diff --git a/internal/webauthn/converter.go b/internal/webauthn/converter.go index 36799ee3dc..c914bb8bf9 100644 --- a/internal/webauthn/converter.go +++ b/internal/webauthn/converter.go @@ -1,16 +1,26 @@ package webauthn import ( + "context" + "strings" + "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" + "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/domain" ) -func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken, rpID string) []webauthn.Credential { +func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNToken, rpID string) []webauthn.Credential { creds := make([]webauthn.Credential, 0) for _, webAuthN := range webAuthNs { - if webAuthN.State == domain.MFAStateReady && webAuthN.RPID == rpID { + // only add credentials that are ready and + // either match the rpID or + // if they were added through Console / old login UI, there is no stored rpID set; + // then we check if the requested rpID matches the instance domain + if webAuthN.State == domain.MFAStateReady && + (webAuthN.RPID == rpID || + (webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) { creds = append(creds, webauthn.Credential{ ID: webAuthN.KeyID, PublicKey: webAuthN.PublicKey, diff --git a/internal/webauthn/converter_test.go b/internal/webauthn/converter_test.go new file mode 100644 index 0000000000..a8f2a3608b --- /dev/null +++ b/internal/webauthn/converter_test.go @@ -0,0 +1,153 @@ +package webauthn + +import ( + "context" + "testing" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/domain" +) + +func TestWebAuthNsToCredentials(t *testing.T) { + type args struct { + ctx context.Context + webAuthNs []*domain.WebAuthNToken + rpID string + } + tests := []struct { + name string + args args + want []webauthn.Credential + }{ + { + name: "unready credential", + args: args{ + ctx: context.Background(), + webAuthNs: []*domain.WebAuthNToken{ + { + KeyID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + AAGUID: []byte("aaguid1"), + SignCount: 1, + State: domain.MFAStateNotReady, + }, + }, + rpID: "example.com", + }, + want: []webauthn.Credential{}, + }, + { + name: "not matching rpID", + args: args{ + ctx: context.Background(), + webAuthNs: []*domain.WebAuthNToken{ + { + KeyID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + AAGUID: []byte("aaguid1"), + SignCount: 1, + State: domain.MFAStateReady, + RPID: "other.com", + }, + }, + rpID: "example.com", + }, + want: []webauthn.Credential{}, + }, + { + name: "matching rpID", + args: args{ + ctx: context.Background(), + webAuthNs: []*domain.WebAuthNToken{ + { + KeyID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + AAGUID: []byte("aaguid1"), + SignCount: 1, + State: domain.MFAStateReady, + RPID: "example.com", + }, + }, + rpID: "example.com", + }, + want: []webauthn.Credential{ + { + ID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + Authenticator: webauthn.Authenticator{ + AAGUID: []byte("aaguid1"), + SignCount: 1, + }, + }, + }, + }, + { + name: "no rpID, different host", + args: args{ + ctx: http.WithDomainContext(context.Background(), &http.DomainCtx{ + InstanceHost: "other.com:443", + PublicHost: "other.com:443", + Protocol: "https", + }), + webAuthNs: []*domain.WebAuthNToken{ + { + KeyID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + AAGUID: []byte("aaguid1"), + SignCount: 1, + State: domain.MFAStateReady, + RPID: "", + }, + }, + rpID: "example.com", + }, + want: []webauthn.Credential{}, + }, + { + name: "no rpID, same host", + args: args{ + ctx: http.WithDomainContext(context.Background(), &http.DomainCtx{ + InstanceHost: "example.com:443", + PublicHost: "example.com:443", + Protocol: "https", + }), + webAuthNs: []*domain.WebAuthNToken{ + { + KeyID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + AAGUID: []byte("aaguid1"), + SignCount: 1, + State: domain.MFAStateReady, + RPID: "", + }, + }, + rpID: "example.com", + }, + want: []webauthn.Credential{ + { + ID: []byte("key1"), + PublicKey: []byte("publicKey1"), + AttestationType: "attestation1", + Authenticator: webauthn.Authenticator{ + AAGUID: []byte("aaguid1"), + SignCount: 1, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, WebAuthNsToCredentials(tt.args.ctx, tt.args.webAuthNs, tt.args.rpID), "WebAuthNsToCredentials(%v, %v, %v)", tt.args.ctx, tt.args.webAuthNs, tt.args.rpID) + }) + } +} diff --git a/internal/webauthn/webauthn.go b/internal/webauthn/webauthn.go index 998c013a3c..10d6fc52bf 100644 --- a/internal/webauthn/webauthn.go +++ b/internal/webauthn/webauthn.go @@ -57,7 +57,7 @@ func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, acco if err != nil { return nil, err } - creds := WebAuthNsToCredentials(webAuthNs, rpID) + creds := WebAuthNsToCredentials(ctx, webAuthNs, rpID) existing := make([]protocol.CredentialDescriptor, len(creds)) for i, cred := range creds { existing[i] = protocol.CredentialDescriptor{ @@ -136,7 +136,7 @@ func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerific } assertion, sessionData, err := webAuthNServer.BeginLogin(&webUser{ Human: user, - credentials: WebAuthNsToCredentials(webAuthNs, rpID), + credentials: WebAuthNsToCredentials(ctx, webAuthNs, rpID), }, webauthn.WithUserVerification(UserVerificationFromDomain(userVerification))) if err != nil { logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn login could not be started") @@ -163,7 +163,7 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN * } webUser := &webUser{ Human: user, - credentials: WebAuthNsToCredentials(webAuthNs, webAuthN.RPID), + credentials: WebAuthNsToCredentials(ctx, webAuthNs, webAuthN.RPID), } webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, assertionData.Response.CollectedClientData.Origin) if err != nil {