feat: refresh token (#1728)

* begin refresh tokens

* refresh tokens

* list and revoke refresh tokens

* handle remove

* tests for refresh tokens

* uniqueness and default expiration

* rename oidc token methods

* cleanup

* migration version

* Update internal/static/i18n/en.yaml

Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com>

* fixes

* feat: update oidc pkg for refresh tokens

Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com>
This commit is contained in:
Livio Amstutz
2021-05-20 13:33:35 +02:00
committed by GitHub
parent bc21eeb114
commit ec5020bebc
36 changed files with 2732 additions and 55 deletions

View File

@@ -0,0 +1,94 @@
package eventstore
import (
"context"
"time"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/telemetry/tracing"
usr_model "github.com/caos/zitadel/internal/user/model"
"github.com/caos/zitadel/internal/user/repository/view/model"
)
type RefreshTokenRepo struct {
Eventstore v1.Eventstore
View *view.View
SearchLimit uint64
KeyAlgorithm crypto.EncryptionAlgorithm
}
func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, refreshToken string) (*usr_model.RefreshTokenView, error) {
userID, tokenID, token, err := domain.FromRefreshToken(refreshToken, r.KeyAlgorithm)
if err != nil {
return nil, err
}
tokenView, viewErr := r.View.RefreshTokenByID(tokenID)
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
if errors.IsNotFound(viewErr) {
tokenView = new(model.RefreshTokenView)
tokenView.ID = tokenID
tokenView.UserID = userID
}
events, esErr := r.getUserEvents(ctx, userID, tokenView.Sequence)
if errors.IsNotFound(viewErr) && len(events) == 0 {
return nil, errors.ThrowNotFound(nil, "EVENT-BHB52", "Errors.User.RefreshToken.Invalid")
}
if esErr != nil {
logging.Log("EVENT-AE462").WithError(viewErr).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("error retrieving new events")
return model.RefreshTokenViewToModel(tokenView), nil
}
viewToken := *tokenView
for _, event := range events {
err := tokenView.AppendEventIfMyRefreshToken(event)
if err != nil {
return model.RefreshTokenViewToModel(&viewToken), nil
}
}
if !tokenView.Expiration.After(time.Now()) || tokenView.Token != token {
return nil, errors.ThrowNotFound(nil, "EVENT-5Bm9s", "Errors.User.RefreshToken.Invalid")
}
return model.RefreshTokenViewToModel(tokenView), nil
}
func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID string, request *usr_model.RefreshTokenSearchRequest) (*usr_model.RefreshTokenSearchResponse, error) {
err := request.EnsureLimit(r.SearchLimit)
if err != nil {
return nil, err
}
sequence, err := r.View.GetLatestRefreshTokenSequence()
logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence")
request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID})
tokens, count, err := r.View.SearchRefreshTokens(request)
if err != nil {
return nil, err
}
return &usr_model.RefreshTokenSearchResponse{
Offset: request.Offset,
Limit: request.Limit,
TotalResult: count,
Sequence: sequence.CurrentSequence,
Timestamp: sequence.LastSuccessfulSpoolerRun,
Result: model.RefreshTokenViewsToModel(tokens),
}, nil
}
func (r *RefreshTokenRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, sequence)
if err != nil {
return nil, err
}
return r.Eventstore.FilterEvents(ctx, query)
}

View File

@@ -69,6 +69,7 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es
newProjectRole(handler{view, bulkLimit, configs.cycleDuration("ProjectRole"), errorCount, es}),
newLabelPolicy(handler{view, bulkLimit, configs.cycleDuration("LabelPolicy"), errorCount, es}),
newFeatures(handler{view, bulkLimit, configs.cycleDuration("Features"), errorCount, es}),
newRefreshToken(handler{view, bulkLimit, configs.cycleDuration("RefreshToken"), errorCount, es}),
}
}

View File

@@ -0,0 +1,123 @@
package handler
import (
"encoding/json"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler"
project_es_model "github.com/caos/zitadel/internal/project/repository/eventsourcing/model"
user_repo "github.com/caos/zitadel/internal/repository/user"
user_es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
)
const (
refreshTokenTable = "auth.refresh_tokens"
)
type RefreshToken struct {
handler
subscription *v1.Subscription
}
func newRefreshToken(
handler handler,
) *RefreshToken {
h := &RefreshToken{
handler: handler,
}
h.subscribe()
return h
}
func (t *RefreshToken) subscribe() {
t.subscription = t.es.Subscribe(t.AggregateTypes()...)
go func() {
for event := range t.subscription.Events {
query.ReduceEvent(t, event)
}
}()
}
func (t *RefreshToken) ViewModel() string {
return refreshTokenTable
}
func (t *RefreshToken) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{user_es_model.UserAggregate, project_es_model.ProjectAggregate}
}
func (t *RefreshToken) CurrentSequence() (uint64, error) {
sequence, err := t.view.GetLatestRefreshTokenSequence()
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}
func (t *RefreshToken) EventQuery() (*es_models.SearchQuery, error) {
sequence, err := t.view.GetLatestRefreshTokenSequence()
if err != nil {
return nil, err
}
return es_models.NewSearchQuery().
AggregateTypeFilter(user_es_model.UserAggregate, project_es_model.ProjectAggregate).
LatestSequenceFilter(sequence.CurrentSequence), nil
}
func (t *RefreshToken) Reduce(event *es_models.Event) (err error) {
switch eventstore.EventType(event.Type) {
case user_repo.HumanRefreshTokenAddedType:
token := new(view_model.RefreshTokenView)
err := token.AppendEvent(event)
if err != nil {
return err
}
return t.view.PutRefreshToken(token, event)
case user_repo.HumanRefreshTokenRenewedType:
e := new(user_repo.HumanRefreshTokenRenewedEvent)
if err := json.Unmarshal(event.Data, e); err != nil {
logging.Log("EVEN-DBbn4").WithError(err).Error("could not unmarshal event data")
return caos_errs.ThrowInternal(nil, "MODEL-BHn75", "could not unmarshal data")
}
token, err := t.view.RefreshTokenByID(e.TokenID)
if err != nil {
return err
}
err = token.AppendEvent(event)
if err != nil {
return err
}
return t.view.PutRefreshToken(token, event)
case user_repo.HumanRefreshTokenRemovedType:
e := new(user_repo.HumanRefreshTokenRemovedEvent)
if err := json.Unmarshal(event.Data, e); err != nil {
logging.Log("EVEN-BDbh3").WithError(err).Error("could not unmarshal event data")
return caos_errs.ThrowInternal(nil, "MODEL-Bz653", "could not unmarshal data")
}
return t.view.DeleteRefreshToken(e.TokenID, event)
case user_repo.UserLockedType,
user_repo.UserDeactivatedType,
user_repo.UserRemovedType:
return t.view.DeleteUserRefreshTokens(event.AggregateID, event)
default:
return t.view.ProcessedRefreshTokenSequence(event)
}
}
func (t *RefreshToken) OnError(event *es_models.Event, err error) error {
logging.LogWithFields("SPOOL-3jkl4", "id", event.AggregateID).WithError(err).Warn("something went wrong in token handler")
return spooler.HandleError(event, err, t.view.GetLatestTokenFailedEvent, t.view.ProcessedTokenFailedEvent, t.view.ProcessedTokenSequence, t.errorCountUntilSkip)
}
func (t *RefreshToken) OnSuccess() error {
return spooler.HandleSuccess(t.view.UpdateTokenSpoolerRunTimestamp)
}

View File

@@ -36,6 +36,7 @@ type EsRepository struct {
eventstore.UserRepo
eventstore.AuthRequestRepo
eventstore.TokenRepo
eventstore.RefreshTokenRepo
eventstore.KeyRepository
eventstore.ApplicationRepo
eventstore.UserSessionRepo
@@ -110,6 +111,12 @@ func Start(conf Config, authZ authz.Config, systemDefaults sd.SystemDefaults, co
View: view,
Eventstore: es,
},
eventstore.RefreshTokenRepo{
View: view,
Eventstore: es,
SearchLimit: conf.SearchLimit,
KeyAlgorithm: keyAlgorithm,
},
eventstore.KeyRepository{
View: view,
Commands: command,

View File

@@ -0,0 +1,86 @@
package view
import (
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
user_model "github.com/caos/zitadel/internal/user/model"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/zitadel/internal/user/repository/view/model"
"github.com/caos/zitadel/internal/view/repository"
)
const (
refreshTokenTable = "auth.refresh_tokens"
)
func (v *View) RefreshTokenByID(tokenID string) (*model.RefreshTokenView, error) {
return usr_view.RefreshTokenByID(v.Db, refreshTokenTable, tokenID)
}
func (v *View) RefreshTokensByUserID(userID string) ([]*model.RefreshTokenView, error) {
return usr_view.RefreshTokensByUserID(v.Db, refreshTokenTable, userID)
}
func (v *View) SearchRefreshTokens(request *user_model.RefreshTokenSearchRequest) ([]*model.RefreshTokenView, uint64, error) {
return usr_view.SearchRefreshTokens(v.Db, refreshTokenTable, request)
}
func (v *View) PutRefreshToken(token *model.RefreshTokenView, event *models.Event) error {
err := usr_view.PutRefreshToken(v.Db, refreshTokenTable, token)
if err != nil {
return err
}
return v.ProcessedTokenSequence(event)
}
func (v *View) PutRefreshTokens(token []*model.RefreshTokenView, event *models.Event) error {
err := usr_view.PutRefreshTokens(v.Db, refreshTokenTable, token...)
if err != nil {
return err
}
return v.ProcessedRefreshTokenSequence(event)
}
func (v *View) DeleteRefreshToken(tokenID string, event *models.Event) error {
err := usr_view.DeleteRefreshToken(v.Db, refreshTokenTable, tokenID)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedRefreshTokenSequence(event)
}
func (v *View) DeleteUserRefreshTokens(userID string, event *models.Event) error {
err := usr_view.DeleteUserRefreshTokens(v.Db, refreshTokenTable, userID)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedRefreshTokenSequence(event)
}
func (v *View) DeleteApplicationRefreshTokens(event *models.Event, ids ...string) error {
err := usr_view.DeleteApplicationTokens(v.Db, refreshTokenTable, ids)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedRefreshTokenSequence(event)
}
func (v *View) GetLatestRefreshTokenSequence() (*repository.CurrentSequence, error) {
return v.latestSequence(refreshTokenTable)
}
func (v *View) ProcessedRefreshTokenSequence(event *models.Event) error {
return v.saveCurrentSequence(refreshTokenTable, event)
}
func (v *View) UpdateRefreshTokenSpoolerRunTimestamp() error {
return v.updateSpoolerRunSequence(refreshTokenTable)
}
func (v *View) GetLatestRefreshTokenFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
return v.latestFailedEvent(refreshTokenTable, sequence)
}
func (v *View) ProcessedRefreshTokenFailedEvent(failedEvent *repository.FailedEvent) error {
return v.saveFailedEvent(failedEvent)
}

View File

@@ -0,0 +1,12 @@
package repository
import (
"context"
"github.com/caos/zitadel/internal/user/model"
)
type RefreshTokenRepository interface {
RefreshTokenByID(ctx context.Context, refreshToken string) (*model.RefreshTokenView, error)
SearchMyRefreshTokens(ctx context.Context, userID string, request *model.RefreshTokenSearchRequest) (*model.RefreshTokenSearchResponse, error)
}

View File

@@ -17,4 +17,5 @@ type Repository interface {
OrgRepository
IAMRepository
FeaturesRepository
RefreshTokenRepository
}