feat(api/v2): implement U2F session check (#6339)

This commit is contained in:
Tim Möhlmann
2023-08-11 18:36:18 +03:00
committed by GitHub
parent 4e0c3115fe
commit 86af67d1be
47 changed files with 1035 additions and 665 deletions

View File

@@ -358,7 +358,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate)),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
@@ -401,7 +401,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate)),
),
),
tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) {
@@ -444,7 +444,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,
@@ -523,7 +523,7 @@ func TestCommands_LinkSessionToAuthRequest(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate),
),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate,

View File

@@ -164,7 +164,7 @@ func TestCommands_AddOIDCSessionAccessToken(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
@@ -365,7 +365,7 @@ func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,

View File

@@ -15,7 +15,6 @@ import (
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -138,10 +137,8 @@ func (s *SessionCommands) Exec(ctx context.Context) error {
return nil
}
func (s *SessionCommands) Start(ctx context.Context, domain string) {
s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate, domain))
// set the domain so checks can use it
s.sessionWriteModel.Domain = domain
func (s *SessionCommands) Start(ctx context.Context) {
s.eventCommands = append(s.eventCommands, session.NewAddedEvent(ctx, s.sessionWriteModel.aggregate))
}
func (s *SessionCommands) UserChecked(ctx context.Context, userID string, checkedAt time.Time) error {
@@ -159,15 +156,23 @@ func (s *SessionCommands) IntentChecked(ctx context.Context, checkedAt time.Time
s.eventCommands = append(s.eventCommands, session.NewIntentCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt))
}
func (s *SessionCommands) PasskeyChallenged(ctx context.Context, challenge string, allowedCrentialIDs [][]byte, userVerification domain.UserVerificationRequirement) {
s.eventCommands = append(s.eventCommands, session.NewPasskeyChallengedEvent(ctx, s.sessionWriteModel.aggregate, challenge, allowedCrentialIDs, userVerification))
func (s *SessionCommands) WebAuthNChallenged(ctx context.Context, challenge string, allowedCrentialIDs [][]byte, userVerification domain.UserVerificationRequirement, rpid string) {
s.eventCommands = append(s.eventCommands, session.NewWebAuthNChallengedEvent(ctx, s.sessionWriteModel.aggregate, challenge, allowedCrentialIDs, userVerification, rpid))
}
func (s *SessionCommands) PasskeyChecked(ctx context.Context, checkedAt time.Time, tokenID string, signCount uint32) {
func (s *SessionCommands) WebAuthNChecked(ctx context.Context, checkedAt time.Time, tokenID string, signCount uint32, userVerified bool) {
s.eventCommands = append(s.eventCommands,
session.NewPasskeyCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt),
usr_repo.NewHumanPasswordlessSignCountChangedEvent(ctx, s.sessionWriteModel.aggregate, tokenID, signCount),
session.NewWebAuthNCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt, userVerified),
)
if s.sessionWriteModel.WebAuthNChallenge.UserVerification == domain.UserVerificationRequirementRequired {
s.eventCommands = append(s.eventCommands,
user.NewHumanPasswordlessSignCountChangedEvent(ctx, s.sessionWriteModel.aggregate, tokenID, signCount),
)
} else {
s.eventCommands = append(s.eventCommands,
user.NewHumanU2FSignCountChangedEvent(ctx, s.sessionWriteModel.aggregate, tokenID, signCount),
)
}
}
func (s *SessionCommands) SetToken(ctx context.Context, tokenID string) {
@@ -226,7 +231,7 @@ func (s *SessionCommands) commands(ctx context.Context) (string, []eventstore.Co
return token, s.eventCommands, nil
}
func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, sessionDomain string, metadata map[string][]byte) (set *SessionChanged, err error) {
func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, metadata map[string][]byte) (set *SessionChanged, err error) {
sessionID, err := c.idGenerator.Next()
if err != nil {
return nil, err
@@ -237,7 +242,7 @@ func (c *Commands) CreateSession(ctx context.Context, cmds []SessionCommand, ses
return nil, err
}
cmd := c.NewSessionCommands(cmds, sessionWriteModel)
cmd.Start(ctx, sessionDomain)
cmd.Start(ctx)
return c.updateSession(ctx, cmd, metadata)
}

View File

@@ -4,22 +4,18 @@ import (
"time"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/session"
)
type PasskeyChallengeModel struct {
type WebAuthNChallengeModel struct {
Challenge string
AllowedCrentialIDs [][]byte
UserVerification domain.UserVerificationRequirement
RPID string
}
func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) (*domain.WebAuthNLogin, error) {
if p == nil {
return nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.Passkey.NoChallenge")
}
func (p *WebAuthNChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) *domain.WebAuthNLogin {
return &domain.WebAuthNLogin{
ObjectRoot: human.ObjectRoot,
CredentialAssertionData: credentialAssertionData,
@@ -27,23 +23,23 @@ func (p *PasskeyChallengeModel) WebAuthNLogin(human *domain.Human, credentialAss
AllowedCredentialIDs: p.AllowedCrentialIDs,
UserVerification: p.UserVerification,
RPID: p.RPID,
}, nil
}
}
type SessionWriteModel struct {
eventstore.WriteModel
TokenID string
UserID string
UserCheckedAt time.Time
PasswordCheckedAt time.Time
IntentCheckedAt time.Time
PasskeyCheckedAt time.Time
Metadata map[string][]byte
Domain string
State domain.SessionState
TokenID string
UserID string
UserCheckedAt time.Time
PasswordCheckedAt time.Time
IntentCheckedAt time.Time
WebAuthNCheckedAt time.Time
WebAuthNUserVerified bool
Metadata map[string][]byte
State domain.SessionState
PasskeyChallenge *PasskeyChallengeModel
WebAuthNChallenge *WebAuthNChallengeModel
aggregate *eventstore.Aggregate
}
@@ -70,10 +66,10 @@ func (wm *SessionWriteModel) Reduce() error {
wm.reducePasswordChecked(e)
case *session.IntentCheckedEvent:
wm.reduceIntentChecked(e)
case *session.PasskeyChallengedEvent:
wm.reducePasskeyChallenged(e)
case *session.PasskeyCheckedEvent:
wm.reducePasskeyChecked(e)
case *session.WebAuthNChallengedEvent:
wm.reduceWebAuthNChallenged(e)
case *session.WebAuthNCheckedEvent:
wm.reduceWebAuthNChecked(e)
case *session.TokenSetEvent:
wm.reduceTokenSet(e)
case *session.TerminateEvent:
@@ -93,8 +89,8 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
session.UserCheckedType,
session.PasswordCheckedType,
session.IntentCheckedType,
session.PasskeyChallengedType,
session.PasskeyCheckedType,
session.WebAuthNChallengedType,
session.WebAuthNCheckedType,
session.TokenSetType,
session.MetadataSetType,
session.TerminateType,
@@ -108,7 +104,6 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
}
func (wm *SessionWriteModel) reduceAdded(e *session.AddedEvent) {
wm.Domain = e.Domain
wm.State = domain.SessionStateActive
}
@@ -125,18 +120,19 @@ func (wm *SessionWriteModel) reduceIntentChecked(e *session.IntentCheckedEvent)
wm.IntentCheckedAt = e.CheckedAt
}
func (wm *SessionWriteModel) reducePasskeyChallenged(e *session.PasskeyChallengedEvent) {
wm.PasskeyChallenge = &PasskeyChallengeModel{
func (wm *SessionWriteModel) reduceWebAuthNChallenged(e *session.WebAuthNChallengedEvent) {
wm.WebAuthNChallenge = &WebAuthNChallengeModel{
Challenge: e.Challenge,
AllowedCrentialIDs: e.AllowedCrentialIDs,
UserVerification: e.UserVerification,
RPID: wm.Domain,
RPID: e.RPID,
}
}
func (wm *SessionWriteModel) reducePasskeyChecked(e *session.PasskeyCheckedEvent) {
wm.PasskeyChallenge = nil
wm.PasskeyCheckedAt = e.CheckedAt
func (wm *SessionWriteModel) reduceWebAuthNChecked(e *session.WebAuthNCheckedEvent) {
wm.WebAuthNChallenge = nil
wm.WebAuthNCheckedAt = e.CheckedAt
wm.WebAuthNUserVerified = e.UserVerified
}
func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) {
@@ -152,9 +148,9 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time {
var authTime time.Time
for _, check := range []time.Time{
wm.PasswordCheckedAt,
wm.PasskeyCheckedAt,
wm.WebAuthNCheckedAt,
wm.IntentCheckedAt,
// TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477
// TODO: add OTP check https://github.com/zitadel/zitadel/issues/5477
// TODO: add OTP (sms and email) check https://github.com/zitadel/zitadel/issues/6224
} {
if check.After(authTime) {
@@ -170,8 +166,12 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
if !wm.PasswordCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePassword)
}
if !wm.PasskeyCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypePasswordless)
if !wm.WebAuthNCheckedAt.IsZero() {
if wm.WebAuthNUserVerified {
types = append(types, domain.UserAuthMethodTypePasswordless)
} else {
types = append(types, domain.UserAuthMethodTypeU2F)
}
}
if !wm.IntentCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeIDP)
@@ -181,9 +181,6 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
if !wm.TOTPCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeTOTP)
}
if !wm.U2FCheckedAt.IsZero() {
types = append(types, domain.UserAuthMethodTypeU2F)
}
*/
// TODO: add checks with https://github.com/zitadel/zitadel/issues/6224
/*

View File

@@ -0,0 +1,75 @@
package command
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/domain"
)
func TestSessionWriteModel_AuthMethodTypes(t *testing.T) {
type fields struct {
PasswordCheckedAt time.Time
IntentCheckedAt time.Time
WebAuthNCheckedAt time.Time
WebAuthNUserVerified bool
}
tests := []struct {
name string
fields fields
want []domain.UserAuthMethodType
}{
{
name: "password",
fields: fields{
PasswordCheckedAt: testNow,
},
want: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePassword,
},
},
{
name: "passwordless",
fields: fields{
WebAuthNCheckedAt: testNow,
WebAuthNUserVerified: true,
},
want: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless,
},
},
{
name: "u2f",
fields: fields{
WebAuthNCheckedAt: testNow,
WebAuthNUserVerified: false,
},
want: []domain.UserAuthMethodType{
domain.UserAuthMethodTypeU2F,
},
},
{
name: "intent",
fields: fields{
IntentCheckedAt: testNow,
},
want: []domain.UserAuthMethodType{
domain.UserAuthMethodTypeIDP,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wm := &SessionWriteModel{
PasswordCheckedAt: tt.fields.PasswordCheckedAt,
IntentCheckedAt: tt.fields.IntentCheckedAt,
WebAuthNCheckedAt: tt.fields.WebAuthNCheckedAt,
WebAuthNUserVerified: tt.fields.WebAuthNUserVerified,
}
got := wm.AuthMethodTypes()
assert.Equal(t, got, tt.want)
})
}
}

View File

@@ -1,84 +0,0 @@
package command
import (
"context"
"encoding/json"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
type humanPasskeys struct {
human *domain.Human
tokens []*domain.WebAuthNToken
}
func (s *SessionCommands) getHumanPasskeys(ctx context.Context) (*humanPasskeys, error) {
humanWritemodel, err := s.gethumanWriteModel(ctx)
if err != nil {
return nil, err
}
tokenReadModel, err := s.getHumanPasswordlessTokenReadModel(ctx)
if err != nil {
return nil, err
}
return &humanPasskeys{
human: writeModelToHuman(humanWritemodel),
tokens: readModelToPasswordlessTokens(tokenReadModel),
}, nil
}
func (s *SessionCommands) getHumanPasswordlessTokenReadModel(ctx context.Context) (*HumanPasswordlessTokensReadModel, error) {
tokenReadModel := NewHumanPasswordlessTokensReadModel(s.sessionWriteModel.UserID, s.sessionWriteModel.ResourceOwner)
err := s.eventstore.FilterToQueryReducer(ctx, tokenReadModel)
if err != nil {
return nil, err
}
return tokenReadModel, nil
}
func (c *Commands) CreatePasskeyChallenge(userVerification domain.UserVerificationRequirement, dst json.Unmarshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
humanPasskeys, err := cmd.getHumanPasskeys(ctx)
if err != nil {
return err
}
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, cmd.sessionWriteModel.Domain, humanPasskeys.tokens...)
if err != nil {
return err
}
if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil {
return caos_errs.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
}
cmd.PasskeyChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification)
return nil
}
}
func (c *Commands) CheckPasskey(credentialAssertionData json.Marshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
credentialAssertionData, err := json.Marshal(credentialAssertionData)
if err != nil {
return caos_errs.ThrowInvalidArgument(err, "COMMAND-ohG2o", "todo")
}
humanPasskeys, err := cmd.getHumanPasskeys(ctx)
if err != nil {
return err
}
webAuthN, err := cmd.sessionWriteModel.PasskeyChallenge.WebAuthNLogin(humanPasskeys.human, credentialAssertionData)
if err != nil {
return err
}
keyID, signCount, err := c.webauthnConfig.FinishLogin(ctx, humanPasskeys.human, webAuthN, credentialAssertionData, humanPasskeys.tokens...)
if err != nil && keyID == nil {
return err
}
_, token := domain.GetTokenByKeyID(humanPasskeys.tokens, keyID)
if token == nil {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
}
cmd.PasskeyChecked(ctx, cmd.now(), token.WebAuthNTokenID, signCount)
return nil
}
}

View File

@@ -1,131 +0,0 @@
package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/user"
)
func TestSessionCommands_getHumanPasskeys(t *testing.T) {
userAggr := &user.NewAggregate("user1", "org1").Aggregate
type fields struct {
eventstore *eventstore.Eventstore
sessionWriteModel *SessionWriteModel
}
type res struct {
want *humanPasskeys
err error
}
tests := []struct {
name string
fields fields
res res
}{
{
name: "missing UID",
fields: fields{
eventstore: &eventstore.Eventstore{},
sessionWriteModel: &SessionWriteModel{},
},
res: res{
want: nil,
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-eeR2e", "Errors.User.UserIDMissing"),
},
},
{
name: "passwordless filter error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilterError(io.ErrClosedPipe),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
res: res{
want: nil,
err: io.ErrClosedPipe,
},
},
{
name: "ok",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType,
), "111", "challenge", "rpID"),
)),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
res: res{
want: &humanPasskeys{
human: &domain.Human{
ObjectRoot: models.ObjectRoot{
AggregateID: "user1",
ResourceOwner: "org1",
},
State: domain.UserStateActive,
Profile: &domain.Profile{
PreferredLanguage: language.Georgian,
Gender: domain.GenderDiverse,
},
Email: &domain.Email{},
},
tokens: []*domain.WebAuthNToken{{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
WebAuthNTokenID: "111",
State: domain.MFAStateNotReady,
Challenge: "challenge",
RPID: "rpID",
}},
},
err: nil,
},
},
}
for _, tt := range tests {
s := &SessionCommands{
eventstore: tt.fields.eventstore,
sessionWriteModel: tt.fields.sessionWriteModel,
}
got, err := s.getHumanPasskeys(context.Background())
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
}
}

View File

@@ -146,7 +146,6 @@ func TestCommands_CreateSession(t *testing.T) {
type args struct {
ctx context.Context
checks []SessionCommand
domain string
metadata map[string][]byte
}
type res struct {
@@ -205,40 +204,7 @@ func TestCommands_CreateSession(t *testing.T) {
expectFilter(),
expectPush(
eventPusherToEvents(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, ""),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
),
),
},
res{
want: &SessionChanged{
ObjectDetails: &domain.ObjectDetails{ResourceOwner: "org1"},
ID: "sessionID",
NewToken: "token",
},
},
},
{
"empty session with domain",
fields{
idGenerator: mock.NewIDGeneratorExpectIDs(t, "sessionID"),
tokenCreator: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
},
args{
ctx: authz.NewMockContext("", "org1", ""),
domain: "domain.tld",
},
[]expect{
expectFilter(),
expectPush(
eventPusherToEvents(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate),
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID",
),
@@ -262,7 +228,7 @@ func TestCommands_CreateSession(t *testing.T) {
idGenerator: tt.fields.idGenerator,
sessionTokenCreator: tt.fields.tokenCreator,
}
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.domain, tt.args.metadata)
got, err := c.CreateSession(tt.args.ctx, tt.args.checks, tt.args.metadata)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
})
@@ -311,7 +277,7 @@ func TestCommands_UpdateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
@@ -336,7 +302,7 @@ func TestCommands_UpdateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
@@ -769,7 +735,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
@@ -794,7 +760,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID")),
@@ -823,7 +789,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"),
@@ -854,7 +820,7 @@ func TestCommands_TerminateSession(t *testing.T) {
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")),
session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate)),
eventFromEventPusher(
session.NewTokenSetEvent(context.Background(), &session.NewAggregate("sessionID", "org1").Aggregate,
"tokenID"),

View File

@@ -0,0 +1,89 @@
package command
import (
"context"
"encoding/json"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
type humanWebAuthNTokens struct {
human *domain.Human
tokens []*domain.WebAuthNToken
}
func (s *SessionCommands) getHumanWebAuthNTokens(ctx context.Context, userVerification domain.UserVerificationRequirement) (*humanWebAuthNTokens, error) {
humanWritemodel, err := s.gethumanWriteModel(ctx)
if err != nil {
return nil, err
}
tokenReadModel, err := s.getHumanWebAuthNTokenReadModel(ctx, userVerification)
if err != nil {
return nil, err
}
return &humanWebAuthNTokens{
human: writeModelToHuman(humanWritemodel),
tokens: readModelToWebAuthNTokens(tokenReadModel),
}, nil
}
func (s *SessionCommands) getHumanWebAuthNTokenReadModel(ctx context.Context, userVerification domain.UserVerificationRequirement) (readModel HumanWebAuthNTokensReadModel, err error) {
readModel = NewHumanU2FTokensReadModel(s.sessionWriteModel.UserID, s.sessionWriteModel.ResourceOwner)
if userVerification == domain.UserVerificationRequirementRequired {
readModel = NewHumanPasswordlessTokensReadModel(s.sessionWriteModel.UserID, s.sessionWriteModel.ResourceOwner)
}
err = s.eventstore.FilterToQueryReducer(ctx, readModel)
if err != nil {
return nil, err
}
return readModel, nil
}
func (c *Commands) CreateWebAuthNChallenge(userVerification domain.UserVerificationRequirement, rpid string, dst json.Unmarshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
humanPasskeys, err := cmd.getHumanWebAuthNTokens(ctx, userVerification)
if err != nil {
return err
}
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, rpid, humanPasskeys.tokens...)
if err != nil {
return err
}
if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil {
return caos_errs.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
}
cmd.WebAuthNChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification, rpid)
return nil
}
}
func (c *Commands) CheckWebAuthN(credentialAssertionData json.Marshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
credentialAssertionData, err := json.Marshal(credentialAssertionData)
if err != nil {
return caos_errs.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
}
challenge := cmd.sessionWriteModel.WebAuthNChallenge
if challenge == nil {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
}
webAuthNTokens, err := cmd.getHumanWebAuthNTokens(ctx, challenge.UserVerification)
if err != nil {
return err
}
webAuthN := challenge.WebAuthNLogin(webAuthNTokens.human, credentialAssertionData)
credential, err := c.webauthnConfig.FinishLogin(ctx, webAuthNTokens.human, webAuthN, credentialAssertionData, webAuthNTokens.tokens...)
if err != nil && (credential == nil || credential.ID == nil) {
return err
}
_, token := domain.GetTokenByKeyID(webAuthNTokens.tokens, credential.ID)
if token == nil {
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
}
cmd.WebAuthNChecked(ctx, cmd.now(), token.WebAuthNTokenID, credential.Authenticator.SignCount, credential.Flags.UserVerified)
return nil
}
}

View File

@@ -0,0 +1,250 @@
package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/user"
)
func TestSessionCommands_getHumanWebAuthNTokens(t *testing.T) {
userAggr := &user.NewAggregate("user1", "org1").Aggregate
type fields struct {
eventstore *eventstore.Eventstore
sessionWriteModel *SessionWriteModel
}
type args struct {
userVerification domain.UserVerificationRequirement
}
type res struct {
want *humanWebAuthNTokens
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "missing UID",
fields: fields{
eventstore: &eventstore.Eventstore{},
sessionWriteModel: &SessionWriteModel{},
},
args: args{
domain.UserVerificationRequirementDiscouraged,
},
res: res{
want: nil,
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-eeR2e", "Errors.User.UserIDMissing"),
},
},
{
name: "passwordless filter error",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilterError(io.ErrClosedPipe),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
args: args{
domain.UserVerificationRequirementDiscouraged,
},
res: res{
want: nil,
err: io.ErrClosedPipe,
},
},
{
name: "ok, discouraged, u2f",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanU2FTokenAddedType,
), "111", "challenge", "rpID"),
)),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
args: args{
domain.UserVerificationRequirementDiscouraged,
},
res: res{
want: &humanWebAuthNTokens{
human: &domain.Human{
ObjectRoot: models.ObjectRoot{
AggregateID: "user1",
ResourceOwner: "org1",
},
State: domain.UserStateActive,
Profile: &domain.Profile{
PreferredLanguage: language.Georgian,
Gender: domain.GenderDiverse,
},
Email: &domain.Email{},
},
tokens: []*domain.WebAuthNToken{{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
WebAuthNTokenID: "111",
State: domain.MFAStateNotReady,
Challenge: "challenge",
RPID: "rpID",
}},
},
err: nil,
},
},
{
name: "ok, preferred, u2f",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanU2FTokenAddedType,
), "111", "challenge", "rpID"),
)),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
args: args{
domain.UserVerificationRequirementPreferred,
},
res: res{
want: &humanWebAuthNTokens{
human: &domain.Human{
ObjectRoot: models.ObjectRoot{
AggregateID: "user1",
ResourceOwner: "org1",
},
State: domain.UserStateActive,
Profile: &domain.Profile{
PreferredLanguage: language.Georgian,
Gender: domain.GenderDiverse,
},
Email: &domain.Email{},
},
tokens: []*domain.WebAuthNToken{{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
WebAuthNTokenID: "111",
State: domain.MFAStateNotReady,
Challenge: "challenge",
RPID: "rpID",
}},
},
err: nil,
},
},
{
name: "ok, required, u2f",
fields: fields{
eventstore: eventstoreExpect(t,
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
userAggr,
"", "", "", "", "", language.Georgian,
domain.GenderDiverse, "", true,
),
),
),
expectFilter(eventFromEventPusher(
user.NewHumanWebAuthNAddedEvent(eventstore.NewBaseEventForPush(
context.Background(), &org.NewAggregate("org1").Aggregate, user.HumanPasswordlessTokenAddedType,
), "111", "challenge", "rpID"),
)),
),
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
},
},
args: args{
domain.UserVerificationRequirementRequired,
},
res: res{
want: &humanWebAuthNTokens{
human: &domain.Human{
ObjectRoot: models.ObjectRoot{
AggregateID: "user1",
ResourceOwner: "org1",
},
State: domain.UserStateActive,
Profile: &domain.Profile{
PreferredLanguage: language.Georgian,
Gender: domain.GenderDiverse,
},
Email: &domain.Email{},
},
tokens: []*domain.WebAuthNToken{{
ObjectRoot: models.ObjectRoot{
AggregateID: "org1",
},
WebAuthNTokenID: "111",
State: domain.MFAStateNotReady,
Challenge: "challenge",
RPID: "rpID",
}},
},
err: nil,
},
},
}
for _, tt := range tests {
s := &SessionCommands{
eventstore: tt.fields.eventstore,
sessionWriteModel: tt.fields.sessionWriteModel,
}
got, err := s.getHumanWebAuthNTokens(context.Background(), tt.args.userVerification)
require.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.want, got)
}
}

View File

@@ -112,17 +112,9 @@ func personalTokenWriteModelToToken(wm *PersonalAccessTokenWriteModel, algorithm
}, base64.RawURLEncoding.EncodeToString(encrypted), nil
}
func readModelToU2FTokens(wm *HumanU2FTokensReadModel) []*domain.WebAuthNToken {
tokens := make([]*domain.WebAuthNToken, len(wm.WebAuthNTokens))
for i, token := range wm.WebAuthNTokens {
tokens[i] = writeModelToWebAuthN(token)
}
return tokens
}
func readModelToPasswordlessTokens(wm *HumanPasswordlessTokensReadModel) []*domain.WebAuthNToken {
tokens := make([]*domain.WebAuthNToken, len(wm.WebAuthNTokens))
for i, token := range wm.WebAuthNTokens {
func readModelToWebAuthNTokens(readModel HumanWebAuthNTokensReadModel) []*domain.WebAuthNToken {
tokens := make([]*domain.WebAuthNToken, len(readModel.GetWebAuthNTokens()))
for i, token := range readModel.GetWebAuthNTokens() {
tokens[i] = writeModelToWebAuthN(token)
}
return tokens

View File

@@ -24,7 +24,7 @@ func (c *Commands) getHumanU2FTokens(ctx context.Context, userID, resourceowner
if tokenReadModel.UserState == domain.UserStateDeleted {
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-4M0ds", "Errors.User.NotFound")
}
return readModelToU2FTokens(tokenReadModel), nil
return readModelToWebAuthNTokens(tokenReadModel), nil
}
func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resourceOwner string) ([]*domain.WebAuthNToken, error) {
@@ -36,7 +36,7 @@ func (c *Commands) getHumanPasswordlessTokens(ctx context.Context, userID, resou
if tokenReadModel.UserState == domain.UserStateDeleted {
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-Mv9sd", "Errors.User.NotFound")
}
return readModelToPasswordlessTokens(tokenReadModel), nil
return readModelToWebAuthNTokens(tokenReadModel), nil
}
func (c *Commands) getHumanU2FLogin(ctx context.Context, userID, authReqID, resourceowner string) (*domain.WebAuthNLogin, error) {
@@ -454,12 +454,12 @@ func (c *Commands) finishWebAuthNLogin(ctx context.Context, userID, resourceOwne
if err != nil {
return nil, nil, 0, err
}
keyID, signCount, err := c.webauthnConfig.FinishLogin(ctx, human, webAuthN, credentialData, tokens...)
if err != nil && keyID == nil {
credential, err := c.webauthnConfig.FinishLogin(ctx, human, webAuthN, credentialData, tokens...)
if err != nil && (credential == nil || credential.ID == nil) {
return nil, nil, 0, err
}
_, token := domain.GetTokenByKeyID(tokens, keyID)
_, token := domain.GetTokenByKeyID(tokens, credential.ID)
if token == nil {
return nil, nil, 0, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-3b7zs", "Errors.User.WebAuthN.NotFound")
}
@@ -470,7 +470,7 @@ func (c *Commands) finishWebAuthNLogin(ctx context.Context, userID, resourceOwne
}
userAgg := UserAggregateFromWriteModel(&writeModel.WriteModel)
return userAgg, token, signCount, nil
return userAgg, token, credential.Authenticator.SignCount, nil
}
func (c *Commands) HumanRemoveU2F(ctx context.Context, userID, webAuthNID, resourceOwner string) (*domain.ObjectDetails, error) {

View File

@@ -146,6 +146,12 @@ func (wm *HumanWebAuthNWriteModel) Query() *eventstore.SearchQueryBuilder {
Builder()
}
type HumanWebAuthNTokensReadModel interface {
eventstore.QueryReducer
GetWebAuthNTokens() []*HumanWebAuthNWriteModel
WebAuthNTokenByID(id string) (int, *HumanWebAuthNWriteModel)
}
type HumanU2FTokensReadModel struct {
eventstore.WriteModel
@@ -220,6 +226,10 @@ func (rm *HumanU2FTokensReadModel) Query() *eventstore.SearchQueryBuilder {
}
func (wm *HumanU2FTokensReadModel) GetWebAuthNTokens() []*HumanWebAuthNWriteModel {
return wm.WebAuthNTokens
}
func (wm *HumanU2FTokensReadModel) WebAuthNTokenByID(id string) (idx int, token *HumanWebAuthNWriteModel) {
for idx, token = range wm.WebAuthNTokens {
if token.WebauthNTokenID == id {
@@ -303,6 +313,10 @@ func (rm *HumanPasswordlessTokensReadModel) Query() *eventstore.SearchQueryBuild
}
func (wm *HumanPasswordlessTokensReadModel) GetWebAuthNTokens() []*HumanWebAuthNWriteModel {
return wm.WebAuthNTokens
}
func (wm *HumanPasswordlessTokensReadModel) WebAuthNTokenByID(id string) (idx int, token *HumanWebAuthNWriteModel) {
for idx, token = range wm.WebAuthNTokens {
if token.WebauthNTokenID == id {