package command

import (
	"context"
	"encoding/base64"
	"fmt"
	"strings"
	"time"

	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/crypto"
	"github.com/zitadel/zitadel/internal/domain"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/id"
	"github.com/zitadel/zitadel/internal/repository/authrequest"
	"github.com/zitadel/zitadel/internal/repository/oidcsession"
	"github.com/zitadel/zitadel/internal/repository/user"
	"github.com/zitadel/zitadel/internal/zerrors"
)

const (
	// TokenDelimiter joins the OIDC session ID and the token ID in the IDs returned to callers.
	TokenDelimiter = "-"
	// AccessTokenPrefix is prepended to every generated access token ID.
	AccessTokenPrefix = "at_"
	// RefreshTokenPrefix is prepended to every generated refresh token ID.
	RefreshTokenPrefix = "rt_"

	// oidcTokenSubjectDelimiter separates the token ID from the subject inside a plain refresh token.
	oidcTokenSubjectDelimiter = ":"
	// oidcTokenFormat renders "<tokenID>:<subject>".
	oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
)

// AddOIDCSessionAccessToken starts a new OIDC session for the authenticated auth request
// identified by authRequestID and issues an access token with the auth request's scope.
// It returns the access token ID (prefixed with the OIDC session ID) and its expiration.
// The auth request is marked as succeeded; for an OIDC Auth Code Flow this sets the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
	events, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
	if err != nil {
		return "", time.Time{}, err
	}
	events.AddSession(ctx)

	requestedScope := events.authRequestWriteModel.Scope
	if err := events.AddAccessToken(ctx, requestedScope, domain.TokenReasonAuthRequest, nil); err != nil {
		return "", time.Time{}, err
	}
	events.SetAuthRequestSuccessful(ctx)

	// The refresh token return value is always empty here because none was added.
	tokenID, _, expiration, err := events.PushEvents(ctx)
	return tokenID, expiration, err
}

// AddOIDCSessionRefreshAndAccessToken starts a new OIDC session for the authenticated auth request
// identified by authRequestID and issues both an access token and a refresh token.
// It returns the access token ID (prefixed with the OIDC session ID), the encrypted and
// base64 (raw URL) encoded refresh token, and the access token expiration.
// The auth request is marked as succeeded; for an OIDC Auth Code Flow this sets the code as exchanged.
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) { cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID) if err != nil { return "", "", time.Time{}, err } cmd.AddSession(ctx) if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope, domain.TokenReasonAuthRequest, nil); err != nil { return "", "", time.Time{}, err } if err = cmd.AddRefreshToken(ctx); err != nil { return "", "", time.Time{}, err } cmd.SetAuthRequestSuccessful(ctx) return cmd.PushEvents(ctx) } // ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token. // It returns the access token id and expiration and the new refresh token. func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) { cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken) if err != nil { return "", "", time.Time{}, err } if err = cmd.AddAccessToken(ctx, scope, domain.TokenReasonRefresh, nil); err != nil { return "", "", time.Time{}, err } if err = cmd.RenewRefreshToken(ctx); err != nil { return "", "", time.Time{}, err } return cmd.PushEvents(ctx) } // OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant). // If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned. 
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) { oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken) if err != nil { return nil, err } writeModel := NewOIDCSessionWriteModel(oidcSessionID, "") err = c.eventstore.FilterToQueryReducer(ctx, writeModel) if err != nil { return nil, zerrors.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid") } if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil { return nil, err } return writeModel, nil } func oidcSessionTokenIDsFromToken(token string) (oidcSessionID, refreshTokenID, accessTokenID string, err error) { split := strings.Split(token, TokenDelimiter) if len(split) != 2 { return "", "", "", zerrors.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid") } if strings.HasPrefix(split[1], RefreshTokenPrefix) { return split[0], split[1], "", nil } if strings.HasPrefix(split[1], AccessTokenPrefix) { return split[0], "", split[1], nil } return "", "", "", zerrors.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid") } // RevokeOIDCSessionToken revokes an access_token or refresh_token // if the OIDCSession cannot be retrieved by the provided token, is not active or if the token is already revoked, // then no error will be returned. // The only possible error (except db connection or other internal errors) occurs if a client tries to revoke a token, // which was not part of the audience. 
func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID string) (err error) { oidcSessionID, refreshTokenID, accessTokenID, err := oidcSessionTokenIDsFromToken(token) if err != nil { logging.WithError(err).Info("token revocation with invalid token format") return nil } writeModel := NewOIDCSessionWriteModel(oidcSessionID, "") err = c.eventstore.FilterToQueryReducer(ctx, writeModel) if err != nil { return zerrors.ThrowInternal(err, "OIDCS-NB3t2", "Errors.Internal") } if err = writeModel.CheckClient(clientID); err != nil { return err } if refreshTokenID != "" { if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil { logging.WithFields("oidcSessionID", oidcSessionID, "refreshTokenID", refreshTokenID).WithError(err). Info("refresh token revocation with invalid token") return nil } return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewRefreshTokenRevokedEvent(ctx, writeModel.aggregate)) } if err = writeModel.CheckAccessToken(accessTokenID); err != nil { logging.WithFields("oidcSessionID", oidcSessionID, "accessTokenID", accessTokenID).WithError(err). 
Info("access token revocation with invalid token") return nil } return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate)) } func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) { authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID) if err != nil { return nil, err } if err = authRequestWriteModel.CheckAuthenticated(); err != nil { return nil, err } sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetInstance(ctx).InstanceID()) err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) if err != nil { return nil, err } if err = sessionWriteModel.CheckIsActive(); err != nil { return nil, err } resourceOwner, err := c.getResourceOwnerOfSessionUser(ctx, sessionWriteModel.UserID, sessionWriteModel.InstanceID) if err != nil { return nil, err } accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) if err != nil { return nil, err } sessionID, err := c.idGenerator.Next() if err != nil { return nil, err } sessionID = IDPrefixV2 + sessionID return &OIDCSessionEvents{ eventstore: c.eventstore, idGenerator: c.idGenerator, encryptionAlg: c.keyAlgorithm, oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner), sessionWriteModel: sessionWriteModel, authRequestWriteModel: authRequestWriteModel, accessTokenLifetime: accessTokenLifetime, refreshTokenLifeTime: refreshTokenLifeTime, refreshTokenIdleLifetime: refreshTokenIdleLifetime, }, nil } func (c *Commands) getResourceOwnerOfSessionUser(ctx context.Context, userID, instanceID string) (string, error) { events, err := c.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). InstanceID(instanceID). AllowTimeTravel(). OrderAsc(). Limit(1). AddQuery(). AggregateTypes(user.AggregateType). AggregateIDs(userID). 
Builder()) if err != nil || len(events) != 1 { return "", zerrors.ThrowInternal(err, "OIDCS-sferh", "Errors.Internal") } return events[0].Aggregate().ResourceOwner, nil } func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) { decoded, err := base64.RawURLEncoding.DecodeString(refreshToken) if err != nil { return "", zerrors.ThrowInvalidArgument(err, "OIDCS-Cux9a", "Errors.User.RefreshToken.Invalid") } decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID()) if err != nil { return "", err } _, refreshTokenID, err = parseRefreshToken(decrypted) return refreshTokenID, err } func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) { split := strings.Split(refreshToken, TokenDelimiter) if len(split) < 2 || !strings.HasPrefix(split[1], RefreshTokenPrefix) { return "", "", zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid") } // the oidc library requires that every token has the format of : // the V2 tokens don't use the userID anymore, so let's just remove it return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil } func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) { refreshTokenID, err := c.decryptRefreshToken(refreshToken) if err != nil { return nil, err } sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, "") if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil { return nil, err } if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil { return nil, err } accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) if err != nil { return nil, err } return &OIDCSessionEvents{ eventstore: c.eventstore, idGenerator: c.idGenerator, encryptionAlg: c.keyAlgorithm, oidcSessionWriteModel: sessionWriteModel, accessTokenLifetime: 
accessTokenLifetime, refreshTokenLifeTime: refreshTokenLifeTime, refreshTokenIdleLifetime: refreshTokenIdleLifetime, }, nil } type OIDCSessionEvents struct { eventstore *eventstore.Eventstore idGenerator id.Generator encryptionAlg crypto.EncryptionAlgorithm events []eventstore.Command oidcSessionWriteModel *OIDCSessionWriteModel sessionWriteModel *SessionWriteModel authRequestWriteModel *AuthRequestWriteModel accessTokenLifetime time.Duration refreshTokenLifeTime time.Duration refreshTokenIdleLifetime time.Duration // accessTokenID is set by the command accessTokenID string // refreshToken is set by the command refreshToken string } func (c *OIDCSessionEvents) AddSession(ctx context.Context) { c.events = append(c.events, oidcsession.NewAddedEvent( ctx, c.oidcSessionWriteModel.aggregate, c.sessionWriteModel.UserID, c.sessionWriteModel.AggregateID, c.authRequestWriteModel.ClientID, c.authRequestWriteModel.Audience, c.authRequestWriteModel.Scope, c.sessionWriteModel.AuthMethodTypes(), c.sessionWriteModel.AuthenticationTime(), )) } func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) { c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate)) } func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, reason domain.TokenReason, actor *domain.TokenActor) error { accessTokenID, err := c.idGenerator.Next() if err != nil { return err } c.accessTokenID = AccessTokenPrefix + accessTokenID c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime, reason, actor)) return nil } func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) { var refreshTokenID string refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID) if err != nil { return err } c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, 
refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime)) return nil } func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) { var refreshTokenID string refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.oidcSessionWriteModel.UserID) if err != nil { return err } c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime)) return nil } func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) { refreshTokenID, err = c.idGenerator.Next() if err != nil { return "", "", err } refreshTokenID = RefreshTokenPrefix + refreshTokenID token, err := c.encryptionAlg.Encrypt([]byte(fmt.Sprintf(oidcTokenFormat, c.oidcSessionWriteModel.OIDCRefreshTokenID(refreshTokenID), userID))) if err != nil { return "", "", err } return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil } func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) { pushedEvents, err := c.eventstore.Push(ctx, c.events...) if err != nil { return "", "", time.Time{}, err } err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...) 
if err != nil { return "", "", time.Time{}, err } // prefix the returned id with the oidcSessionID so that we can retrieve it later on // we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil } func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) { oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx) err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings) if err != nil { return 0, 0, 0, err } accessTokenLifetime = c.defaultAccessTokenLifetime refreshTokenLifetime = c.defaultRefreshTokenLifetime refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime if oidcSettings.AccessTokenLifetime > 0 { accessTokenLifetime = oidcSettings.AccessTokenLifetime } if oidcSettings.RefreshTokenExpiration > 0 { refreshTokenLifetime = oidcSettings.RefreshTokenExpiration } if oidcSettings.RefreshTokenIdleExpiration > 0 { refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration } return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil }