fix: provide domain in session, passkey and u2f (#6097)

This fix provides a possibility to pass a domain on the session, which
will be used (as rpID) to create a passkey / u2f assertion and
attestation. This is useful in cases where the login UI is served under
a different domain / origin than the ZITADEL API.
This commit is contained in:
Livio Spring 2023-06-27 14:36:07 +02:00 committed by GitHub
parent d0cda1b479
commit bd5defa96a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 287 additions and 123 deletions

View File

@ -47,7 +47,7 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe
} }
challengeResponse, cmds := s.challengesToCommand(req.GetChallenges(), checks) challengeResponse, cmds := s.challengesToCommand(req.GetChallenges(), checks)
set, err := s.command.CreateSession(ctx, cmds, metadata) set, err := s.command.CreateSession(ctx, cmds, req.GetDomain(), metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -107,6 +107,7 @@ func sessionToPb(s *query.Session) *session.Session {
Sequence: s.Sequence, Sequence: s.Sequence,
Factors: factorsToPb(s), Factors: factorsToPb(s),
Metadata: s.Metadata, Metadata: s.Metadata,
Domain: s.Domain,
} }
} }

View File

@ -141,6 +141,7 @@ func TestServer_CreateSession(t *testing.T) {
}, },
}, },
Metadata: map[string][]byte{"foo": []byte("bar")}, Metadata: map[string][]byte{"foo": []byte("bar")},
Domain: "domain",
}, },
want: &session.CreateSessionResponse{ want: &session.CreateSessionResponse{
Details: &object.Details{ Details: &object.Details{
@ -169,6 +170,22 @@ func TestServer_CreateSession(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "passkey without domain (not registered) error",
req: &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: User.GetUserId(),
},
},
},
Challenges: []session.ChallengeKind{
session.ChallengeKind_CHALLENGE_KIND_PASSKEY,
},
},
wantErr: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -198,6 +215,7 @@ func TestServer_CreateSession_passkey(t *testing.T) {
Challenges: []session.ChallengeKind{ Challenges: []session.ChallengeKind{
session.ChallengeKind_CHALLENGE_KIND_PASSKEY, session.ChallengeKind_CHALLENGE_KIND_PASSKEY,
}, },
Domain: Tester.Config.ExternalDomain,
}) })
require.NoError(t, err) require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil) verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil)
@ -325,7 +343,7 @@ func TestServer_SetSession_flow(t *testing.T) {
var wantFactors []wantFactor var wantFactors []wantFactor
// create new, empty session // create new, empty session
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{Domain: Tester.Config.ExternalDomain})
require.NoError(t, err) require.NoError(t, err)
verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, wantFactors...) verifyCurrentSession(t, createResp.GetSessionId(), createResp.GetSessionToken(), createResp.GetDetails().GetSequence(), time.Minute, nil, wantFactors...)
sessionToken := createResp.GetSessionToken() sessionToken := createResp.GetSessionToken()

View File

@ -20,11 +20,11 @@ func (s *Server) RegisterPasskey(ctx context.Context, req *user.RegisterPasskeyR
) )
if code := req.GetCode(); code != nil { if code := req.GetCode(); code != nil {
return passkeyRegistrationDetailsToPb( return passkeyRegistrationDetailsToPb(
s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, s.userCodeAlg), s.command.RegisterUserPasskeyWithCode(ctx, req.GetUserId(), resourceOwner, authenticator, code.Id, code.Code, req.GetDomain(), s.userCodeAlg),
) )
} }
return passkeyRegistrationDetailsToPb( return passkeyRegistrationDetailsToPb(
s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, authenticator), s.command.RegisterUserPasskey(ctx, req.GetUserId(), resourceOwner, req.GetDomain(), authenticator),
) )
} }

View File

@ -12,7 +12,7 @@ import (
func (s *Server) RegisterU2F(ctx context.Context, req *user.RegisterU2FRequest) (*user.RegisterU2FResponse, error) { func (s *Server) RegisterU2F(ctx context.Context, req *user.RegisterU2FRequest) (*user.RegisterU2FResponse, error) {
return u2fRegistrationDetailsToPb( return u2fRegistrationDetailsToPb(
s.command.RegisterUserU2F(ctx, req.GetUserId(), authz.GetCtxData(ctx).ResourceOwner), s.command.RegisterUserU2F(ctx, req.GetUserId(), authz.GetCtxData(ctx).ResourceOwner, req.GetDomain()),
) )
} }

View File

@ -157,7 +157,7 @@ func (s *SessionCommands) commands(ctx context.Context) (string, []eventstore.Co
return token, s.sessionWriteModel.commands, nil return token, s.sessionWriteModel.commands, nil
} }
func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte) (set *SessionChanged, err error) { func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, sessionDomain string, metadata map[string][]byte) (set *SessionChanged, err error) {
sessionID, err := c.idGenerator.Next() sessionID, err := c.idGenerator.Next()
if err != nil { if err != nil {
return nil, err return nil, err
@ -167,8 +167,8 @@ func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, met
if err != nil { if err != nil {
return nil, err return nil, err
} }
sessionWriteModel.Start(ctx, sessionDomain)
cmd := c.NewSessionCommands(cmds, sessionWriteModel) cmd := c.NewSessionCommands(cmds, sessionWriteModel)
cmd.sessionWriteModel.Start(ctx)
return c.updateSession(ctx, cmd, metadata) return c.updateSession(ctx, cmd, metadata)
} }

View File

@ -16,6 +16,7 @@ type PasskeyChallengeModel struct {
Challenge string Challenge string
AllowedCrentialIDs [][]byte AllowedCrentialIDs [][]byte
UserVerification domain.UserVerificationRequirement UserVerification domain.UserVerificationRequirement
RPID string
} }
func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) (*domain.WebAuthNLogin, error) { func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) (*domain.WebAuthNLogin, error) {
@ -28,6 +29,7 @@ func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAss
Challenge: p.Challenge, Challenge: p.Challenge,
AllowedCredentialIDs: p.AllowedCrentialIDs, AllowedCredentialIDs: p.AllowedCrentialIDs,
UserVerification: p.UserVerification, UserVerification: p.UserVerification,
RPID: p.RPID,
}, nil }, nil
} }
@ -41,6 +43,7 @@ type SessionWriteModel struct {
IntentCheckedAt time.Time IntentCheckedAt time.Time
PasskeyCheckedAt time.Time PasskeyCheckedAt time.Time
Metadata map[string][]byte Metadata map[string][]byte
Domain string
State domain.SessionState State domain.SessionState
PasskeyChallenge *PasskeyChallengeModel PasskeyChallenge *PasskeyChallengeModel
@ -109,6 +112,7 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
} }
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) { func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
wm.Domain = e.Domain
wm.State = domain.SessionStateActive wm.State = domain.SessionStateActive
} }
@ -130,6 +134,7 @@ func (wm *SessionWriteModel) reducePasskeyChallenged(e *session.PasskeyChallenge
Challenge: e.Challenge, Challenge: e.Challenge,
AllowedCrentialIDs: e.AllowedCrentialIDs, AllowedCrentialIDs: e.AllowedCrentialIDs,
UserVerification: e.UserVerification, UserVerification: e.UserVerification,
RPID: wm.Domain,
} }
} }
@ -146,8 +151,10 @@ func (wm *SessionWriteModel) reduceTerminate() {
wm.State = domain.SessionStateTerminated wm.State = domain.SessionStateTerminated
} }
func (wm *SessionWriteModel) Start(ctx context.Context) { func (wm *SessionWriteModel) Start(ctx context.Context, domain string) {
wm.commands = append(wm.commands, session.NewAddedEvent(ctx, wm.aggregate)) wm.commands = append(wm.commands, session.NewAddedEvent(ctx, wm.aggregate, domain))
// set the domain so checks can use it
wm.Domain = domain
} }
func (wm *SessionWriteModel) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error { func (wm *SessionWriteModel) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error {

View File

@ -43,7 +43,7 @@ func (c *Commands) CreatePasskeyChallenge(userVerification domain.UserVerificati
if err != nil { if err != nil {
return err return err
} }
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, humanPasskeys.tokens...) webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, cmd.sessionWriteModel.Domain, humanPasskeys.tokens...)
if err != nil { if err != nil {
return err return err
} }

View File

@ -84,7 +84,7 @@ func TestSessionCommands_getHumanPasskeys(t *testing.T) {
expectFilter(eventFromEventPusher( expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType,
), "111", "challenge"), ), "111", "challenge", "rpID"),
)), )),
), ),
sessionWriteModel: &SessionWriteModel{ sessionWriteModel: &SessionWriteModel{
@ -112,6 +112,7 @@ func TestSessionCommands_getHumanPasskeys(t *testing.T) {
WebAuthNTokenID: "111", WebAuthNTokenID: "111",
State: domain.MFAStateNotReady, State: domain.MFAStateNotReady,
Challenge: "challenge", Challenge: "challenge",
RPID: "rpID",
}}, }},
}, },
err: nil, err: nil,

View File

@ -147,6 +147,7 @@ func TestCommands_CreateSession(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
checks []SessionCommand checks []SessionCommand
domain string
metadata map[string][]byte metadata map[string][]byte
} }
type res struct { type res struct {
@ -194,7 +195,7 @@ func TestCommands_CreateSession(t *testing.T) {
expectFilter(), expectFilter(),
expectPush( expectPush(
eventPusherToEvents( eventPusherToEvents(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, ""),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID", "tokenID",
), ),
@ -218,6 +219,39 @@ func TestCommands_CreateSession(t *testing.T) {
}, },
}, },
}, },
{
"empty session with domain",
fields{
idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"),
eventstore: eventstoreExpect(t,
expectFilter(),
expectPush(
eventPusherToEvents(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
),
),
),
tokenCreator: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
},
args{
ctx: authz.NewMockContext("", "org1", ""),
domain: "domain.tld",
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{ResourceOwner: "org1"},
ID: "sessionID",
NewToken: "token",
},
},
},
// the rest is tested in the Test_updateSession // the rest is tested in the Test_updateSession
} }
for _, tt := range tests { for _, tt := range tests {
@ -227,7 +261,7 @@ func TestCommands_CreateSession(t *testing.T) {
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
sessionTokenCreator: tt.fields.tokenCreator, sessionTokenCreator: tt.fields.tokenCreator,
} }
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata) got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.domain, tt.args.metadata)
require.ErrorIs(t, err, tt.res.err) require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got) assert.Equal(t, tt.res.want, got)
}) })
@ -276,7 +310,7 @@ func TestCommands_UpdateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")), "tokenID")),
@ -301,7 +335,7 @@ func TestCommands_UpdateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")), "tokenID")),
@ -739,7 +773,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")), "tokenID")),
@ -764,7 +798,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")), "tokenID")),
@ -793,7 +827,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"), "tokenID"),
@ -824,7 +858,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t, eventstore: eventstoreExpect(t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)), session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
eventFromEventPusher( eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"), "tokenID"),

View File

@ -140,6 +140,7 @@ func writeModelToWebAuthN(wm *HumanWebAuthNWriteModel) *domain.WebAuthNToken {
SignCount: wm.SignCount, SignCount: wm.SignCount,
WebAuthNTokenName: wm.WebAuthNTokenName, WebAuthNTokenName: wm.WebAuthNTokenName,
State: wm.State, State: wm.State,
RPID: wm.RPID,
} }
} }

View File

@ -27,8 +27,8 @@ func (c *Commands) getHumanU2FTokens(ctx context.Context, userID, resourceowner
return readModelToU2FTokens(tokenReadModel), nil return readModelToU2FTokens(tokenReadModel), nil
} }
func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resourceowner string) ([]*domain.WebAuthNToken, error) { func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resourceOwner string) ([]*domain.WebAuthNToken, error) {
tokenReadModel := NewHumanPasswordlessTokensReadModel(userID, resourceowner) tokenReadModel := NewHumanPasswordlessTokensReadModel(userID, resourceOwner)
err := c.eventstore.FilterToQueryReducer(ctx, tokenReadModel) err := c.eventstore.FilterToQueryReducer(ctx, tokenReadModel)
if err != nil { if err != nil {
return nil, err return nil, err
@ -82,12 +82,12 @@ func (c *Commands) HumanAddU2FSetup(ctx context.Context, userID, resourceowner s
if err != nil { if err != nil {
return nil, err return nil, err
} }
addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, isLoginUI, u2fTokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementDiscouraged) addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, "", u2fTokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementDiscouraged)
if err != nil { if err != nil {
return nil, err return nil, err
} }
events, err := c.eventstore.Push(ctx, usr_repo.NewHumanU2FAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge)) events, err := c.eventstore.Push(ctx, usr_repo.NewHumanU2FAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge, ""))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -108,12 +108,12 @@ func (c *Commands) HumanAddPasswordlessSetup(ctx context.Context, userID, resour
if err != nil { if err != nil {
return nil, err return nil, err
} }
addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, isLoginUI, passwordlessTokens, authenticatorPlatform, domain.UserVerificationRequirementRequired) addWebAuthN, userAgg, webAuthN, err := c.addHumanWebAuthN(ctx, userID, resourceowner, "", passwordlessTokens, authenticatorPlatform, domain.UserVerificationRequirementRequired)
if err != nil { if err != nil {
return nil, err return nil, err
} }
events, err := c.eventstore.Push(ctx, usr_repo.NewHumanPasswordlessAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge)) events, err := c.eventstore.Push(ctx, usr_repo.NewHumanPasswordlessAddedEvent(ctx, userAgg, addWebAuthN.WebauthNTokenID, webAuthN.Challenge, ""))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -137,7 +137,7 @@ func (c *Commands) HumanAddPasswordlessSetupInitCode(ctx context.Context, userID
return c.HumanAddPasswordlessSetup(ctx, userID, resourceowner, true, preferredPlatformType) return c.HumanAddPasswordlessSetup(ctx, userID, resourceowner, true, preferredPlatformType)
} }
func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner string, isLoginUI bool, tokens []*domain.WebAuthNToken, authenticatorPlatform domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner, rpID string, tokens []*domain.WebAuthNToken, authenticatorPlatform domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) {
if userID == "" { if userID == "" {
return nil, nil, nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-3M0od", "Errors.IDMissing") return nil, nil, nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-3M0od", "Errors.IDMissing")
} }
@ -157,7 +157,7 @@ func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner s
if accountName == "" { if accountName == "" {
accountName = string(user.EmailAddress) accountName = string(user.EmailAddress)
} }
webAuthN, err := c.webauthnConfig.BeginRegistration(ctx, user, accountName, authenticatorPlatform, userVerification, isLoginUI, tokens...) webAuthN, err := c.webauthnConfig.BeginRegistration(ctx, user, accountName, authenticatorPlatform, userVerification, rpID, tokens...)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -343,7 +343,7 @@ func (c *Commands) beginWebAuthNLogin(ctx context.Context, userID, resourceOwner
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, human, userVerification, tokens...) webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, human, userVerification, "", tokens...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -21,6 +21,7 @@ type HumanWebAuthNWriteModel struct {
AAGUID []byte AAGUID []byte
SignCount uint32 SignCount uint32
WebAuthNTokenName string WebAuthNTokenName string
RPID string
State domain.MFAState State domain.MFAState
} }
@ -113,6 +114,7 @@ func (wm *HumanWebAuthNWriteModel) Reduce() error {
func (wm *HumanWebAuthNWriteModel) appendAddedEvent(e *user.HumanWebAuthNAddedEvent) { func (wm *HumanWebAuthNWriteModel) appendAddedEvent(e *user.HumanWebAuthNAddedEvent) {
wm.WebauthNTokenID = e.WebAuthNTokenID wm.WebauthNTokenID = e.WebAuthNTokenID
wm.Challenge = e.Challenge wm.Challenge = e.Challenge
wm.RPID = e.RPID
wm.State = domain.MFAStateNotReady wm.State = domain.MFAStateNotReady
} }

View File

@ -16,22 +16,22 @@ import (
// RegisterUserPasskey creates a passkey registration for the current authenticated user. // RegisterUserPasskey creates a passkey registration for the current authenticated user.
// UserID, usually taken from the request is compared against the user ID in the context. // UserID, usually taken from the request is compared against the user ID in the context.
func (c *Commands) RegisterUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) RegisterUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment) (*domain.WebAuthNRegistrationDetails, error) {
if err := authz.UserIDInCTX(ctx, userID); err != nil { if err := authz.UserIDInCTX(ctx, userID); err != nil {
return nil, err return nil, err
} }
return c.registerUserPasskey(ctx, userID, resourceOwner, authenticator) return c.registerUserPasskey(ctx, userID, resourceOwner, rpID, authenticator)
} }
// RegisterUserPasskeyWithCode registers a new passkey for a unauthenticated user id. // RegisterUserPasskeyWithCode registers a new passkey for a unauthenticated user id.
// The resource is protected by the code, identified by the codeID. // The resource is protected by the code, identified by the codeID.
func (c *Commands) RegisterUserPasskeyWithCode(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, codeID, code string, alg crypto.EncryptionAlgorithm) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) RegisterUserPasskeyWithCode(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, codeID, code, rpID string, alg crypto.EncryptionAlgorithm) (*domain.WebAuthNRegistrationDetails, error) {
event, err := c.verifyUserPasskeyCode(ctx, userID, resourceOwner, codeID, code, alg) event, err := c.verifyUserPasskeyCode(ctx, userID, resourceOwner, codeID, code, alg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.registerUserPasskey(ctx, userID, resourceOwner, authenticator, event) return c.registerUserPasskey(ctx, userID, resourceOwner, rpID, authenticator, event)
} }
type eventCallback func(context.Context, *eventstore.Aggregate) eventstore.Command type eventCallback func(context.Context, *eventstore.Aggregate) eventstore.Command
@ -63,25 +63,25 @@ func (c *Commands) verifyUserPasskeyCodeFailed(ctx context.Context, wm *HumanPas
logging.WithFields("userID", userAgg.ID).OnError(err).Error("RegisterUserPasskeyWithCode push failed") logging.WithFields("userID", userAgg.ID).OnError(err).Error("RegisterUserPasskeyWithCode push failed")
} }
func (c *Commands) registerUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) registerUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) {
wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, userID, resourceOwner, authenticator) wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, userID, resourceOwner, rpID, authenticator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.pushUserPasskey(ctx, wm, userAgg, webAuthN, events...) return c.pushUserPasskey(ctx, wm, userAgg, webAuthN, events...)
} }
func (c *Commands) createUserPasskey(ctx context.Context, userID, resourceOwner string, authenticator domain.AuthenticatorAttachment) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { func (c *Commands) createUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) {
passwordlessTokens, err := c.getHumanPasswordlessTokens(ctx, userID, resourceOwner) passwordlessTokens, err := c.getHumanPasswordlessTokens(ctx, userID, resourceOwner)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
return c.addHumanWebAuthN(ctx, userID, resourceOwner, false, passwordlessTokens, authenticator, domain.UserVerificationRequirementRequired) return c.addHumanWebAuthN(ctx, userID, resourceOwner, rpID, passwordlessTokens, authenticator, domain.UserVerificationRequirementRequired)
} }
func (c *Commands) pushUserPasskey(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) pushUserPasskey(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken, events ...eventCallback) (*domain.WebAuthNRegistrationDetails, error) {
cmds := make([]eventstore.Command, len(events)+1) cmds := make([]eventstore.Command, len(events)+1)
cmds[0] = user.NewHumanPasswordlessAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge) cmds[0] = user.NewHumanPasswordlessAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge, webAuthN.RPID)
for i, event := range events { for i, event := range events {
cmds[i+1] = event(ctx, userAgg) cmds[i+1] = event(ctx, userAgg)
} }

View File

@ -40,6 +40,7 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
type args struct { type args struct {
userID string userID string
resourceOwner string resourceOwner string
rpID string
authenticator domain.AuthenticatorAttachment authenticator domain.AuthenticatorAttachment
} }
tests := []struct { tests := []struct {
@ -121,7 +122,7 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
webauthnConfig: webauthnConfig, webauthnConfig: webauthnConfig,
} }
_, err := c.RegisterUserPasskey(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator) _, err := c.RegisterUserPasskey(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.rpID, tt.args.authenticator)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
// successful case can't be tested due to random challenge. // successful case can't be tested due to random challenge.
}) })
@ -148,6 +149,7 @@ func TestCommands_RegisterUserPasskeyWithCode(t *testing.T) {
type args struct { type args struct {
userID string userID string
resourceOwner string resourceOwner string
rpID string
authenticator domain.AuthenticatorAttachment authenticator domain.AuthenticatorAttachment
codeID string codeID string
code string code string
@ -222,7 +224,7 @@ func TestCommands_RegisterUserPasskeyWithCode(t *testing.T) {
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
webauthnConfig: webauthnConfig, webauthnConfig: webauthnConfig,
} }
_, err := c.RegisterUserPasskeyWithCode(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator, tt.args.codeID, tt.args.code, alg) _, err := c.RegisterUserPasskeyWithCode(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.authenticator, tt.args.codeID, tt.args.code, tt.args.rpID, alg)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
// successful case can't be tested due to random challenge. // successful case can't be tested due to random challenge.
}) })
@ -376,7 +378,7 @@ func TestCommands_pushUserPasskey(t *testing.T) {
expectFilter(eventFromEventPusher( expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType,
), "111", "challenge"), ), "111", "challenge", "rpID"),
)), )),
} }
@ -394,7 +396,7 @@ func TestCommands_pushUserPasskey(t *testing.T) {
expectPush: func(challenge string) expect { expectPush: func(challenge string) expect {
return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher( return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher(
user.NewHumanPasswordlessAddedEvent(ctx, user.NewHumanPasswordlessAddedEvent(ctx,
userAgg, "123", challenge, userAgg, "123", challenge, "rpID",
), ),
)}) )})
}, },
@ -406,7 +408,7 @@ func TestCommands_pushUserPasskey(t *testing.T) {
expectPush: func(challenge string) expect { expectPush: func(challenge string) expect {
return expectPush([]*repository.Event{eventFromEventPusher( return expectPush([]*repository.Event{eventFromEventPusher(
user.NewHumanPasswordlessAddedEvent(ctx, user.NewHumanPasswordlessAddedEvent(ctx,
userAgg, "123", challenge, userAgg, "123", challenge, "rpID",
), ),
)}) )})
}, },
@ -418,7 +420,7 @@ func TestCommands_pushUserPasskey(t *testing.T) {
return expectPush([]*repository.Event{ return expectPush([]*repository.Event{
eventFromEventPusher( eventFromEventPusher(
user.NewHumanPasswordlessAddedEvent(ctx, user.NewHumanPasswordlessAddedEvent(ctx,
userAgg, "123", challenge, userAgg, "123", challenge, "rpID",
), ),
), ),
eventFromEventPusher( eventFromEventPusher(
@ -440,7 +442,7 @@ func TestCommands_pushUserPasskey(t *testing.T) {
webauthnConfig: webauthnConfig, webauthnConfig: webauthnConfig,
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"),
} }
wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", domain.AuthenticatorAttachmentCrossPlattform) wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", "rpID", domain.AuthenticatorAttachmentCrossPlattform)
require.NoError(t, err) require.NoError(t, err)
c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge)) c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge))

View File

@ -9,31 +9,31 @@ import (
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
) )
func (c *Commands) RegisterUserU2F(ctx context.Context, userID, resourceOwner string) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) RegisterUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*domain.WebAuthNRegistrationDetails, error) {
if err := authz.UserIDInCTX(ctx, userID); err != nil { if err := authz.UserIDInCTX(ctx, userID); err != nil {
return nil, err return nil, err
} }
return c.registerUserU2F(ctx, userID, resourceOwner) return c.registerUserU2F(ctx, userID, resourceOwner, rpID)
} }
func (c *Commands) registerUserU2F(ctx context.Context, userID, resourceOwner string) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) registerUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*domain.WebAuthNRegistrationDetails, error) {
wm, userAgg, webAuthN, err := c.createUserU2F(ctx, userID, resourceOwner) wm, userAgg, webAuthN, err := c.createUserU2F(ctx, userID, resourceOwner, rpID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.pushUserU2F(ctx, wm, userAgg, webAuthN) return c.pushUserU2F(ctx, wm, userAgg, webAuthN)
} }
func (c *Commands) createUserU2F(ctx context.Context, userID, resourceOwner string) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) { func (c *Commands) createUserU2F(ctx context.Context, userID, resourceOwner, rpID string) (*HumanWebAuthNWriteModel, *eventstore.Aggregate, *domain.WebAuthNToken, error) {
tokens, err := c.getHumanU2FTokens(ctx, userID, resourceOwner) tokens, err := c.getHumanU2FTokens(ctx, userID, resourceOwner)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
return c.addHumanWebAuthN(ctx, userID, resourceOwner, false, tokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementRequired) return c.addHumanWebAuthN(ctx, userID, resourceOwner, rpID, tokens, domain.AuthenticatorAttachmentUnspecified, domain.UserVerificationRequirementRequired)
} }
func (c *Commands) pushUserU2F(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken) (*domain.WebAuthNRegistrationDetails, error) { func (c *Commands) pushUserU2F(ctx context.Context, wm *HumanWebAuthNWriteModel, userAgg *eventstore.Aggregate, webAuthN *domain.WebAuthNToken) (*domain.WebAuthNRegistrationDetails, error) {
cmd := user.NewHumanU2FAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge) cmd := user.NewHumanU2FAddedEvent(ctx, userAgg, wm.WebauthNTokenID, webAuthN.Challenge, webAuthN.RPID)
err := c.pushAppendAndReduce(ctx, wm, cmd) err := c.pushAppendAndReduce(ctx, wm, cmd)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -37,6 +37,7 @@ func TestCommands_RegisterUserU2F(t *testing.T) {
type args struct { type args struct {
userID string userID string
resourceOwner string resourceOwner string
rpID string
} }
tests := []struct { tests := []struct {
name string name string
@ -114,7 +115,7 @@ func TestCommands_RegisterUserU2F(t *testing.T) {
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
webauthnConfig: webauthnConfig, webauthnConfig: webauthnConfig,
} }
_, err := c.RegisterUserU2F(ctx, tt.args.userID, tt.args.resourceOwner) _, err := c.RegisterUserU2F(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.rpID)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
// successful case can't be tested due to random challenge. // successful case can't be tested due to random challenge.
}) })
@ -160,7 +161,7 @@ func TestCommands_pushUserU2F(t *testing.T) {
expectFilter(eventFromEventPusher( expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush( user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType, ctx, &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType,
), "111", "challenge"), ), "111", "challenge", "rpID"),
)), )),
} }
@ -174,7 +175,7 @@ func TestCommands_pushUserU2F(t *testing.T) {
expectPush: func(challenge string) expect { expectPush: func(challenge string) expect {
return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher( return expectPushFailed(io.ErrClosedPipe, []*repository.Event{eventFromEventPusher(
user.NewHumanU2FAddedEvent(ctx, user.NewHumanU2FAddedEvent(ctx,
userAgg, "123", challenge, userAgg, "123", challenge, "rpID",
), ),
)}) )})
}, },
@ -185,7 +186,7 @@ func TestCommands_pushUserU2F(t *testing.T) {
expectPush: func(challenge string) expect { expectPush: func(challenge string) expect {
return expectPush([]*repository.Event{eventFromEventPusher( return expectPush([]*repository.Event{eventFromEventPusher(
user.NewHumanU2FAddedEvent(ctx, user.NewHumanU2FAddedEvent(ctx,
userAgg, "123", challenge, userAgg, "123", challenge, "rpID",
), ),
)}) )})
}, },
@ -198,7 +199,7 @@ func TestCommands_pushUserU2F(t *testing.T) {
webauthnConfig: webauthnConfig, webauthnConfig: webauthnConfig,
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "123"),
} }
wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", domain.AuthenticatorAttachmentCrossPlattform) wm, userAgg, webAuthN, err := c.createUserPasskey(ctx, "user1", "org1", "rpID", domain.AuthenticatorAttachmentCrossPlattform)
require.NoError(t, err) require.NoError(t, err)
c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge)) c.eventstore = eventstoreExpect(t, tt.expectPush(webAuthN.Challenge))

View File

@ -23,6 +23,7 @@ type WebAuthNToken struct {
AAGUID []byte AAGUID []byte
SignCount uint32 SignCount uint32
WebAuthNTokenName string WebAuthNTokenName string
RPID string
} }
type WebAuthNLogin struct { type WebAuthNLogin struct {
@ -32,6 +33,7 @@ type WebAuthNLogin struct {
Challenge string Challenge string
AllowedCredentialIDs [][]byte AllowedCredentialIDs [][]byte
UserVerification UserVerificationRequirement UserVerification UserVerificationRequirement
RPID string
} }
type UserVerificationRequirement int32 type UserVerificationRequirement int32

View File

@ -89,6 +89,7 @@ func (s *Tester) RegisterUserPasskey(ctx context.Context, userID string) {
pkr, err := s.Client.UserV2.RegisterPasskey(ctx, &user.RegisterPasskeyRequest{ pkr, err := s.Client.UserV2.RegisterPasskey(ctx, &user.RegisterPasskeyRequest{
UserId: userID, UserId: userID,
Code: reg.GetCode(), Code: reg.GetCode(),
Domain: s.Config.ExternalDomain,
}) })
logging.OnError(err).Fatal("create user passkey") logging.OnError(err).Fatal("create user passkey")
attestationResponse, err := s.WebAuthN.CreateAttestationResponse(pkr.GetPublicKeyCredentialCreationOptions()) attestationResponse, err := s.WebAuthN.CreateAttestationResponse(pkr.GetPublicKeyCredentialCreationOptions())

View File

@ -14,7 +14,7 @@ import (
) )
const ( const (
SessionsProjectionTable = "projections.sessions2" SessionsProjectionTable = "projections.sessions3"
SessionColumnID = "id" SessionColumnID = "id"
SessionColumnCreationDate = "creation_date" SessionColumnCreationDate = "creation_date"
@ -22,6 +22,7 @@ const (
SessionColumnSequence = "sequence" SessionColumnSequence = "sequence"
SessionColumnState = "state" SessionColumnState = "state"
SessionColumnResourceOwner = "resource_owner" SessionColumnResourceOwner = "resource_owner"
SessionColumnDomain = "domain"
SessionColumnInstanceID = "instance_id" SessionColumnInstanceID = "instance_id"
SessionColumnCreator = "creator" SessionColumnCreator = "creator"
SessionColumnUserID = "user_id" SessionColumnUserID = "user_id"
@ -49,6 +50,7 @@ func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfi
crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64), crdb.NewColumn(SessionColumnSequence, crdb.ColumnTypeInt64),
crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum), crdb.NewColumn(SessionColumnState, crdb.ColumnTypeEnum),
crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnResourceOwner, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnDomain, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText), crdb.NewColumn(SessionColumnCreator, crdb.ColumnTypeText),
crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()), crdb.NewColumn(SessionColumnUserID, crdb.ColumnTypeText, crdb.Nullable()),
@ -140,6 +142,7 @@ func (p *sessionProjection) reduceSessionAdded(event eventstore.Event) (*handler
handler.NewCol(SessionColumnCreationDate, e.CreationDate()), handler.NewCol(SessionColumnCreationDate, e.CreationDate()),
handler.NewCol(SessionColumnChangeDate, e.CreationDate()), handler.NewCol(SessionColumnChangeDate, e.CreationDate()),
handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner), handler.NewCol(SessionColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(SessionColumnDomain, e.Domain),
handler.NewCol(SessionColumnState, domain.SessionStateActive), handler.NewCol(SessionColumnState, domain.SessionStateActive),
handler.NewCol(SessionColumnSequence, e.Sequence()), handler.NewCol(SessionColumnSequence, e.Sequence()),
handler.NewCol(SessionColumnCreator, e.User), handler.NewCol(SessionColumnCreator, e.User),

View File

@ -30,7 +30,9 @@ func TestSessionProjection_reduces(t *testing.T) {
event: getEvent(testEvent( event: getEvent(testEvent(
session.AddedType, session.AddedType,
session.AggregateType, session.AggregateType,
[]byte(`{}`), []byte(`{
"domain": "domain"
}`),
), session.AddedEventMapper), ), session.AddedEventMapper),
}, },
reduce: (&sessionProjection{}).reduceSessionAdded, reduce: (&sessionProjection{}).reduceSessionAdded,
@ -41,13 +43,14 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "INSERT INTO projections.sessions2 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", expectedStmt: "INSERT INTO projections.sessions3 (id, instance_id, creation_date, change_date, resource_owner, domain, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
"agg-id", "agg-id",
"instance-id", "instance-id",
anyArg{}, anyArg{},
anyArg{}, anyArg{},
"ro-id", "ro-id",
"domain",
domain.SessionStateActive, domain.SessionStateActive,
uint64(15), uint64(15),
"editor-user", "editor-user",
@ -77,7 +80,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
anyArg{}, anyArg{},
anyArg{}, anyArg{},
@ -110,7 +113,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
anyArg{}, anyArg{},
anyArg{}, anyArg{},
@ -142,7 +145,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
anyArg{}, anyArg{},
anyArg{}, anyArg{},
@ -174,7 +177,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
anyArg{}, anyArg{},
anyArg{}, anyArg{},
@ -208,7 +211,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedStmt: "UPDATE projections.sessions3 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
anyArg{}, anyArg{},
anyArg{}, anyArg{},
@ -240,7 +243,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "DELETE FROM projections.sessions2 WHERE (id = $1) AND (instance_id = $2)", expectedStmt: "DELETE FROM projections.sessions3 WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
"agg-id", "agg-id",
"instance-id", "instance-id",
@ -267,7 +270,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "DELETE FROM projections.sessions2 WHERE (instance_id = $1)", expectedStmt: "DELETE FROM projections.sessions3 WHERE (instance_id = $1)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
"agg-id", "agg-id",
}, },
@ -298,7 +301,7 @@ func TestSessionProjection_reduces(t *testing.T) {
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.sessions2 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", expectedStmt: "UPDATE projections.sessions3 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
nil, nil,
"agg-id", "agg-id",

View File

@ -29,6 +29,7 @@ type Session struct {
Sequence uint64 Sequence uint64
State domain.SessionState State domain.SessionState
ResourceOwner string ResourceOwner string
Domain string
Creator string Creator string
UserFactor SessionUserFactor UserFactor SessionUserFactor
PasswordFactor SessionPasswordFactor PasswordFactor SessionPasswordFactor
@ -98,6 +99,10 @@ var (
name: projection.SessionColumnResourceOwner, name: projection.SessionColumnResourceOwner,
table: sessionsTable, table: sessionsTable,
} }
SessionColumnDomain = Column{
name: projection.SessionColumnDomain,
table: sessionsTable,
}
SessionColumnInstanceID = Column{ SessionColumnInstanceID = Column{
name: projection.SessionColumnInstanceID, name: projection.SessionColumnInstanceID,
table: sessionsTable, table: sessionsTable,
@ -211,6 +216,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
SessionColumnState.identifier(), SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(), SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(), SessionColumnCreator.identifier(),
SessionColumnDomain.identifier(),
SessionColumnUserID.identifier(), SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(), SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(), LoginNameNameCol.identifier(),
@ -236,6 +242,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
passkeyCheckedAt sql.NullTime passkeyCheckedAt sql.NullTime
metadata database.Map[[]byte] metadata database.Map[[]byte]
token sql.NullString token sql.NullString
sessionDomain sql.NullString
) )
err := row.Scan( err := row.Scan(
@ -246,6 +253,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
&session.State, &session.State,
&session.ResourceOwner, &session.ResourceOwner,
&session.Creator, &session.Creator,
&sessionDomain,
&userID, &userID,
&userCheckedAt, &userCheckedAt,
&loginName, &loginName,
@ -264,6 +272,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal") return nil, "", errors.ThrowInternal(err, "QUERY-SAder", "Errors.Internal")
} }
session.Domain = sessionDomain.String
session.UserFactor.UserID = userID.String session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String session.UserFactor.LoginName = loginName.String
@ -286,6 +295,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
SessionColumnState.identifier(), SessionColumnState.identifier(),
SessionColumnResourceOwner.identifier(), SessionColumnResourceOwner.identifier(),
SessionColumnCreator.identifier(), SessionColumnCreator.identifier(),
SessionColumnDomain.identifier(),
SessionColumnUserID.identifier(), SessionColumnUserID.identifier(),
SessionColumnUserCheckedAt.identifier(), SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(), LoginNameNameCol.identifier(),
@ -313,6 +323,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
intentCheckedAt sql.NullTime intentCheckedAt sql.NullTime
passkeyCheckedAt sql.NullTime passkeyCheckedAt sql.NullTime
metadata database.Map[[]byte] metadata database.Map[[]byte]
sessionDomain sql.NullString
) )
err := rows.Scan( err := rows.Scan(
@ -323,6 +334,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
&session.State, &session.State,
&session.ResourceOwner, &session.ResourceOwner,
&session.Creator, &session.Creator,
&sessionDomain,
&userID, &userID,
&userCheckedAt, &userCheckedAt,
&loginName, &loginName,
@ -337,6 +349,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-SAfeg", "Errors.Internal")
} }
session.Domain = sessionDomain.String
session.UserFactor.UserID = userID.String session.UserFactor.UserID = userID.String
session.UserFactor.UserCheckedAt = userCheckedAt.Time session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String session.UserFactor.LoginName = loginName.String

View File

@ -17,45 +17,47 @@ import (
) )
var ( var (
expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions2.id,` + expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
` projections.sessions2.creation_date,` + ` projections.sessions3.creation_date,` +
` projections.sessions2.change_date,` + ` projections.sessions3.change_date,` +
` projections.sessions2.sequence,` + ` projections.sessions3.sequence,` +
` projections.sessions2.state,` + ` projections.sessions3.state,` +
` projections.sessions2.resource_owner,` + ` projections.sessions3.resource_owner,` +
` projections.sessions2.creator,` + ` projections.sessions3.creator,` +
` projections.sessions2.user_id,` + ` projections.sessions3.domain,` +
` projections.sessions2.user_checked_at,` + ` projections.sessions3.user_id,` +
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` + ` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` + ` projections.users8_humans.display_name,` +
` projections.sessions2.password_checked_at,` + ` projections.sessions3.password_checked_at,` +
` projections.sessions2.intent_checked_at,` + ` projections.sessions3.intent_checked_at,` +
` projections.sessions2.passkey_checked_at,` + ` projections.sessions3.passkey_checked_at,` +
` projections.sessions2.metadata,` + ` projections.sessions3.metadata,` +
` projections.sessions2.token_id` + ` projections.sessions3.token_id` +
` FROM projections.sessions2` + ` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions2.user_id = projections.login_names2.user_id AND projections.sessions2.instance_id = projections.login_names2.instance_id` + ` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions2.user_id = projections.users8_humans.user_id AND projections.sessions2.instance_id = projections.users8_humans.instance_id` + ` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`) ` AS OF SYSTEM TIME '-1 ms'`)
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions2.id,` + expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
` projections.sessions2.creation_date,` + ` projections.sessions3.creation_date,` +
` projections.sessions2.change_date,` + ` projections.sessions3.change_date,` +
` projections.sessions2.sequence,` + ` projections.sessions3.sequence,` +
` projections.sessions2.state,` + ` projections.sessions3.state,` +
` projections.sessions2.resource_owner,` + ` projections.sessions3.resource_owner,` +
` projections.sessions2.creator,` + ` projections.sessions3.creator,` +
` projections.sessions2.user_id,` + ` projections.sessions3.domain,` +
` projections.sessions2.user_checked_at,` + ` projections.sessions3.user_id,` +
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` + ` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` + ` projections.users8_humans.display_name,` +
` projections.sessions2.password_checked_at,` + ` projections.sessions3.password_checked_at,` +
` projections.sessions2.intent_checked_at,` + ` projections.sessions3.intent_checked_at,` +
` projections.sessions2.passkey_checked_at,` + ` projections.sessions3.passkey_checked_at,` +
` projections.sessions2.metadata,` + ` projections.sessions3.metadata,` +
` COUNT(*) OVER ()` + ` COUNT(*) OVER ()` +
` FROM projections.sessions2` + ` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions2.user_id = projections.login_names2.user_id AND projections.sessions2.instance_id = projections.login_names2.instance_id` + ` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions2.user_id = projections.users8_humans.user_id AND projections.sessions2.instance_id = projections.users8_humans.instance_id` + ` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`) ` AS OF SYSTEM TIME '-1 ms'`)
sessionCols = []string{ sessionCols = []string{
@ -66,6 +68,7 @@ var (
"state", "state",
"resource_owner", "resource_owner",
"creator", "creator",
"domain",
"user_id", "user_id",
"user_checked_at", "user_checked_at",
"login_name", "login_name",
@ -85,6 +88,7 @@ var (
"state", "state",
"resource_owner", "resource_owner",
"creator", "creator",
"domain",
"user_id", "user_id",
"user_checked_at", "user_checked_at",
"login_name", "login_name",
@ -136,6 +140,7 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive, domain.SessionStateActive,
"ro", "ro",
"creator", "creator",
"domain",
"user-id", "user-id",
testNow, testNow,
"login-name", "login-name",
@ -161,6 +166,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive, State: domain.SessionStateActive,
ResourceOwner: "ro", ResourceOwner: "ro",
Creator: "creator", Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{ UserFactor: SessionUserFactor{
UserID: "user-id", UserID: "user-id",
UserCheckedAt: testNow, UserCheckedAt: testNow,
@ -199,6 +205,7 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive, domain.SessionStateActive,
"ro", "ro",
"creator", "creator",
"domain",
"user-id", "user-id",
testNow, testNow,
"login-name", "login-name",
@ -216,6 +223,7 @@ func Test_SessionsPrepare(t *testing.T) {
domain.SessionStateActive, domain.SessionStateActive,
"ro", "ro",
"creator2", "creator2",
"domain",
"user-id2", "user-id2",
testNow, testNow,
"login-name2", "login-name2",
@ -241,6 +249,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive, State: domain.SessionStateActive,
ResourceOwner: "ro", ResourceOwner: "ro",
Creator: "creator", Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{ UserFactor: SessionUserFactor{
UserID: "user-id", UserID: "user-id",
UserCheckedAt: testNow, UserCheckedAt: testNow,
@ -268,6 +277,7 @@ func Test_SessionsPrepare(t *testing.T) {
State: domain.SessionStateActive, State: domain.SessionStateActive,
ResourceOwner: "ro", ResourceOwner: "ro",
Creator: "creator2", Creator: "creator2",
Domain: "domain",
UserFactor: SessionUserFactor{ UserFactor: SessionUserFactor{
UserID: "user-id2", UserID: "user-id2",
UserCheckedAt: testNow, UserCheckedAt: testNow,
@ -359,6 +369,7 @@ func Test_SessionPrepare(t *testing.T) {
domain.SessionStateActive, domain.SessionStateActive,
"ro", "ro",
"creator", "creator",
"domain",
"user-id", "user-id",
testNow, testNow,
"login-name", "login-name",
@ -379,6 +390,7 @@ func Test_SessionPrepare(t *testing.T) {
State: domain.SessionStateActive, State: domain.SessionStateActive,
ResourceOwner: "ro", ResourceOwner: "ro",
Creator: "creator", Creator: "creator",
Domain: "domain",
UserFactor: SessionUserFactor{ UserFactor: SessionUserFactor{
UserID: "user-id", UserID: "user-id",
UserCheckedAt: testNow, UserCheckedAt: testNow,

View File

@ -26,6 +26,8 @@ const (
type AddedEvent struct { type AddedEvent struct {
eventstore.BaseEvent `json:"-"` eventstore.BaseEvent `json:"-"`
Domain string `json:"domain,omitempty"`
} }
func (e *AddedEvent) Data() interface{} { func (e *AddedEvent) Data() interface{} {
@ -38,6 +40,7 @@ func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
func NewAddedEvent(ctx context.Context, func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
domain string,
) *AddedEvent { ) *AddedEvent {
return &AddedEvent{ return &AddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush( BaseEvent: *eventstore.NewBaseEventForPush(
@ -45,6 +48,7 @@ func NewAddedEvent(ctx context.Context,
aggregate, aggregate,
AddedType, AddedType,
), ),
Domain: domain,
} }
} }

View File

@ -39,6 +39,7 @@ func NewHumanPasswordlessAddedEvent(
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
webAuthNTokenID, webAuthNTokenID,
challenge string, challenge string,
rpID string,
) *HumanPasswordlessAddedEvent { ) *HumanPasswordlessAddedEvent {
return &HumanPasswordlessAddedEvent{ return &HumanPasswordlessAddedEvent{
HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent( HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent(
@ -49,6 +50,7 @@ func NewHumanPasswordlessAddedEvent(
), ),
webAuthNTokenID, webAuthNTokenID,
challenge, challenge,
rpID,
), ),
} }
} }

View File

@ -28,6 +28,7 @@ func NewHumanU2FAddedEvent(
aggregate *eventstore.Aggregate, aggregate *eventstore.Aggregate,
webAuthNTokenID, webAuthNTokenID,
challenge string, challenge string,
rpID string,
) *HumanU2FAddedEvent { ) *HumanU2FAddedEvent {
return &HumanU2FAddedEvent{ return &HumanU2FAddedEvent{
HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent( HumanWebAuthNAddedEvent: *NewHumanWebAuthNAddedEvent(
@ -38,6 +39,7 @@ func NewHumanU2FAddedEvent(
), ),
webAuthNTokenID, webAuthNTokenID,
challenge, challenge,
rpID,
), ),
} }
} }

View File

@ -14,6 +14,7 @@ type HumanWebAuthNAddedEvent struct {
WebAuthNTokenID string `json:"webAuthNTokenId"` WebAuthNTokenID string `json:"webAuthNTokenId"`
Challenge string `json:"challenge"` Challenge string `json:"challenge"`
RPID string `json:"rpID,omitempty"`
} }
func (e *HumanWebAuthNAddedEvent) Data() interface{} { func (e *HumanWebAuthNAddedEvent) Data() interface{} {
@ -28,11 +29,13 @@ func NewHumanWebAuthNAddedEvent(
base *eventstore.BaseEvent, base *eventstore.BaseEvent,
webAuthNTokenID, webAuthNTokenID,
challenge string, challenge string,
rpID string,
) *HumanWebAuthNAddedEvent { ) *HumanWebAuthNAddedEvent {
return &HumanWebAuthNAddedEvent{ return &HumanWebAuthNAddedEvent{
BaseEvent: *base, BaseEvent: *base,
WebAuthNTokenID: webAuthNTokenID, WebAuthNTokenID: webAuthNTokenID,
Challenge: challenge, Challenge: challenge,
RPID: rpID,
} }
} }

View File

@ -7,10 +7,10 @@ import (
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
) )
func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken) []webauthn.Credential { func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken, rpID string) []webauthn.Credential {
creds := make([]webauthn.Credential, 0) creds := make([]webauthn.Credential, 0)
for _, webAuthN := range webAuthNs { for _, webAuthN := range webAuthNs {
if webAuthN.State == domain.MFAStateReady { if webAuthN.State == domain.MFAStateReady && webAuthN.RPID == rpID {
creds = append(creds, webauthn.Credential{ creds = append(creds, webauthn.Credential{
ID: webAuthN.KeyID, ID: webAuthN.KeyID,
PublicKey: webAuthN.PublicKey, PublicKey: webAuthN.PublicKey,

View File

@ -52,12 +52,12 @@ func (u *webUser) WebAuthnCredentials() []webauthn.Credential {
return u.credentials return u.credentials
} }
func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, isLoginUI bool, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) { func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) {
webAuthNServer, err := w.serverFromContext(ctx) webAuthNServer, err := w.serverFromContext(ctx, rpID, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
creds := WebAuthNsToCredentials(webAuthNs) creds := WebAuthNsToCredentials(webAuthNs, rpID)
existing := make([]protocol.CredentialDescriptor, len(creds)) existing := make([]protocol.CredentialDescriptor, len(creds))
for i, cred := range creds { for i, cred := range creds {
existing[i] = protocol.CredentialDescriptor{ existing[i] = protocol.CredentialDescriptor{
@ -90,6 +90,7 @@ func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, acco
CredentialCreationData: cred, CredentialCreationData: cred,
AllowedCredentialIDs: sessionData.AllowedCredentialIDs, AllowedCredentialIDs: sessionData.AllowedCredentialIDs,
UserVerification: UserVerificationToDomain(sessionData.UserVerification), UserVerification: UserVerificationToDomain(sessionData.UserVerification),
RPID: webAuthNServer.Config.RPID,
}, nil }, nil
} }
@ -104,7 +105,7 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web
return nil, caos_errs.ThrowInternal(err, "WEBAU-sEr8c", "Errors.User.WebAuthN.ErrorOnParseCredential") return nil, caos_errs.ThrowInternal(err, "WEBAU-sEr8c", "Errors.User.WebAuthN.ErrorOnParseCredential")
} }
sessionData := WebAuthNToSessionData(webAuthN) sessionData := WebAuthNToSessionData(webAuthN)
webAuthNServer, err := w.serverFromContext(ctx) webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, credentialData.Response.CollectedClientData.Origin)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -124,17 +125,18 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web
webAuthN.AAGUID = credential.Authenticator.AAGUID webAuthN.AAGUID = credential.Authenticator.AAGUID
webAuthN.SignCount = credential.Authenticator.SignCount webAuthN.SignCount = credential.Authenticator.SignCount
webAuthN.WebAuthNTokenName = tokenName webAuthN.WebAuthNTokenName = tokenName
webAuthN.RPID = webAuthNServer.Config.RPID
return webAuthN, nil return webAuthN, nil
} }
func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) { func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) {
webAuthNServer, err := w.serverFromContext(ctx) webAuthNServer, err := w.serverFromContext(ctx, rpID, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
assertion, sessionData, err := webAuthNServer.BeginLogin(&webUser{ assertion, sessionData, err := webAuthNServer.BeginLogin(&webUser{
Human: user, Human: user,
credentials: WebAuthNsToCredentials(webAuthNs), credentials: WebAuthNsToCredentials(webAuthNs, rpID),
}, webauthn.WithUserVerification(UserVerificationFromDomain(userVerification))) }, webauthn.WithUserVerification(UserVerificationFromDomain(userVerification)))
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "WEBAU-4G8sw", "Errors.User.WebAuthN.BeginLoginFailed") return nil, caos_errs.ThrowInternal(err, "WEBAU-4G8sw", "Errors.User.WebAuthN.BeginLoginFailed")
@ -148,6 +150,7 @@ func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerific
CredentialAssertionData: cred, CredentialAssertionData: cred,
AllowedCredentialIDs: sessionData.AllowedCredentialIDs, AllowedCredentialIDs: sessionData.AllowedCredentialIDs,
UserVerification: userVerification, UserVerification: userVerification,
RPID: webAuthNServer.Config.RPID,
}, nil }, nil
} }
@ -158,9 +161,9 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN *
} }
webUser := &webUser{ webUser := &webUser{
Human: user, Human: user,
credentials: WebAuthNsToCredentials(webAuthNs), credentials: WebAuthNsToCredentials(webAuthNs, webAuthN.RPID),
} }
webAuthNServer, err := w.serverFromContext(ctx) webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, assertionData.Response.CollectedClientData.Origin)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -175,15 +178,31 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN *
return credential.ID, credential.Authenticator.SignCount, nil return credential.ID, credential.Authenticator.SignCount, nil
} }
func (w *Config) serverFromContext(ctx context.Context) (*webauthn.WebAuthn, error) { func (w *Config) serverFromContext(ctx context.Context, id, origin string) (*webauthn.WebAuthn, error) {
instance := authz.GetInstance(ctx) config := w.config(id, origin)
webAuthn, err := webauthn.New(&webauthn.Config{ if id == "" {
RPDisplayName: w.DisplayName, config = w.configFromContext(ctx)
RPID: instance.RequestedDomain(), }
RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)}, webAuthn, err := webauthn.New(config)
})
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig") return nil, caos_errs.ThrowInternal(err, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig")
} }
return webAuthn, nil return webAuthn, nil
} }
func (w *Config) configFromContext(ctx context.Context) *webauthn.Config {
instance := authz.GetInstance(ctx)
return &webauthn.Config{
RPDisplayName: w.DisplayName,
RPID: instance.RequestedDomain(),
RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)},
}
}
func (w *Config) config(id, origin string) *webauthn.Config {
return &webauthn.Config{
RPDisplayName: w.DisplayName,
RPID: id,
RPOrigins: []string{origin},
}
}

View File

@ -14,7 +14,9 @@ import (
func TestConfig_serverFromContext(t *testing.T) { func TestConfig_serverFromContext(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
id string
origin string
} }
tests := []struct { tests := []struct {
name string name string
@ -24,12 +26,12 @@ func TestConfig_serverFromContext(t *testing.T) {
}{ }{
{ {
name: "webauthn error", name: "webauthn error",
args: args{context.Background()}, args: args{context.Background(), "", ""},
wantErr: caos_errs.ThrowInternal(nil, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig"), wantErr: caos_errs.ThrowInternal(nil, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig"),
}, },
{ {
name: "success", name: "success from ctx",
args: args{authz.WithRequestedDomain(context.Background(), "example.com")}, args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "", ""},
want: &webauthn.WebAuthn{ want: &webauthn.WebAuthn{
Config: &webauthn.Config{ Config: &webauthn.Config{
RPDisplayName: "DisplayName", RPDisplayName: "DisplayName",
@ -38,6 +40,17 @@ func TestConfig_serverFromContext(t *testing.T) {
}, },
}, },
}, },
{
name: "success from id",
args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "external.com", "https://external.com"},
want: &webauthn.WebAuthn{
Config: &webauthn.Config{
RPDisplayName: "DisplayName",
RPID: "external.com",
RPOrigins: []string{"https://external.com"},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -45,7 +58,7 @@ func TestConfig_serverFromContext(t *testing.T) {
DisplayName: "DisplayName", DisplayName: "DisplayName",
ExternalSecure: true, ExternalSecure: true,
} }
got, err := w.serverFromContext(tt.args.ctx) got, err := w.serverFromContext(tt.args.ctx, tt.args.id, tt.args.origin)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
if tt.want != nil { if tt.want != nil {
require.NotNil(t, got) require.NotNil(t, got)

View File

@ -39,6 +39,11 @@ message Session {
description: "\"custom key value list\""; description: "\"custom key value list\"";
} }
]; ];
string domain = 7 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "\"domain on which the session was created\"";
}
];
} }
message Factors { message Factors {

View File

@ -245,6 +245,11 @@ message CreateSessionRequest{
} }
]; ];
repeated ChallengeKind challenges = 3; repeated ChallengeKind challenges = 3;
string domain = 4 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "\"Domain on which the session was created. Will be used for Passkey and U2F challenges.\"";
}
];
} }
message CreateSessionResponse{ message CreateSessionResponse{

View File

@ -587,6 +587,11 @@ message RegisterPasskeyRequest{
description: "\"Optionally specify the authenticator type of the passkey device (platform or cross-platform). If none is provided, both values are allowed.\""; description: "\"Optionally specify the authenticator type of the passkey device (platform or cross-platform). If none is provided, both values are allowed.\"";
} }
]; ];
string domain = 4 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "\"Domain on which the user is authenticated.\"";
}
];
} }
message RegisterPasskeyResponse{ message RegisterPasskeyResponse{
@ -658,6 +663,11 @@ message RegisterU2FRequest{
example: "\"d654e6ba-70a3-48ef-a95d-37c8d8a7901a\""; example: "\"d654e6ba-70a3-48ef-a95d-37c8d8a7901a\"";
} }
]; ];
string domain = 2 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "\"Domain on which the user is authenticated.\"";
}
];
} }
message RegisterU2FResponse{ message RegisterU2FResponse{