zitadel/internal/command/oidc_session.go

282 lines
11 KiB
Go
Raw Normal View History

package command
import (
"context"
"encoding/base64"
"strings"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx)
return accessTokenID, accessTokenExpiration, err
}
// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token.
// It returns the access token id, expiration and the refresh token.
// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged.
func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID)
if err != nil {
return "", "", time.Time{}, err
}
cmd.AddSession(ctx)
if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
cmd.SetAuthRequestSuccessful(ctx)
return cmd.PushEvents(ctx)
}
// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token.
// It returns the access token id and expiration and the new refresh token.
func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) {
cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken)
if err != nil {
return "", "", time.Time{}, err
}
if err = cmd.AddAccessToken(ctx, scope); err != nil {
return "", "", time.Time{}, err
}
if err = cmd.RenewRefreshToken(ctx); err != nil {
return "", "", time.Time{}, err
}
return cmd.PushEvents(ctx)
}
// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant).
// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned.
func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) {
split := strings.Split(refreshToken, ":")
if len(split) != 2 {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")
}
writeModel := NewOIDCSessionWriteModel(split[0], "")
err := c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid")
}
if err = writeModel.CheckRefreshToken(split[1]); err != nil {
return nil, err
}
return writeModel, nil
}
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) {
authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID)
if err != nil {
return nil, err
}
if err = authRequestWriteModel.CheckAuthenticated(); err != nil {
return nil, err
}
sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, err
}
if sessionWriteModel.State != domain.SessionStateActive {
return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated")
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()),
sessionWriteModel: sessionWriteModel,
authRequestWriteModel: authRequestWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) {
decoded, err := base64.RawURLEncoding.DecodeString(refreshToken)
if err != nil {
return "", err
}
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
if err != nil {
return "", err
}
split := strings.Split(decrypted, ":")
if len(split) != 2 {
return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid")
}
return split[1], nil
}
func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) {
refreshTokenID, err := c.decryptRefreshToken(refreshToken)
if err != nil {
return nil, err
}
sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID())
if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil {
return nil, err
}
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
return nil, err
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil {
return nil, err
}
return &OIDCSessionEvents{
eventstore: c.eventstore,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel,
accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime,
}, nil
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel
sessionWriteModel *SessionWriteModel
authRequestWriteModel *AuthRequestWriteModel
accessTokenLifetime time.Duration
refreshTokenLifeTime time.Duration
refreshTokenIdleLifetime time.Duration
// accessTokenID is set by the command
accessTokenID string
// refreshToken is set by the command
refreshToken string
}
func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
c.events = append(c.events, oidcsession.NewAddedEvent(
ctx,
c.oidcSessionWriteModel.aggregate,
c.sessionWriteModel.UserID,
c.sessionWriteModel.AggregateID,
c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope,
amr.List(c.sessionWriteModel),
c.sessionWriteModel.AuthenticationTime(),
))
}
func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) {
c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate))
}
func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) {
c.accessTokenID, err = c.idGenerator.Next()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime))
return nil
}
func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) {
var refreshTokenID string
refreshTokenID, c.refreshToken, err = c.generateRefreshToken()
if err != nil {
return err
}
c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime))
return nil
}
func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) {
refreshTokenID, err = c.idGenerator.Next()
if err != nil {
return "", "", err
}
token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID))
if err != nil {
return "", "", err
}
return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
if err != nil {
return "", "", time.Time{}, err
}
err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...)
if err != nil {
return "", "", time.Time{}, err
}
// prefix the returned id with the oidcSessionID so that we can retrieve it later on
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil
}
func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) {
oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx)
err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings)
if err != nil {
return 0, 0, 0, err
}
accessTokenLifetime = c.defaultAccessTokenLifetime
refreshTokenLifetime = c.defaultRefreshTokenLifetime
refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime
if oidcSettings.AccessTokenLifetime > 0 {
accessTokenLifetime = oidcSettings.AccessTokenLifetime
}
if oidcSettings.RefreshTokenExpiration > 0 {
refreshTokenLifetime = oidcSettings.RefreshTokenExpiration
}
if oidcSettings.RefreshTokenIdleExpiration > 0 {
refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration
}
return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil
}