zitadel/internal/query/access_token.go
Livio Spring 3289698d4c
fix: return 401 instead of 403 on expired tokens (#8476)
# Which Problems Are Solved

The access token verifier returned a permission denied (HTTP 403 / GRPC
7) error instead of an unauthenticated (HTTP 401 / GRPC 16) error.

# How the Problems Are Solved

Return the correct error type.

# Additional Changes

None

# Additional Context

close #8392

(cherry picked from commit cbbd44c303c6a06a5ef3d6c8fecd6fca63ec8705)
2024-09-04 13:03:24 +02:00

214 lines
6.5 KiB
Go

package query
import (
"context"
"strings"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
// OIDCSessionAccessTokenReadModel is the state of an OIDC session and its
// current access token, reduced from the session's oidcsession events.
type OIDCSessionAccessTokenReadModel struct {
	eventstore.ReadModel

	// Identifiers copied from the oidcsession.AddedEvent.
	UserID, SessionID, ClientID string

	// Audience and Scope are copied from the oidcsession.AddedEvent.
	Audience, Scope []string

	// AuthMethods, AuthTime and Nonce are copied from the oidcsession.AddedEvent.
	AuthMethods []domain.UserAuthMethodType
	AuthTime    time.Time
	Nonce       string

	// State is set to domain.OIDCSessionStateActive once the AddedEvent was reduced.
	State domain.OIDCSessionState

	// AccessTokenID is the ID of the most recently added access token.
	// It is reset to the empty string when a token revocation event is reduced.
	AccessTokenID string

	// AccessTokenCreation is the creation date of the access token event.
	// AccessTokenExpiration is that creation date plus the token lifetime,
	// or the creation time of the revocation event once the token was revoked.
	AccessTokenCreation, AccessTokenExpiration time.Time

	// PreferredLanguage and UserAgent are copied from the oidcsession.AddedEvent.
	PreferredLanguage *language.Tag
	UserAgent         *domain.UserAgent

	// Reason and Actor are copied from the oidcsession.AccessTokenAddedEvent.
	Reason domain.TokenReason
	Actor  *domain.TokenActor
}
// newOIDCSessionAccessTokenReadModel returns an empty read model whose
// aggregate ID is set to the given OIDC session id.
func newOIDCSessionAccessTokenReadModel(id string) *OIDCSessionAccessTokenReadModel {
	model := new(OIDCSessionAccessTokenReadModel)
	model.ReadModel.AggregateID = id
	return model
}
// Reduce applies the collected events to the read model in order and then
// delegates to the embedded eventstore.ReadModel. Event types without a case
// below are ignored.
func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
	for i := range wm.Events {
		switch e := wm.Events[i].(type) {
		case *oidcsession.AddedEvent:
			wm.reduceAdded(e)
		case *oidcsession.AccessTokenAddedEvent:
			wm.reduceAccessTokenAdded(e)
		case *oidcsession.AccessTokenRevokedEvent:
			wm.reduceTokenRevoked(e)
		case *oidcsession.RefreshTokenRevokedEvent:
			// Revoking the refresh token is handled like revoking the access token.
			wm.reduceTokenRevoked(e)
		}
	}
	return wm.ReadModel.Reduce()
}
// Query builds the eventstore search for the events Reduce handles: the added,
// access token added, access token revoked and refresh token revoked events of
// the OIDC session aggregate identified by the model's aggregate ID.
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
	sessionEvents := builder.AddQuery().
		AggregateTypes(oidcsession.AggregateType).
		AggregateIDs(wm.AggregateID).
		EventTypes(
			oidcsession.AddedType,
			oidcsession.AccessTokenAddedType,
			oidcsession.AccessTokenRevokedType,
			oidcsession.RefreshTokenRevokedType,
		)
	return sessionEvents.Builder()
}
// reduceAdded copies the session data of the added event into the read model
// and marks the OIDC session as active.
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
	wm.State = domain.OIDCSessionStateActive

	wm.UserID, wm.SessionID, wm.ClientID = e.UserID, e.SessionID, e.ClientID
	wm.Audience, wm.Scope = e.Audience, e.Scope
	wm.AuthMethods, wm.AuthTime = e.AuthMethods, e.AuthTime
	wm.Nonce = e.Nonce
	wm.PreferredLanguage, wm.UserAgent = e.PreferredLanguage, e.UserAgent
}
// reduceAccessTokenAdded records the newly added access token: its ID, its
// creation date, its expiration (creation date plus lifetime) and the reason
// and actor carried by the event.
func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
	created := e.CreationDate()

	wm.AccessTokenID = e.ID
	wm.AccessTokenCreation = created
	wm.AccessTokenExpiration = created.Add(e.Lifetime)
	wm.Reason, wm.Actor = e.Reason, e.Actor
}
// reduceTokenRevoked invalidates the current access token: the expiration is
// moved to the creation time of the revocation event and the token ID is
// cleared. Reduce calls it for both access and refresh token revocations.
func (wm *OIDCSessionAccessTokenReadModel) reduceTokenRevoked(e eventstore.Event) {
	wm.AccessTokenExpiration = e.CreatedAt()
	wm.AccessTokenID = ""
}
// ActiveAccessTokenByToken checks whether the token is active by retrieving the OIDCSession events from the eventstore.
// The token is expected in the form "<oidc session id>-<access token id>".
// An error is returned if the token is malformed, does not match the session's current access token,
// is expired, or if the underlying session has been terminated.
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	parts := strings.Split(token, "-")
	if len(parts) != 2 {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-LJK2W", "Errors.OIDCSession.Token.Invalid")
	}
	oidcSessionID, tokenID := parts[0], parts[1]

	model, err = q.accessTokenByOIDCSessionAndTokenID(ctx, oidcSessionID, tokenID)
	if err != nil {
		return nil, err
	}

	// The token is only valid while its expiration lies strictly in the future.
	if now := time.Now(); !model.AccessTokenExpiration.After(now) {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
	}

	fingerprintID := model.UserAgent.GetFingerprintID()
	err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.Position, fingerprintID)
	if err != nil {
		return nil, err
	}
	return model, nil
}
// accessTokenByOIDCSessionAndTokenID reduces the events of the OIDC session
// identified by oidcSessionID and verifies that tokenID is the session's
// current access token.
//
// It returns the reduced read model on success. An Unauthenticated error with
// the "Errors.OIDCSession.Token.Invalid" message is returned if the events
// could not be filtered and reduced, if the session has no active access token
// (none was issued or it was revoked), or if tokenID does not match it.
func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSessionID, tokenID string) (model *OIDCSessionAccessTokenReadModel, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	model = newOIDCSessionAccessTokenReadModel(oidcSessionID)
	if err = q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
		return nil, zerrors.ThrowUnauthenticated(err, "QUERY-ASfe2", "Errors.OIDCSession.Token.Invalid")
	}

	// The read model uses an empty AccessTokenID for "no active token": it is
	// the zero value before any token was added and is reset to "" on revocation.
	// Without the explicit check a presented token with an empty ID part
	// ("<session>-") would compare equal to that marker and be treated as a match.
	if model.AccessTokenID == "" || model.AccessTokenID != tokenID {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-M2u9w", "Errors.OIDCSession.Token.Invalid")
	}
	return model, nil
}
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event (or user events leading to a session termination)
// occurred after the given eventstore position and will return an error if so.
//
// NOTE(review): a failure of the eventstore query itself is also returned as an
// Unauthenticated error (with the "Errors.Internal" message) - confirm this is intended.
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, position float64, fingerprintID string) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	var check sessionTerminatedModel
	check.sessionID = sessionID
	check.userID = userID
	check.position = position
	check.fingerPrintID = fingerprintID

	if err = q.eventstore.FilterToQueryReducer(ctx, &check); err != nil {
		return zerrors.ThrowUnauthenticated(err, "QUERY-SJ642", "Errors.Internal")
	}
	if !check.terminated {
		return nil
	}
	return zerrors.ThrowUnauthenticated(nil, "QUERY-IJL3H", "Errors.OIDCSession.Token.Invalid")
}
// sessionTerminatedModel counts the events returned for its Query and reports
// the session as terminated as soon as at least one such event exists.
type sessionTerminatedModel struct {
	// position is the eventstore position after which events are searched.
	position float64
	// sessionID and userID select the aggregates to search; fingerPrintID is
	// matched against the "userAgentID" event data of user sign-out events.
	sessionID, userID, fingerPrintID string

	// events is the number of events appended so far.
	events int
	// terminated is set by Reduce when events is greater than zero.
	terminated bool
}

// Reduce derives the terminated flag from the number of appended events.
// It never fails.
func (s *sessionTerminatedModel) Reduce() error {
	anyEvent := s.events > 0
	s.terminated = anyEvent
	return nil
}
// AppendEvents only keeps track of how many events were passed in; the events
// themselves are not stored.
func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
	received := len(events)
	s.events += received
}
// Query searches for events after the model's position that terminate the session:
// the session's own terminate event and, if a user ID is set, the user's
// deactivated, locked and removed events as well as a human sign-out event whose
// "userAgentID" data equals the fingerprint ID.
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		PositionAfter(s.position)

	builder = builder.AddQuery().
		AggregateTypes(session.AggregateType).
		AggregateIDs(s.sessionID).
		EventTypes(session.TerminateType).
		Builder()

	// Without a user there are no user events to consider.
	if s.userID == "" {
		return builder
	}

	userStateEvents := builder.AddQuery().
		AggregateTypes(user.AggregateType).
		AggregateIDs(s.userID).
		EventTypes(
			user.UserDeactivatedType,
			user.UserLockedType,
			user.UserRemovedType,
		)

	// for specific logout on v1 sessions from the same user agent
	signedOutEvents := userStateEvents.Or().
		AggregateTypes(user.AggregateType).
		AggregateIDs(s.userID).
		EventTypes(user.HumanSignedOutType).
		EventData(map[string]interface{}{"userAgentID": s.fingerPrintID})

	return signedOutEvents.Builder()
}