mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
feat(api): add otp (sms and email) checks in session api (#6422)
* feat: add otp (sms and email) checks in session api * implement sending * fix tests * add tests * add integration tests * fix merge main and add tests * put default OTP Email url into config --------- Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
@@ -34,8 +34,9 @@ import (
|
||||
type Commands struct {
|
||||
httpClient *http.Client
|
||||
|
||||
checkPermission domain.PermissionCheck
|
||||
newCode cryptoCodeFunc
|
||||
checkPermission domain.PermissionCheck
|
||||
newCode cryptoCodeFunc
|
||||
newCodeWithDefault cryptoCodeWithDefaultFunc
|
||||
|
||||
eventstore *eventstore.Eventstore
|
||||
static static.Storage
|
||||
@@ -122,6 +123,7 @@ func StartCommands(
|
||||
httpClient: httpClient,
|
||||
checkPermission: permissionCheck,
|
||||
newCode: newCryptoCode,
|
||||
newCodeWithDefault: newCryptoCodeWithDefaultConfig,
|
||||
sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg),
|
||||
sessionTokenVerifier: sessionTokenVerifier,
|
||||
defaultAccessTokenLifetime: defaultAccessTokenLifetime,
|
||||
|
@@ -12,6 +12,10 @@ import (
|
||||
|
||||
type cryptoCodeFunc func(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (*CryptoCode, error)
|
||||
|
||||
type cryptoCodeWithDefaultFunc func(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (*CryptoCode, error)
|
||||
|
||||
var emptyConfig = &crypto.GeneratorConfig{}
|
||||
|
||||
type CryptoCode struct {
|
||||
Crypted *crypto.CryptoValue
|
||||
Plain string
|
||||
@@ -19,7 +23,11 @@ type CryptoCode struct {
|
||||
}
|
||||
|
||||
func newCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (*CryptoCode, error) {
|
||||
gen, config, err := secretGenerator(ctx, filter, typ, alg)
|
||||
return newCryptoCodeWithDefaultConfig(ctx, filter, typ, alg, emptyConfig)
|
||||
}
|
||||
|
||||
func newCryptoCodeWithDefaultConfig(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (*CryptoCode, error) {
|
||||
gen, config, err := secretGenerator(ctx, filter, typ, alg, defaultConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -35,15 +43,15 @@ func newCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer,
|
||||
}
|
||||
|
||||
func verifyCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, creation time.Time, expiry time.Duration, crypted *crypto.CryptoValue, plain string) error {
|
||||
gen, _, err := secretGenerator(ctx, filter, typ, alg)
|
||||
gen, _, err := secretGenerator(ctx, filter, typ, alg, emptyConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return crypto.VerifyCode(creation, expiry, crypted, plain, gen)
|
||||
}
|
||||
|
||||
func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (crypto.Generator, *crypto.GeneratorConfig, error) {
|
||||
config, err := secretGeneratorConfig(ctx, filter, typ)
|
||||
func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (crypto.Generator, *crypto.GeneratorConfig, error) {
|
||||
config, err := secretGeneratorConfigWithDefault(ctx, filter, typ, defaultConfig)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -58,26 +66,10 @@ func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReduce
|
||||
}
|
||||
|
||||
func secretGeneratorConfig(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType) (*crypto.GeneratorConfig, error) {
|
||||
wm := NewInstanceSecretGeneratorConfigWriteModel(ctx, typ)
|
||||
events, err := filter(ctx, wm.Query())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wm.AppendEvents(events...)
|
||||
if err := wm.Reduce(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &crypto.GeneratorConfig{
|
||||
Length: wm.Length,
|
||||
Expiry: wm.Expiry,
|
||||
IncludeLowerLetters: wm.IncludeLowerLetters,
|
||||
IncludeUpperLetters: wm.IncludeUpperLetters,
|
||||
IncludeDigits: wm.IncludeDigits,
|
||||
IncludeSymbols: wm.IncludeSymbols,
|
||||
}, nil
|
||||
return secretGeneratorConfigWithDefault(ctx, filter, typ, emptyConfig)
|
||||
}
|
||||
|
||||
func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, defaultGenerator *crypto.GeneratorConfig) (*crypto.GeneratorConfig, error) {
|
||||
func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, defaultConfig *crypto.GeneratorConfig) (*crypto.GeneratorConfig, error) {
|
||||
wm := NewInstanceSecretGeneratorConfigWriteModel(ctx, typ)
|
||||
events, err := filter(ctx, wm.Query())
|
||||
if err != nil {
|
||||
@@ -88,7 +80,7 @@ func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.Fi
|
||||
return nil, err
|
||||
}
|
||||
if wm.State != domain.SecretGeneratorStateActive {
|
||||
return defaultGenerator, nil
|
||||
return defaultConfig, nil
|
||||
}
|
||||
return &crypto.GeneratorConfig{
|
||||
Length: wm.Length,
|
||||
|
@@ -33,6 +33,21 @@ func mockCode(code string, exp time.Duration) cryptoCodeFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func mockCodeWithDefault(code string, exp time.Duration) cryptoCodeWithDefaultFunc {
|
||||
return func(ctx context.Context, filter preparation.FilterToQueryReducer, _ domain.SecretGeneratorType, alg crypto.Crypto, _ *crypto.GeneratorConfig) (*CryptoCode, error) {
|
||||
return &CryptoCode{
|
||||
Crypted: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte(code),
|
||||
},
|
||||
Plain: code,
|
||||
Expiry: exp,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
testGeneratorConfig = crypto.GeneratorConfig{
|
||||
Length: 12,
|
||||
@@ -175,8 +190,9 @@ func Test_verifyCryptoCode(t *testing.T) {
|
||||
|
||||
func Test_secretGenerator(t *testing.T) {
|
||||
type args struct {
|
||||
typ domain.SecretGeneratorType
|
||||
alg crypto.Crypto
|
||||
typ domain.SecretGeneratorType
|
||||
alg crypto.Crypto
|
||||
defaultConfig *crypto.GeneratorConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -190,8 +206,9 @@ func Test_secretGenerator(t *testing.T) {
|
||||
name: "filter config error",
|
||||
eventsore: eventstoreExpect(t, expectFilterError(io.ErrClosedPipe)),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockHashAlg(gomock.NewController(t)),
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockHashAlg(gomock.NewController(t)),
|
||||
defaultConfig: emptyConfig,
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
@@ -201,8 +218,9 @@ func Test_secretGenerator(t *testing.T) {
|
||||
eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)),
|
||||
)),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockHashAlg(gomock.NewController(t)),
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockHashAlg(gomock.NewController(t)),
|
||||
defaultConfig: emptyConfig,
|
||||
},
|
||||
want: crypto.NewHashGenerator(testGeneratorConfig, crypto.CreateMockHashAlg(gomock.NewController(t))),
|
||||
wantConf: &testGeneratorConfig,
|
||||
@@ -213,8 +231,31 @@ func Test_secretGenerator(t *testing.T) {
|
||||
eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)),
|
||||
)),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
defaultConfig: emptyConfig,
|
||||
},
|
||||
want: crypto.NewEncryptionGenerator(testGeneratorConfig, crypto.CreateMockEncryptionAlg(gomock.NewController(t))),
|
||||
wantConf: &testGeneratorConfig,
|
||||
},
|
||||
{
|
||||
name: "hash generator with default config",
|
||||
eventsore: eventstoreExpect(t, expectFilter()),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockHashAlg(gomock.NewController(t)),
|
||||
defaultConfig: &testGeneratorConfig,
|
||||
},
|
||||
want: crypto.NewHashGenerator(testGeneratorConfig, crypto.CreateMockHashAlg(gomock.NewController(t))),
|
||||
wantConf: &testGeneratorConfig,
|
||||
},
|
||||
{
|
||||
name: "encryption generator with default config",
|
||||
eventsore: eventstoreExpect(t, expectFilter()),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
defaultConfig: &testGeneratorConfig,
|
||||
},
|
||||
want: crypto.NewEncryptionGenerator(testGeneratorConfig, crypto.CreateMockEncryptionAlg(gomock.NewController(t))),
|
||||
wantConf: &testGeneratorConfig,
|
||||
@@ -225,15 +266,16 @@ func Test_secretGenerator(t *testing.T) {
|
||||
eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)),
|
||||
)),
|
||||
args: args{
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: nil,
|
||||
typ: domain.SecretGeneratorTypeVerifyEmailCode,
|
||||
alg: nil,
|
||||
defaultConfig: emptyConfig,
|
||||
},
|
||||
wantErr: errors.ThrowInternalf(nil, "COMMA-RreV6", "Errors.Internal unsupported crypto algorithm type %T", nil),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, gotConf, err := secretGenerator(context.Background(), tt.eventsore.Filter, tt.args.typ, tt.args.alg)
|
||||
got, gotConf, err := secretGenerator(context.Background(), tt.eventsore.Filter, tt.args.typ, tt.args.alg, tt.args.defaultConfig)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.IsType(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantConf, gotConf)
|
||||
|
@@ -33,6 +33,8 @@ type SessionCommands struct {
|
||||
hasher *crypto.PasswordHasher
|
||||
intentAlg crypto.EncryptionAlgorithm
|
||||
totpAlg crypto.EncryptionAlgorithm
|
||||
otpAlg crypto.EncryptionAlgorithm
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
createToken func(sessionID string) (id string, token string, err error)
|
||||
now func() time.Time
|
||||
}
|
||||
@@ -45,6 +47,8 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
|
||||
hasher: c.userPasswordHasher,
|
||||
intentAlg: c.idpConfigEncryption,
|
||||
totpAlg: c.multifactors.OTP.CryptoMFA,
|
||||
otpAlg: c.userEncryption,
|
||||
createCode: c.newCodeWithDefault,
|
||||
createToken: c.sessionTokenCreator,
|
||||
now: time.Now,
|
||||
}
|
||||
@@ -204,6 +208,22 @@ func (s *SessionCommands) TOTPChecked(ctx context.Context, checkedAt time.Time)
|
||||
s.eventCommands = append(s.eventCommands, session.NewTOTPCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) OTPSMSChallenged(ctx context.Context, code *crypto.CryptoValue, expiry time.Duration, returnCode bool) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewOTPSMSChallengedEvent(ctx, s.sessionWriteModel.aggregate, code, expiry, returnCode))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) OTPSMSChecked(ctx context.Context, checkedAt time.Time) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewOTPSMSCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) OTPEmailChallenged(ctx context.Context, code *crypto.CryptoValue, expiry time.Duration, returnCode bool, urlTmpl string) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewOTPEmailChallengedEvent(ctx, s.sessionWriteModel.aggregate, code, expiry, returnCode, urlTmpl))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) OTPEmailChecked(ctx context.Context, checkedAt time.Time) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewOTPEmailCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt))
|
||||
}
|
||||
|
||||
func (s *SessionCommands) SetToken(ctx context.Context, tokenID string) {
|
||||
s.eventCommands = append(s.eventCommands, session.NewTokenSetEvent(ctx, s.sessionWriteModel.aggregate, tokenID))
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package command
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
@@ -15,6 +16,12 @@ type WebAuthNChallengeModel struct {
|
||||
RPID string
|
||||
}
|
||||
|
||||
type OTPCode struct {
|
||||
Code *crypto.CryptoValue
|
||||
Expiry time.Duration
|
||||
CreationDate time.Time
|
||||
}
|
||||
|
||||
func (p *WebAuthNChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) *domain.WebAuthNLogin {
|
||||
return &domain.WebAuthNLogin{
|
||||
ObjectRoot: human.ObjectRoot,
|
||||
@@ -36,11 +43,15 @@ type SessionWriteModel struct {
|
||||
IntentCheckedAt time.Time
|
||||
WebAuthNCheckedAt time.Time
|
||||
TOTPCheckedAt time.Time
|
||||
OTPSMSCheckedAt time.Time
|
||||
OTPEmailCheckedAt time.Time
|
||||
WebAuthNUserVerified bool
|
||||
Metadata map[string][]byte
|
||||
State domain.SessionState
|
||||
|
||||
WebAuthNChallenge *WebAuthNChallengeModel
|
||||
WebAuthNChallenge *WebAuthNChallengeModel
|
||||
OTPSMSCodeChallenge *OTPCode
|
||||
OTPEmailCodeChallenge *OTPCode
|
||||
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
@@ -73,6 +84,14 @@ func (wm *SessionWriteModel) Reduce() error {
|
||||
wm.reduceWebAuthNChecked(e)
|
||||
case *session.TOTPCheckedEvent:
|
||||
wm.reduceTOTPChecked(e)
|
||||
case *session.OTPSMSChallengedEvent:
|
||||
wm.reduceOTPSMSChallenged(e)
|
||||
case *session.OTPSMSCheckedEvent:
|
||||
wm.reduceOTPSMSChecked(e)
|
||||
case *session.OTPEmailChallengedEvent:
|
||||
wm.reduceOTPEmailChallenged(e)
|
||||
case *session.OTPEmailCheckedEvent:
|
||||
wm.reduceOTPEmailChecked(e)
|
||||
case *session.TokenSetEvent:
|
||||
wm.reduceTokenSet(e)
|
||||
case *session.TerminateEvent:
|
||||
@@ -95,6 +114,10 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
session.WebAuthNChallengedType,
|
||||
session.WebAuthNCheckedType,
|
||||
session.TOTPCheckedType,
|
||||
session.OTPSMSChallengedType,
|
||||
session.OTPSMSCheckedType,
|
||||
session.OTPEmailChallengedType,
|
||||
session.OTPEmailCheckedType,
|
||||
session.TokenSetType,
|
||||
session.MetadataSetType,
|
||||
session.TerminateType,
|
||||
@@ -143,6 +166,32 @@ func (wm *SessionWriteModel) reduceTOTPChecked(e *session.TOTPCheckedEvent) {
|
||||
wm.TOTPCheckedAt = e.CheckedAt
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceOTPSMSChallenged(e *session.OTPSMSChallengedEvent) {
|
||||
wm.OTPSMSCodeChallenge = &OTPCode{
|
||||
Code: e.Code,
|
||||
Expiry: e.Expiry,
|
||||
CreationDate: e.CreationDate(),
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceOTPSMSChecked(e *session.OTPSMSCheckedEvent) {
|
||||
wm.OTPSMSCodeChallenge = nil
|
||||
wm.OTPSMSCheckedAt = e.CheckedAt
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceOTPEmailChallenged(e *session.OTPEmailChallengedEvent) {
|
||||
wm.OTPEmailCodeChallenge = &OTPCode{
|
||||
Code: e.Code,
|
||||
Expiry: e.Expiry,
|
||||
CreationDate: e.CreationDate(),
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceOTPEmailChecked(e *session.OTPEmailCheckedEvent) {
|
||||
wm.OTPEmailCodeChallenge = nil
|
||||
wm.OTPEmailCheckedAt = e.CheckedAt
|
||||
}
|
||||
|
||||
func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) {
|
||||
wm.TokenID = e.TokenID
|
||||
}
|
||||
@@ -159,7 +208,8 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time {
|
||||
wm.WebAuthNCheckedAt,
|
||||
wm.TOTPCheckedAt,
|
||||
wm.IntentCheckedAt,
|
||||
// TODO: add OTP (sms and email) check https://github.com/zitadel/zitadel/issues/6224
|
||||
wm.OTPSMSCheckedAt,
|
||||
wm.OTPEmailCheckedAt,
|
||||
} {
|
||||
if check.After(authTime) {
|
||||
authTime = check
|
||||
@@ -187,14 +237,11 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType {
|
||||
if !wm.TOTPCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeTOTP)
|
||||
}
|
||||
// TODO: add checks with https://github.com/zitadel/zitadel/issues/6224
|
||||
/*
|
||||
if !wm.TOTPFactor.OTPSMSCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeOTPSMS)
|
||||
}
|
||||
if !wm.TOTPFactor.OTPEmailCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeOTPEmail)
|
||||
}
|
||||
*/
|
||||
if !wm.OTPSMSCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeOTPSMS)
|
||||
}
|
||||
if !wm.OTPEmailCheckedAt.IsZero() {
|
||||
types = append(types, domain.UserAuthMethodTypeOTPEmail)
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
148
internal/command/session_otp.go
Normal file
148
internal/command/session_otp.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
)
|
||||
|
||||
func (c *Commands) CreateOTPSMSChallengeReturnCode(dst *string) SessionCommand {
|
||||
return c.createOTPSMSChallenge(true, dst)
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOTPSMSChallenge() SessionCommand {
|
||||
return c.createOTPSMSChallenge(false, nil)
|
||||
}
|
||||
|
||||
func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
|
||||
}
|
||||
writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
|
||||
return err
|
||||
}
|
||||
if !writeModel.OTPAdded() {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if returnCode {
|
||||
*dst = code.Plain
|
||||
}
|
||||
cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) OTPSMSSent(ctx context.Context, sessionID, resourceOwner string) error {
|
||||
sessionWriteModel := NewSessionWriteModel(sessionID, resourceOwner)
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sessionWriteModel.OTPSMSCodeChallenge == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-G3t31", "Errors.User.Code.NotFound")
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, sessionWriteModel,
|
||||
session.NewOTPSMSSentEvent(ctx, &session.NewAggregate(sessionID, sessionWriteModel.ResourceOwner).Aggregate),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOTPEmailChallengeURLTemplate(urlTmpl string) (SessionCommand, error) {
|
||||
if err := domain.RenderOTPEmailURLTemplate(io.Discard, urlTmpl, "code", "userID", "loginName", "displayName", language.English); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.createOTPEmailChallenge(false, urlTmpl, nil), nil
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOTPEmailChallengeReturnCode(dst *string) SessionCommand {
|
||||
return c.createOTPEmailChallenge(true, "", dst)
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOTPEmailChallenge() SessionCommand {
|
||||
return c.createOTPEmailChallenge(false, "", nil)
|
||||
}
|
||||
|
||||
func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
|
||||
}
|
||||
writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
|
||||
return err
|
||||
}
|
||||
if !writeModel.OTPAdded() {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if returnCode {
|
||||
*dst = code.Plain
|
||||
}
|
||||
cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error {
|
||||
sessionWriteModel := NewSessionWriteModel(sessionID, resourceOwner)
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sessionWriteModel.OTPEmailCodeChallenge == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SLr02", "Errors.User.Code.NotFound")
|
||||
}
|
||||
return c.pushAppendAndReduce(ctx, sessionWriteModel,
|
||||
session.NewOTPEmailSentEvent(ctx, &session.NewAggregate(sessionID, sessionWriteModel.ResourceOwner).Aggregate),
|
||||
)
|
||||
}
|
||||
|
||||
func CheckOTPSMS(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing")
|
||||
}
|
||||
challenge := cmd.sessionWriteModel.OTPSMSCodeChallenge
|
||||
if challenge == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound")
|
||||
}
|
||||
err = crypto.VerifyCodeWithAlgorithm(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.OTPSMSChecked(ctx, cmd.now())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func CheckOTPEmail(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing")
|
||||
}
|
||||
challenge := cmd.sessionWriteModel.OTPEmailCodeChallenge
|
||||
if challenge == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound")
|
||||
}
|
||||
err = crypto.VerifyCodeWithAlgorithm(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.OTPEmailChecked(ctx, cmd.now())
|
||||
return nil
|
||||
}
|
||||
}
|
951
internal/command/session_otp_test.go
Normal file
951
internal/command/session_otp_test.go
Normal file
@@ -0,0 +1,951 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
)
|
||||
|
||||
func TestCommands_CreateOTPSMSChallengeReturnCode(t *testing.T) {
|
||||
type fields struct {
|
||||
userID string
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
returnCode string
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userID missing, precondition error",
|
||||
fields: fields{
|
||||
userID: "",
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "otp not ready, precondition error",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generate code",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
createCode: mockCodeWithDefault("1234567", 5*time.Minute),
|
||||
},
|
||||
res: res{
|
||||
returnCode: "1234567",
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
true,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
// config will not be actively used for the test (is only for default),
|
||||
// but not providing it would result in a nil pointer
|
||||
defaultSecretGenerators: &SecretGenerators{
|
||||
OTPSMS: emptyConfig,
|
||||
},
|
||||
}
|
||||
var dst string
|
||||
cmd := c.CreateOTPSMSChallengeReturnCode(&dst)
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
createCode: tt.fields.createCode,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.returnCode, dst)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOTPSMSChallenge(t *testing.T) {
|
||||
type fields struct {
|
||||
userID string
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userID missing, precondition error",
|
||||
fields: fields{
|
||||
userID: "",
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "otp not ready, precondition error",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generate code",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
createCode: mockCodeWithDefault("1234567", 5*time.Minute),
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
false,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
// config will not be actively used for the test (is only for default),
|
||||
// but not providing it would result in a nil pointer
|
||||
defaultSecretGenerators: &SecretGenerators{
|
||||
OTPSMS: emptyConfig,
|
||||
},
|
||||
}
|
||||
|
||||
cmd := c.CreateOTPSMSChallenge()
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
createCode: tt.fields.createCode,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_OTPSMSSent(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
sessionID string
|
||||
resourceOwner string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "not challenged, precondition error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
sessionID: "sessionID",
|
||||
resourceOwner: "instanceID",
|
||||
},
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-G3t31", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "challenged and sent",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
false,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
eventPusherToEvents(
|
||||
session.NewOTPSMSSentEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
sessionID: "sessionID",
|
||||
resourceOwner: "instanceID",
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
err := c.OTPSMSSent(tt.args.ctx, tt.args.sessionID, tt.args.resourceOwner)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOTPEmailChallengeURLTemplate(t *testing.T) {
|
||||
type fields struct {
|
||||
userID string
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
}
|
||||
type args struct {
|
||||
urlTmpl string
|
||||
}
|
||||
type res struct {
|
||||
templateError error
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid template, precondition error",
|
||||
args: args{
|
||||
urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.InvalidField}}",
|
||||
},
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
templateError: caos_errs.ThrowInvalidArgument(nil, "DOMAIN-ieYa7", "Errors.User.InvalidURLTemplate"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "userID missing, precondition error",
|
||||
args: args{
|
||||
urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}",
|
||||
},
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "otp not ready, precondition error",
|
||||
args: args{
|
||||
urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}",
|
||||
},
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generate code",
|
||||
args: args{
|
||||
urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}",
|
||||
},
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
createCode: mockCodeWithDefault("1234567", 5*time.Minute),
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
false,
|
||||
"https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}",
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
// config will not be actively used for the test (is only for default),
|
||||
// but not providing it would result in a nil pointer
|
||||
defaultSecretGenerators: &SecretGenerators{
|
||||
OTPEmail: emptyConfig,
|
||||
},
|
||||
}
|
||||
|
||||
cmd, err := c.CreateOTPEmailChallengeURLTemplate(tt.args.urlTmpl)
|
||||
assert.ErrorIs(t, err, tt.res.templateError)
|
||||
if tt.res.templateError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
createCode: tt.fields.createCode,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err = cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOTPEmailChallengeReturnCode(t *testing.T) {
|
||||
type fields struct {
|
||||
userID string
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
returnCode string
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userID missing, precondition error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "otp not ready, precondition error",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generate code",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
createCode: mockCodeWithDefault("1234567", 5*time.Minute),
|
||||
},
|
||||
res: res{
|
||||
returnCode: "1234567",
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
true,
|
||||
"",
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
// config will not be actively used for the test (is only for default),
|
||||
// but not providing it would result in a nil pointer
|
||||
defaultSecretGenerators: &SecretGenerators{
|
||||
OTPEmail: emptyConfig,
|
||||
},
|
||||
}
|
||||
var dst string
|
||||
cmd := c.CreateOTPEmailChallengeReturnCode(&dst)
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
createCode: tt.fields.createCode,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.returnCode, dst)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateOTPEmailChallenge(t *testing.T) {
|
||||
type fields struct {
|
||||
userID string
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
createCode cryptoCodeWithDefaultFunc
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "userID missing, precondition error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "otp not ready, precondition error",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generate code",
|
||||
fields: fields{
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
createCode: mockCodeWithDefault("1234567", 5*time.Minute),
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
false,
|
||||
"",
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
// config will not be actively used for the test (is only for default),
|
||||
// but not providing it would result in a nil pointer
|
||||
defaultSecretGenerators: &SecretGenerators{
|
||||
OTPEmail: emptyConfig,
|
||||
},
|
||||
}
|
||||
|
||||
cmd := c.CreateOTPEmailChallenge()
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
createCode: tt.fields.createCode,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_OTPEmailSent(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
sessionID string
|
||||
resourceOwner string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "not challenged, precondition error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
sessionID: "sessionID",
|
||||
resourceOwner: "instanceID",
|
||||
},
|
||||
wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SLr02", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
{
|
||||
name: "challenged and sent",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("1234567"),
|
||||
},
|
||||
5*time.Minute,
|
||||
false,
|
||||
"",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
eventPusherToEvents(
|
||||
session.NewOTPEmailSentEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
sessionID: "sessionID",
|
||||
resourceOwner: "instanceID",
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
err := c.OTPEmailSent(tt.args.ctx, tt.args.sessionID, tt.args.resourceOwner)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOTPSMS(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
userID string
|
||||
otpCodeChallenge *OTPCode
|
||||
otpAlg crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "missing userID",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "",
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing challenge",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: nil,
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow.Add(-10 * time.Minute),
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow,
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := CheckOTPSMS(tt.args.code)
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
OTPSMSCodeChallenge: tt.fields.otpCodeChallenge,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
otpAlg: tt.fields.otpAlg,
|
||||
now: func() time.Time {
|
||||
return testNow
|
||||
},
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOTPEmail(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
userID string
|
||||
otpCodeChallenge *OTPCode
|
||||
otpAlg crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "missing userID",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "",
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing challenge",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: nil,
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow.Add(-10 * time.Minute),
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: caos_errs.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow,
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := CheckOTPEmail(tt.args.code)
|
||||
|
||||
sessionModel := &SessionWriteModel{
|
||||
UserID: tt.fields.userID,
|
||||
UserCheckedAt: testNow,
|
||||
State: domain.SessionStateActive,
|
||||
OTPEmailCodeChallenge: tt.fields.otpCodeChallenge,
|
||||
aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
}
|
||||
cmds := &SessionCommands{
|
||||
sessionCommands: []SessionCommand{cmd},
|
||||
sessionWriteModel: sessionModel,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
otpAlg: tt.fields.otpAlg,
|
||||
now: func() time.Time {
|
||||
return testNow
|
||||
},
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
}
|
@@ -310,7 +310,6 @@ func (c *Commands) HumanCheckOTPSMS(ctx context.Context, userID, code, resourceO
|
||||
resourceOwner,
|
||||
authRequest,
|
||||
writeModel,
|
||||
domain.SecretGeneratorTypeOTPSMS,
|
||||
succeededEvent,
|
||||
failedEvent,
|
||||
)
|
||||
@@ -431,7 +430,6 @@ func (c *Commands) HumanCheckOTPEmail(ctx context.Context, userID, code, resourc
|
||||
resourceOwner,
|
||||
authRequest,
|
||||
writeModel,
|
||||
domain.SecretGeneratorTypeOTPEmail,
|
||||
succeededEvent,
|
||||
failedEvent,
|
||||
)
|
||||
@@ -497,7 +495,6 @@ func (c *Commands) humanCheckOTP(
|
||||
userID, code, resourceOwner string,
|
||||
authRequest *domain.AuthRequest,
|
||||
writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error),
|
||||
secretGeneratorType domain.SecretGeneratorType,
|
||||
checkSucceededEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
|
||||
checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
|
||||
) error {
|
||||
|
Reference in New Issue
Block a user