mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
feat(api/v2): implement TOTP session check (#6362)
* feat(api/v2): implement TOTP session check * add integration test * correct typo in projection test * fix event type typos --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
@@ -26,11 +26,13 @@ type SessionCommands struct {
|
||||
sessionWriteModel *SessionWriteModel
|
||||
passwordWriteModel *HumanPasswordWriteModel
|
||||
intentWriteModel *IDPIntentWriteModel
|
||||
totpWriteModel *HumanTOTPWriteModel
|
||||
eventstore *eventstore.Eventstore
|
||||
eventCommands []eventstore.Command
|
||||
|
||||
hasher *crypto.PasswordHasher
|
||||
intentAlg crypto.EncryptionAlgorithm
|
||||
totpAlg crypto.EncryptionAlgorithm
|
||||
createToken func(sessionID string) (id string, token string, err error)
|
||||
now func() time.Time
|
||||
}
|
||||
@@ -42,6 +44,7 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
|
||||
eventstore: c.eventstore,
|
||||
hasher: c.userPasswordHasher,
|
||||
intentAlg: c.idpConfigEncryption,
|
||||
totpAlg: c.multifactors.OTP.CryptoMFA,
|
||||
createToken: c.sessionTokenCreator,
|
||||
now: time.Now,
|
||||
}
|
||||
@@ -127,6 +130,28 @@ func CheckIntent(intentID, token string) SessionCommand {
|
||||
}
|
||||
}
|
||||
|
||||
func CheckTOTP(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing")
|
||||
}
|
||||
cmd.totpWriteModel = NewHumanTOTPWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
err = cmd.eventstore.FilterToQueryReducer(ctx, cmd.totpWriteModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.totpWriteModel.State != domain.MFAStateReady {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
err = domain.VerifyTOTP(code, cmd.totpWriteModel.Secret, cmd.totpAlg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.TOTPChecked(ctx, cmd.now())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Exec will execute the commands specified and returns an error on the first occurrence
|
||||
func (s *SessionCommands) Exec(ctx context.Context) error {
|
||||
for _, cmd := range s.sessionCommands {
|
||||
@@ -175,6 +200,10 @@ func (s *SessionCommands) WebAuthNChecked(ctx context.Context, checkedAt time.Ti
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionCommands) TOTPChecked(ctx context.Context, checkedAt time.Time) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewTOTPCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) SetToken(ctx context.Context, tokenID string) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewTokenSetEvent(ctx, s.sessionWriteModel.aggregate, tokenID))
|
||||
}
|
||||
|
@@ -35,6 +35,7 @@ type SessionWriteModel struct {
|
||||
PasswordCheckedAt time.Time
|
||||
IntentCheckedAt time.Time
|
||||
WebAuthNCheckedAt time.Time
|
||||
TOTPCheckedAt time.Time
|
||||
WebAuthNUserVerified bool
|
||||
Metadata map[string][]byte
|
||||
State domain.SessionState
|
||||
@@ -70,6 +71,8 @@ func (wm *SessionWriteModel) Reduce() error {
|
||||
wm.reduceWebAuthNChallenged(e)
|
||||
case *session.WebAuthNCheckedEvent:
|
||||
wm.reduceWebAuthNChecked(e)
|
||||
case *session.TOTPCheckedEvent:
|
||||
wm.reduceTOTPChecked(e)
|
||||
case *session.TokenSetEvent:
|
||||
wm.reduceTokenSet(e)
|
||||
case *session.TerminateEvent:
|
||||
@@ -91,6 +94,7 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
session.IntentCheckedType,
|
||||
session.WebAuthNChallengedType,
|
||||
session.WebAuthNCheckedType,
|
||||
session.TOTPCheckedType,
|
||||
session.TokenSetType,
|
||||
session.MetadataSetType,
|
||||
session.TerminateType,
|
||||
@@ -135,6 +139,10 @@ func (wm *SessionWriteModel) reduceWebAuthNChecked(e *session.WebAuthNCheckedEve
|
||||
wm.WebAuthNUserVerified = e.UserVerified
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceTOTPChecked(e *session.TOTPCheckedEvent) {
|
||||
wm.TOTPCheckedAt = e.CheckedAt
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) {
|
||||
wm.TokenID = e.TokenID
|
||||
}
|
||||
@@ -149,8 +157,8 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time {
|
||||
for _, check := range []time.Time{
|
||||
wm.PasswordCheckedAt,
|
||||
wm.WebAuthNCheckedAt,
|
||||
wm.TOTPCheckedAt,
|
||||
wm.IntentCheckedAt,
|
||||
// TODO: add OTP check https://github.com/zitadel/zitadel/issues/5477
|
||||
// TODO: add OTP (sms and email) check https://github.com/zitadel/zitadel/issues/6224
|
||||
} {
|
||||
if check.After(authTime) {
|
||||
@@ -176,12 +184,9 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
|
||||
if !wm.IntentCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeIDP)
|
||||
}
|
||||
// TODO: add checks with https://github.com/zitadel/zitadel/issues/5477
|
||||
/*
|
||||
if !wm.TOTPCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeTOTP)
|
||||
}
|
||||
*/
|
||||
if !wm.TOTPCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeTOTP)
|
||||
}
|
||||
// TODO: add checks with https://github.com/zitadel/zitadel/issues/6224
|
||||
/*
|
||||
if !wm.TOTPFactor.OTPSMSCheckedAt.IsZero() {
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
@@ -695,6 +696,138 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckTOTP(t *testing.T) {
|
||||
ctx := authz.NewMockContext("", "org1", "user1")
|
||||
|
||||
cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
key, secret, err := domain.NewTOTPKey("example.com", "user1", cryptoAlg)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessAgg := &session.NewAggregate("session1", "org1").Aggregate
|
||||
userAgg := &user.NewAggregate("user1", "org1").Aggregate
|
||||
|
||||
code, err := totp.GenerateCode(key.Secret(), testNow)
|
||||
require.NoError(t, err)
|
||||
|
||||
type fields struct {
|
||||
sessionWriteModel *SessionWriteModel
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
fields fields
|
||||
wantEventCommands []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "missing userID",
|
||||
code: code,
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
{
|
||||
name: "filter error",
|
||||
code: code,
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "otp not ready error",
|
||||
code: code,
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
{
|
||||
name: "otp verify error",
|
||||
code: "foobar",
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErr: caos_errs.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
code: code,
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantEventCommands: []eventstore.Command{
|
||||
session.NewTOTPCheckedEvent(ctx, sessAgg, testNow),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := &SessionCommands{
|
||||
sessionWriteModel: tt.fields.sessionWriteModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
totpAlg: cryptoAlg,
|
||||
now: func() time.Time { return testNow },
|
||||
}
|
||||
err := CheckTOTP(tt.code)(ctx, cmd)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantEventCommands, cmd.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_TerminateSession(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
|
Reference in New Issue
Block a user