fix: provide domain in session, passkey and u2f (#6097)

This fix provides a possibility to pass a domain on the session, which
will be used (as rpID) to create a passkey / u2f assertion and
attestation. This is useful in cases where the login UI is served under
a different domain / origin than the ZITADEL API.
This commit is contained in:
Livio Spring
2023-06-27 14:36:07 +02:00
committed by GitHub
parent d0cda1b479
commit bd5defa96a
32 changed files with 287 additions and 123 deletions

View File

@@ -7,10 +7,10 @@ import (
"github.com/zitadel/zitadel/internal/domain"
)
func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken) []webauthn.Credential {
func WebAuthNsToCredentials(webAuthNs []*domain.WebAuthNToken, rpID string) []webauthn.Credential {
creds := make([]webauthn.Credential, 0)
for _, webAuthN := range webAuthNs {
if webAuthN.State == domain.MFAStateReady {
if webAuthN.State == domain.MFAStateReady && webAuthN.RPID == rpID {
creds = append(creds, webauthn.Credential{
ID: webAuthN.KeyID,
PublicKey: webAuthN.PublicKey,

View File

@@ -52,12 +52,12 @@ func (u *webUser) WebAuthnCredentials() []webauthn.Credential {
return u.credentials
}
func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, isLoginUI bool, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) {
webAuthNServer, err := w.serverFromContext(ctx)
func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) {
webAuthNServer, err := w.serverFromContext(ctx, rpID, "")
if err != nil {
return nil, err
}
creds := WebAuthNsToCredentials(webAuthNs)
creds := WebAuthNsToCredentials(webAuthNs, rpID)
existing := make([]protocol.CredentialDescriptor, len(creds))
for i, cred := range creds {
existing[i] = protocol.CredentialDescriptor{
@@ -90,6 +90,7 @@ func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, acco
CredentialCreationData: cred,
AllowedCredentialIDs: sessionData.AllowedCredentialIDs,
UserVerification: UserVerificationToDomain(sessionData.UserVerification),
RPID: webAuthNServer.Config.RPID,
}, nil
}
@@ -104,7 +105,7 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web
return nil, caos_errs.ThrowInternal(err, "WEBAU-sEr8c", "Errors.User.WebAuthN.ErrorOnParseCredential")
}
sessionData := WebAuthNToSessionData(webAuthN)
webAuthNServer, err := w.serverFromContext(ctx)
webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, credentialData.Response.CollectedClientData.Origin)
if err != nil {
return nil, err
}
@@ -124,17 +125,18 @@ func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, web
webAuthN.AAGUID = credential.Authenticator.AAGUID
webAuthN.SignCount = credential.Authenticator.SignCount
webAuthN.WebAuthNTokenName = tokenName
webAuthN.RPID = webAuthNServer.Config.RPID
return webAuthN, nil
}
func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) {
webAuthNServer, err := w.serverFromContext(ctx)
func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) {
webAuthNServer, err := w.serverFromContext(ctx, rpID, "")
if err != nil {
return nil, err
}
assertion, sessionData, err := webAuthNServer.BeginLogin(&webUser{
Human: user,
credentials: WebAuthNsToCredentials(webAuthNs),
credentials: WebAuthNsToCredentials(webAuthNs, rpID),
}, webauthn.WithUserVerification(UserVerificationFromDomain(userVerification)))
if err != nil {
return nil, caos_errs.ThrowInternal(err, "WEBAU-4G8sw", "Errors.User.WebAuthN.BeginLoginFailed")
@@ -148,6 +150,7 @@ func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerific
CredentialAssertionData: cred,
AllowedCredentialIDs: sessionData.AllowedCredentialIDs,
UserVerification: userVerification,
RPID: webAuthNServer.Config.RPID,
}, nil
}
@@ -158,9 +161,9 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN *
}
webUser := &webUser{
Human: user,
credentials: WebAuthNsToCredentials(webAuthNs),
credentials: WebAuthNsToCredentials(webAuthNs, webAuthN.RPID),
}
webAuthNServer, err := w.serverFromContext(ctx)
webAuthNServer, err := w.serverFromContext(ctx, webAuthN.RPID, assertionData.Response.CollectedClientData.Origin)
if err != nil {
return nil, 0, err
}
@@ -175,15 +178,31 @@ func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN *
return credential.ID, credential.Authenticator.SignCount, nil
}
func (w *Config) serverFromContext(ctx context.Context) (*webauthn.WebAuthn, error) {
instance := authz.GetInstance(ctx)
webAuthn, err := webauthn.New(&webauthn.Config{
RPDisplayName: w.DisplayName,
RPID: instance.RequestedDomain(),
RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)},
})
func (w *Config) serverFromContext(ctx context.Context, id, origin string) (*webauthn.WebAuthn, error) {
config := w.config(id, origin)
if id == "" {
config = w.configFromContext(ctx)
}
webAuthn, err := webauthn.New(config)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig")
}
return webAuthn, nil
}
func (w *Config) configFromContext(ctx context.Context) *webauthn.Config {
instance := authz.GetInstance(ctx)
return &webauthn.Config{
RPDisplayName: w.DisplayName,
RPID: instance.RequestedDomain(),
RPOrigins: []string{http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)},
}
}
func (w *Config) config(id, origin string) *webauthn.Config {
return &webauthn.Config{
RPDisplayName: w.DisplayName,
RPID: id,
RPOrigins: []string{origin},
}
}

View File

@@ -14,7 +14,9 @@ import (
func TestConfig_serverFromContext(t *testing.T) {
type args struct {
ctx context.Context
ctx context.Context
id string
origin string
}
tests := []struct {
name string
@@ -24,12 +26,12 @@ func TestConfig_serverFromContext(t *testing.T) {
}{
{
name: "webauthn error",
args: args{context.Background()},
args: args{context.Background(), "", ""},
wantErr: caos_errs.ThrowInternal(nil, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig"),
},
{
name: "success",
args: args{authz.WithRequestedDomain(context.Background(), "example.com")},
name: "success from ctx",
args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "", ""},
want: &webauthn.WebAuthn{
Config: &webauthn.Config{
RPDisplayName: "DisplayName",
@@ -38,6 +40,17 @@ func TestConfig_serverFromContext(t *testing.T) {
},
},
},
{
name: "success from id",
args: args{authz.WithRequestedDomain(context.Background(), "example.com"), "external.com", "https://external.com"},
want: &webauthn.WebAuthn{
Config: &webauthn.Config{
RPDisplayName: "DisplayName",
RPID: "external.com",
RPOrigins: []string{"https://external.com"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -45,7 +58,7 @@ func TestConfig_serverFromContext(t *testing.T) {
DisplayName: "DisplayName",
ExternalSecure: true,
}
got, err := w.serverFromContext(tt.args.ctx)
got, err := w.serverFromContext(tt.args.ctx, tt.args.id, tt.args.origin)
require.ErrorIs(t, err, tt.wantErr)
if tt.want != nil {
require.NotNil(t, got)