diff --git a/internal/api/oidc/amr.go b/internal/api/oidc/amr.go new file mode 100644 index 0000000000..727e0f8889 --- /dev/null +++ b/internal/api/oidc/amr.go @@ -0,0 +1,47 @@ +package oidc + +import "github.com/zitadel/zitadel/internal/domain" + +const ( + // Password states that the users password has been verified + // Deprecated: use `PWD` instead + Password = "password" + // PWD states that the users password has been verified + PWD = "pwd" + // MFA states that multiple factors have been verified (e.g. pwd and otp or passkey) + MFA = "mfa" + // OTP states that a one time password has been verified (e.g. TOTP) + OTP = "otp" + // UserPresence states that the end users presence has been verified (e.g. passkey and u2f) + UserPresence = "user" +) + +// AuthMethodTypesToAMR maps zitadel auth method types to Authentication Method Reference Values +// as defined in [RFC 8176, section 2]. +// +// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2 +func AuthMethodTypesToAMR(methodTypes []domain.UserAuthMethodType) []string { + amr := make([]string, 0, 4) + var mfa bool + for _, methodType := range methodTypes { + switch methodType { + case domain.UserAuthMethodTypePassword: + amr = append(amr, PWD) + case domain.UserAuthMethodTypePasswordless: + mfa = true + amr = append(amr, UserPresence) + case domain.UserAuthMethodTypeU2F: + amr = append(amr, UserPresence) + case domain.UserAuthMethodTypeOTP: + amr = append(amr, OTP) + case domain.UserAuthMethodTypeIDP: + // no AMR value according to specification + case domain.UserAuthMethodTypeUnspecified: + // ignore + } + } + if mfa || len(amr) >= 2 { + amr = append(amr, MFA) + } + return amr +} diff --git a/internal/api/oidc/amr/amr.go b/internal/api/oidc/amr/amr.go deleted file mode 100644 index 1791f767c8..0000000000 --- a/internal/api/oidc/amr/amr.go +++ /dev/null @@ -1,43 +0,0 @@ -// Package amr maps zitadel session factors to Authentication Method Reference Values -// as defined in [RFC 8176, section 2]. -// -// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2 -package amr - -const ( - // Password states that the users password has been verified - // Deprecated: use `PWD` instead - Password = "password" - // PWD states that the users password has been verified - PWD = "pwd" - // MFA states that multiple factors have been verified (e.g. pwd and otp or passkey) - MFA = "mfa" - // OTP states that a one time password has been verified (e.g. TOTP) - OTP = "otp" - // UserPresence states that the end users presence has been verified (e.g. passkey and u2f) - UserPresence = "user" -) - -type AuthenticationMethodReference interface { - IsPasswordChecked() bool - IsPasskeyChecked() bool - IsU2FChecked() bool - IsOTPChecked() bool -} - -func List(model AuthenticationMethodReference) []string { - amr := make([]string, 0) - if model.IsPasswordChecked() { - amr = append(amr, PWD) - } - if model.IsPasskeyChecked() || model.IsU2FChecked() { - amr = append(amr, UserPresence) - } - if model.IsOTPChecked() { - amr = append(amr, OTP) - } - if model.IsPasskeyChecked() || len(amr) >= 2 { - amr = append(amr, MFA) - } - return amr -} diff --git a/internal/api/oidc/amr/amr_test.go b/internal/api/oidc/amr_test.go similarity index 52% rename from internal/api/oidc/amr/amr_test.go rename to internal/api/oidc/amr_test.go index f2c5189bcf..1861085bc5 100644 --- a/internal/api/oidc/amr/amr_test.go +++ b/internal/api/oidc/amr_test.go @@ -1,14 +1,16 @@ -package amr +package oidc import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/domain" ) func TestAMR(t *testing.T) { type args struct { - model AuthenticationMethodReference + methodTypes []domain.UserAuthMethodType } tests := []struct { name string @@ -18,76 +20,50 @@ func TestAMR(t *testing.T) { { "no checks, empty", args{ - new(test), + nil, }, []string{}, }, { "pw checked", args{ - &test{pwChecked: true}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, }, []string{PWD}, }, { "passkey checked", args{ - &test{passkeyChecked: true}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePasswordless}, }, []string{UserPresence, MFA}, }, { "u2f checked", args{ - &test{u2fChecked: true}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypeU2F}, }, []string{UserPresence}, }, { "otp checked", args{ - &test{otpChecked: true}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypeOTP}, }, []string{OTP}, }, { "multiple checked", args{ - &test{ - pwChecked: true, - u2fChecked: true, - }, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword, domain.UserAuthMethodTypeU2F}, }, []string{PWD, UserPresence, MFA}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := List(tt.args.model) + got := AuthMethodTypesToAMR(tt.args.methodTypes) assert.Equal(t, tt.want, got) }) } } - -type test struct { - pwChecked bool - passkeyChecked bool - u2fChecked bool - otpChecked bool -} - -func (t test) IsPasswordChecked() bool { - return t.pwChecked -} - -func (t test) IsPasskeyChecked() bool { - return t.passkeyChecked -} - -func (t test) IsU2FChecked() bool { - return t.u2fChecked -} - -func (t test) IsOTPChecked() bool { - return t.otpChecked -} diff --git a/internal/api/oidc/auth_request_converter.go b/internal/api/oidc/auth_request_converter.go index 02fd3273ed..02cbf8cba6 100644 --- a/internal/api/oidc/auth_request_converter.go +++ b/internal/api/oidc/auth_request_converter.go @@ -12,7 +12,6 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" http_utils "github.com/zitadel/zitadel/internal/api/http" - "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/user/model" @@ -34,10 +33,10 @@ func (a *AuthRequest) GetACR() string { func (a *AuthRequest) GetAMR() []string { list := make([]string, 0) if a.PasswordVerified { - list = append(list, amr.Password, amr.PWD) + list = append(list, Password, PWD) } if len(a.MFAsVerified) > 0 { - list = append(list, amr.MFA) + list = append(list, MFA) for _, mfa := range a.MFAsVerified { if amrMFA := AMRFromMFAType(mfa); amrMFA != "" { list = append(list, amrMFA) @@ -263,10 +262,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng func AMRFromMFAType(mfaType domain.MFAType) string { switch mfaType { case domain.MFATypeOTP: - return amr.OTP + return OTP case domain.MFATypeU2F, domain.MFATypeU2FUserVerification: - return amr.UserPresence + return UserPresence default: return "" } diff --git a/internal/api/oidc/auth_request_converter_v2.go b/internal/api/oidc/auth_request_converter_v2.go index fd9c5f48ef..9f4c0d0a1f 100644 --- a/internal/api/oidc/auth_request_converter_v2.go +++ b/internal/api/oidc/auth_request_converter_v2.go @@ -21,7 +21,7 @@ func (a *AuthRequestV2) GetACR() string { } func (a *AuthRequestV2) GetAMR() []string { - return a.AMR + return AuthMethodTypesToAMR(a.AuthMethods) } func (a *AuthRequestV2) GetAudience() []string { @@ -78,7 +78,7 @@ type RefreshTokenRequestV2 struct { } func (r *RefreshTokenRequestV2) GetAMR() []string { - return r.AuthMethodsReferences + return AuthMethodTypesToAMR(r.AuthMethods) } func (r *RefreshTokenRequestV2) GetAudience() []string { diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index d58958256f..1223ff47e7 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -15,7 +15,7 @@ import ( "github.com/zitadel/oidc/v2/pkg/oidc" "golang.org/x/oauth2" - "github.com/zitadel/zitadel/internal/api/oidc/amr" + oidc_api "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/integration" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" @@ -270,6 +270,6 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) { assert.Equal(t, User.GetUserId(), claims.Subject) - assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences) + assert.Equal(t, []string{oidc_api.UserPresence, oidc_api.MFA}, claims.AuthenticationMethodsReferences) assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second)) } diff --git a/internal/command/auth_request.go b/internal/command/auth_request.go index 1ba0a55173..bbcee503c0 100644 --- a/internal/command/auth_request.go +++ b/internal/command/auth_request.go @@ -5,7 +5,6 @@ import ( "time" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/repository/authrequest" @@ -32,10 +31,10 @@ type AuthRequest struct { type CurrentAuthRequest struct { *AuthRequest - SessionID string - UserID string - AMR []string - AuthTime time.Time + SessionID string + UserID string + AuthMethods []domain.UserAuthMethodType + AuthTime time.Time } const IDPrefixV2 = "V2_" @@ -108,7 +107,7 @@ func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionID, sessionWriteModel.UserID, sessionWriteModel.AuthenticationTime(), - amr.List(sessionWriteModel), + sessionWriteModel.AuthMethodTypes(), )); err != nil { return nil, nil, err } @@ -187,10 +186,10 @@ func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel LoginHint: writeModel.LoginHint, HintUserID: writeModel.HintUserID, }, - SessionID: writeModel.SessionID, - UserID: writeModel.UserID, - AMR: writeModel.AMR, - AuthTime: writeModel.AuthTime, + SessionID: writeModel.SessionID, + UserID: writeModel.UserID, + AuthMethods: writeModel.AuthMethods, + AuthTime: writeModel.AuthTime, } } diff --git a/internal/command/auth_request_model.go b/internal/command/auth_request_model.go index 91bbaf955a..d27caf8764 100644 --- a/internal/command/auth_request_model.go +++ b/internal/command/auth_request_model.go @@ -32,7 +32,7 @@ type AuthRequestWriteModel struct { SessionID string UserID string AuthTime time.Time - AMR []string + AuthMethods []domain.UserAuthMethodType AuthRequestState domain.AuthRequestState } @@ -68,7 +68,7 @@ func (m *AuthRequestWriteModel) Reduce() error { m.SessionID = e.SessionID m.UserID = e.UserID m.AuthTime = e.AuthTime - m.AMR = e.AMR + m.AuthMethods = e.AuthMethods case *authrequest.CodeAddedEvent: m.AuthRequestState = domain.AuthRequestStateCodeAdded case *authrequest.FailedEvent: diff --git a/internal/command/auth_request_test.go b/internal/command/auth_request_test.go index ce12a119c1..f54043b368 100644 --- a/internal/command/auth_request_test.go +++ b/internal/command/auth_request_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore" @@ -463,7 +462,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), )}), ), @@ -492,9 +491,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { Audience: []string{"audience"}, ResponseType: domain.OIDCResponseTypeCode, }, - SessionID: "sessionID", - UserID: "userID", - AMR: []string{amr.PWD}, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, }, }, }, @@ -542,7 +541,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), )}), ), @@ -572,9 +571,9 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) { Audience: []string{"audience"}, ResponseType: domain.OIDCResponseTypeCode, }, - SessionID: "sessionID", - UserID: "userID", - AMR: []string{amr.PWD}, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, }, }, }, @@ -798,7 +797,7 @@ func TestCommands_AddAuthRequestCode(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), ), @@ -930,7 +929,7 @@ func TestCommands_ExchangeAuthCode(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), eventFromEventPusher( @@ -972,9 +971,9 @@ func TestCommands_ExchangeAuthCode(t *testing.T) { LoginHint: gu.Ptr("loginHint"), HintUserID: gu.Ptr("hintUserID"), }, - SessionID: "sessionID", - UserID: "userID", - AMR: []string{"pwd"}, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, }, }, }, diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index 6af09d91b0..2505c22b43 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -7,7 +7,6 @@ import ( "time" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" @@ -195,7 +194,7 @@ func (c *OIDCSessionEvents) AddSession(ctx context.Context) { c.authRequestWriteModel.ClientID, c.authRequestWriteModel.Audience, c.authRequestWriteModel.Scope, - amr.List(c.sessionWriteModel), + c.sessionWriteModel.AuthMethodTypes(), c.sessionWriteModel.AuthenticationTime(), )) } diff --git a/internal/command/oidc_session_model.go b/internal/command/oidc_session_model.go index f1c117f2b2..31eb98b451 100644 --- a/internal/command/oidc_session_model.go +++ b/internal/command/oidc_session_model.go @@ -17,7 +17,7 @@ type OIDCSessionWriteModel struct { ClientID string Audience []string Scope []string - AuthMethodsReferences []string + AuthMethods []domain.UserAuthMethodType AuthTime time.Time State domain.OIDCSessionState AccessTokenExpiration time.Time @@ -79,7 +79,7 @@ func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) { wm.ClientID = e.ClientID wm.Audience = e.Audience wm.Scope = e.Scope - wm.AuthMethodsReferences = e.AuthMethodsReferences + wm.AuthMethods = e.AuthMethods wm.AuthTime = e.AuthTime wm.State = domain.OIDCSessionStateActive } diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go index 2a1dfe08af..3bbc8f5791 100644 --- a/internal/command/oidc_session_test.go +++ b/internal/command/oidc_session_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" @@ -99,7 +98,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), eventFromEventPusher( @@ -151,7 +150,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), eventFromEventPusher( @@ -179,7 +178,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { []*repository.Event{ eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -293,7 +292,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), eventFromEventPusher( @@ -345,7 +344,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { "sessionID", "userID", testNow, - []string{amr.PWD}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, ), ), eventFromEventPusher( @@ -373,7 +372,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { []*repository.Event{ eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusherWithInstanceID("instanceID", oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -491,7 +490,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { expectFilter( eventFromEventPusher( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -517,7 +516,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { expectFilter( eventFromEventPusher( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -547,7 +546,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { expectFilter( eventFromEventPusherWithCreationDateNow( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusherWithCreationDateNow( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -670,7 +669,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { expectFilter( eventFromEventPusher( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -695,7 +694,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { expectFilter( eventFromEventPusher( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusher( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -724,7 +723,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { expectFilter( eventFromEventPusherWithCreationDateNow( oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, - "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow), ), eventFromEventPusherWithCreationDateNow( oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, @@ -753,7 +752,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { ClientID: "clientID", Audience: []string{"audience"}, Scope: []string{"openid", "profile", "offline_access"}, - AuthMethodsReferences: []string{amr.PWD}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, AuthTime: testNow, State: domain.OIDCSessionStateActive, RefreshTokenID: "refreshTokenID", @@ -783,7 +782,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { assert.Equal(t, tt.res.model.ClientID, got.ClientID) assert.Equal(t, tt.res.model.Audience, got.Audience) assert.Equal(t, tt.res.model.Scope, got.Scope) - assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences) + assert.Equal(t, tt.res.model.AuthMethods, got.AuthMethods) assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second)) assert.Equal(t, tt.res.model.State, got.State) assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID) diff --git a/internal/command/session_model.go b/internal/command/session_model.go index 2779c7dc37..261de4cd98 100644 --- a/internal/command/session_model.go +++ b/internal/command/session_model.go @@ -52,24 +52,6 @@ type SessionWriteModel struct { aggregate *eventstore.Aggregate } -func (wm *SessionWriteModel) IsPasswordChecked() bool { - return !wm.PasswordCheckedAt.IsZero() -} - -func (wm *SessionWriteModel) IsPasskeyChecked() bool { - return !wm.PasskeyCheckedAt.IsZero() -} - -func (wm *SessionWriteModel) IsU2FChecked() bool { - // TODO: implement with https://github.com/zitadel/zitadel/issues/5477 - return false -} - -func (wm *SessionWriteModel) IsOTPChecked() bool { - // TODO: implement with https://github.com/zitadel/zitadel/issues/5477 - return false -} - func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel { return &SessionWriteModel{ WriteModel: eventstore.WriteModel{ @@ -244,3 +226,27 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time { } return authTime } + +// AuthMethodTypes returns a list of UserAuthMethodTypes based on succeeded checks +func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType { + types := make([]domain.UserAuthMethodType, 0, domain.UserAuthMethodTypeIDP) + if !wm.PasswordCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypePassword) + } + if !wm.PasskeyCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypePasswordless) + } + if !wm.IntentCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeIDP) + } + // TODO: add checks with https://github.com/zitadel/zitadel/issues/5477 + /* + if !wm.TOTPCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeOTP) + } + if !wm.U2FCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeU2F) + } + */ + return types +} diff --git a/internal/repository/authrequest/auth_request.go b/internal/repository/authrequest/auth_request.go index 06ef5fea9e..a279f23c92 100644 --- a/internal/repository/authrequest/auth_request.go +++ b/internal/repository/authrequest/auth_request.go @@ -103,10 +103,10 @@ func AddedEventMapper(event *repository.Event) (eventstore.Event, error) { type SessionLinkedEvent struct { eventstore.BaseEvent `json:"-"` - SessionID string `json:"session_id"` - UserID string `json:"user_id"` - AuthTime time.Time `json:"auth_time"` - AMR []string `json:"amr"` + SessionID string `json:"session_id"` + UserID string `json:"user_id"` + AuthTime time.Time `json:"auth_time"` + AuthMethods []domain.UserAuthMethodType `json:"auth_methods"` } func (e *SessionLinkedEvent) Data() interface{} { @@ -122,7 +122,7 @@ func NewSessionLinkedEvent(ctx context.Context, sessionID, userID string, authTime time.Time, - amr []string, + authMethods []domain.UserAuthMethodType, ) *SessionLinkedEvent { return &SessionLinkedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -130,10 +130,10 @@ func NewSessionLinkedEvent(ctx context.Context, aggregate, SessionLinkedType, ), - SessionID: sessionID, - UserID: userID, - AuthTime: authTime, - AMR: amr, + SessionID: sessionID, + UserID: userID, + AuthTime: authTime, + AuthMethods: authMethods, } } diff --git a/internal/repository/oidcsession/oidc_session.go b/internal/repository/oidcsession/oidc_session.go index 842013b34f..b128c45dc2 100644 --- a/internal/repository/oidcsession/oidc_session.go +++ b/internal/repository/oidcsession/oidc_session.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore/repository" @@ -21,13 +22,13 @@ const ( type AddedEvent struct { eventstore.BaseEvent `json:"-"` - UserID string `json:"userID"` - SessionID string `json:"sessionID"` - ClientID string `json:"clientID"` - Audience []string `json:"audience"` - Scope []string `json:"scope"` - AuthMethodsReferences []string `json:"authMethodsReferences"` - AuthTime time.Time `json:"authTime"` + UserID string `json:"userID"` + SessionID string `json:"sessionID"` + ClientID string `json:"clientID"` + Audience []string `json:"audience"` + Scope []string `json:"scope"` + AuthMethods []domain.UserAuthMethodType `json:"authMethods"` + AuthTime time.Time `json:"authTime"` } func (e *AddedEvent) Data() interface{} { @@ -45,7 +46,7 @@ func NewAddedEvent(ctx context.Context, clientID string, audience, scope []string, - authMethodsReferences []string, + authMethods []domain.UserAuthMethodType, authTime time.Time, ) *AddedEvent { return &AddedEvent{ @@ -54,13 +55,13 @@ func NewAddedEvent(ctx context.Context, aggregate, AddedType, ), - UserID: userID, - SessionID: sessionID, - ClientID: clientID, - Audience: audience, - Scope: scope, - AuthMethodsReferences: authMethodsReferences, - AuthTime: authTime, + UserID: userID, + SessionID: sessionID, + ClientID: clientID, + Audience: audience, + Scope: scope, + AuthMethods: authMethods, + AuthTime: authTime, } }