From bd5defa96a734365116456b63ae55422a472a45f Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Tue, 27 Jun 2023 14:36:07 +0200 Subject: [PATCH] fix: provide domain in session, passkey and u2f (#6097) This fix provides a possibility to pass a domain on the session, which will be used (as rpID) to create a passkey / u2f assertion and attestation. This is useful in cases where the login UI is served under a different domain / origin than the ZITADEL API. --- internal/api/grpc/session/v2/session.go | 3 +- .../session/v2/session_integration_test.go | 20 ++++- internal/api/grpc/user/v2/passkey.go | 4 +- internal/api/grpc/user/v2/u2f.go | 2 +- internal/command/session.go | 4 +- internal/command/session_model.go | 11 ++- internal/command/session_passkey.go | 2 +- internal/command/session_passkeys_test.go | 3 +- internal/command/session_test.go | 50 ++++++++++-- internal/command/user_converter.go | 1 + internal/command/user_human_webauthn.go | 18 ++--- internal/command/user_human_webauthn_model.go | 2 + internal/command/user_v2_passkey.go | 18 ++--- internal/command/user_v2_passkey_test.go | 16 ++-- internal/command/user_v2_u2f.go | 14 ++-- internal/command/user_v2_u2f_test.go | 11 +-- internal/domain/human_web_auth_n.go | 2 + internal/integration/client.go | 1 + internal/query/projection/session.go | 5 +- internal/query/projection/session_test.go | 23 +++--- internal/query/session.go | 13 ++++ internal/query/sessions_test.go | 78 +++++++++++-------- internal/repository/session/session.go | 4 + .../repository/user/human_mfa_passwordless.go | 2 + internal/repository/user/human_mfa_u2f.go | 2 + .../repository/user/human_mfa_web_auth_n.go | 3 + internal/webauthn/converter.go | 4 +- internal/webauthn/webauthn.go | 51 ++++++++---- internal/webauthn/webauthn_test.go | 23 ++++-- proto/zitadel/session/v2alpha/session.proto | 5 ++ .../session/v2alpha/session_service.proto | 5 ++ proto/zitadel/user/v2alpha/user_service.proto | 10 +++ 32 files changed, 287 insertions(+), 123 deletions(-) diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index 6c31fbeab8..8ac29a4880 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -47,7 +47,7 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe } challengeResponse, cmds := s.challengesToCommand(req.GetChallenges(), checks) - set, err := s.command.CreateSession(ctx, cmds, metadata) + set, err := s.command.CreateSession(ctx, cmds, req.GetDomain(), metadata) if err != nil { return nil, err } @@ -107,6 +107,7 @@ func sessionToPb(s *query.Session) *session.Session { Sequence: s.Sequence, Factors: factorsToPb(s), Metadata: s.Metadata, + Domain: s.Domain, } } diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index 097bba7bea..6c3c2720dd 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -141,6 +141,7 @@ func TestServer_CreateSession(t *testing.T) { }, }, Metadata: map[string][]byte{"foo": []byte("bar")}, + Domain: "domain", }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -169,6 +170,22 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, + { + name: "passkey without domain (not registered) error", + req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + Challenges: []session.ChallengeKind{ + session.ChallengeKind_CHALLENGE_KIND_PASSKEY, + }, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -198,6 +215,7 @@ func TestServer_CreateSession_passkey(t *testing.T) { Challenges: []session.ChallengeKind{ session.ChallengeKind_CHALLENGE_KIND_PASSKEY, }, + Domain: Tester.Config.ExternalDomain, }) require.NoError(t, err) verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) @@ -325,7 +343,7 @@ func TestServer_SetSession_flow(t *testing.T) { var wantFactors []wantFactor // create new, empty session - createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{Domain: Tester.Config.ExternalDomain}) require.NoError(t, err) verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, wantFactors...) sessionToken := createResp.GetSessionToken() diff --git a/internal/api/grpc/user/v2/passkey.go b/internal/api/grpc/user/v2/passkey.go index 6f62be3dec..1b4414474c 100644 --- a/internal/api/grpc/user/v2/passkey.go +++ b/internal/api/grpc/user/v2/passkey.go @@ -20,11 +20,11 @@ func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyR ) if code := req.GetCode(); code != nil { return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, s.userCodeAlg), + s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, req.GetDomain(), s.userCodeAlg), ) } return passkeyRegistrationDetailsToPb( - s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, authenticator), + s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, req.GetDomain(), authenticator), ) } diff --git a/internal/api/grpc/user/v2/u2f.go b/internal/api/grpc/user/v2/u2f.go index ae1daf8443..5178125f7a 100644 --- a/internal/api/grpc/user/v2/u2f.go +++ b/internal/api/grpc/user/v2/u2f.go @@ -12,7 +12,7 @@ import ( func (s *Server) RegisterU2F(ctx context.Context, req *user.RegisterU2FRequest) (*user.RegisterU2FResponse, error) { return u2fRegistrationDetailsToPb( - s.command.RegisterUserU2F(ctx, req.GetUserId(), authz.GetCtxData(ctx).ResourceOwner), + s.command.RegisterUserU2F(ctx, req.GetUserId(), authz.GetCtxData(ctx).ResourceOwner, req.GetDomain()), ) } diff --git a/internal/command/session.go b/internal/command/session.go index 71e9a69826..fb1dd42d04 100644 --- a/internal/command/session.go +++ b/internal/command/session.go @@ -157,7 +157,7 @@ func (s *SessionCommands) commands(ctx context.Context) (string, []eventstore.Co return token, s.sessionWriteModel.commands, nil } -func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte) (set *SessionChanged, err error) { +func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, sessionDomain string, metadata map[string][]byte) (set *SessionChanged, err error) { sessionID, err := c.idGenerator.Next() if err != nil { return nil, err @@ -167,8 +167,8 @@ func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, met if err != nil { return nil, err } + sessionWriteModel.Start(ctx, sessionDomain) cmd := c.NewSessionCommands(cmds, sessionWriteModel) - cmd.sessionWriteModel.Start(ctx) return c.updateSession(ctx, cmd, metadata) } diff --git a/internal/command/session_model.go b/internal/command/session_model.go index 99fd9f36e4..aed5426976 100644 --- a/internal/command/session_model.go +++ b/internal/command/session_model.go @@ -16,6 +16,7 @@ type PasskeyChallengeModel struct { Challenge string AllowedCrentialIDs [][]byte UserVerification domain.UserVerificationRequirement + RPID string } func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) (*domain.WebAuthNLogin, error) { @@ -28,6 +29,7 @@ func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAss Challenge: p.Challenge, AllowedCredentialIDs: p.AllowedCrentialIDs, UserVerification: p.UserVerification, + RPID: p.RPID, }, nil } @@ -41,6 +43,7 @@ type SessionWriteModel struct { IntentCheckedAt time.Time PasskeyCheckedAt time.Time Metadata map[string][]byte + Domain string State domain.SessionState PasskeyChallenge *PasskeyChallengeModel @@ -109,6 +112,7 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder { } func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) { + wm.Domain = e.Domain wm.State = domain.SessionStateActive } @@ -130,6 +134,7 @@ func (wm *SessionWriteModel) reducePasskeyChallenged(e *session.PasskeyChallenge Challenge: e.Challenge, AllowedCrentialIDs: e.AllowedCrentialIDs, UserVerification: e.UserVerification, + RPID: wm.Domain, } } @@ -146,8 +151,10 @@ func (wm *SessionWriteModel) reduceTerminate() { wm.State = domain.SessionStateTerminated } -func (wm *SessionWriteModel) Start(ctx context.Context) { - wm.commands = append(wm.commands, session.NewAddedEvent(ctx, wm.aggregate)) +func (wm *SessionWriteModel) Start(ctx context.Context, domain string) { + wm.commands = append(wm.commands, session.NewAddedEvent(ctx, wm.aggregate, domain)) + // set the domain so checks can use it + wm.Domain = domain } func (wm *SessionWriteModel) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error { diff --git a/internal/command/session_passkey.go b/internal/command/session_passkey.go index d4664ca2fe..6de850b23f 100644 --- a/internal/command/session_passkey.go +++ b/internal/command/session_passkey.go @@ -43,7 +43,7 @@ func (c *Commands) CreatePasskeyChallenge(userVerification domain.UserVerificati if err != nil { return err } - webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, humanPasskeys.tokens...) + webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, cmd.sessionWriteModel.Domain, humanPasskeys.tokens...) if err != nil { return err } diff --git a/internal/command/session_passkeys_test.go b/internal/command/session_passkeys_test.go index fdc3245374..a934c83d77 100644 --- a/internal/command/session_passkeys_test.go +++ b/internal/command/session_passkeys_test.go @@ -84,7 +84,7 @@ func TestSessionCommands_getHumanPasskeys(t *testing.T) { expectFilter(eventFromEventPusher( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, - ), "111", "challenge"), + ), "111", "challenge", "rpID"), )), ), sessionWriteModel: &SessionWriteModel{ @@ -112,6 +112,7 @@ func TestSessionCommands_getHumanPasskeys(t *testing.T) { WebAuthNTokenID: "111", State: domain.MFAStateNotReady, Challenge: "challenge", + RPID: "rpID", }}, }, err: nil, diff --git a/internal/command/session_test.go b/internal/command/session_test.go index f9287a7723..92d19ac243 100644 --- a/internal/command/session_test.go +++ b/internal/command/session_test.go @@ -147,6 +147,7 @@ func TestCommands_CreateSession(t *testing.T) { type args struct { ctx context.Context checks []SessionCommand + domain string metadata map[string][]byte } type res struct { @@ -194,7 +195,7 @@ func TestCommands_CreateSession(t *testing.T) { expectFilter(), expectPush( eventPusherToEvents( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, ""), session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID", ), @@ -218,6 +219,39 @@ func TestCommands_CreateSession(t *testing.T) { }, }, }, + { + "empty session with domain", + fields{ + idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"), + eventstore: eventstoreExpect(t, + expectFilter(), + expectPush( + eventPusherToEvents( + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"), + session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, + "tokenID", + ), + ), + ), + ), + tokenCreator: func(sessionID string) (string, string, error) { + return "tokenID", + "token", + nil + }, + }, + args{ + ctx: authz.NewMockContext("", "org1", ""), + domain: "domain.tld", + }, + res{ + want: &SessionChanged{ + ObjectDetails: &domain.ObjectDetails{ResourceOwner: "org1"}, + ID: "sessionID", + NewToken: "token", + }, + }, + }, // the rest is tested in the Test_updateSession } for _, tt := range tests { @@ -227,7 +261,7 @@ func TestCommands_CreateSession(t *testing.T) { idGenerator: tt.fields.idGenerator, sessionTokenCreator: tt.fields.tokenCreator, } - got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata) + got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.domain, tt.args.metadata) require.ErrorIs(t, err, tt.res.err) assert.Equal(t, tt.res.want, got) }) @@ -276,7 +310,7 @@ func TestCommands_UpdateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -301,7 +335,7 @@ func TestCommands_UpdateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -739,7 +773,7 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -764,7 +798,7 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID")), @@ -793,7 +827,7 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID"), @@ -824,7 +858,7 @@ func TestCommands_TerminateSession(t *testing.T) { eventstore: eventstoreExpect(t, expectFilter( eventFromEventPusher( - session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), eventFromEventPusher( session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "tokenID"), diff --git a/internal/command/user_converter.go b/internal/command/user_converter.go index 599671e1f6..62cee0bedd 100644 --- a/internal/command/user_converter.go +++ b/internal/command/user_converter.go @@ -140,6 +140,7 @@ func writeModelToWebAuthN(wm *HumanWebAuthNWriteModel) *domain.WebAuthNToken { SignCount: wm.SignCount, WebAuthNTokenName: wm.WebAuthNTokenName, State: wm.State, + RPID: wm.RPID, } } diff --git a/internal/command/user_human_webauthn.go b/internal/command/user_human_webauthn.go index 5156b5f8f3..c5e03bf3d4 100644 --- a/internal/command/user_human_webauthn.go +++ b/internal/command/user_human_webauthn.go @@ -27,8 +27,8 @@ func (c *Commands) getHumanU2FTokens(ctx context.Context, userID, resourceowner return readModelToU2FTokens(tokenReadModel), nil } -func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resourceowner string) ([]*domain.WebAuthNToken, error) { - tokenReadModel := NewHumanPasswordlessTokensReadModel(userID, resourceowner) +func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resourceOwner string) ([]*domain.WebAuthNToken, error) { + tokenReadModel := NewHumanPasswordlessTokensReadModel(userID, resourceOwner) err := c.eventstore.FilterToQueryReducer(ctx, tokenReadModel) if err != nil { return nil, err @@ -82,12 +82,12 @@ func (c *Commands) HumanAddU2FSetup(ctx context.Context, userID, resourceowner s if err != nil { return nil, err } - addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, isLoginUI, u2fTokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementDiscouraged) + addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, "", u2fTokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementDiscouraged) if err != nil { return nil, err } - events, err := c.eventstore.Push(ctx, usr_repo.NewHumanU2FAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge)) + events, err := c.eventstore.Push(ctx, usr_repo.NewHumanU2FAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge, "")) if err != nil { return nil, err } @@ -108,12 +108,12 @@ func (c *Commands) HumanAddPasswordlessSetup(ctx context.Context, userID, resour if err != nil { return nil, err } - addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, isLoginUI, passwordlessTokens, authenticatorPlatform, domain.UserVerificationRequirementRequired) + addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, "", passwordlessTokens, authenticatorPlatform, domain.UserVerificationRequirementRequired) if err != nil { return nil, err } - events, err := c.eventstore.Push(ctx, usr_repo.NewHumanPasswordlessAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge)) + events, err := c.eventstore.Push(ctx, usr_repo.NewHumanPasswordlessAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge, "")) if err != nil { return nil, err } @@ -137,7 +137,7 @@ func (c *Commands) HumanAddPasswordlessSetupInitCode(ctx context.Context, userID return c.HumanAddPasswordlessSetup(ctx, userID, resourceowner, true, preferredPlatformType) } -func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner string, isLoginUI bool, tokens []*domain.WebAuthNToken, authenticatorPlatform domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { +func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner, rpID string, tokens []*domain.WebAuthNToken, authenticatorPlatform domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { if userID == "" { return nil, nil, nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-3M0od", "Errors.IDMissing") } @@ -157,7 +157,7 @@ func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner s if accountName == "" { accountName = string(user.EmailAddress) } - webAuthN, err := c.webauthnConfig.BeginRegistration(ctx, user, accountName, authenticatorPlatform, userVerification, isLoginUI, tokens...) + webAuthN, err := c.webauthnConfig.BeginRegistration(ctx, user, accountName, authenticatorPlatform, userVerification, rpID, tokens...) if err != nil { return nil, nil, nil, err } @@ -343,7 +343,7 @@ func (c *Commands) beginWebAuthNLogin(ctx context.Context, userID, resourceOwner if err != nil { return nil, nil, err } - webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, human, userVerification, tokens...) + webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, human, userVerification, "", tokens...) if err != nil { return nil, nil, err } diff --git a/internal/command/user_human_webauthn_model.go b/internal/command/user_human_webauthn_model.go index b92e35dbd0..daf9e880d4 100644 --- a/internal/command/user_human_webauthn_model.go +++ b/internal/command/user_human_webauthn_model.go @@ -21,6 +21,7 @@ type HumanWebAuthNWriteModel struct { AAGUID []byte SignCount uint32 WebAuthNTokenName string + RPID string State domain.MFAState } @@ -113,6 +114,7 @@ func (wm *HumanWebAuthNWriteModel) Reduce() error { func (wm *HumanWebAuthNWriteModel) appendAddedEvent(e *user.HumanWebAuthNAddedEvent) { wm.WebauthNTokenID = e.WebAuthNTokenID wm.Challenge = e.Challenge + wm.RPID = e.RPID wm.State = domain.MFAStateNotReady } diff --git a/internal/command/user_v2_passkey.go b/internal/command/user_v2_passkey.go index 19b06d97f9..f08a0632ca 100644 --- a/internal/command/user_v2_passkey.go +++ b/internal/command/user_v2_passkey.go @@ -16,22 +16,22 @@ import ( // RegisterUserPasskey creates a passkey registration for the current authenticated user. // UserID, usually taken from the request is compared against the user ID in the context. -func (c *Commands) RegisterUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment) (*domain.WebAuthNRegistrationDetails, error) { +func (c *Commands) RegisterUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment) (*domain.WebAuthNRegistrationDetails, error) { if err := authz.UserIDInCTX(ctx, userID); err != nil { return nil, err } - return c.registerUserPasskey(ctx, userID, resourceOwner, authenticator) + return c.registerUserPasskey(ctx, userID, resourceOwner, rpID, authenticator) } // RegisterUserPasskeyWithCode registers a new passkey for a unauthenticated user id. // The resource is protected by the code, identified by the codeID. -func (c *Commands) RegisterUserPasskeyWithCode(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, codeID, code string, alg crypto.EncryptionAlgorithm) (*domain.WebAuthNRegistrationDetails, error) { +func (c *Commands) RegisterUserPasskeyWithCode(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, codeID, code, rpID string, alg crypto.EncryptionAlgorithm) (*domain.WebAuthNRegistrationDetails, error) { event, err := c.verifyUserPasskeyCode(ctx, userID, resourceOwner, codeID, code, alg) if err != nil { return nil, err } - return c.registerUserPasskey(ctx, userID, resourceOwner, authenticator, event) + return c.registerUserPasskey(ctx, userID, resourceOwner, rpID, authenticator, event) } type eventCallback func(context.Context, *eventstore.Aggregate) eventstore.Command @@ -63,25 +63,25 @@ func (c *Commands) verifyUserPasskeyCodeFailed(ctx context.Context, wm *HumanPas logging.WithFields("userID", userAgg.ID).OnError(err).Error("RegisterUserPasskeyWithCode push failed") } -func (c *Commands) registerUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) { - wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, userID, resourceOwner, authenticator) +func (c *Commands) registerUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) { + wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, userID, resourceOwner, rpID, authenticator) if err != nil { return nil, err } return c.pushUserPasskey(ctx, wm, userAgg, webAuthN, events...) } -func (c *Commands) createUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { +func (c *Commands) createUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { passwordlessTokens, err := c.getHumanPasswordlessTokens(ctx, userID, resourceOwner) if err != nil { return nil, nil, nil, err } - return c.addHumanWebAuthN(ctx, userID, resourceOwner, false, passwordlessTokens, authenticator, domain.UserVerificationRequirementRequired) + return c.addHumanWebAuthN(ctx, userID, resourceOwner, rpID, passwordlessTokens, authenticator, domain.UserVerificationRequirementRequired) } func (c *Commands) pushUserPasskey(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) { cmds := make([]eventstore.Command, len(events)+1) - cmds[0] = user.NewHumanPasswordlessAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge) + cmds[0] = user.NewHumanPasswordlessAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge, webAuthN.RPID) for i, event := range events { cmds[i+1] = event(ctx, userAgg) } diff --git a/internal/command/user_v2_passkey_test.go b/internal/command/user_v2_passkey_test.go index d8b0d0a188..1885f6ab27 100644 --- a/internal/command/user_v2_passkey_test.go +++ b/internal/command/user_v2_passkey_test.go @@ -40,6 +40,7 @@ func TestCommands_RegisterUserPasskey(t *testing.T) { type args struct { userID string resourceOwner string + rpID string authenticator domain.AuthenticatorAttachment } tests := []struct { @@ -121,7 +122,7 @@ func TestCommands_RegisterUserPasskey(t *testing.T) { idGenerator: tt.fields.idGenerator, webauthnConfig: webauthnConfig, } - _, err := c.RegisterUserPasskey(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator) + _, err := c.RegisterUserPasskey(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.rpID, tt.args.authenticator) require.ErrorIs(t, err, tt.wantErr) // successful case can't be tested due to random challenge. }) @@ -148,6 +149,7 @@ func TestCommands_RegisterUserPasskeyWithCode(t *testing.T) { type args struct { userID string resourceOwner string + rpID string authenticator domain.AuthenticatorAttachment codeID string code string @@ -222,7 +224,7 @@ func TestCommands_RegisterUserPasskeyWithCode(t *testing.T) { idGenerator: tt.fields.idGenerator, webauthnConfig: webauthnConfig, } - _, err := c.RegisterUserPasskeyWithCode(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator, tt.args.codeID, tt.args.code, alg) + _, err := c.RegisterUserPasskeyWithCode(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator, tt.args.codeID, tt.args.code, tt.args.rpID, alg) require.ErrorIs(t, err, tt.wantErr) // successful case can't be tested due to random challenge. }) @@ -376,7 +378,7 @@ func TestCommands_pushUserPasskey(t *testing.T) { expectFilter(eventFromEventPusher( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, - ), "111", "challenge"), + ), "111", "challenge", "rpID"), )), } @@ -394,7 +396,7 @@ func TestCommands_pushUserPasskey(t *testing.T) { expectPush: func(challenge string) expect { return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher( user.NewHumanPasswordlessAddedEvent(ctx, - userAgg, "123", challenge, + userAgg, "123", challenge, "rpID", ), )}) }, @@ -406,7 +408,7 @@ func TestCommands_pushUserPasskey(t *testing.T) { expectPush: func(challenge string) expect { return expectPush([]*repository.Event{eventFromEventPusher( user.NewHumanPasswordlessAddedEvent(ctx, - userAgg, "123", challenge, + userAgg, "123", challenge, "rpID", ), )}) }, @@ -418,7 +420,7 @@ func TestCommands_pushUserPasskey(t *testing.T) { return expectPush([]*repository.Event{ eventFromEventPusher( user.NewHumanPasswordlessAddedEvent(ctx, - userAgg, "123", challenge, + userAgg, "123", challenge, "rpID", ), ), eventFromEventPusher( @@ -440,7 +442,7 @@ func TestCommands_pushUserPasskey(t *testing.T) { webauthnConfig: webauthnConfig, idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"), } - wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", domain.AuthenticatorAttachmentCrossPlattform) + wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", "rpID", domain.AuthenticatorAttachmentCrossPlattform) require.NoError(t, err) c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge)) diff --git a/internal/command/user_v2_u2f.go b/internal/command/user_v2_u2f.go index fcda1de6b3..50c16aa30f 100644 --- a/internal/command/user_v2_u2f.go +++ b/internal/command/user_v2_u2f.go @@ -9,31 +9,31 @@ import ( "github.com/zitadel/zitadel/internal/repository/user" ) -func (c *Commands) RegisterUserU2F(ctx context.Context, userID, resourceOwner string) (*domain.WebAuthNRegistrationDetails, error) { +func (c *Commands) RegisterUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*domain.WebAuthNRegistrationDetails, error) { if err := authz.UserIDInCTX(ctx, userID); err != nil { return nil, err } - return c.registerUserU2F(ctx, userID, resourceOwner) + return c.registerUserU2F(ctx, userID, resourceOwner, rpID) } -func (c *Commands) registerUserU2F(ctx context.Context, userID, resourceOwner string) (*domain.WebAuthNRegistrationDetails, error) { - wm, userAgg, webAuthN, err := c.createUserU2F(ctx, userID, resourceOwner) +func (c *Commands) registerUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*domain.WebAuthNRegistrationDetails, error) { + wm, userAgg, webAuthN, err := c.createUserU2F(ctx, userID, resourceOwner, rpID) if err != nil { return nil, err } return c.pushUserU2F(ctx, wm, userAgg, webAuthN) } -func (c *Commands) createUserU2F(ctx context.Context, userID, resourceOwner string) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { +func (c *Commands) createUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { tokens, err := c.getHumanU2FTokens(ctx, userID, resourceOwner) if err != nil { return nil, nil, nil, err } - return c.addHumanWebAuthN(ctx, userID, resourceOwner, false, tokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementRequired) + return c.addHumanWebAuthN(ctx, userID, resourceOwner, rpID, tokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementRequired) } func (c *Commands) pushUserU2F(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken) (*domain.WebAuthNRegistrationDetails, error) { - cmd := user.NewHumanU2FAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge) + cmd := user.NewHumanU2FAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge, webAuthN.RPID) err := c.pushAppendAndReduce(ctx, wm, cmd) if err != nil { return nil, err diff --git a/internal/command/user_v2_u2f_test.go b/internal/command/user_v2_u2f_test.go index aecb0ad79a..ebab0a4ced 100644 --- a/internal/command/user_v2_u2f_test.go +++ b/internal/command/user_v2_u2f_test.go @@ -37,6 +37,7 @@ func TestCommands_RegisterUserU2F(t *testing.T) { type args struct { userID string resourceOwner string + rpID string } tests := []struct { name string @@ -114,7 +115,7 @@ func TestCommands_RegisterUserU2F(t *testing.T) { idGenerator: tt.fields.idGenerator, webauthnConfig: webauthnConfig, } - _, err := c.RegisterUserU2F(ctx, tt.args.userID, tt.args.resourceOwner) + _, err := c.RegisterUserU2F(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.rpID) require.ErrorIs(t, err, tt.wantErr) // successful case can't be tested due to random challenge. }) @@ -160,7 +161,7 @@ func TestCommands_pushUserU2F(t *testing.T) { expectFilter(eventFromEventPusher( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, - ), "111", "challenge"), + ), "111", "challenge", "rpID"), )), } @@ -174,7 +175,7 @@ func TestCommands_pushUserU2F(t *testing.T) { expectPush: func(challenge string) expect { return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher( user.NewHumanU2FAddedEvent(ctx, - userAgg, "123", challenge, + userAgg, "123", challenge, "rpID", ), )}) }, @@ -185,7 +186,7 @@ func TestCommands_pushUserU2F(t *testing.T) { expectPush: func(challenge string) expect { return expectPush([]*repository.Event{eventFromEventPusher( user.NewHumanU2FAddedEvent(ctx, - userAgg, "123", challenge, + userAgg, "123", challenge, "rpID", ), )}) }, @@ -198,7 +199,7 @@ func TestCommands_pushUserU2F(t *testing.T) { webauthnConfig: webauthnConfig, idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"), } - wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", domain.AuthenticatorAttachmentCrossPlattform) + wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", "rpID", domain.AuthenticatorAttachmentCrossPlattform) require.NoError(t, err) c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge)) diff --git a/internal/domain/human_web_auth_n.go b/internal/domain/human_web_auth_n.go index 8b1a7134b8..16590d43ca 100644 --- a/internal/domain/human_web_auth_n.go +++ b/internal/domain/human_web_auth_n.go @@ -23,6 +23,7 @@ type WebAuthNToken struct { AAGUID []byte SignCount uint32 WebAuthNTokenName string + RPID string } type WebAuthNLogin struct { @@ -32,6 +33,7 @@ type WebAuthNLogin struct { Challenge string AllowedCredentialIDs [][]byte UserVerification UserVerificationRequirement + RPID string } type UserVerificationRequirement int32 diff --git a/internal/integration/client.go b/internal/integration/client.go index f66db5b2aa..86b78bb87e 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -89,6 +89,7 @@ func (s *Tester) RegisterUserPasskey(ctx context.Context, userID string) { pkr, err := s.Client.UserV2.RegisterPasskey(ctx, &user.RegisterPasskeyRequest{ UserId: userID, Code: reg.GetCode(), + Domain: s.Config.ExternalDomain, }) logging.OnError(err).Fatal("create user passkey") attestationResponse, err := s.WebAuthN.CreateAttestationResponse(pkr.GetPublicKeyCredentialCreationOptions()) diff --git a/internal/query/projection/session.go b/internal/query/projection/session.go index dc31fd37a1..2fceb22439 100644 --- a/internal/query/projection/session.go +++ b/internal/query/projection/session.go @@ -14,7 +14,7 @@ import ( ) const ( - SessionsProjectionTable = "projections.sessions2" + SessionsProjectionTable = "projections.sessions3" SessionColumnID = "id" SessionColumnCreationDate = "creation_date" @@ -22,6 +22,7 @@ const ( SessionColumnSequence = "sequence" SessionColumnState = "state" SessionColumnResourceOwner = "resource_owner" + SessionColumnDomain = "domain" SessionColumnInstanceID = "instance_id" SessionColumnCreator = "creator" SessionColumnUserID = "user_id" @@ -49,6 +50,7 @@ func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfi crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64), crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum), crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText), + crdb.NewColumn(SessionColumnDomain, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()), @@ -140,6 +142,7 @@ func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler handler.NewCol(SessionColumnCreationDate, e.CreationDate()), handler.NewCol(SessionColumnChangeDate, e.CreationDate()), handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner), + handler.NewCol(SessionColumnDomain, e.Domain), handler.NewCol(SessionColumnState, domain.SessionStateActive), handler.NewCol(SessionColumnSequence, e.Sequence()), handler.NewCol(SessionColumnCreator, e.User), diff --git a/internal/query/projection/session_test.go b/internal/query/projection/session_test.go index 976eab8b7f..ae19e247ec 100644 --- a/internal/query/projection/session_test.go +++ b/internal/query/projection/session_test.go @@ -30,7 +30,9 @@ func TestSessionProjection_reduces(t *testing.T) { event: getEvent(testEvent( session.AddedType, session.AggregateType, - []byte(`{}`), + []byte(`{ + "domain": "domain" + }`), ), session.AddedEventMapper), }, reduce: (&sessionProjection{}).reduceSessionAdded, @@ -41,13 +43,14 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "INSERT INTO projections.sessions2 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + expectedStmt: "INSERT INTO projections.sessions3 (id, instance_id, creation_date, change_date, resource_owner, domain, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", expectedArgs: []interface{}{ "agg-id", "instance-id", anyArg{}, anyArg{}, "ro-id", + "domain", domain.SessionStateActive, uint64(15), "editor-user", @@ -77,7 +80,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", + expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -110,7 +113,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -142,7 +145,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -174,7 +177,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -208,7 +211,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -240,7 +243,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions2 WHERE (id = $1) AND (instance_id = $2)", + expectedStmt: "DELETE FROM projections.sessions3 WHERE (id = $1) AND (instance_id = $2)", expectedArgs: []interface{}{ "agg-id", "instance-id", @@ -267,7 +270,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions2 WHERE (instance_id = $1)", + expectedStmt: "DELETE FROM projections.sessions3 WHERE (instance_id = $1)", expectedArgs: []interface{}{ "agg-id", }, @@ -298,7 +301,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions2 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", + expectedStmt: "UPDATE projections.sessions3 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", expectedArgs: []interface{}{ nil, "agg-id", diff --git a/internal/query/session.go b/internal/query/session.go index 427fc90f2c..e9706410ac 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -29,6 +29,7 @@ type Session struct { Sequence uint64 State domain.SessionState ResourceOwner string + Domain string Creator string UserFactor SessionUserFactor PasswordFactor SessionPasswordFactor @@ -98,6 +99,10 @@ var ( name: projection.SessionColumnResourceOwner, table: sessionsTable, } + SessionColumnDomain = Column{ + name: projection.SessionColumnDomain, + table: sessionsTable, + } SessionColumnInstanceID = Column{ name: projection.SessionColumnInstanceID, table: sessionsTable, @@ -211,6 +216,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil SessionColumnState.identifier(), SessionColumnResourceOwner.identifier(), SessionColumnCreator.identifier(), + SessionColumnDomain.identifier(), SessionColumnUserID.identifier(), SessionColumnUserCheckedAt.identifier(), LoginNameNameCol.identifier(), @@ -236,6 +242,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil passkeyCheckedAt sql.NullTime metadata database.Map[[]byte] token sql.NullString + sessionDomain sql.NullString ) err := row.Scan( @@ -246,6 +253,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil &session.State, &session.ResourceOwner, &session.Creator, + &sessionDomain, &userID, &userCheckedAt, &loginName, @@ -264,6 +272,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal") } + session.Domain = sessionDomain.String session.UserFactor.UserID = userID.String session.UserFactor.UserCheckedAt = userCheckedAt.Time session.UserFactor.LoginName = loginName.String @@ -286,6 +295,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui SessionColumnState.identifier(), SessionColumnResourceOwner.identifier(), SessionColumnCreator.identifier(), + SessionColumnDomain.identifier(), SessionColumnUserID.identifier(), SessionColumnUserCheckedAt.identifier(), LoginNameNameCol.identifier(), @@ -313,6 +323,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui intentCheckedAt sql.NullTime passkeyCheckedAt sql.NullTime metadata database.Map[[]byte] + sessionDomain sql.NullString ) err := rows.Scan( @@ -323,6 +334,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui &session.State, &session.ResourceOwner, &session.Creator, + &sessionDomain, &userID, &userCheckedAt, &loginName, @@ -337,6 +349,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui if err != nil { return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal") } + session.Domain = sessionDomain.String session.UserFactor.UserID = userID.String session.UserFactor.UserCheckedAt = userCheckedAt.Time session.UserFactor.LoginName = loginName.String diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index 84990e2e14..38662ead77 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -17,45 +17,47 @@ import ( ) var ( - expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions2.id,` + - ` projections.sessions2.creation_date,` + - ` projections.sessions2.change_date,` + - ` projections.sessions2.sequence,` + - ` projections.sessions2.state,` + - ` projections.sessions2.resource_owner,` + - ` projections.sessions2.creator,` + - ` projections.sessions2.user_id,` + - ` projections.sessions2.user_checked_at,` + + expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` + + ` projections.sessions3.creation_date,` + + ` projections.sessions3.change_date,` + + ` projections.sessions3.sequence,` + + ` projections.sessions3.state,` + + ` projections.sessions3.resource_owner,` + + ` projections.sessions3.creator,` + + ` projections.sessions3.domain,` + + ` projections.sessions3.user_id,` + + ` projections.sessions3.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + - ` projections.sessions2.password_checked_at,` + - ` projections.sessions2.intent_checked_at,` + - ` projections.sessions2.passkey_checked_at,` + - ` projections.sessions2.metadata,` + - ` projections.sessions2.token_id` + - ` FROM projections.sessions2` + - ` LEFT JOIN projections.login_names2 ON projections.sessions2.user_id = projections.login_names2.user_id AND projections.sessions2.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions2.user_id = projections.users8_humans.user_id AND projections.sessions2.instance_id = projections.users8_humans.instance_id` + + ` projections.sessions3.password_checked_at,` + + ` projections.sessions3.intent_checked_at,` + + ` projections.sessions3.passkey_checked_at,` + + ` projections.sessions3.metadata,` + + ` projections.sessions3.token_id` + + ` FROM projections.sessions3` + + ` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) - expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions2.id,` + - ` projections.sessions2.creation_date,` + - ` projections.sessions2.change_date,` + - ` projections.sessions2.sequence,` + - ` projections.sessions2.state,` + - ` projections.sessions2.resource_owner,` + - ` projections.sessions2.creator,` + - ` projections.sessions2.user_id,` + - ` projections.sessions2.user_checked_at,` + + expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` + + ` projections.sessions3.creation_date,` + + ` projections.sessions3.change_date,` + + ` projections.sessions3.sequence,` + + ` projections.sessions3.state,` + + ` projections.sessions3.resource_owner,` + + ` projections.sessions3.creator,` + + ` projections.sessions3.domain,` + + ` projections.sessions3.user_id,` + + ` projections.sessions3.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + - ` projections.sessions2.password_checked_at,` + - ` projections.sessions2.intent_checked_at,` + - ` projections.sessions2.passkey_checked_at,` + - ` projections.sessions2.metadata,` + + ` projections.sessions3.password_checked_at,` + + ` projections.sessions3.intent_checked_at,` + + ` projections.sessions3.passkey_checked_at,` + + ` projections.sessions3.metadata,` + ` COUNT(*) OVER ()` + - ` FROM projections.sessions2` + - ` LEFT JOIN projections.login_names2 ON projections.sessions2.user_id = projections.login_names2.user_id AND projections.sessions2.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions2.user_id = projections.users8_humans.user_id AND projections.sessions2.instance_id = projections.users8_humans.instance_id` + + ` FROM projections.sessions3` + + ` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) sessionCols = []string{ @@ -66,6 +68,7 @@ var ( "state", "resource_owner", "creator", + "domain", "user_id", "user_checked_at", "login_name", @@ -85,6 +88,7 @@ var ( "state", "resource_owner", "creator", + "domain", "user_id", "user_checked_at", "login_name", @@ -136,6 +140,7 @@ func Test_SessionsPrepare(t *testing.T) { domain.SessionStateActive, "ro", "creator", + "domain", "user-id", testNow, "login-name", @@ -161,6 +166,7 @@ func Test_SessionsPrepare(t *testing.T) { State: domain.SessionStateActive, ResourceOwner: "ro", Creator: "creator", + Domain: "domain", UserFactor: SessionUserFactor{ UserID: "user-id", UserCheckedAt: testNow, @@ -199,6 +205,7 @@ func Test_SessionsPrepare(t *testing.T) { domain.SessionStateActive, "ro", "creator", + "domain", "user-id", testNow, "login-name", @@ -216,6 +223,7 @@ func Test_SessionsPrepare(t *testing.T) { domain.SessionStateActive, "ro", "creator2", + "domain", "user-id2", testNow, "login-name2", @@ -241,6 +249,7 @@ func Test_SessionsPrepare(t *testing.T) { State: domain.SessionStateActive, ResourceOwner: "ro", Creator: "creator", + Domain: "domain", UserFactor: SessionUserFactor{ UserID: "user-id", UserCheckedAt: testNow, @@ -268,6 +277,7 @@ func Test_SessionsPrepare(t *testing.T) { State: domain.SessionStateActive, ResourceOwner: "ro", Creator: "creator2", + Domain: "domain", UserFactor: SessionUserFactor{ UserID: "user-id2", UserCheckedAt: testNow, @@ -359,6 +369,7 @@ func Test_SessionPrepare(t *testing.T) { domain.SessionStateActive, "ro", "creator", + "domain", "user-id", testNow, "login-name", @@ -379,6 +390,7 @@ func Test_SessionPrepare(t *testing.T) { State: domain.SessionStateActive, ResourceOwner: "ro", Creator: "creator", + Domain: "domain", UserFactor: SessionUserFactor{ UserID: "user-id", UserCheckedAt: testNow, diff --git a/internal/repository/session/session.go b/internal/repository/session/session.go index c11a6eb458..9b48543ecb 100644 --- a/internal/repository/session/session.go +++ b/internal/repository/session/session.go @@ -26,6 +26,8 @@ const ( type AddedEvent struct { eventstore.BaseEvent `json:"-"` + + Domain string `json:"domain,omitempty"` } func (e *AddedEvent) Data() interface{} { @@ -38,6 +40,7 @@ func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { func NewAddedEvent(ctx context.Context, aggregate *eventstore.Aggregate, + domain string, ) *AddedEvent { return &AddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -45,6 +48,7 @@ func NewAddedEvent(ctx context.Context, aggregate, AddedType, ), + Domain: domain, } } diff --git a/internal/repository/user/human_mfa_passwordless.go b/internal/repository/user/human_mfa_passwordless.go index f5e68b6394..1ffa14af3f 100644 --- a/internal/repository/user/human_mfa_passwordless.go +++ b/internal/repository/user/human_mfa_passwordless.go @@ -39,6 +39,7 @@ func NewHumanPasswordlessAddedEvent( aggregate *eventstore.Aggregate, webAuthNTokenID, challenge string, + rpID string, ) *HumanPasswordlessAddedEvent { return &HumanPasswordlessAddedEvent{ HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent( @@ -49,6 +50,7 @@ func NewHumanPasswordlessAddedEvent( ), webAuthNTokenID, challenge, + rpID, ), } } diff --git a/internal/repository/user/human_mfa_u2f.go b/internal/repository/user/human_mfa_u2f.go index 27fe6be3fc..096026b872 100644 --- a/internal/repository/user/human_mfa_u2f.go +++ b/internal/repository/user/human_mfa_u2f.go @@ -28,6 +28,7 @@ func NewHumanU2FAddedEvent( aggregate *eventstore.Aggregate, webAuthNTokenID, challenge string, + rpID string, ) *HumanU2FAddedEvent { return &HumanU2FAddedEvent{ HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent( @@ -38,6 +39,7 @@ func NewHumanU2FAddedEvent( ), webAuthNTokenID, challenge, + rpID, ), } } diff --git a/internal/repository/user/human_mfa_web_auth_n.go b/internal/repository/user/human_mfa_web_auth_n.go index faeb3124c6..77f27fb64b 100644 --- a/internal/repository/user/human_mfa_web_auth_n.go +++ b/internal/repository/user/human_mfa_web_auth_n.go @@ -14,6 +14,7 @@ type HumanWebAuthNAddedEvent struct { WebAuthNTokenID string `json:"webAuthNTokenId"` Challenge string `json:"challenge"` + RPID string `json:"rpID,omitempty"` } func (e *HumanWebAuthNAddedEvent) Data() interface{} { @@ -28,11 +29,13 @@ func NewHumanWebAuthNAddedEvent( base *eventstore.BaseEvent, webAuthNTokenID, challenge string, + rpID string, ) *HumanWebAuthNAddedEvent { return &HumanWebAuthNAddedEvent{ BaseEvent: *base, WebAuthNTokenID: webAuthNTokenID, Challenge: challenge, + RPID: rpID, } } diff --git a/internal/webauthn/converter.go b/internal/webauthn/converter.go index 9b49ec0163..36799ee3dc 100644 --- a/internal/webauthn/converter.go +++ b/internal/webauthn/converter.go @@ -7,10 +7,10 @@ import ( "github.com/zitadel/zitadel/internal/domain" ) -func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken) []webauthn.Credential { +func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken, rpID string) []webauthn.Credential { creds := make([]webauthn.Credential, 0) for _, webAuthN := range webAuthNs { - if webAuthN.State == domain.MFAStateReady { + if webAuthN.State == domain.MFAStateReady && webAuthN.RPID == rpID { creds = append(creds, webauthn.Credential{ ID: webAuthN.KeyID, PublicKey: webAuthN.PublicKey, diff --git a/internal/webauthn/webauthn.go b/internal/webauthn/webauthn.go index 5d2517bc47..bced4f220f 100644 --- a/internal/webauthn/webauthn.go +++ b/internal/webauthn/webauthn.go @@ -52,12 +52,12 @@ func (u *webUser) WebAuthnCredentials() []webauthn.Credential { return u.credentials } -func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, isLoginUI bool, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) { - webAuthNServer, err := w.serverFromContext(ctx) +func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) { + webAuthNServer, err := w.serverFromContext(ctx, rpID, "") if err != nil { return nil, err } - creds := WebAuthNsToCredentials(webAuthNs) + creds := WebAuthNsToCredentials(webAuthNs, rpID) existing := make([]protocol.CredentialDescriptor, len(creds)) for i, cred := range creds { existing[i] = protocol.CredentialDescriptor{ @@ -90,6 +90,7 @@ func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, acco CredentialCreationData: cred, AllowedCredentialIDs: sessionData.AllowedCredentialIDs, UserVerification: UserVerificationToDomain(sessionData.UserVerification), + RPID: webAuthNServer.Config.RPID, }, nil } @@ -104,7 +105,7 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web return nil, caos_errs.ThrowInternal(err, "WEBAU-sEr8c", "Errors.User.WebAuthN.ErrorOnParseCredential") } sessionData := WebAuthNToSessionData(webAuthN) - webAuthNServer, err := w.serverFromContext(ctx) + webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, credentialData.Response.CollectedClientData.Origin) if err != nil { return nil, err } @@ -124,17 +125,18 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web webAuthN.AAGUID = credential.Authenticator.AAGUID webAuthN.SignCount = credential.Authenticator.SignCount webAuthN.WebAuthNTokenName = tokenName + webAuthN.RPID = webAuthNServer.Config.RPID return webAuthN, nil } -func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) { - webAuthNServer, err := w.serverFromContext(ctx) +func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) { + webAuthNServer, err := w.serverFromContext(ctx, rpID, "") if err != nil { return nil, err } assertion, sessionData, err := webAuthNServer.BeginLogin(&webUser{ Human: user, - credentials: WebAuthNsToCredentials(webAuthNs), + credentials: WebAuthNsToCredentials(webAuthNs, rpID), }, webauthn.WithUserVerification(UserVerificationFromDomain(userVerification))) if err != nil { return nil, caos_errs.ThrowInternal(err, "WEBAU-4G8sw", "Errors.User.WebAuthN.BeginLoginFailed") @@ -148,6 +150,7 @@ func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerific CredentialAssertionData: cred, AllowedCredentialIDs: sessionData.AllowedCredentialIDs, UserVerification: userVerification, + RPID: webAuthNServer.Config.RPID, }, nil } @@ -158,9 +161,9 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN * } webUser := &webUser{ Human: user, - credentials: WebAuthNsToCredentials(webAuthNs), + credentials: WebAuthNsToCredentials(webAuthNs, webAuthN.RPID), } - webAuthNServer, err := w.serverFromContext(ctx) + webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, assertionData.Response.CollectedClientData.Origin) if err != nil { return nil, 0, err } @@ -175,15 +178,31 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN * return credential.ID, credential.Authenticator.SignCount, nil } -func (w *Config) serverFromContext(ctx context.Context) (*webauthn.WebAuthn, error) { - instance := authz.GetInstance(ctx) - webAuthn, err := webauthn.New(&webauthn.Config{ - RPDisplayName: w.DisplayName, - RPID: instance.RequestedDomain(), - RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)}, - }) +func (w *Config) serverFromContext(ctx context.Context, id, origin string) (*webauthn.WebAuthn, error) { + config := w.config(id, origin) + if id == "" { + config = w.configFromContext(ctx) + } + webAuthn, err := webauthn.New(config) if err != nil { return nil, caos_errs.ThrowInternal(err, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig") } return webAuthn, nil } + +func (w *Config) configFromContext(ctx context.Context) *webauthn.Config { + instance := authz.GetInstance(ctx) + return &webauthn.Config{ + RPDisplayName: w.DisplayName, + RPID: instance.RequestedDomain(), + RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)}, + } +} + +func (w *Config) config(id, origin string) *webauthn.Config { + return &webauthn.Config{ + RPDisplayName: w.DisplayName, + RPID: id, + RPOrigins: []string{origin}, + } +} diff --git a/internal/webauthn/webauthn_test.go b/internal/webauthn/webauthn_test.go index df05d34e92..7d8052048a 100644 --- a/internal/webauthn/webauthn_test.go +++ b/internal/webauthn/webauthn_test.go @@ -14,7 +14,9 @@ import ( func TestConfig_serverFromContext(t *testing.T) { type args struct { - ctx context.Context + ctx context.Context + id string + origin string } tests := []struct { name string @@ -24,12 +26,12 @@ func TestConfig_serverFromContext(t *testing.T) { }{ { name: "webauthn error", - args: args{context.Background()}, + args: args{context.Background(), "", ""}, wantErr: caos_errs.ThrowInternal(nil, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig"), }, { - name: "success", - args: args{authz.WithRequestedDomain(context.Background(), "example.com")}, + name: "success from ctx", + args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "", ""}, want: &webauthn.WebAuthn{ Config: &webauthn.Config{ RPDisplayName: "DisplayName", @@ -38,6 +40,17 @@ func TestConfig_serverFromContext(t *testing.T) { }, }, }, + { + name: "success from id", + args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "external.com", "https://external.com"}, + want: &webauthn.WebAuthn{ + Config: &webauthn.Config{ + RPDisplayName: "DisplayName", + RPID: "external.com", + RPOrigins: []string{"https://external.com"}, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -45,7 +58,7 @@ func TestConfig_serverFromContext(t *testing.T) { DisplayName: "DisplayName", ExternalSecure: true, } - got, err := w.serverFromContext(tt.args.ctx) + got, err := w.serverFromContext(tt.args.ctx, tt.args.id, tt.args.origin) require.ErrorIs(t, err, tt.wantErr) if tt.want != nil { require.NotNil(t, got) diff --git a/proto/zitadel/session/v2alpha/session.proto b/proto/zitadel/session/v2alpha/session.proto index a73eb8a79e..b3e11a4435 100644 --- a/proto/zitadel/session/v2alpha/session.proto +++ b/proto/zitadel/session/v2alpha/session.proto @@ -39,6 +39,11 @@ message Session { description: "\"custom key value list\""; } ]; + string domain = 7 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"domain on which the session was created\""; + } + ]; } message Factors { diff --git a/proto/zitadel/session/v2alpha/session_service.proto b/proto/zitadel/session/v2alpha/session_service.proto index 6ae18fd88d..2703176262 100644 --- a/proto/zitadel/session/v2alpha/session_service.proto +++ b/proto/zitadel/session/v2alpha/session_service.proto @@ -245,6 +245,11 @@ message CreateSessionRequest{ } ]; repeated ChallengeKind challenges = 3; + string domain = 4 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"Domain on which the session was created. Will be used for Passkey and U2F challenges.\""; + } + ]; } message CreateSessionResponse{ diff --git a/proto/zitadel/user/v2alpha/user_service.proto b/proto/zitadel/user/v2alpha/user_service.proto index 5c548f9178..1534a4d74c 100644 --- a/proto/zitadel/user/v2alpha/user_service.proto +++ b/proto/zitadel/user/v2alpha/user_service.proto @@ -587,6 +587,11 @@ message RegisterPasskeyRequest{ description: "\"Optionally specify the authenticator type of the passkey device (platform or cross-platform). If none is provided, both values are allowed.\""; } ]; + string domain = 4 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"Domain on which the user is authenticated.\""; + } + ]; } message RegisterPasskeyResponse{ @@ -658,6 +663,11 @@ message RegisterU2FRequest{ example: "\"d654e6ba-70a3-48ef-a95d-37c8d8a7901a\""; } ]; + string domain = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"Domain on which the user is authenticated.\""; + } + ]; } message RegisterU2FResponse{