fix: store auth methods instead of AMR in auth request linking and OIDC Session (#6192)

This PR changes the information stored on the SessionLinkedEvent and (OIDC Session) AddedEvent from OIDC AMR strings to domain.UserAuthMethodTypes, so no information is lost in the process (e.g. authentication with an IDP)
This commit is contained in:
Livio Spring 2023-07-12 14:24:01 +02:00 committed by GitHub
parent a3a1e245ad
commit ee26f99ebf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 156 additions and 174 deletions

47
internal/api/oidc/amr.go Normal file
View File

@ -0,0 +1,47 @@
package oidc
import "github.com/zitadel/zitadel/internal/domain"
const (
// Password states that the users password has been verified
// Deprecated: use `PWD` instead
Password = "password"
// PWD states that the users password has been verified
PWD = "pwd"
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
MFA = "mfa"
// OTP states that a one time password has been verified (e.g. TOTP)
OTP = "otp"
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
UserPresence = "user"
)
// AuthMethodTypesToAMR maps zitadel auth method types to Authentication Method Reference Values
// as defined in [RFC 8176, section 2].
//
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
func AuthMethodTypesToAMR(methodTypes []domain.UserAuthMethodType) []string {
amr := make([]string, 0, 4)
var mfa bool
for _, methodType := range methodTypes {
switch methodType {
case domain.UserAuthMethodTypePassword:
amr = append(amr, PWD)
case domain.UserAuthMethodTypePasswordless:
mfa = true
amr = append(amr, UserPresence)
case domain.UserAuthMethodTypeU2F:
amr = append(amr, UserPresence)
case domain.UserAuthMethodTypeOTP:
amr = append(amr, OTP)
case domain.UserAuthMethodTypeIDP:
// no AMR value according to specification
case domain.UserAuthMethodTypeUnspecified:
// ignore
}
}
if mfa || len(amr) >= 2 {
amr = append(amr, MFA)
}
return amr
}

View File

@ -1,43 +0,0 @@
// Package amr maps zitadel session factors to Authentication Method Reference Values
// as defined in [RFC 8176, section 2].
//
// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2
package amr
const (
// Password states that the users password has been verified
// Deprecated: use `PWD` instead
Password = "password"
// PWD states that the users password has been verified
PWD = "pwd"
// MFA states that multiple factors have been verified (e.g. pwd and otp or passkey)
MFA = "mfa"
// OTP states that a one time password has been verified (e.g. TOTP)
OTP = "otp"
// UserPresence states that the end users presence has been verified (e.g. passkey and u2f)
UserPresence = "user"
)
type AuthenticationMethodReference interface {
IsPasswordChecked() bool
IsPasskeyChecked() bool
IsU2FChecked() bool
IsOTPChecked() bool
}
func List(model AuthenticationMethodReference) []string {
amr := make([]string, 0)
if model.IsPasswordChecked() {
amr = append(amr, PWD)
}
if model.IsPasskeyChecked() || model.IsU2FChecked() {
amr = append(amr, UserPresence)
}
if model.IsOTPChecked() {
amr = append(amr, OTP)
}
if model.IsPasskeyChecked() || len(amr) >= 2 {
amr = append(amr, MFA)
}
return amr
}

View File

@ -1,14 +1,16 @@
package amr package oidc
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/domain"
) )
func TestAMR(t *testing.T) { func TestAMR(t *testing.T) {
type args struct { type args struct {
model AuthenticationMethodReference methodTypes []domain.UserAuthMethodType
} }
tests := []struct { tests := []struct {
name string name string
@ -18,76 +20,50 @@ func TestAMR(t *testing.T) {
{ {
"no checks, empty", "no checks, empty",
args{ args{
new(test), nil,
}, },
[]string{}, []string{},
}, },
{ {
"pw checked", "pw checked",
args{ args{
&test{pwChecked: true}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
}, },
[]string{PWD}, []string{PWD},
}, },
{ {
"passkey checked", "passkey checked",
args{ args{
&test{passkeyChecked: true}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless},
}, },
[]string{UserPresence, MFA}, []string{UserPresence, MFA},
}, },
{ {
"u2f checked", "u2f checked",
args{ args{
&test{u2fChecked: true}, []domain.UserAuthMethodType{domain.UserAuthMethodTypeU2F},
}, },
[]string{UserPresence}, []string{UserPresence},
}, },
{ {
"otp checked", "otp checked",
args{ args{
&test{otpChecked: true}, []domain.UserAuthMethodType{domain.UserAuthMethodTypeOTP},
}, },
[]string{OTP}, []string{OTP},
}, },
{ {
"multiple checked", "multiple checked",
args{ args{
&test{ []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword, domain.UserAuthMethodTypeU2F},
pwChecked: true,
u2fChecked: true,
},
}, },
[]string{PWD, UserPresence, MFA}, []string{PWD, UserPresence, MFA},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := List(tt.args.model) got := AuthMethodTypesToAMR(tt.args.methodTypes)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
} }
type test struct {
pwChecked bool
passkeyChecked bool
u2fChecked bool
otpChecked bool
}
func (t test) IsPasswordChecked() bool {
return t.pwChecked
}
func (t test) IsPasskeyChecked() bool {
return t.passkeyChecked
}
func (t test) IsU2FChecked() bool {
return t.u2fChecked
}
func (t test) IsOTPChecked() bool {
return t.otpChecked
}

View File

@ -12,7 +12,6 @@ import (
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http" http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/user/model" "github.com/zitadel/zitadel/internal/user/model"
@ -34,10 +33,10 @@ func (a *AuthRequest) GetACR() string {
func (a *AuthRequest) GetAMR() []string { func (a *AuthRequest) GetAMR() []string {
list := make([]string, 0) list := make([]string, 0)
if a.PasswordVerified { if a.PasswordVerified {
list = append(list, amr.Password, amr.PWD) list = append(list, Password, PWD)
} }
if len(a.MFAsVerified) > 0 { if len(a.MFAsVerified) > 0 {
list = append(list, amr.MFA) list = append(list, MFA)
for _, mfa := range a.MFAsVerified { for _, mfa := range a.MFAsVerified {
if amrMFA := AMRFromMFAType(mfa); amrMFA != "" { if amrMFA := AMRFromMFAType(mfa); amrMFA != "" {
list = append(list, amrMFA) list = append(list, amrMFA)
@ -263,10 +262,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng
func AMRFromMFAType(mfaType domain.MFAType) string { func AMRFromMFAType(mfaType domain.MFAType) string {
switch mfaType { switch mfaType {
case domain.MFATypeOTP: case domain.MFATypeOTP:
return amr.OTP return OTP
case domain.MFATypeU2F, case domain.MFATypeU2F,
domain.MFATypeU2FUserVerification: domain.MFATypeU2FUserVerification:
return amr.UserPresence return UserPresence
default: default:
return "" return ""
} }

View File

@ -21,7 +21,7 @@ func (a *AuthRequestV2) GetACR() string {
} }
func (a *AuthRequestV2) GetAMR() []string { func (a *AuthRequestV2) GetAMR() []string {
return a.AMR return AuthMethodTypesToAMR(a.AuthMethods)
} }
func (a *AuthRequestV2) GetAudience() []string { func (a *AuthRequestV2) GetAudience() []string {
@ -78,7 +78,7 @@ type RefreshTokenRequestV2 struct {
} }
func (r *RefreshTokenRequestV2) GetAMR() []string { func (r *RefreshTokenRequestV2) GetAMR() []string {
return r.AuthMethodsReferences return AuthMethodTypesToAMR(r.AuthMethods)
} }
func (r *RefreshTokenRequestV2) GetAudience() []string { func (r *RefreshTokenRequestV2) GetAudience() []string {

View File

@ -15,7 +15,7 @@ import (
"github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/api/oidc/amr" oidc_api "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha"
@ -270,6 +270,6 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir
func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) { func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) {
assert.Equal(t, User.GetUserId(), claims.Subject) assert.Equal(t, User.GetUserId(), claims.Subject)
assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences) assert.Equal(t, []string{oidc_api.UserPresence, oidc_api.MFA}, claims.AuthenticationMethodsReferences)
assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second)) assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second))
} }

View File

@ -5,7 +5,6 @@ import (
"time" "time"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/authrequest"
@ -32,10 +31,10 @@ type AuthRequest struct {
type CurrentAuthRequest struct { type CurrentAuthRequest struct {
*AuthRequest *AuthRequest
SessionID string SessionID string
UserID string UserID string
AMR []string AuthMethods []domain.UserAuthMethodType
AuthTime time.Time AuthTime time.Time
} }
const IDPrefixV2 = "V2_" const IDPrefixV2 = "V2_"
@ -108,7 +107,7 @@ func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID,
sessionID, sessionID,
sessionWriteModel.UserID, sessionWriteModel.UserID,
sessionWriteModel.AuthenticationTime(), sessionWriteModel.AuthenticationTime(),
amr.List(sessionWriteModel), sessionWriteModel.AuthMethodTypes(),
)); err != nil { )); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -187,10 +186,10 @@ func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel
LoginHint: writeModel.LoginHint, LoginHint: writeModel.LoginHint,
HintUserID: writeModel.HintUserID, HintUserID: writeModel.HintUserID,
}, },
SessionID: writeModel.SessionID, SessionID: writeModel.SessionID,
UserID: writeModel.UserID, UserID: writeModel.UserID,
AMR: writeModel.AMR, AuthMethods: writeModel.AuthMethods,
AuthTime: writeModel.AuthTime, AuthTime: writeModel.AuthTime,
} }
} }

View File

@ -32,7 +32,7 @@ type AuthRequestWriteModel struct {
SessionID string SessionID string
UserID string UserID string
AuthTime time.Time AuthTime time.Time
AMR []string AuthMethods []domain.UserAuthMethodType
AuthRequestState domain.AuthRequestState AuthRequestState domain.AuthRequestState
} }
@ -68,7 +68,7 @@ func (m *AuthRequestWriteModel) Reduce() error {
m.SessionID = e.SessionID m.SessionID = e.SessionID
m.UserID = e.UserID m.UserID = e.UserID
m.AuthTime = e.AuthTime m.AuthTime = e.AuthTime
m.AMR = e.AMR m.AuthMethods = e.AuthMethods
case *authrequest.CodeAddedEvent: case *authrequest.CodeAddedEvent:
m.AuthRequestState = domain.AuthRequestStateCodeAdded m.AuthRequestState = domain.AuthRequestStateCodeAdded
case *authrequest.FailedEvent: case *authrequest.FailedEvent:

View File

@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors" caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
@ -463,7 +462,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
)}), )}),
), ),
@ -492,9 +491,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
Audience: []string{"audience"}, Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode, ResponseType: domain.OIDCResponseTypeCode,
}, },
SessionID: "sessionID", SessionID: "sessionID",
UserID: "userID", UserID: "userID",
AMR: []string{amr.PWD}, AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
}, },
}, },
}, },
@ -542,7 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
)}), )}),
), ),
@ -572,9 +571,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
Audience: []string{"audience"}, Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode, ResponseType: domain.OIDCResponseTypeCode,
}, },
SessionID: "sessionID", SessionID: "sessionID",
UserID: "userID", UserID: "userID",
AMR: []string{amr.PWD}, AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
}, },
}, },
}, },
@ -798,7 +797,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
), ),
@ -930,7 +929,7 @@ func TestCommands_ExchangeAuthCode(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -972,9 +971,9 @@ func TestCommands_ExchangeAuthCode(t *testing.T) {
LoginHint: gu.Ptr("loginHint"), LoginHint: gu.Ptr("loginHint"),
HintUserID: gu.Ptr("hintUserID"), HintUserID: gu.Ptr("hintUserID"),
}, },
SessionID: "sessionID", SessionID: "sessionID",
UserID: "userID", UserID: "userID",
AMR: []string{"pwd"}, AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
}, },
}, },
}, },

View File

@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors" caos_errs "github.com/zitadel/zitadel/internal/errors"
@ -195,7 +194,7 @@ func (c *OIDCSessionEvents) AddSession(ctx context.Context) {
c.authRequestWriteModel.ClientID, c.authRequestWriteModel.ClientID,
c.authRequestWriteModel.Audience, c.authRequestWriteModel.Audience,
c.authRequestWriteModel.Scope, c.authRequestWriteModel.Scope,
amr.List(c.sessionWriteModel), c.sessionWriteModel.AuthMethodTypes(),
c.sessionWriteModel.AuthenticationTime(), c.sessionWriteModel.AuthenticationTime(),
)) ))
} }

View File

@ -17,7 +17,7 @@ type OIDCSessionWriteModel struct {
ClientID string ClientID string
Audience []string Audience []string
Scope []string Scope []string
AuthMethodsReferences []string AuthMethods []domain.UserAuthMethodType
AuthTime time.Time AuthTime time.Time
State domain.OIDCSessionState State domain.OIDCSessionState
AccessTokenExpiration time.Time AccessTokenExpiration time.Time
@ -79,7 +79,7 @@ func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.ClientID = e.ClientID wm.ClientID = e.ClientID
wm.Audience = e.Audience wm.Audience = e.Audience
wm.Scope = e.Scope wm.Scope = e.Scope
wm.AuthMethodsReferences = e.AuthMethodsReferences wm.AuthMethods = e.AuthMethods
wm.AuthTime = e.AuthTime wm.AuthTime = e.AuthTime
wm.State = domain.OIDCSessionStateActive wm.State = domain.OIDCSessionStateActive
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/oidc/amr"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors" caos_errs "github.com/zitadel/zitadel/internal/errors"
@ -99,7 +98,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -151,7 +150,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -179,7 +178,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
[]*repository.Event{ []*repository.Event{
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -293,7 +292,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -345,7 +344,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
"sessionID", "sessionID",
"userID", "userID",
testNow, testNow,
[]string{amr.PWD}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -373,7 +372,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
[]*repository.Event{ []*repository.Event{
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusherWithInstanceID("instanceID", eventFromEventPusherWithInstanceID("instanceID",
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -491,7 +490,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -517,7 +516,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -547,7 +546,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -670,7 +669,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -695,7 +694,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -724,7 +723,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
expectFilter( expectFilter(
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
"userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow),
), ),
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate,
@ -753,7 +752,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
ClientID: "clientID", ClientID: "clientID",
Audience: []string{"audience"}, Audience: []string{"audience"},
Scope: []string{"openid", "profile", "offline_access"}, Scope: []string{"openid", "profile", "offline_access"},
AuthMethodsReferences: []string{amr.PWD}, AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
AuthTime: testNow, AuthTime: testNow,
State: domain.OIDCSessionStateActive, State: domain.OIDCSessionStateActive,
RefreshTokenID: "refreshTokenID", RefreshTokenID: "refreshTokenID",
@ -783,7 +782,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
assert.Equal(t, tt.res.model.ClientID, got.ClientID) assert.Equal(t, tt.res.model.ClientID, got.ClientID)
assert.Equal(t, tt.res.model.Audience, got.Audience) assert.Equal(t, tt.res.model.Audience, got.Audience)
assert.Equal(t, tt.res.model.Scope, got.Scope) assert.Equal(t, tt.res.model.Scope, got.Scope)
assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences) assert.Equal(t, tt.res.model.AuthMethods, got.AuthMethods)
assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second)) assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second))
assert.Equal(t, tt.res.model.State, got.State) assert.Equal(t, tt.res.model.State, got.State)
assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID) assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID)

View File

@ -52,24 +52,6 @@ type SessionWriteModel struct {
aggregate *eventstore.Aggregate aggregate *eventstore.Aggregate
} }
func (wm *SessionWriteModel) IsPasswordChecked() bool {
return !wm.PasswordCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsPasskeyChecked() bool {
return !wm.PasskeyCheckedAt.IsZero()
}
func (wm *SessionWriteModel) IsU2FChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func (wm *SessionWriteModel) IsOTPChecked() bool {
// TODO: implement with https://github.com/zitadel/zitadel/issues/5477
return false
}
func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel { func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel {
return &SessionWriteModel{ return &SessionWriteModel{
WriteModel: eventstore.WriteModel{ WriteModel: eventstore.WriteModel{
@ -244,3 +226,27 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time {
} }
return authTime return authTime
} }
// AuthMethodTypes returns a list of UserAuthMethodTypes based on succeeded checks
func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
types := make([]domain.UserAuthMethodType, 0, domain.UserAuthMethodTypeIDP)
if !wm.PasswordCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePassword)
}
if !wm.PasskeyCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePasswordless)
}
if !wm.IntentCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeIDP)
}
// TODO: add checks with https://github.com/zitadel/zitadel/issues/5477
/*
if !wm.TOTPCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeOTP)
}
if !wm.U2FCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeU2F)
}
*/
return types
}

View File

@ -103,10 +103,10 @@ func AddedEventMapper(event *repository.Event) (eventstore.Event, error) {
type SessionLinkedEvent struct { type SessionLinkedEvent struct {
eventstore.BaseEvent `json:"-"` eventstore.BaseEvent `json:"-"`
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
UserID string `json:"user_id"` UserID string `json:"user_id"`
AuthTime time.Time `json:"auth_time"` AuthTime time.Time `json:"auth_time"`
AMR []string `json:"amr"` AuthMethods []domain.UserAuthMethodType `json:"auth_methods"`
} }
func (e *SessionLinkedEvent) Data() interface{} { func (e *SessionLinkedEvent) Data() interface{} {
@ -122,7 +122,7 @@ func NewSessionLinkedEvent(ctx context.Context,
sessionID, sessionID,
userID string, userID string,
authTime time.Time, authTime time.Time,
amr []string, authMethods []domain.UserAuthMethodType,
) *SessionLinkedEvent { ) *SessionLinkedEvent {
return &SessionLinkedEvent{ return &SessionLinkedEvent{
BaseEvent: *eventstore.NewBaseEventForPush( BaseEvent: *eventstore.NewBaseEventForPush(
@ -130,10 +130,10 @@ func NewSessionLinkedEvent(ctx context.Context,
aggregate, aggregate,
SessionLinkedType, SessionLinkedType,
), ),
SessionID: sessionID, SessionID: sessionID,
UserID: userID, UserID: userID,
AuthTime: authTime, AuthTime: authTime,
AMR: amr, AuthMethods: authMethods,
} }
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository" "github.com/zitadel/zitadel/internal/eventstore/repository"
@ -21,13 +22,13 @@ const (
type AddedEvent struct { type AddedEvent struct {
eventstore.BaseEvent `json:"-"` eventstore.BaseEvent `json:"-"`
UserID string `json:"userID"` UserID string `json:"userID"`
SessionID string `json:"sessionID"` SessionID string `json:"sessionID"`
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
Audience []string `json:"audience"` Audience []string `json:"audience"`
Scope []string `json:"scope"` Scope []string `json:"scope"`
AuthMethodsReferences []string `json:"authMethodsReferences"` AuthMethods []domain.UserAuthMethodType `json:"authMethods"`
AuthTime time.Time `json:"authTime"` AuthTime time.Time `json:"authTime"`
} }
func (e *AddedEvent) Data() interface{} { func (e *AddedEvent) Data() interface{} {
@ -45,7 +46,7 @@ func NewAddedEvent(ctx context.Context,
clientID string, clientID string,
audience, audience,
scope []string, scope []string,
authMethodsReferences []string, authMethods []domain.UserAuthMethodType,
authTime time.Time, authTime time.Time,
) *AddedEvent { ) *AddedEvent {
return &AddedEvent{ return &AddedEvent{
@ -54,13 +55,13 @@ func NewAddedEvent(ctx context.Context,
aggregate, aggregate,
AddedType, AddedType,
), ),
UserID: userID, UserID: userID,
SessionID: sessionID, SessionID: sessionID,
ClientID: clientID, ClientID: clientID,
Audience: audience, Audience: audience,
Scope: scope, Scope: scope,
AuthMethodsReferences: authMethodsReferences, AuthMethods: authMethods,
AuthTime: authTime, AuthTime: authTime,
} }
} }