feat(OIDC): support token revocation of V2 tokens (#6203)

This PR adds support for OAuth2 token revocation of V2 tokens.

Unlike with V1 tokens, it's now possible to revoke a token not only from the authorized client / client which the token was issued to, but rather from all trusted clients (audience)
This commit is contained in:
Livio Spring 2023-07-17 14:33:37 +02:00 committed by GitHub
parent ecf9835cb8
commit e1b3cda98a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 689 additions and 102 deletions

View File

@ -265,12 +265,12 @@ func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
plainCode, err := o.decryptGrant(refreshToken) plainToken, err := o.decryptGrant(refreshToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if strings.HasPrefix(plainCode, command.IDPrefixV2) { if strings.HasPrefix(plainToken, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode) oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -308,7 +308,25 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin
return err return err
} }
func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) *oidc.Error { func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(token, command.IDPrefixV2) {
err := o.command.RevokeOIDCSessionToken(ctx, token, clientID)
if err == nil {
return nil
}
if errors.IsPreconditionFailed(err) {
return oidc.ErrInvalidClient().WithDescription("token was not issued for this client")
}
return oidc.ErrServerError().WithParent(err)
}
return o.revokeTokenV1(ctx, token, userID, clientID)
}
func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID string) *oidc.Error {
refreshToken, err := o.repo.RefreshTokenByID(ctx, token, userID) refreshToken, err := o.repo.RefreshTokenByID(ctx, token, userID)
if err == nil { if err == nil {
if refreshToken.ClientID != clientID { if refreshToken.ClientID != clientID {
@ -338,6 +356,17 @@ func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID str
} }
func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) { func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
plainToken, err := o.decryptGrant(token)
if err != nil {
return "", "", op.ErrInvalidRefreshToken
}
if strings.HasPrefix(plainToken, command.IDPrefixV2) {
oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainToken)
if err != nil {
return "", "", op.ErrInvalidRefreshToken
}
return oidcSession.UserID, oidcSession.OIDCRefreshTokenID(oidcSession.RefreshTokenID), nil
}
refreshToken, err := o.repo.RefreshTokenByToken(ctx, token) refreshToken, err := o.repo.RefreshTokenByToken(ctx, token)
if err != nil { if err != nil {
return "", "", op.ErrInvalidRefreshToken return "", "", op.ErrInvalidRefreshToken

View File

@ -184,6 +184,196 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
require.Error(t, err) require.Error(t, err)
} }
func TestOPStorage_RevokeToken_access_token(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
// revoke access token
err = rp.RevokeToken(provider, tokens.AccessToken, "access_token")
require.NoError(t, err)
// userinfo must fail
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
require.Error(t, err)
// refresh grant must still work
_, err = refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
// revocation with the same access token must not fail (with or without hint)
err = rp.RevokeToken(provider, tokens.AccessToken, "access_token")
require.NoError(t, err)
err = rp.RevokeToken(provider, tokens.AccessToken, "")
require.NoError(t, err)
}
func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
// revoke access token
err = rp.RevokeToken(provider, tokens.AccessToken, "refresh_token")
require.NoError(t, err)
// userinfo must fail
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
require.Error(t, err)
// refresh grant must still work
_, err = refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
}
func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
// revoke refresh token -> invalidates also access token
err = rp.RevokeToken(provider, tokens.RefreshToken, "refresh_token")
require.NoError(t, err)
// userinfo must fail
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
require.Error(t, err)
// refresh must fail
_, err = refreshTokens(t, clientID, tokens.RefreshToken)
require.Error(t, err)
// revocation with the same refresh token must not fail (with or without hint)
err = rp.RevokeToken(provider, tokens.RefreshToken, "refresh_token")
require.NoError(t, err)
err = rp.RevokeToken(provider, tokens.RefreshToken, "")
require.NoError(t, err)
}
func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.T) {
clientID := createClient(t)
provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
// revoke refresh token even with a wrong hint
err = rp.RevokeToken(provider, tokens.RefreshToken, "access_token")
require.NoError(t, err)
// userinfo must fail
_, err = rp.Userinfo(tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
require.Error(t, err)
// refresh must fail
_, err = refreshTokens(t, clientID, tokens.RefreshToken)
require.Error(t, err)
}
func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
clientID := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId())
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
// simulate second client (not part of the audience) trying to revoke the token
otherClientID := createClient(t)
provider, err := Tester.CreateRelyingParty(otherClientID, redirectURI)
require.NoError(t, err)
err = rp.RevokeToken(provider, tokens.AccessToken, "")
require.Error(t, err)
}
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
provider, err := Tester.CreateRelyingParty(clientID, redirectURI) provider, err := Tester.CreateRelyingParty(clientID, redirectURI)
require.NoError(t, err) require.NoError(t, err)

View File

@ -3,9 +3,12 @@ package command
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt"
"strings" "strings"
"time" "time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
@ -17,6 +20,14 @@ import (
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
) )
const (
TokenDelimiter = "-"
AccessTokenPrefix = "at_"
RefreshTokenPrefix = "rt_"
oidcTokenSubjectDelimiter = ":"
oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
)
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration. // AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged. // If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) { func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
@ -71,21 +82,70 @@ func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context,
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant). // OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned. // If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) { func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
split := strings.Split(refreshToken, ":") oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken)
if len(split) != 2 { if err != nil {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid") return nil, err
} }
writeModel := NewOIDCSessionWriteModel(split[0], "") writeModel := NewOIDCSessionWriteModel(oidcSessionID, "")
err := c.eventstore.FilterToQueryReducer(ctx, writeModel) err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil { if err != nil {
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid") return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
} }
if err = writeModel.CheckRefreshToken(split[1]); err != nil { if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil {
return nil, err return nil, err
} }
return writeModel, nil return writeModel, nil
} }
func oidcSessionTokenIDsFromToken(token string) (oidcSessionID, refreshTokenID, accessTokenID string, err error) {
split := strings.Split(token, TokenDelimiter)
if len(split) != 2 {
return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid")
}
if strings.HasPrefix(split[1], RefreshTokenPrefix) {
return split[0], split[1], "", nil
}
if strings.HasPrefix(split[1], AccessTokenPrefix) {
return split[0], "", split[1], nil
}
return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid")
}
// RevokeOIDCSessionToken revokes an access_token or refresh_token
// if the OIDCSession cannot be retrieved by the provided token, is not active or if the token is already revoked,
// then no error will be returned.
// The only possible error (except db connection or other internal errors) occurs if a client tries to revoke a token,
// which was not part of the audience.
func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID string) (err error) {
oidcSessionID, refreshTokenID, accessTokenID, err := oidcSessionTokenIDsFromToken(token)
if err != nil {
logging.WithError(err).Info("token revocation with invalid token format")
return nil
}
writeModel := NewOIDCSessionWriteModel(oidcSessionID, "")
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return caos_errs.ThrowInternal(err, "OIDCS-NB3t2", "Errors.Internal")
}
if err = writeModel.CheckClient(clientID); err != nil {
return err
}
if refreshTokenID != "" {
if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil {
logging.WithFields("oidcSessionID", oidcSessionID, "refreshTokenID", refreshTokenID).WithError(err).
Info("refresh token revocation with invalid token")
return nil
}
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewRefreshTokenRevokedEvent(ctx, writeModel.aggregate))
}
if err = writeModel.CheckAccessToken(accessTokenID); err != nil {
logging.WithFields("oidcSessionID", oidcSessionID, "accessTokenID", accessTokenID).WithError(err).
Info("access token revocation with invalid token")
return nil
}
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) { func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID) authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil { if err != nil {
@ -153,11 +213,18 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID stri
if err != nil { if err != nil {
return "", err return "", err
} }
split := strings.Split(decrypted, ":") _, refreshTokenID, err = parseRefreshToken(decrypted)
if len(split) != 2 { return refreshTokenID, err
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid") }
func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) {
split := strings.Split(refreshToken, TokenDelimiter)
if len(split) < 2 || !strings.HasPrefix(split[1], RefreshTokenPrefix) {
return "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
} }
return split[1], nil // the oidc library requires that every token has the format of <tokenID>:<userID>
// the V2 tokens don't use the userID anymore, so let's just remove it
return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil
} }
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) { func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
@ -224,18 +291,19 @@ func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate)) c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
} }
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) { func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) error {
c.accessTokenID, err = c.idGenerator.Next() accessTokenID, err := c.idGenerator.Next()
if err != nil { if err != nil {
return err return err
} }
c.accessTokenID = AccessTokenPrefix + accessTokenID
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime)) c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
return nil return nil
} }
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) { func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken() refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID)
if err != nil { if err != nil {
return err return err
} }
@ -245,7 +313,7 @@ func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) { func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken() refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.oidcSessionWriteModel.UserID)
if err != nil { if err != nil {
return err return err
} }
@ -253,12 +321,13 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
return nil return nil
} }
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) { func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) {
refreshTokenID, err = c.idGenerator.Next() refreshTokenID, err = c.idGenerator.Next()
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID)) refreshTokenID = RefreshTokenPrefix + refreshTokenID
token, err := c.encryptionAlg.Encrypt([]byte(fmt.Sprintf(oidcTokenFormat, c.oidcSessionWriteModel.OIDCRefreshTokenID(refreshTokenID), userID)))
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -276,7 +345,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID strin
} }
// prefix the returned id with the oidcSessionID so that we can retrieve it later on // prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts // we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
} }
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) { func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {

View File

@ -20,9 +20,11 @@ type OIDCSessionWriteModel struct {
AuthMethods []domain.UserAuthMethodType AuthMethods []domain.UserAuthMethodType
AuthTime time.Time AuthTime time.Time
State domain.OIDCSessionState State domain.OIDCSessionState
AccessTokenID string
AccessTokenCreation time.Time AccessTokenCreation time.Time
AccessTokenExpiration time.Time AccessTokenExpiration time.Time
RefreshTokenID string RefreshTokenID string
RefreshToken string
RefreshTokenExpiration time.Time RefreshTokenExpiration time.Time
RefreshTokenIdleExpiration time.Time RefreshTokenIdleExpiration time.Time
@ -46,10 +48,14 @@ func (wm *OIDCSessionWriteModel) Reduce() error {
wm.reduceAdded(e) wm.reduceAdded(e)
case *oidcsession.AccessTokenAddedEvent: case *oidcsession.AccessTokenAddedEvent:
wm.reduceAccessTokenAdded(e) wm.reduceAccessTokenAdded(e)
case *oidcsession.AccessTokenRevokedEvent:
wm.reduceAccessTokenRevoked(e)
case *oidcsession.RefreshTokenAddedEvent: case *oidcsession.RefreshTokenAddedEvent:
wm.reduceRefreshTokenAdded(e) wm.reduceRefreshTokenAdded(e)
case *oidcsession.RefreshTokenRenewedEvent: case *oidcsession.RefreshTokenRenewedEvent:
wm.reduceRefreshTokenRenewed(e) wm.reduceRefreshTokenRenewed(e)
case *oidcsession.RefreshTokenRevokedEvent:
wm.reduceRefreshTokenRevoked(e)
} }
} }
return wm.WriteModel.Reduce() return wm.WriteModel.Reduce()
@ -65,6 +71,7 @@ func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
oidcsession.AccessTokenAddedType, oidcsession.AccessTokenAddedType,
oidcsession.RefreshTokenAddedType, oidcsession.RefreshTokenAddedType,
oidcsession.RefreshTokenRenewedType, oidcsession.RefreshTokenRenewedType,
oidcsession.RefreshTokenRevokedType,
). ).
Builder() Builder()
@ -91,9 +98,15 @@ func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
} }
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) { func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
wm.AccessTokenID = e.ID
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime) wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
} }
func (wm *OIDCSessionWriteModel) reduceAccessTokenRevoked(e *oidcsession.AccessTokenRevokedEvent) {
wm.AccessTokenID = ""
wm.AccessTokenExpiration = e.CreationDate()
}
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) { func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
wm.RefreshTokenID = e.ID wm.RefreshTokenID = e.ID
wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime) wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime)
@ -105,6 +118,14 @@ func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.Refres
wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime) wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime)
} }
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRevoked(e *oidcsession.RefreshTokenRevokedEvent) {
wm.RefreshTokenID = ""
wm.RefreshTokenExpiration = e.CreationDate()
wm.RefreshTokenIdleExpiration = e.CreationDate()
wm.AccessTokenID = ""
wm.AccessTokenExpiration = e.CreationDate()
}
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error { func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
if wm.State != domain.OIDCSessionStateActive { if wm.State != domain.OIDCSessionStateActive {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid") return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
@ -118,3 +139,29 @@ func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error
} }
return nil return nil
} }
func (wm *OIDCSessionWriteModel) CheckAccessToken(accessTokenID string) error {
if wm.State != domain.OIDCSessionStateActive {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-KL2pk", "Errors.OIDCSession.Token.Invalid")
}
if wm.AccessTokenID != accessTokenID {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JLKW2", "Errors.OIDCSession.Token.Invalid")
}
if wm.AccessTokenExpiration.Before(time.Now()) {
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3j3md", "Errors.OIDCSession.Token.Invalid")
}
return nil
}
func (wm *OIDCSessionWriteModel) CheckClient(clientID string) error {
for _, aud := range wm.Audience {
if aud == clientID {
return nil
}
}
return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient")
}
func (wm *OIDCSessionWriteModel) OIDCRefreshTokenID(refreshTokenID string) string {
return wm.AggregateID + TokenDelimiter + refreshTokenID
}

View File

@ -191,7 +191,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid"}, time.Hour), "at_accessTokenID", []string{"openid"}, time.Hour),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
@ -207,7 +207,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
authRequestID: "V2_authRequestID", authRequestID: "V2_authRequestID",
}, },
res{ res{
id: "V2_oidcSessionID-accessTokenID", id: "V2_oidcSessionID-at_accessTokenID",
expiration: tokenCreationNow.Add(time.Hour), expiration: tokenCreationNow.Add(time.Hour),
}, },
}, },
@ -392,11 +392,11 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
@ -415,8 +415,8 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
authRequestID: "V2_authRequestID", authRequestID: "V2_authRequestID",
}, },
res{ res{
id: "V2_oidcSessionID-accessTokenID", id: "V2_oidcSessionID-at_accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID
expiration: tokenCreationNow.Add(time.Hour), expiration: tokenCreationNow.Add(time.Hour),
}, },
}, },
@ -476,10 +476,10 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID", oidcSessionID: "V2_oidcSessionID",
refreshToken: "aW52YWxpZA", refreshToken: "aW52YWxpZA", // invalid
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"),
}, },
}, },
{ {
@ -493,7 +493,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID", oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -519,7 +519,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID", oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -536,11 +536,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
), ),
), ),
@ -549,7 +549,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID", oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -566,11 +566,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
), ),
expectFilter(), // token lifetime expectFilter(), // token lifetime
@ -578,11 +578,11 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
[]*repository.Event{ []*repository.Event{
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID2", 24*time.Hour), "rt_refreshTokenID2", 24*time.Hour),
), ),
}, },
), ),
@ -596,12 +596,12 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcSessionID: "V2_oidcSessionID", oidcSessionID: "V2_oidcSessionID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
}, },
res{ res{
id: "V2_oidcSessionID-accessTokenID", id: "V2_oidcSessionID-at_accessTokenID",
refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI", refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDI6dXNlcklE", // V2_oidcSessionID-rt_refreshTokenID2:userID%
expiration: time.Time{}.Add(time.Hour), expiration: time.Time{}.Add(time.Hour),
}, },
}, },
@ -672,7 +672,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
}, },
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID", refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID",
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -689,7 +689,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
), ),
), ),
), ),
@ -697,7 +697,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
}, },
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID", refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID",
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -714,11 +714,11 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
), ),
), ),
@ -726,7 +726,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
}, },
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID", refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID",
}, },
res{ res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
@ -743,11 +743,11 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), "at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
), ),
), ),
@ -755,7 +755,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
}, },
args{ args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "V2_oidcSessionID:refreshTokenID", refreshToken: "V2_oidcSessionID-rt_refreshTokenID:userID",
}, },
res{ res{
model: &OIDCSessionWriteModel{ model: &OIDCSessionWriteModel{
@ -771,7 +771,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow, AuthTime: testNow,
State: domain.OIDCSessionStateActive, State: domain.OIDCSessionStateActive,
RefreshTokenID: "refreshTokenID", RefreshTokenID: "rt_refreshTokenID",
RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour), RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour),
RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour), RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour),
}, },
@ -808,3 +808,207 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
}) })
} }
} }
func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
keyAlgorithm crypto.EncryptionAlgorithm
}
type args struct {
ctx context.Context
token string
clientID string
}
type res struct {
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"invalid token",
fields{
eventstore: eventstoreExpect(t),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "invalid",
},
res{
err: nil,
},
},
{
"refresh_token inactive",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-rt_refreshTokenID",
clientID: "clientID",
},
res{
err: nil,
},
},
{
"refresh_token invalid client",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "otherClientID", []string{"otherClientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-rt_refreshTokenID",
clientID: "clientID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient"),
},
},
{
"refresh_token revoked",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
expectPush([]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewRefreshTokenRevokedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate),
),
}),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-rt_refreshTokenID",
clientID: "clientID",
},
res{
err: nil,
},
},
{
"access_token inactive session",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-at_accessTokenID",
clientID: "clientID",
},
res{
err: nil,
},
},
{
"access_token invalid client",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "otherClientID", []string{"otherClientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-at_accessTokenID",
clientID: "clientID",
},
res{
err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient"),
},
},
{
"access_token revoked",
fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "sessionID", "clientID", []string{"clientID"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
expectPush([]*repository.Event{
eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenRevokedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate),
),
}),
),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
token: "V2_oidcSessionID-at_accessTokenID",
clientID: "clientID",
},
res{
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
keyAlgorithm: tt.fields.keyAlgorithm,
}
err := c.RevokeOIDCSessionToken(tt.args.ctx, tt.args.token, tt.args.clientID)
require.ErrorIs(t, err, tt.res.err)
})
}
}

View File

@ -43,6 +43,9 @@ func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
wm.reduceAdded(e) wm.reduceAdded(e)
case *oidcsession.AccessTokenAddedEvent: case *oidcsession.AccessTokenAddedEvent:
wm.reduceAccessTokenAdded(e) wm.reduceAccessTokenAdded(e)
case *oidcsession.AccessTokenRevokedEvent,
*oidcsession.RefreshTokenRevokedEvent:
wm.reduceTokenRevoked(event)
} }
} }
return wm.WriteModel.Reduce() return wm.WriteModel.Reduce()
@ -57,6 +60,8 @@ func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilde
EventTypes( EventTypes(
oidcsession.AddedType, oidcsession.AddedType,
oidcsession.AccessTokenAddedType, oidcsession.AccessTokenAddedType,
oidcsession.AccessTokenRevokedType,
oidcsession.RefreshTokenRevokedType,
). ).
Builder() Builder()
} }
@ -78,6 +83,11 @@ func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime) wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
} }
func (wm *OIDCSessionAccessTokenReadModel) reduceTokenRevoked(e eventstore.Event) {
wm.AccessTokenID = ""
wm.AccessTokenExpiration = e.CreationDate()
}
// ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore. // ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore.
// refreshed or expired tokens will return an error // refreshed or expired tokens will return an error
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) { func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {

View File

@ -1,11 +1,15 @@
package oidcsession package oidcsession
import "github.com/zitadel/zitadel/internal/eventstore" import (
"github.com/zitadel/zitadel/internal/eventstore"
)
func RegisterEventMappers(es *eventstore.Eventstore) { func RegisterEventMappers(es *eventstore.Eventstore) {
es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper). es.RegisterFilterEventMapper(AggregateType, AddedType, eventstore.GenericEventMapper[AddedEvent]).
RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, AccessTokenAddedEventMapper). RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, eventstore.GenericEventMapper[AccessTokenAddedEvent]).
RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, RefreshTokenAddedEventMapper). RegisterFilterEventMapper(AggregateType, AccessTokenRevokedType, eventstore.GenericEventMapper[AccessTokenRevokedEvent]).
RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, RefreshTokenRenewedEventMapper) RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, eventstore.GenericEventMapper[RefreshTokenAddedEvent]).
RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, eventstore.GenericEventMapper[RefreshTokenRenewedEvent]).
RegisterFilterEventMapper(AggregateType, RefreshTokenRevokedType, eventstore.GenericEventMapper[RefreshTokenRevokedEvent])
} }

View File

@ -2,21 +2,20 @@ package oidcsession
import ( import (
"context" "context"
"encoding/json"
"time" "time"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
) )
const ( const (
oidcSessionEventPrefix = "oidc_session." oidcSessionEventPrefix = "oidc_session."
AddedType = oidcSessionEventPrefix + "added" AddedType = oidcSessionEventPrefix + "added"
AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added" AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added"
AccessTokenRevokedType = oidcSessionEventPrefix + "access_token.revoked"
RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added" RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added"
RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed" RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed"
RefreshTokenRevokedType = oidcSessionEventPrefix + "refresh_token.revoked"
) )
type AddedEvent struct { type AddedEvent struct {
@ -39,6 +38,10 @@ func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil return nil
} }
func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewAddedEvent(ctx context.Context, func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
userID, userID,
@ -65,18 +68,6 @@ func NewAddedEvent(ctx context.Context,
} }
} }
func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &AddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-DG4gn", "unable to unmarshal oidc session added")
}
return added, nil
}
type AccessTokenAddedEvent struct { type AccessTokenAddedEvent struct {
eventstore.BaseEvent `json:"-"` eventstore.BaseEvent `json:"-"`
@ -93,6 +84,10 @@ func (e *AccessTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueCon
return nil return nil
} }
func (e *AccessTokenAddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewAccessTokenAddedEvent( func NewAccessTokenAddedEvent(
ctx context.Context, ctx context.Context,
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
@ -112,16 +107,33 @@ func NewAccessTokenAddedEvent(
} }
} }
func AccessTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) { type AccessTokenRevokedEvent struct {
added := &AccessTokenAddedEvent{ eventstore.BaseEvent `json:"-"`
BaseEvent: *eventstore.BaseEventFromRepo(event), }
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-DSGn5", "unable to unmarshal access token added")
}
return added, nil func (e *AccessTokenRevokedEvent) Data() interface{} {
return e
}
func (e *AccessTokenRevokedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return nil
}
func (e *AccessTokenRevokedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewAccessTokenRevokedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
) *AccessTokenAddedEvent {
return &AccessTokenAddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
AccessTokenRevokedType,
),
}
} }
type RefreshTokenAddedEvent struct { type RefreshTokenAddedEvent struct {
@ -140,6 +152,10 @@ func (e *RefreshTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueCo
return nil return nil
} }
func (e *RefreshTokenAddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewRefreshTokenAddedEvent( func NewRefreshTokenAddedEvent(
ctx context.Context, ctx context.Context,
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
@ -159,18 +175,6 @@ func NewRefreshTokenAddedEvent(
} }
} }
func RefreshTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
added := &RefreshTokenAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),
}
err := json.Unmarshal(event.Data, added)
if err != nil {
return nil, errors.ThrowInternal(err, "OIDCS-aW3gqq", "unable to unmarshal refresh token added")
}
return added, nil
}
type RefreshTokenRenewedEvent struct { type RefreshTokenRenewedEvent struct {
eventstore.BaseEvent `json:"-"` eventstore.BaseEvent `json:"-"`
@ -186,6 +190,10 @@ func (e *RefreshTokenRenewedEvent) UniqueConstraints() []*eventstore.EventUnique
return nil return nil
} }
func (e *RefreshTokenRenewedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewRefreshTokenRenewedEvent( func NewRefreshTokenRenewedEvent(
ctx context.Context, ctx context.Context,
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
@ -203,14 +211,31 @@ func NewRefreshTokenRenewedEvent(
} }
} }
func RefreshTokenRenewedEventMapper(event *repository.Event) (eventstore.Event, error) { type RefreshTokenRevokedEvent struct {
added := &RefreshTokenRenewedEvent{ eventstore.BaseEvent `json:"-"`
BaseEvent: *eventstore.BaseEventFromRepo(event), }
}
err := json.Unmarshal(event.Data, added) func (e *RefreshTokenRevokedEvent) Data() interface{} {
if err != nil { return e
return nil, errors.ThrowInternal(err, "OIDCS-SF3fc", "unable to unmarshal refresh token renewed") }
}
func (e *RefreshTokenRevokedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
return added, nil return nil
}
func (e *RefreshTokenRevokedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewRefreshTokenRevokedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
) *RefreshTokenRevokedEvent {
return &RefreshTokenRevokedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
RefreshTokenRevokedType,
),
}
} }

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: Token ist ungültig Invalid: Token ist ungültig
Expired: Token ist abgelaufen Expired: Token ist abgelaufen
InvalidClient: Token wurde nicht für diesen Client ausgestellt
AggregateTypes: AggregateTypes:
action: Action action: Action

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: Token is invalid Invalid: Token is invalid
Expired: Token is expired Expired: Token is expired
InvalidClient: Token was not issued for this client
AggregateTypes: AggregateTypes:
action: Action action: Action

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: El token no es válido Invalid: El token no es válido
Expired: El token ha caducado Expired: El token ha caducado
InvalidClient: El token no ha sido emitido para este cliente
AggregateTypes: AggregateTypes:
action: Acción action: Acción

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: Le jeton n'est pas valide Invalid: Le jeton n'est pas valide
Expired: Le jeton est expiré Expired: Le jeton est expiré
InvalidClient: Le token n'a pas été émis pour ce client
AggregateTypes: AggregateTypes:
action: Action action: Action

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: Token non è valido Invalid: Token non è valido
Expired: Token è scaduto Expired: Token è scaduto
InvalidClient: Il token non è stato emesso per questo cliente
AggregateTypes: AggregateTypes:
action: Azione action: Azione

View File

@ -488,6 +488,7 @@ Errors:
Token: Token:
Invalid: トークンが無効です Invalid: トークンが無効です
Expired: トークンの有効期限が切れている Expired: トークンの有効期限が切れている
InvalidClient: トークンが発行されていません
AggregateTypes: AggregateTypes:
action: アクション action: アクション

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: токенот е неважечки Invalid: токенот е неважечки
Expired: токенот е истечен Expired: токенот е истечен
InvalidClient: Токен не беше издаден на овој клиент
AggregateTypes: AggregateTypes:
action: Акција action: Акција

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: Token jest nieprawidłowy Invalid: Token jest nieprawidłowy
Expired: Token wygasł Expired: Token wygasł
InvalidClient: Token nie został wydany dla tego klienta
AggregateTypes: AggregateTypes:
action: Działanie action: Działanie

View File

@ -499,6 +499,7 @@ Errors:
Token: Token:
Invalid: 令牌无效 Invalid: 令牌无效
Expired: 令牌已过期 Expired: 令牌已过期
InvalidClient: 没有为该客户发放令牌
AggregateTypes: AggregateTypes:
action: 动作 action: 动作