From e1b3cda98ac55abda57a550e10d61d6bc17f16e9 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Mon, 17 Jul 2023 14:33:37 +0200 Subject: [PATCH] feat(OIDC): support token revocation of V2 tokens (#6203) This PR adds support for OAuth2 token revocation of V2 tokens. Unlike with V1 tokens, it's now possible to revoke a token not only from the authorized client / client which the token was issued to, but rather from all trusted clients (audience) --- internal/api/oidc/auth_request.go | 37 ++- .../api/oidc/auth_request_integration_test.go | 190 +++++++++++++ internal/command/oidc_session.go | 103 +++++-- internal/command/oidc_session_model.go | 47 ++++ internal/command/oidc_session_test.go | 264 ++++++++++++++++-- internal/query/access_token.go | 10 + internal/repository/oidcsession/eventstore.go | 14 +- .../repository/oidcsession/oidc_session.go | 117 +++++--- internal/static/i18n/de.yaml | 1 + internal/static/i18n/en.yaml | 1 + internal/static/i18n/es.yaml | 1 + internal/static/i18n/fr.yaml | 1 + internal/static/i18n/it.yaml | 1 + internal/static/i18n/ja.yaml | 1 + internal/static/i18n/mk.yaml | 1 + internal/static/i18n/pl.yaml | 1 + internal/static/i18n/zh.yaml | 1 + 17 files changed, 689 insertions(+), 102 deletions(-) diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index e2ccc8f86d..93c79f07dd 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -265,12 +265,12 @@ func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - plainCode, err := o.decryptGrant(refreshToken) + plainToken, err := o.decryptGrant(refreshToken) if err != nil { return nil, err } - if strings.HasPrefix(plainCode, command.IDPrefixV2) { - oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode) + if strings.HasPrefix(plainToken, command.IDPrefixV2) { + oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken) if err != nil { return nil, err } @@ -308,7 +308,25 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin return err } -func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) *oidc.Error { +func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if strings.HasPrefix(token, command.IDPrefixV2) { + err := o.command.RevokeOIDCSessionToken(ctx, token, clientID) + if err == nil { + return nil + } + if errors.IsPreconditionFailed(err) { + return oidc.ErrInvalidClient().WithDescription("token was not issued for this client") + } + return oidc.ErrServerError().WithParent(err) + } + + return o.revokeTokenV1(ctx, token, userID, clientID) +} + +func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID string) *oidc.Error { refreshToken, err := o.repo.RefreshTokenByID(ctx, token, userID) if err == nil { if refreshToken.ClientID != clientID { @@ -338,6 +356,17 @@ func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID str } func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) { + plainToken, err := o.decryptGrant(token) + if err != nil { + return "", "", op.ErrInvalidRefreshToken + } + if strings.HasPrefix(plainToken, command.IDPrefixV2) { + oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken) + if err != nil { + return "", "", op.ErrInvalidRefreshToken + } + return oidcSession.UserID, oidcSession.OIDCRefreshTokenID(oidcSession.RefreshTokenID), nil + } refreshToken, err := o.repo.RefreshTokenByToken(ctx, token) if err != nil { return "", "", op.ErrInvalidRefreshToken diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index 9d6b4c701a..42b3121ffd 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -184,6 +184,196 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { require.Error(t, err) } +func TestOPStorage_RevokeToken_access_token(t *testing.T) { + clientID := createClient(t) + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + + // revoke access token + err = rp.RevokeToken(provider, tokens.AccessToken, "access_token") + require.NoError(t, err) + + // userinfo must fail + _, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + require.Error(t, err) + + // refresh grant must still work + _, err = refreshTokens(t, clientID, tokens.RefreshToken) + require.NoError(t, err) + + // revocation with the same access token must not fail (with or without hint) + err = rp.RevokeToken(provider, tokens.AccessToken, "access_token") + require.NoError(t, err) + err = rp.RevokeToken(provider, tokens.AccessToken, "") + require.NoError(t, err) +} + +func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T) { + clientID := createClient(t) + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + + // revoke access token + err = rp.RevokeToken(provider, tokens.AccessToken, "refresh_token") + require.NoError(t, err) + + // userinfo must fail + _, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + require.Error(t, err) + + // refresh grant must still work + _, err = refreshTokens(t, clientID, tokens.RefreshToken) + require.NoError(t, err) +} + +func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { + clientID := createClient(t) + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + + // revoke refresh token -> invalidates also access token + err = rp.RevokeToken(provider, tokens.RefreshToken, "refresh_token") + require.NoError(t, err) + + // userinfo must fail + _, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + require.Error(t, err) + + // refresh must fail + _, err = refreshTokens(t, clientID, tokens.RefreshToken) + require.Error(t, err) + + // revocation with the same refresh token must not fail (with or without hint) + err = rp.RevokeToken(provider, tokens.RefreshToken, "refresh_token") + require.NoError(t, err) + err = rp.RevokeToken(provider, tokens.RefreshToken, "") + require.NoError(t, err) +} + +func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.T) { + clientID := createClient(t) + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + + // revoke refresh token even with a wrong hint + err = rp.RevokeToken(provider, tokens.RefreshToken, "access_token") + require.NoError(t, err) + + // userinfo must fail + _, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) + require.Error(t, err) + + // refresh must fail + _, err = refreshTokens(t, clientID, tokens.RefreshToken) + require.Error(t, err) +} + +func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { + clientID := createClient(t) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + + // simulate second client (not part of the audience) trying to revoke the token + otherClientID := createClient(t) + provider, err := Tester.CreateRelyingParty(otherClientID, redirectURI) + require.NoError(t, err) + err = rp.RevokeToken(provider, tokens.AccessToken, "") + require.Error(t, err) +} + func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { provider, err := Tester.CreateRelyingParty(clientID, redirectURI) require.NoError(t, err) diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index 7a5f5dfda0..d990d0053a 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -3,9 +3,12 @@ package command import ( "context" "encoding/base64" + "fmt" "strings" "time" + "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" @@ -17,6 +20,14 @@ import ( "github.com/zitadel/zitadel/internal/repository/user" ) +const ( + TokenDelimiter = "-" + AccessTokenPrefix = "at_" + RefreshTokenPrefix = "rt_" + oidcTokenSubjectDelimiter = ":" + oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s" +) + // AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration. // If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged. func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) { @@ -71,21 +82,70 @@ func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, // OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant). // If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned. func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) { - split := strings.Split(refreshToken, ":") - if len(split) != 2 { - return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid") + oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken) + if err != nil { + return nil, err } - writeModel := NewOIDCSessionWriteModel(split[0], "") - err := c.eventstore.FilterToQueryReducer(ctx, writeModel) + writeModel := NewOIDCSessionWriteModel(oidcSessionID, "") + err = c.eventstore.FilterToQueryReducer(ctx, writeModel) if err != nil { return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid") } - if err = writeModel.CheckRefreshToken(split[1]); err != nil { + if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil { return nil, err } return writeModel, nil } +func oidcSessionTokenIDsFromToken(token string) (oidcSessionID, refreshTokenID, accessTokenID string, err error) { + split := strings.Split(token, TokenDelimiter) + if len(split) != 2 { + return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid") + } + if strings.HasPrefix(split[1], RefreshTokenPrefix) { + return split[0], split[1], "", nil + } + if strings.HasPrefix(split[1], AccessTokenPrefix) { + return split[0], "", split[1], nil + } + return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid") +} + +// RevokeOIDCSessionToken revokes an access_token or refresh_token +// if the OIDCSession cannot be retrieved by the provided token, is not active or if the token is already revoked, +// then no error will be returned. +// The only possible error (except db connection or other internal errors) occurs if a client tries to revoke a token, +// which was not part of the audience. +func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID string) (err error) { + oidcSessionID, refreshTokenID, accessTokenID, err := oidcSessionTokenIDsFromToken(token) + if err != nil { + logging.WithError(err).Info("token revocation with invalid token format") + return nil + } + writeModel := NewOIDCSessionWriteModel(oidcSessionID, "") + err = c.eventstore.FilterToQueryReducer(ctx, writeModel) + if err != nil { + return caos_errs.ThrowInternal(err, "OIDCS-NB3t2", "Errors.Internal") + } + if err = writeModel.CheckClient(clientID); err != nil { + return err + } + if refreshTokenID != "" { + if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil { + logging.WithFields("oidcSessionID", oidcSessionID, "refreshTokenID", refreshTokenID).WithError(err). + Info("refresh token revocation with invalid token") + return nil + } + return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewRefreshTokenRevokedEvent(ctx, writeModel.aggregate)) + } + if err = writeModel.CheckAccessToken(accessTokenID); err != nil { + logging.WithFields("oidcSessionID", oidcSessionID, "accessTokenID", accessTokenID).WithError(err). + Info("access token revocation with invalid token") + return nil + } + return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate)) +} + func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) { authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID) if err != nil { @@ -153,11 +213,18 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID stri if err != nil { return "", err } - split := strings.Split(decrypted, ":") - if len(split) != 2 { - return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid") + _, refreshTokenID, err = parseRefreshToken(decrypted) + return refreshTokenID, err +} + +func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) { + split := strings.Split(refreshToken, TokenDelimiter) + if len(split) < 2 || !strings.HasPrefix(split[1], RefreshTokenPrefix) { + return "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid") } - return split[1], nil + // the oidc library requires that every token has the format of : + // the V2 tokens don't use the userID anymore, so let's just remove it + return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil } func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) { @@ -224,18 +291,19 @@ func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) { c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate)) } -func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) { - c.accessTokenID, err = c.idGenerator.Next() +func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) error { + accessTokenID, err := c.idGenerator.Next() if err != nil { return err } + c.accessTokenID = AccessTokenPrefix + accessTokenID c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime)) return nil } func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) { var refreshTokenID string - refreshTokenID, c.refreshToken, err = c.generateRefreshToken() + refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID) if err != nil { return err } @@ -245,7 +313,7 @@ func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) { func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) { var refreshTokenID string - refreshTokenID, c.refreshToken, err = c.generateRefreshToken() + refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.oidcSessionWriteModel.UserID) if err != nil { return err } @@ -253,12 +321,13 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) { return nil } -func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) { +func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) { refreshTokenID, err = c.idGenerator.Next() if err != nil { return "", "", err } - token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID)) + refreshTokenID = RefreshTokenPrefix + refreshTokenID + token, err := c.encryptionAlg.Encrypt([]byte(fmt.Sprintf(oidcTokenFormat, c.oidcSessionWriteModel.OIDCRefreshTokenID(refreshTokenID), userID))) if err != nil { return "", "", err } @@ -276,7 +345,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID strin } // prefix the returned id with the oidcSessionID so that we can retrieve it later on // we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts - return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil + return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil } func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) { diff --git a/internal/command/oidc_session_model.go b/internal/command/oidc_session_model.go index 4a25bf6f41..94725d933e 100644 --- a/internal/command/oidc_session_model.go +++ b/internal/command/oidc_session_model.go @@ -20,9 +20,11 @@ type OIDCSessionWriteModel struct { AuthMethods []domain.UserAuthMethodType AuthTime time.Time State domain.OIDCSessionState + AccessTokenID string AccessTokenCreation time.Time AccessTokenExpiration time.Time RefreshTokenID string + RefreshToken string RefreshTokenExpiration time.Time RefreshTokenIdleExpiration time.Time @@ -46,10 +48,14 @@ func (wm *OIDCSessionWriteModel) Reduce() error { wm.reduceAdded(e) case *oidcsession.AccessTokenAddedEvent: wm.reduceAccessTokenAdded(e) + case *oidcsession.AccessTokenRevokedEvent: + wm.reduceAccessTokenRevoked(e) case *oidcsession.RefreshTokenAddedEvent: wm.reduceRefreshTokenAdded(e) case *oidcsession.RefreshTokenRenewedEvent: wm.reduceRefreshTokenRenewed(e) + case *oidcsession.RefreshTokenRevokedEvent: + wm.reduceRefreshTokenRevoked(e) } } return wm.WriteModel.Reduce() @@ -65,6 +71,7 @@ func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder { oidcsession.AccessTokenAddedType, oidcsession.RefreshTokenAddedType, oidcsession.RefreshTokenRenewedType, + oidcsession.RefreshTokenRevokedType, ). Builder() @@ -91,9 +98,15 @@ func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) { } func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) { + wm.AccessTokenID = e.ID wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime) } +func (wm *OIDCSessionWriteModel) reduceAccessTokenRevoked(e *oidcsession.AccessTokenRevokedEvent) { + wm.AccessTokenID = "" + wm.AccessTokenExpiration = e.CreationDate() +} + func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) { wm.RefreshTokenID = e.ID wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime) @@ -105,6 +118,14 @@ func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.Refres wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime) } +func (wm *OIDCSessionWriteModel) reduceRefreshTokenRevoked(e *oidcsession.RefreshTokenRevokedEvent) { + wm.RefreshTokenID = "" + wm.RefreshTokenExpiration = e.CreationDate() + wm.RefreshTokenIdleExpiration = e.CreationDate() + wm.AccessTokenID = "" + wm.AccessTokenExpiration = e.CreationDate() +} + func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error { if wm.State != domain.OIDCSessionStateActive { return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid") @@ -118,3 +139,29 @@ func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error } return nil } + +func (wm *OIDCSessionWriteModel) CheckAccessToken(accessTokenID string) error { + if wm.State != domain.OIDCSessionStateActive { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-KL2pk", "Errors.OIDCSession.Token.Invalid") + } + if wm.AccessTokenID != accessTokenID { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JLKW2", "Errors.OIDCSession.Token.Invalid") + } + if wm.AccessTokenExpiration.Before(time.Now()) { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3j3md", "Errors.OIDCSession.Token.Invalid") + } + return nil +} + +func (wm *OIDCSessionWriteModel) CheckClient(clientID string) error { + for _, aud := range wm.Audience { + if aud == clientID { + return nil + } + } + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient") +} + +func (wm *OIDCSessionWriteModel) OIDCRefreshTokenID(refreshTokenID string) string { + return wm.AggregateID + TokenDelimiter + refreshTokenID +} diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go index 9f0d305f96..b8a3a462ea 100644 --- a/internal/command/oidc_session_test.go +++ b/internal/command/oidc_session_test.go @@ -191,7 +191,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid"}, time.Hour), + "at_accessTokenID", []string{"openid"}, time.Hour), ), eventFromEventPusherWithInstanceID("instanceID", authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), @@ -207,7 +207,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { authRequestID: "V2_authRequestID", }, res{ - id: "V2_oidcSessionID-accessTokenID", + id: "V2_oidcSessionID-at_accessTokenID", expiration: tokenCreationNow.Add(time.Hour), }, }, @@ -392,11 +392,11 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour), ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), ), eventFromEventPusherWithInstanceID("instanceID", authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), @@ -415,8 +415,8 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { authRequestID: "V2_authRequestID", }, res{ - id: "V2_oidcSessionID-accessTokenID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID + id: "V2_oidcSessionID-at_accessTokenID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID expiration: tokenCreationNow.Add(time.Hour), }, }, @@ -476,10 +476,10 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), oidcSessionID: "V2_oidcSessionID", - refreshToken: "aW52YWxpZA", + refreshToken: "aW52YWxpZA", // invalid }, res{ - err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"), + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"), }, }, { @@ -493,7 +493,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), oidcSessionID: "V2_oidcSessionID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -519,7 +519,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), oidcSessionID: "V2_oidcSessionID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -536,11 +536,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), ), eventFromEventPusher( oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), ), ), ), @@ -549,7 +549,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), oidcSessionID: "V2_oidcSessionID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -566,11 +566,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { ), eventFromEventPusherWithCreationDateNow( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), ), eventFromEventPusherWithCreationDateNow( oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), ), ), expectFilter(), // token lifetime @@ -578,11 +578,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { []*repository.Event{ eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour), ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID2", 24*time.Hour), + "rt_refreshTokenID2", 24*time.Hour), ), }, ), @@ -596,12 +596,12 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), oidcSessionID: "V2_oidcSessionID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID scope: []string{"openid", "offline_access"}, }, res{ - id: "V2_oidcSessionID-accessTokenID", - refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI", + id: "V2_oidcSessionID-at_accessTokenID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDI6dXNlcklE", // V2_oidcSessionID-rt_refreshTokenID2:userID% expiration: time.Time{}.Add(time.Hour), }, }, @@ -672,7 +672,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { }, args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), - refreshToken: "V2_oidcSessionID:refreshTokenID", + refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID", }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -689,7 +689,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), ), ), ), @@ -697,7 +697,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { }, args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), - refreshToken: "V2_oidcSessionID:refreshTokenID", + refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID", }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -714,11 +714,11 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), ), eventFromEventPusher( oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), ), ), ), @@ -726,7 +726,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { }, args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), - refreshToken: "V2_oidcSessionID:refreshTokenID", + refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID", }, res{ err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), @@ -743,11 +743,11 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { ), eventFromEventPusherWithCreationDateNow( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), ), eventFromEventPusherWithCreationDateNow( oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, - "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), ), ), ), @@ -755,7 +755,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { }, args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), - refreshToken: "V2_oidcSessionID:refreshTokenID", + refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID", }, res{ model: &OIDCSessionWriteModel{ @@ -771,7 +771,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, AuthTime: testNow, State: domain.OIDCSessionStateActive, - RefreshTokenID: "refreshTokenID", + RefreshTokenID: "rt_refreshTokenID", RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour), RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour), }, @@ -808,3 +808,207 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { }) } } + +func TestCommands_RevokeOIDCSessionToken(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + keyAlgorithm crypto.EncryptionAlgorithm + } + type args struct { + ctx context.Context + token string + clientID string + } + type res struct { + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "invalid token", + fields{ + eventstore: eventstoreExpect(t), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "invalid", + }, + res{ + err: nil, + }, + }, + { + "refresh_token inactive", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-rt_refreshTokenID", + clientID: "clientID", + }, + res{ + err: nil, + }, + }, + { + "refresh_token invalid client", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "otherClientID", []string{"otherClientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-rt_refreshTokenID", + clientID: "clientID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient"), + }, + }, + { + "refresh_token revoked", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + expectPush([]*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewRefreshTokenRevokedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate), + ), + }), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-rt_refreshTokenID", + clientID: "clientID", + }, + res{ + err: nil, + }, + }, + { + "access_token inactive session", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-at_accessTokenID", + clientID: "clientID", + }, + res{ + err: nil, + }, + }, + { + "access_token invalid client", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "otherClientID", []string{"otherClientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-at_accessTokenID", + clientID: "clientID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient"), + }, + }, + { + "access_token revoked", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + expectPush([]*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAccessTokenRevokedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate), + ), + }), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + token: "V2_oidcSessionID-at_accessTokenID", + clientID: "clientID", + }, + res{ + err: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + keyAlgorithm: tt.fields.keyAlgorithm, + } + err := c.RevokeOIDCSessionToken(tt.args.ctx, tt.args.token, tt.args.clientID) + require.ErrorIs(t, err, tt.res.err) + }) + } +} diff --git a/internal/query/access_token.go b/internal/query/access_token.go index 617c3c0623..78c7778268 100644 --- a/internal/query/access_token.go +++ b/internal/query/access_token.go @@ -43,6 +43,9 @@ func (wm *OIDCSessionAccessTokenReadModel) Reduce() error { wm.reduceAdded(e) case *oidcsession.AccessTokenAddedEvent: wm.reduceAccessTokenAdded(e) + case *oidcsession.AccessTokenRevokedEvent, + *oidcsession.RefreshTokenRevokedEvent: + wm.reduceTokenRevoked(event) } } return wm.WriteModel.Reduce() @@ -57,6 +60,8 @@ func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilde EventTypes( oidcsession.AddedType, oidcsession.AccessTokenAddedType, + oidcsession.AccessTokenRevokedType, + oidcsession.RefreshTokenRevokedType, ). Builder() } @@ -78,6 +83,11 @@ func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime) } +func (wm *OIDCSessionAccessTokenReadModel) reduceTokenRevoked(e eventstore.Event) { + wm.AccessTokenID = "" + wm.AccessTokenExpiration = e.CreationDate() +} + // ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore. // refreshed or expired tokens will return an error func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) { diff --git a/internal/repository/oidcsession/eventstore.go b/internal/repository/oidcsession/eventstore.go index 88c78f4593..091bcf5d44 100644 --- a/internal/repository/oidcsession/eventstore.go +++ b/internal/repository/oidcsession/eventstore.go @@ -1,11 +1,15 @@ package oidcsession -import "github.com/zitadel/zitadel/internal/eventstore" +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) func RegisterEventMappers(es *eventstore.Eventstore) { - es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper). - RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, AccessTokenAddedEventMapper). - RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, RefreshTokenAddedEventMapper). - RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, RefreshTokenRenewedEventMapper) + es.RegisterFilterEventMapper(AggregateType, AddedType, eventstore.GenericEventMapper[AddedEvent]). + RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, eventstore.GenericEventMapper[AccessTokenAddedEvent]). + RegisterFilterEventMapper(AggregateType, AccessTokenRevokedType, eventstore.GenericEventMapper[AccessTokenRevokedEvent]). + RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, eventstore.GenericEventMapper[RefreshTokenAddedEvent]). + RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, eventstore.GenericEventMapper[RefreshTokenRenewedEvent]). + RegisterFilterEventMapper(AggregateType, RefreshTokenRevokedType, eventstore.GenericEventMapper[RefreshTokenRevokedEvent]) } diff --git a/internal/repository/oidcsession/oidc_session.go b/internal/repository/oidcsession/oidc_session.go index b128c45dc2..8887b61926 100644 --- a/internal/repository/oidcsession/oidc_session.go +++ b/internal/repository/oidcsession/oidc_session.go @@ -2,21 +2,20 @@ package oidcsession import ( "context" - "encoding/json" "time" "github.com/zitadel/zitadel/internal/domain" - "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore" - "github.com/zitadel/zitadel/internal/eventstore/repository" ) const ( oidcSessionEventPrefix = "oidc_session." AddedType = oidcSessionEventPrefix + "added" AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added" + AccessTokenRevokedType = oidcSessionEventPrefix + "access_token.revoked" RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added" RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed" + RefreshTokenRevokedType = oidcSessionEventPrefix + "refresh_token.revoked" ) type AddedEvent struct { @@ -39,6 +38,10 @@ func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { return nil } +func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + func NewAddedEvent(ctx context.Context, aggregate *eventstore.Aggregate, userID, @@ -65,18 +68,6 @@ func NewAddedEvent(ctx context.Context, } } -func AddedEventMapper(event *repository.Event) (eventstore.Event, error) { - added := &AddedEvent{ - BaseEvent: *eventstore.BaseEventFromRepo(event), - } - err := json.Unmarshal(event.Data, added) - if err != nil { - return nil, errors.ThrowInternal(err, "OIDCS-DG4gn", "unable to unmarshal oidc session added") - } - - return added, nil -} - type AccessTokenAddedEvent struct { eventstore.BaseEvent `json:"-"` @@ -93,6 +84,10 @@ func (e *AccessTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueCon return nil } +func (e *AccessTokenAddedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + func NewAccessTokenAddedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -112,16 +107,33 @@ func NewAccessTokenAddedEvent( } } -func AccessTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) { - added := &AccessTokenAddedEvent{ - BaseEvent: *eventstore.BaseEventFromRepo(event), - } - err := json.Unmarshal(event.Data, added) - if err != nil { - return nil, errors.ThrowInternal(err, "OIDCS-DSGn5", "unable to unmarshal access token added") - } +type AccessTokenRevokedEvent struct { + eventstore.BaseEvent `json:"-"` +} - return added, nil +func (e *AccessTokenRevokedEvent) Data() interface{} { + return e +} + +func (e *AccessTokenRevokedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *AccessTokenRevokedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + +func NewAccessTokenRevokedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, +) *AccessTokenAddedEvent { + return &AccessTokenAddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + AccessTokenRevokedType, + ), + } } type RefreshTokenAddedEvent struct { @@ -140,6 +152,10 @@ func (e *RefreshTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueCo return nil } +func (e *RefreshTokenAddedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + func NewRefreshTokenAddedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -159,18 +175,6 @@ func NewRefreshTokenAddedEvent( } } -func RefreshTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) { - added := &RefreshTokenAddedEvent{ - BaseEvent: *eventstore.BaseEventFromRepo(event), - } - err := json.Unmarshal(event.Data, added) - if err != nil { - return nil, errors.ThrowInternal(err, "OIDCS-aW3gqq", "unable to unmarshal refresh token added") - } - - return added, nil -} - type RefreshTokenRenewedEvent struct { eventstore.BaseEvent `json:"-"` @@ -186,6 +190,10 @@ func (e *RefreshTokenRenewedEvent) UniqueConstraints() []*eventstore.EventUnique return nil } +func (e *RefreshTokenRenewedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + func NewRefreshTokenRenewedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -203,14 +211,31 @@ func NewRefreshTokenRenewedEvent( } } -func RefreshTokenRenewedEventMapper(event *repository.Event) (eventstore.Event, error) { - added := &RefreshTokenRenewedEvent{ - BaseEvent: *eventstore.BaseEventFromRepo(event), - } - err := json.Unmarshal(event.Data, added) - if err != nil { - return nil, errors.ThrowInternal(err, "OIDCS-SF3fc", "unable to unmarshal refresh token renewed") - } - - return added, nil +type RefreshTokenRevokedEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *RefreshTokenRevokedEvent) Data() interface{} { + return e +} + +func (e *RefreshTokenRevokedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *RefreshTokenRevokedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + +func NewRefreshTokenRevokedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, +) *RefreshTokenRevokedEvent { + return &RefreshTokenRevokedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + RefreshTokenRevokedType, + ), + } } diff --git a/internal/static/i18n/de.yaml b/internal/static/i18n/de.yaml index dbf851c299..929ce1db6d 100644 --- a/internal/static/i18n/de.yaml +++ b/internal/static/i18n/de.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: Token ist ungültig Expired: Token ist abgelaufen + InvalidClient: Token wurde nicht für diesen Client ausgestellt AggregateTypes: action: Action diff --git a/internal/static/i18n/en.yaml b/internal/static/i18n/en.yaml index 15b60ae4a9..f623dc2133 100644 --- a/internal/static/i18n/en.yaml +++ b/internal/static/i18n/en.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: Token is invalid Expired: Token is expired + InvalidClient: Token was not issued for this client AggregateTypes: action: Action diff --git a/internal/static/i18n/es.yaml b/internal/static/i18n/es.yaml index eaf7ff8a18..097bd9e7c1 100644 --- a/internal/static/i18n/es.yaml +++ b/internal/static/i18n/es.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: El token no es válido Expired: El token ha caducado + InvalidClient: El token no ha sido emitido para este cliente AggregateTypes: action: Acción diff --git a/internal/static/i18n/fr.yaml b/internal/static/i18n/fr.yaml index 640c9c7866..7fba4deaf2 100644 --- a/internal/static/i18n/fr.yaml +++ b/internal/static/i18n/fr.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: Le jeton n'est pas valide Expired: Le jeton est expiré + InvalidClient: Le token n'a pas été émis pour ce client AggregateTypes: action: Action diff --git a/internal/static/i18n/it.yaml b/internal/static/i18n/it.yaml index 4e1e13c614..4981dc3070 100644 --- a/internal/static/i18n/it.yaml +++ b/internal/static/i18n/it.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: Token non è valido Expired: Token è scaduto + InvalidClient: Il token non è stato emesso per questo cliente AggregateTypes: action: Azione diff --git a/internal/static/i18n/ja.yaml b/internal/static/i18n/ja.yaml index 659d805f3b..d0ba74e274 100644 --- a/internal/static/i18n/ja.yaml +++ b/internal/static/i18n/ja.yaml @@ -488,6 +488,7 @@ Errors: Token: Invalid: トークンが無効です Expired: トークンの有効期限が切れている + InvalidClient: トークンが発行されていません AggregateTypes: action: アクション diff --git a/internal/static/i18n/mk.yaml b/internal/static/i18n/mk.yaml index 2ea8933bcd..a02b79dc91 100644 --- a/internal/static/i18n/mk.yaml +++ b/internal/static/i18n/mk.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: токенот е неважечки Expired: токенот е истечен + InvalidClient: Токен не беше издаден на овој клиент AggregateTypes: action: Акција diff --git a/internal/static/i18n/pl.yaml b/internal/static/i18n/pl.yaml index 4501cc6844..3e48f45825 100644 --- a/internal/static/i18n/pl.yaml +++ b/internal/static/i18n/pl.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: Token jest nieprawidłowy Expired: Token wygasł + InvalidClient: Token nie został wydany dla tego klienta AggregateTypes: action: Działanie diff --git a/internal/static/i18n/zh.yaml b/internal/static/i18n/zh.yaml index 21ba558c4e..1209fd5159 100644 --- a/internal/static/i18n/zh.yaml +++ b/internal/static/i18n/zh.yaml @@ -499,6 +499,7 @@ Errors: Token: Invalid: 令牌无效 Expired: 令牌已过期 + InvalidClient: 没有为该客户发放令牌 AggregateTypes: action: 动作