zitadel/internal/command/oidc_session_model.go
Livio Spring e1b3cda98a
feat(OIDC): support token revocation of V2 tokens (#6203)
This PR adds support for OAuth2 token revocation of V2 tokens.

Unlike with V1 tokens, a V2 token can now be revoked not only by the authorized client (the client the token was issued to), but also by any trusted client in the token's audience.
2023-07-17 14:33:37 +02:00

168 lines
5.4 KiB
Go

package command
import (
"time"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
)
// OIDCSessionWriteModel is the write model of a single OIDC session aggregate.
// It is rebuilt by reducing the oidcsession events selected by Query and is
// consulted by CheckRefreshToken, CheckAccessToken and CheckClient.
type OIDCSessionWriteModel struct {
	eventstore.WriteModel

	// Session data taken from the oidcsession.AddedEvent; State becomes
	// active once that event has been reduced.
	UserID      string
	SessionID   string
	ClientID    string
	Audience    []string
	Scope       []string
	AuthMethods []domain.UserAuthMethodType
	AuthTime    time.Time
	State       domain.OIDCSessionState

	// Current access token. AccessTokenID is cleared when the access token is
	// revoked, either directly or together with the refresh token.
	AccessTokenID         string
	AccessTokenCreation   time.Time
	AccessTokenExpiration time.Time

	// Current refresh token. RefreshTokenID is cleared when it is revoked.
	// NOTE(review): RefreshToken is not assigned by any reducer in this file;
	// presumably it is set by callers — verify before relying on it.
	RefreshTokenID             string
	RefreshToken               string
	RefreshTokenExpiration     time.Time
	RefreshTokenIdleExpiration time.Time

	// aggregate is the oidcsession aggregate for this session. It is rebuilt in
	// reduceAdded when the model was created without a resource owner.
	aggregate *eventstore.Aggregate
}
// NewOIDCSessionWriteModel returns a write model for the OIDC session with the
// given aggregate id. resourceOwner may be empty; in that case reduceAdded
// rebuilds the aggregate with the resource owner found on the added event.
func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel {
	wm := new(OIDCSessionWriteModel)
	wm.WriteModel = eventstore.WriteModel{
		AggregateID:   id,
		ResourceOwner: resourceOwner,
	}
	sessionAggregate := oidcsession.NewAggregate(id, resourceOwner)
	wm.aggregate = &sessionAggregate.Aggregate
	return wm
}
// Reduce applies the buffered events to the model in order, dispatching each
// known oidcsession event to its reducer (unknown event types are ignored),
// and then delegates to the embedded eventstore.WriteModel.
func (wm *OIDCSessionWriteModel) Reduce() error {
	for i := range wm.Events {
		switch ev := wm.Events[i].(type) {
		case *oidcsession.AddedEvent:
			wm.reduceAdded(ev)
		case *oidcsession.AccessTokenAddedEvent:
			wm.reduceAccessTokenAdded(ev)
		case *oidcsession.AccessTokenRevokedEvent:
			wm.reduceAccessTokenRevoked(ev)
		case *oidcsession.RefreshTokenAddedEvent:
			wm.reduceRefreshTokenAdded(ev)
		case *oidcsession.RefreshTokenRenewedEvent:
			wm.reduceRefreshTokenRenewed(ev)
		case *oidcsession.RefreshTokenRevokedEvent:
			wm.reduceRefreshTokenRevoked(ev)
		}
	}
	return wm.WriteModel.Reduce()
}
// Query builds the event search for this session's aggregate.
//
// The selected event types must cover every event handled by Reduce. If one is
// missing, its state change is silently absent from the model. For example,
// without AccessTokenRevokedType a revoked access token would still pass
// CheckAccessToken.
// When the model was created with a resource owner, the search is restricted
// to that resource owner.
func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
	query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		AddQuery().
		AggregateTypes(oidcsession.AggregateType).
		AggregateIDs(wm.AggregateID).
		EventTypes(
			oidcsession.AddedType,
			oidcsession.AccessTokenAddedType,
			oidcsession.AccessTokenRevokedType,
			oidcsession.RefreshTokenAddedType,
			oidcsession.RefreshTokenRenewedType,
			oidcsession.RefreshTokenRevokedType,
		).
		Builder()
	if wm.ResourceOwner != "" {
		query.ResourceOwner(wm.ResourceOwner)
	}
	return query
}
// reduceAdded initializes the session data from the added event and marks the
// session as active. If the model was created without a resource owner, the
// aggregate is rebuilt with the resource owner of the added event.
func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
	wm.UserID, wm.SessionID, wm.ClientID = e.UserID, e.SessionID, e.ClientID
	wm.Audience, wm.Scope = e.Audience, e.Scope
	wm.AuthMethods, wm.AuthTime = e.AuthMethods, e.AuthTime
	wm.State = domain.OIDCSessionStateActive

	if wm.ResourceOwner != "" {
		return
	}
	// The model was initialized without a resource owner, so the aggregate
	// built in the constructor lacks one; take it from the event instead.
	owner := e.Aggregate().ResourceOwner
	wm.aggregate = &oidcsession.NewAggregate(wm.AggregateID, owner).Aggregate
}
// reduceAccessTokenAdded records the newly issued access token. The token's
// creation time is the event's creation date, and it expires Lifetime after
// that. AccessTokenCreation was previously never populated and stayed at the
// zero time.
func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
	wm.AccessTokenID = e.ID
	wm.AccessTokenCreation = e.CreationDate()
	wm.AccessTokenExpiration = wm.AccessTokenCreation.Add(e.Lifetime)
}
// reduceAccessTokenRevoked invalidates the current access token: its id is
// cleared and its expiration is moved to the time of revocation.
func (wm *OIDCSessionWriteModel) reduceAccessTokenRevoked(e *oidcsession.AccessTokenRevokedEvent) {
	revokedAt := e.CreationDate()
	wm.AccessTokenID = ""
	wm.AccessTokenExpiration = revokedAt
}
// reduceRefreshTokenAdded records the newly issued refresh token. Both its
// absolute expiration (Lifetime) and its idle expiration (IdleLifetime) are
// measured from the event's creation date.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) {
	issuedAt := e.CreationDate()
	wm.RefreshTokenID = e.ID
	wm.RefreshTokenExpiration = issuedAt.Add(e.Lifetime)
	wm.RefreshTokenIdleExpiration = issuedAt.Add(e.IdleLifetime)
}
// reduceRefreshTokenRenewed replaces the refresh token id and pushes out only
// the idle expiration. The absolute expiration set when the token was added is
// left untouched.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) {
	renewedAt := e.CreationDate()
	wm.RefreshTokenID = e.ID
	wm.RefreshTokenIdleExpiration = renewedAt.Add(e.IdleLifetime)
}
// reduceRefreshTokenRevoked invalidates the refresh token and, with it, the
// current access token. Both ids are cleared and every expiration is moved to
// the time of revocation.
func (wm *OIDCSessionWriteModel) reduceRefreshTokenRevoked(e *oidcsession.RefreshTokenRevokedEvent) {
	revokedAt := e.CreationDate()
	wm.RefreshTokenID, wm.AccessTokenID = "", ""
	wm.RefreshTokenExpiration = revokedAt
	wm.RefreshTokenIdleExpiration = revokedAt
	wm.AccessTokenExpiration = revokedAt
}
// CheckRefreshToken returns nil only if the session is active, refreshTokenID
// is the session's current refresh token, and neither the absolute nor the
// idle expiration of that token lies in the past. Otherwise it returns a
// precondition-failed error.
func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error {
	switch {
	case wm.State != domain.OIDCSessionStateActive:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid")
	case wm.RefreshTokenID != refreshTokenID:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid")
	}
	// Read the clock once so both expirations are compared against the same instant.
	if now := time.Now(); wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) {
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid")
	}
	return nil
}
// CheckAccessToken returns nil only if the session is active, accessTokenID is
// the session's current access token, and that token has not expired.
// Otherwise it returns a precondition-failed error.
func (wm *OIDCSessionWriteModel) CheckAccessToken(accessTokenID string) error {
	switch {
	case wm.State != domain.OIDCSessionStateActive:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-KL2pk", "Errors.OIDCSession.Token.Invalid")
	case wm.AccessTokenID != accessTokenID:
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JLKW2", "Errors.OIDCSession.Token.Invalid")
	case wm.AccessTokenExpiration.Before(time.Now()):
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3j3md", "Errors.OIDCSession.Token.Invalid")
	default:
		return nil
	}
}
// CheckClient returns nil if clientID is contained in the session's Audience,
// and a precondition-failed error otherwise.
func (wm *OIDCSessionWriteModel) CheckClient(clientID string) error {
	found := false
	for i := 0; i < len(wm.Audience) && !found; i++ {
		found = wm.Audience[i] == clientID
	}
	if !found {
		return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-SKjl3", "Errors.OIDCSession.InvalidClient")
	}
	return nil
}
// OIDCRefreshTokenID builds a composite refresh token id: the session's
// aggregate id, followed by TokenDelimiter, followed by refreshTokenID.
func (wm *OIDCSessionWriteModel) OIDCRefreshTokenID(refreshTokenID string) string {
	sessionPart := wm.AggregateID + TokenDelimiter
	return sessionPart + refreshTokenID
}