mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 19:07:30 +00:00
feat(OIDC): support token revocation of V2 tokens (#6203)
This PR adds support for OAuth2 token revocation of V2 tokens. Unlike with V1 tokens, it's now possible to revoke a token not only from the authorized client / client which the token was issued to, but rather from all trusted clients (audience)
This commit is contained in:
@@ -3,9 +3,12 @@ package command
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -17,6 +20,14 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenDelimiter = "-"
|
||||
AccessTokenPrefix = "at_"
|
||||
RefreshTokenPrefix = "rt_"
|
||||
oidcTokenSubjectDelimiter = ":"
|
||||
oidcTokenFormat = "%s" + oidcTokenSubjectDelimiter + "%s"
|
||||
)
|
||||
|
||||
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
|
||||
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
|
||||
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
|
||||
@@ -71,21 +82,70 @@ func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context,
|
||||
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
|
||||
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
|
||||
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
|
||||
split := strings.Split(refreshToken, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
oidcSessionID, refreshTokenID, err := parseRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writeModel := NewOIDCSessionWriteModel(split[0], "")
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
writeModel := NewOIDCSessionWriteModel(oidcSessionID, "")
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
|
||||
if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
||||
|
||||
func oidcSessionTokenIDsFromToken(token string) (oidcSessionID, refreshTokenID, accessTokenID string, err error) {
|
||||
split := strings.Split(token, TokenDelimiter)
|
||||
if len(split) != 2 {
|
||||
return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid")
|
||||
}
|
||||
if strings.HasPrefix(split[1], RefreshTokenPrefix) {
|
||||
return split[0], split[1], "", nil
|
||||
}
|
||||
if strings.HasPrefix(split[1], AccessTokenPrefix) {
|
||||
return split[0], "", split[1], nil
|
||||
}
|
||||
return "", "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-S87kl", "Errors.OIDCSession.Token.Invalid")
|
||||
}
|
||||
|
||||
// RevokeOIDCSessionToken revokes an access_token or refresh_token
|
||||
// if the OIDCSession cannot be retrieved by the provided token, is not active or if the token is already revoked,
|
||||
// then no error will be returned.
|
||||
// The only possible error (except db connection or other internal errors) occurs if a client tries to revoke a token,
|
||||
// which was not part of the audience.
|
||||
func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID string) (err error) {
|
||||
oidcSessionID, refreshTokenID, accessTokenID, err := oidcSessionTokenIDsFromToken(token)
|
||||
if err != nil {
|
||||
logging.WithError(err).Info("token revocation with invalid token format")
|
||||
return nil
|
||||
}
|
||||
writeModel := NewOIDCSessionWriteModel(oidcSessionID, "")
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "OIDCS-NB3t2", "Errors.Internal")
|
||||
}
|
||||
if err = writeModel.CheckClient(clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
if refreshTokenID != "" {
|
||||
if err = writeModel.CheckRefreshToken(refreshTokenID); err != nil {
|
||||
logging.WithFields("oidcSessionID", oidcSessionID, "refreshTokenID", refreshTokenID).WithError(err).
|
||||
Info("refresh token revocation with invalid token")
|
||||
return nil
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewRefreshTokenRevokedEvent(ctx, writeModel.aggregate))
|
||||
}
|
||||
if err = writeModel.CheckAccessToken(accessTokenID); err != nil {
|
||||
logging.WithFields("oidcSessionID", oidcSessionID, "accessTokenID", accessTokenID).WithError(err).
|
||||
Info("access token revocation with invalid token")
|
||||
return nil
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
|
||||
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
|
||||
if err != nil {
|
||||
@@ -153,11 +213,18 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID stri
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
split := strings.Split(decrypted, ":")
|
||||
if len(split) != 2 {
|
||||
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
_, refreshTokenID, err = parseRefreshToken(decrypted)
|
||||
return refreshTokenID, err
|
||||
}
|
||||
|
||||
func parseRefreshToken(refreshToken string) (oidcSessionID, refreshTokenID string, err error) {
|
||||
split := strings.Split(refreshToken, TokenDelimiter)
|
||||
if len(split) < 2 || !strings.HasPrefix(split[1], RefreshTokenPrefix) {
|
||||
return "", "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
|
||||
}
|
||||
return split[1], nil
|
||||
// the oidc library requires that every token has the format of <tokenID>:<userID>
|
||||
// the V2 tokens don't use the userID anymore, so let's just remove it
|
||||
return split[0], strings.Split(split[1], oidcTokenSubjectDelimiter)[0], nil
|
||||
}
|
||||
|
||||
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
|
||||
@@ -224,18 +291,19 @@ func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
|
||||
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
|
||||
c.accessTokenID, err = c.idGenerator.Next()
|
||||
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) error {
|
||||
accessTokenID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.accessTokenID = AccessTokenPrefix + accessTokenID
|
||||
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.sessionWriteModel.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -245,7 +313,7 @@ func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
|
||||
|
||||
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
var refreshTokenID string
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
|
||||
refreshTokenID, c.refreshToken, err = c.generateRefreshToken(c.oidcSessionWriteModel.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -253,12 +321,13 @@ func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
|
||||
func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID, refreshToken string, err error) {
|
||||
refreshTokenID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
|
||||
refreshTokenID = RefreshTokenPrefix + refreshTokenID
|
||||
token, err := c.encryptionAlg.Encrypt([]byte(fmt.Sprintf(oidcTokenFormat, c.oidcSessionWriteModel.OIDCRefreshTokenID(refreshTokenID), userID)))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
@@ -276,7 +345,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID strin
|
||||
}
|
||||
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
return c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
|
||||
}
|
||||
|
||||
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
|
||||
|
Reference in New Issue
Block a user