zitadel/internal/webauthn/webauthn.go
2023-09-15 13:43:38 +00:00

221 lines
7.6 KiB
Go

package webauthn
import (
"bytes"
"context"
"encoding/json"
"errors"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
// Config holds the relying-party settings from which per-request go-webauthn
// servers are built (see serverFromContext).
type Config struct {
	// DisplayName is used as RPDisplayName for every server built from this Config.
	DisplayName string
	// ExternalSecure is passed to http.BuildOrigin when the origin is derived
	// from the request context; presumably it selects an https vs http origin —
	// confirm in the internal http package.
	ExternalSecure bool
}
// webUser adapts a *domain.Human to the user interface consumed by the
// go-webauthn library (the WebAuthn* methods below). accountName and
// credentials are optional and only populated by callers that need them.
type webUser struct {
	*domain.Human
	accountName string
	credentials []webauthn.Credential
}

// WebAuthnID exposes the user's aggregate ID, as raw bytes, as the WebAuthn user ID.
func (u *webUser) WebAuthnID() []byte {
	id := u.AggregateID
	return []byte(id)
}

// WebAuthnName returns the explicitly supplied account name, or the username
// when no account name was set.
func (u *webUser) WebAuthnName() string {
	if u.accountName == "" {
		return u.GetUsername()
	}
	return u.accountName
}

// WebAuthnDisplayName returns the human's display name, or the username when
// the display name is empty.
func (u *webUser) WebAuthnDisplayName() string {
	switch name := u.DisplayName; name {
	case "":
		return u.GetUsername()
	default:
		return name
	}
}

// WebAuthnIcon always reports no icon.
func (u *webUser) WebAuthnIcon() string {
	var noIcon string
	return noIcon
}

// WebAuthnCredentials returns the credentials attached to this adapter
// (nil when none were set).
func (u *webUser) WebAuthnCredentials() []webauthn.Credential {
	creds := u.credentials
	return creds
}
// BeginRegistration starts a WebAuthn registration ceremony for user.
//
// The server is built for rpID, or derived from ctx when rpID is empty. The
// tokens in webAuthNs are converted via WebAuthNsToCredentials for rpID. They
// are attached to the user and also sent as the exclusion list, so already
// registered credentials are listed in the creation options. Attestation
// conveyance is set to "none" (PreferNoAttestation).
//
// The returned token carries the challenge, the marshalled credential creation
// options, the allowed credential IDs, the requested user verification and the
// effective RP ID. It is not yet a usable credential; FinishRegistration
// completes it.
func (w *Config) BeginRegistration(ctx context.Context, user *domain.Human, accountName string, authType domain.AuthenticatorAttachment, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNToken, error) {
	server, err := w.serverFromContext(ctx, rpID, "")
	if err != nil {
		return nil, err
	}

	knownCreds := WebAuthNsToCredentials(webAuthNs, rpID)
	exclusions := make([]protocol.CredentialDescriptor, 0, len(knownCreds))
	for _, c := range knownCreds {
		exclusions = append(exclusions, protocol.CredentialDescriptor{
			Type:         protocol.PublicKeyCredentialType,
			CredentialID: c.ID,
		})
	}

	regUser := &webUser{
		Human:       user,
		accountName: accountName,
		credentials: knownCreds,
	}
	selection := protocol.AuthenticatorSelection{
		UserVerification:        UserVerificationFromDomain(userVerification),
		AuthenticatorAttachment: AuthenticatorAttachmentFromDomain(authType),
	}
	options, session, err := server.BeginRegistration(
		regUser,
		webauthn.WithAuthenticatorSelection(selection),
		webauthn.WithConveyancePreference(protocol.PreferNoAttestation),
		webauthn.WithExclusions(exclusions),
	)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "WEBAU-bM8sd", "Errors.User.WebAuthN.BeginRegisterFailed")
	}

	creationData, err := json.Marshal(options)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "WEBAU-D7cus", "Errors.User.WebAuthN.MarshalError")
	}

	token := &domain.WebAuthNToken{
		Challenge:              session.Challenge,
		CredentialCreationData: creationData,
		AllowedCredentialIDs:   session.AllowedCredentialIDs,
		UserVerification:       UserVerificationToDomain(session.UserVerification),
		RPID:                   server.Config.RPID,
	}
	return token, nil
}
// FinishRegistration completes a registration ceremony started by
// BeginRegistration.
//
// It parses credData (the client's credential creation response) and
// rebuilds the session from webAuthN. It then lets the server for webAuthN.RPID
// and the client-reported origin create the credential.
//
// On success the key ID, public key, attestation type, AAGUID, sign count,
// tokenName and effective RP ID are written into webAuthN, which is then
// returned. The input token is mutated in place. isLoginUI is not used here.
// It returns an internal error if webAuthN is nil, or if parsing or
// credential creation fails.
func (w *Config) FinishRegistration(ctx context.Context, user *domain.Human, webAuthN *domain.WebAuthNToken, tokenName string, credData []byte, isLoginUI bool) (*domain.WebAuthNToken, error) {
	if webAuthN == nil {
		return nil, caos_errs.ThrowInternal(nil, "WEBAU-5M9so", "Errors.User.WebAuthN.NotFound")
	}

	parsed, err := protocol.ParseCredentialCreationResponseBody(bytes.NewReader(credData))
	if err != nil {
		logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn credential could not be parsed")
		return nil, caos_errs.ThrowInternal(err, "WEBAU-sEr8c", "Errors.User.WebAuthN.ErrorOnParseCredential")
	}

	session := WebAuthNToSessionData(webAuthN)
	clientOrigin := parsed.Response.CollectedClientData.Origin
	server, err := w.serverFromContext(ctx, webAuthN.RPID, clientOrigin)
	if err != nil {
		return nil, err
	}

	created, err := server.CreateCredential(&webUser{Human: user}, session, parsed)
	if err != nil {
		logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn credential could not be created")
		return nil, caos_errs.ThrowInternal(err, "WEBAU-3Vb9s", "Errors.User.WebAuthN.CreateCredentialFailed")
	}

	authenticator := created.Authenticator
	webAuthN.KeyID = created.ID
	webAuthN.PublicKey = created.PublicKey
	webAuthN.AttestationType = created.AttestationType
	webAuthN.AAGUID = authenticator.AAGUID
	webAuthN.SignCount = authenticator.SignCount
	webAuthN.WebAuthNTokenName = tokenName
	// The server may have derived the RP ID from ctx, so record what it actually used.
	webAuthN.RPID = server.Config.RPID
	return webAuthN, nil
}
// BeginLogin starts a WebAuthn login (assertion) ceremony for user.
//
// The server is built for rpID, or derived from ctx when rpID is empty. The
// user's tokens are converted via WebAuthNsToCredentials for rpID and offered
// to the library. The requested user verification is forwarded as a login
// option.
//
// The returned login carries the challenge, the marshalled assertion options,
// the allowed credential IDs, the requested user verification and the
// effective RP ID.
func (w *Config) BeginLogin(ctx context.Context, user *domain.Human, userVerification domain.UserVerificationRequirement, rpID string, webAuthNs ...*domain.WebAuthNToken) (*domain.WebAuthNLogin, error) {
	server, err := w.serverFromContext(ctx, rpID, "")
	if err != nil {
		return nil, err
	}

	loginUser := &webUser{
		Human:       user,
		credentials: WebAuthNsToCredentials(webAuthNs, rpID),
	}
	verification := webauthn.WithUserVerification(UserVerificationFromDomain(userVerification))
	assertion, session, err := server.BeginLogin(loginUser, verification)
	if err != nil {
		logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn login could not be started")
		return nil, caos_errs.ThrowInternal(err, "WEBAU-4G8sw", "Errors.User.WebAuthN.BeginLoginFailed")
	}

	assertionData, err := json.Marshal(assertion)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "WEBAU-2M0s9", "Errors.User.WebAuthN.MarshalError")
	}

	login := &domain.WebAuthNLogin{
		Challenge:               session.Challenge,
		CredentialAssertionData: assertionData,
		AllowedCredentialIDs:    session.AllowedCredentialIDs,
		UserVerification:        userVerification,
		RPID:                    server.Config.RPID,
	}
	return login, nil
}
// FinishLogin completes a WebAuthn login ceremony started by BeginLogin.
//
// credData is the client's assertion response body and webAuthN is the login
// session stored from BeginLogin. webAuthNs are the user's registered tokens,
// converted via WebAuthNsToCredentials for webAuthN.RPID. The server is built
// for webAuthN.RPID and the origin reported in the client data.
//
// It returns an internal error if webAuthN is nil, or if the assertion cannot
// be parsed or fails validation. When the validated credential has the
// library's CloneWarning flag set, the credential is returned TOGETHER with a
// CloneWarning error, so callers can still inspect it.
func (w *Config) FinishLogin(ctx context.Context, user *domain.Human, webAuthN *domain.WebAuthNLogin, credData []byte, webAuthNs ...*domain.WebAuthNToken) (*webauthn.Credential, error) {
	// Mirror FinishRegistration: webAuthN.RPID is dereferenced below, so a
	// missing session must surface as an error instead of a nil-pointer panic.
	if webAuthN == nil {
		return nil, caos_errs.ThrowInternal(nil, "WEBAU-Lq4xn", "Errors.User.WebAuthN.NotFound")
	}
	assertionData, err := protocol.ParseCredentialRequestResponseBody(bytes.NewReader(credData))
	if err != nil {
		logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn assertion could not be parsed")
		return nil, caos_errs.ThrowInternal(err, "WEBAU-ADgv4", "Errors.User.WebAuthN.ValidateLoginFailed")
	}
	// Named loginUser (not webUser) so the local does not shadow the type.
	loginUser := &webUser{
		Human:       user,
		credentials: WebAuthNsToCredentials(webAuthNs, webAuthN.RPID),
	}
	server, err := w.serverFromContext(ctx, webAuthN.RPID, assertionData.Response.CollectedClientData.Origin)
	if err != nil {
		return nil, err
	}
	credential, err := server.ValidateLogin(loginUser, WebAuthNLoginToSessionData(webAuthN), assertionData)
	if err != nil {
		logging.WithFields("error", tryExtractProtocolErrMsg(err)).Debug("webauthn assertion failed")
		return nil, caos_errs.ThrowInternal(err, "WEBAU-3M9si", "Errors.User.WebAuthN.ValidateLoginFailed")
	}
	if credential.Authenticator.CloneWarning {
		// Deliberately return the credential alongside the error (see doc above).
		return credential, caos_errs.ThrowInternal(nil, "WEBAU-4M90s", "Errors.User.WebAuthN.CloneWarning")
	}
	return credential, nil
}
// serverFromContext builds a go-webauthn server for one ceremony.
//
// When id is empty, the RP ID and origin are derived from the instance in ctx
// (configFromContext) and origin is ignored. Otherwise id and origin are used
// verbatim (config). Configuration errors from webauthn.New are wrapped as an
// internal error.
func (w *Config) serverFromContext(ctx context.Context, id, origin string) (*webauthn.WebAuthn, error) {
	var cfg *webauthn.Config
	if id == "" {
		cfg = w.configFromContext(ctx)
	} else {
		cfg = w.config(id, origin)
	}

	server, err := webauthn.New(cfg)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "WEBAU-UX9ta", "Errors.User.WebAuthN.ServerConfig")
	}
	return server, nil
}
// configFromContext derives the relying-party configuration from the instance
// stored in ctx. The RP ID is the requested domain, and the single allowed
// origin is built from the requested host and w.ExternalSecure.
func (w *Config) configFromContext(ctx context.Context) *webauthn.Config {
	instance := authz.GetInstance(ctx)
	rpID := instance.RequestedDomain()
	origin := http.BuildOrigin(instance.RequestedHost(), w.ExternalSecure)
	return &webauthn.Config{
		RPDisplayName: w.DisplayName,
		RPID:          rpID,
		RPOrigins:     []string{origin},
	}
}
// config builds a relying-party configuration from an explicit RP ID and a
// single origin. origin is used as-is, even when empty, as it is for the
// Begin* ceremonies, which pass "".
func (w *Config) config(id, origin string) *webauthn.Config {
	cfg := new(webauthn.Config)
	cfg.RPDisplayName = w.DisplayName
	cfg.RPID = id
	cfg.RPOrigins = []string{origin}
	return cfg
}
// tryExtractProtocolErrMsg returns a log-friendly description of err.
//
// For a go-webauthn *protocol.Error (also when wrapped), it joins the Details
// and DevInfo fields. For any other error it falls back to err.Error(). A nil
// err yields "".
func tryExtractProtocolErrMsg(err error) string {
	if err == nil {
		return ""
	}
	var protocolErr *protocol.Error
	if errors.As(err, &protocolErr) {
		return protocolErr.Details + ": " + protocolErr.DevInfo
	}
	// Must use err here: protocolErr is still a nil pointer when errors.As
	// finds no match, and calling Error() on it would panic.
	return err.Error()
}