diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 7f90aca5c9..0fa789baf1 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -307,6 +307,7 @@ Login: MaxAge: 12h # ZITADEL_LOGIN_CACHE_MAXAGE # 168h is 7 days, one week SharedMaxAge: 168h # ZITADEL_LOGIN_CACHE_SHAREDMAXAGE + DefaultOTPEmailURLV2: "/otp/verify?loginName={{.LoginName}}&code={{.Code}}" # ZITADEL_LOGIN_CACHE_DEFAULTOTPEMAILURLV2 Console: ShortCache: diff --git a/cmd/start/start.go b/cmd/start/start.go index 1191e19b21..245928d919 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -222,7 +222,25 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error actionsLogstoreSvc := logstore.New(queries, usageReporter, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter) actions.SetLogstoreService(actionsLogstoreSvc) - notification.Start(ctx, config.Projections.Customizations["notifications"], config.Projections.Customizations["notificationsquotas"], config.Projections.Customizations["telemetry"], *config.Telemetry, config.ExternalDomain, config.ExternalPort, config.ExternalSecure, commands, queries, eventstoreClient, assets.AssetAPIFromDomain(config.ExternalSecure, config.ExternalPort), config.SystemDefaults.Notifications.FileSystemPath, keys.User, keys.SMTP, keys.SMS) + notification.Start( + ctx, + config.Projections.Customizations["notifications"], + config.Projections.Customizations["notificationsquotas"], + config.Projections.Customizations["telemetry"], + *config.Telemetry, + config.ExternalDomain, + config.ExternalPort, + config.ExternalSecure, + commands, + queries, + eventstoreClient, + assets.AssetAPIFromDomain(config.ExternalSecure, config.ExternalPort), + config.Login.DefaultOTPEmailURLV2, + config.SystemDefaults.Notifications.FileSystemPath, + keys.User, + keys.SMTP, + keys.SMS, + ) router := mux.NewRouter() tlsConfig, err := config.TLS.Config() diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index b1a443a0d5..fea5cd62b1 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -45,7 +45,10 @@ func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRe if err != nil { return nil, err } - challengeResponse, cmds := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + if err != nil { + return nil, err + } set, err := s.command.CreateSession(ctx, cmds, metadata) if err != nil { @@ -64,7 +67,10 @@ func (s *Server) SetSession(ctx context.Context, req *session.SetSessionRequest) if err != nil { return nil, err } - challengeResponse, cmds := s.challengesToCommand(req.GetChallenges(), checks) + challengeResponse, cmds, err := s.challengesToCommand(req.GetChallenges(), checks) + if err != nil { + return nil, err + } set, err := s.command.UpdateSession(ctx, req.GetSessionId(), req.GetSessionToken(), cmds, req.GetMetadata()) if err != nil { @@ -121,6 +127,8 @@ func factorsToPb(s *query.Session) *session.Factors { WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), Intent: intentFactorToPb(s.IntentFactor), Totp: totpFactorToPb(s.TOTPFactor), + OtpSms: otpFactorToPb(s.OTPSMSFactor), + OtpEmail: otpFactorToPb(s.OTPEmailFactor), } } @@ -161,6 +169,15 @@ func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { } } +func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { + if factor.OTPCheckedAt.IsZero() { + return nil + } + return &session.OTPFactor{ + VerifiedAt: timestamppb.New(factor.OTPCheckedAt), + } +} + func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { if factor.UserID == "" || factor.UserCheckedAt.IsZero() { return nil @@ -240,7 +257,7 @@ func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([ if err != nil { return nil, err } - sessionChecks := make([]command.SessionCommand, 0, 3) + sessionChecks := make([]command.SessionCommand, 0, 7) if checkUser != nil { user, err := checkUser.search(ctx, s.query) if err != nil { @@ -260,12 +277,18 @@ func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([ if totp := checks.GetTotp(); totp != nil { sessionChecks = append(sessionChecks, command.CheckTOTP(totp.GetTotp())) } + if otp := checks.GetOtpSms(); otp != nil { + sessionChecks = append(sessionChecks, command.CheckOTPSMS(otp.GetOtp())) + } + if otp := checks.GetOtpEmail(); otp != nil { + sessionChecks = append(sessionChecks, command.CheckOTPEmail(otp.GetOtp())) + } return sessionChecks, nil } -func (s *Server) challengesToCommand(challenges *session.RequestChallenges, cmds []command.SessionCommand) (*session.Challenges, []command.SessionCommand) { +func (s *Server) challengesToCommand(challenges *session.RequestChallenges, cmds []command.SessionCommand) (*session.Challenges, []command.SessionCommand, error) { if challenges == nil { - return nil, cmds + return nil, cmds, nil } resp := new(session.Challenges) if req := challenges.GetWebAuthN(); req != nil { @@ -273,7 +296,20 @@ func (s *Server) challengesToCommand(challenges *session.RequestChallenges, cmds resp.WebAuthN = challenge cmds = append(cmds, cmd) } - return resp, cmds + if req := challenges.GetOtpSms(); req != nil { + challenge, cmd := s.createOTPSMSChallengeCommand(req) + resp.OtpSms = challenge + cmds = append(cmds, cmd) + } + if req := challenges.GetOtpEmail(); req != nil { + challenge, cmd, err := s.createOTPEmailChallengeCommand(req) + if err != nil { + return nil, nil, err + } + resp.OtpEmail = challenge + cmds = append(cmds, cmd) + } + return resp, cmds, nil } func (s *Server) createWebAuthNChallengeCommand(req *session.RequestChallenges_WebAuthN) (*session.Challenges_WebAuthN, command.SessionCommand) { @@ -299,6 +335,34 @@ func userVerificationRequirementToDomain(req session.UserVerificationRequirement } } +func (s *Server) createOTPSMSChallengeCommand(req *session.RequestChallenges_OTPSMS) (*string, command.SessionCommand) { + if req.GetReturnCode() { + challenge := new(string) + return challenge, s.command.CreateOTPSMSChallengeReturnCode(challenge) + } + + return nil, s.command.CreateOTPSMSChallenge() + +} + +func (s *Server) createOTPEmailChallengeCommand(req *session.RequestChallenges_OTPEmail) (*string, command.SessionCommand, error) { + switch t := req.GetDeliveryType().(type) { + case *session.RequestChallenges_OTPEmail_SendCode_: + cmd, err := s.command.CreateOTPEmailChallengeURLTemplate(t.SendCode.GetUrlTemplate()) + if err != nil { + return nil, nil, err + } + return nil, cmd, nil + case *session.RequestChallenges_OTPEmail_ReturnCode_: + challenge := new(string) + return challenge, s.command.CreateOTPEmailChallengeReturnCode(challenge), nil + case nil: + return nil, s.command.CreateOTPEmailChallenge(), nil + default: + return nil, nil, caos_errs.ThrowUnimplementedf(nil, "SESSION-k3ng0", "delivery_type oneOf %T in OTPEmailChallenge not implemented", t) + } +} + func userCheck(user *session.CheckUser) (userSearch, error) { if user == nil { return nil, nil diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index 6c322bc281..2e6b3de46f 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -39,6 +39,14 @@ func TestMain(m *testing.M) { CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx User = Tester.CreateHumanUser(CTX) + Tester.Client.UserV2.VerifyEmail(CTX, &user.VerifyEmailRequest{ + UserId: User.GetUserId(), + VerificationCode: User.GetEmailCode(), + }) + Tester.Client.UserV2.VerifyPhone(CTX, &user.VerifyPhoneRequest{ + UserId: User.GetUserId(), + VerificationCode: User.GetPhoneCode(), + }) Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword) Tester.RegisterUserPasskey(CTX, User.GetUserId()) return m.Run() @@ -75,6 +83,8 @@ const ( wantWebAuthNFactorUserVerified wantTOTPFactor wantIntentFactor + wantOTPSMSFactor + wantOTPEmailFactor ) func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, want []wantFactor) { @@ -107,6 +117,14 @@ func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, pf := factors.GetIntent() assert.NotNil(t, pf) assert.WithinRange(t, pf.GetVerifiedAt().AsTime(), time.Now().Add(-window), time.Now().Add(window)) + case wantOTPSMSFactor: + pf := factors.GetOtpSms() + assert.NotNil(t, pf) + assert.WithinRange(t, pf.GetVerifiedAt().AsTime(), time.Now().Add(-window), time.Now().Add(window)) + case wantOTPEmailFactor: + pf := factors.GetOtpEmail() + assert.NotNil(t, pf) + assert.WithinRange(t, pf.GetVerifiedAt().AsTime(), time.Now().Add(-window), time.Now().Add(window)) } } } @@ -362,6 +380,20 @@ func registerTOTP(ctx context.Context, t *testing.T, userID string) (secret stri return secret } +func registerOTPSMS(ctx context.Context, t *testing.T, userID string) { + _, err := Tester.Client.UserV2.AddOTPSMS(ctx, &user.AddOTPSMSRequest{ + UserId: userID, + }) + require.NoError(t, err) +} + +func registerOTPEmail(ctx context.Context, t *testing.T, userID string) { + _, err := Tester.Client.UserV2.AddOTPEmail(ctx, &user.AddOTPEmailRequest{ + UserId: userID, + }) + require.NoError(t, err) +} + func TestServer_SetSession_flow(t *testing.T) { // create new, empty session createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) @@ -421,6 +453,8 @@ func TestServer_SetSession_flow(t *testing.T) { userAuthCtx := Tester.WithAuthorizationToken(CTX, sessionToken) Tester.RegisterUserU2F(userAuthCtx, User.GetUserId()) totpSecret := registerTOTP(userAuthCtx, t, User.GetUserId()) + registerOTPSMS(userAuthCtx, t, User.GetUserId()) + registerOTPEmail(userAuthCtx, t, User.GetUserId()) t.Run("check webauthn, user not verified (U2F)", func(t *testing.T) { @@ -478,6 +512,66 @@ func TestServer_SetSession_flow(t *testing.T) { sessionToken = resp.GetSessionToken() verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantTOTPFactor) }) + + t.Run("check OTP SMS", func(t *testing.T) { + resp, err := Client.SetSession(CTX, &session.SetSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: sessionToken, + Challenges: &session.RequestChallenges{ + OtpSms: &session.RequestChallenges_OTPSMS{ReturnCode: true}, + }, + }) + require.NoError(t, err) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + sessionToken = resp.GetSessionToken() + + otp := resp.GetChallenges().GetOtpSms() + require.NotEmpty(t, otp) + + resp, err = Client.SetSession(CTX, &session.SetSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: sessionToken, + Checks: &session.Checks{ + OtpSms: &session.CheckOTP{ + Otp: otp, + }, + }, + }) + require.NoError(t, err) + sessionToken = resp.GetSessionToken() + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantOTPSMSFactor) + }) + + t.Run("check OTP Email", func(t *testing.T) { + resp, err := Client.SetSession(CTX, &session.SetSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: sessionToken, + Challenges: &session.RequestChallenges{ + OtpEmail: &session.RequestChallenges_OTPEmail{ + DeliveryType: &session.RequestChallenges_OTPEmail_ReturnCode_{}, + }, + }, + }) + require.NoError(t, err) + verifyCurrentSession(t, createResp.GetSessionId(), resp.GetSessionToken(), resp.GetDetails().GetSequence(), time.Minute, nil) + sessionToken = resp.GetSessionToken() + + otp := resp.GetChallenges().GetOtpEmail() + require.NotEmpty(t, otp) + + resp, err = Client.SetSession(CTX, &session.SetSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: sessionToken, + Checks: &session.Checks{ + OtpEmail: &session.CheckOTP{ + Otp: otp, + }, + }, + }) + require.NoError(t, err) + sessionToken = resp.GetSessionToken() + verifyCurrentSession(t, createResp.GetSessionId(), sessionToken, resp.GetDetails().GetSequence(), time.Minute, nil, wantUserFactor, wantWebAuthNFactor, wantOTPEmailFactor) + }) } func Test_ZITADEL_API_missing_authentication(t *testing.T) { diff --git a/internal/api/ui/login/login.go b/internal/api/ui/login/login.go index 1a76ad1780..64a463f77f 100644 --- a/internal/api/ui/login/login.go +++ b/internal/api/ui/login/login.go @@ -47,6 +47,9 @@ type Config struct { CSRFCookieName string Cache middleware.CacheConfig AssetCache middleware.CacheConfig + + // LoginV2 + DefaultOTPEmailURLV2 string } const ( diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index dcdd62017b..861fd06529 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -206,15 +206,12 @@ func authMethodsFromSession(session *query.Session) []domain.UserAuthMethodType types = append(types, domain.UserAuthMethodTypeTOTP) } */ - // TODO: add checks with https://github.com/zitadel/zitadel/issues/6224 - /* - if !session.TOTPFactor.OTPSMSCheckedAt.IsZero() { - types = append(types, domain.UserAuthMethodTypeOTPSMS) - } - if !session.TOTPFactor.OTPEmailCheckedAt.IsZero() { - types = append(types, domain.UserAuthMethodTypeOTPEmail) - } - */ + if !session.OTPSMSFactor.OTPCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeOTPSMS) + } + if !session.OTPEmailFactor.OTPCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeOTPEmail) + } return types } diff --git a/internal/command/command.go b/internal/command/command.go index 5149652739..aa8c1f3e77 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -34,8 +34,9 @@ import ( type Commands struct { httpClient *http.Client - checkPermission domain.PermissionCheck - newCode cryptoCodeFunc + checkPermission domain.PermissionCheck + newCode cryptoCodeFunc + newCodeWithDefault cryptoCodeWithDefaultFunc eventstore *eventstore.Eventstore static static.Storage @@ -122,6 +123,7 @@ func StartCommands( httpClient: httpClient, checkPermission: permissionCheck, newCode: newCryptoCode, + newCodeWithDefault: newCryptoCodeWithDefaultConfig, sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg), sessionTokenVerifier: sessionTokenVerifier, defaultAccessTokenLifetime: defaultAccessTokenLifetime, diff --git a/internal/command/crypto.go b/internal/command/crypto.go index af9407e0a3..0f7fe11ce5 100644 --- a/internal/command/crypto.go +++ b/internal/command/crypto.go @@ -12,6 +12,10 @@ import ( type cryptoCodeFunc func(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (*CryptoCode, error) +type cryptoCodeWithDefaultFunc func(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (*CryptoCode, error) + +var emptyConfig = &crypto.GeneratorConfig{} + type CryptoCode struct { Crypted *crypto.CryptoValue Plain string @@ -19,7 +23,11 @@ type CryptoCode struct { } func newCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (*CryptoCode, error) { - gen, config, err := secretGenerator(ctx, filter, typ, alg) + return newCryptoCodeWithDefaultConfig(ctx, filter, typ, alg, emptyConfig) +} + +func newCryptoCodeWithDefaultConfig(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (*CryptoCode, error) { + gen, config, err := secretGenerator(ctx, filter, typ, alg, defaultConfig) if err != nil { return nil, err } @@ -35,15 +43,15 @@ func newCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer, } func verifyCryptoCode(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, creation time.Time, expiry time.Duration, crypted *crypto.CryptoValue, plain string) error { - gen, _, err := secretGenerator(ctx, filter, typ, alg) + gen, _, err := secretGenerator(ctx, filter, typ, alg, emptyConfig) if err != nil { return err } return crypto.VerifyCode(creation, expiry, crypted, plain, gen) } -func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto) (crypto.Generator, *crypto.GeneratorConfig, error) { - config, err := secretGeneratorConfig(ctx, filter, typ) +func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, alg crypto.Crypto, defaultConfig *crypto.GeneratorConfig) (crypto.Generator, *crypto.GeneratorConfig, error) { + config, err := secretGeneratorConfigWithDefault(ctx, filter, typ, defaultConfig) if err != nil { return nil, nil, err } @@ -58,26 +66,10 @@ func secretGenerator(ctx context.Context, filter preparation.FilterToQueryReduce } func secretGeneratorConfig(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType) (*crypto.GeneratorConfig, error) { - wm := NewInstanceSecretGeneratorConfigWriteModel(ctx, typ) - events, err := filter(ctx, wm.Query()) - if err != nil { - return nil, err - } - wm.AppendEvents(events...) - if err := wm.Reduce(); err != nil { - return nil, err - } - return &crypto.GeneratorConfig{ - Length: wm.Length, - Expiry: wm.Expiry, - IncludeLowerLetters: wm.IncludeLowerLetters, - IncludeUpperLetters: wm.IncludeUpperLetters, - IncludeDigits: wm.IncludeDigits, - IncludeSymbols: wm.IncludeSymbols, - }, nil + return secretGeneratorConfigWithDefault(ctx, filter, typ, emptyConfig) } -func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, defaultGenerator *crypto.GeneratorConfig) (*crypto.GeneratorConfig, error) { +func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.FilterToQueryReducer, typ domain.SecretGeneratorType, defaultConfig *crypto.GeneratorConfig) (*crypto.GeneratorConfig, error) { wm := NewInstanceSecretGeneratorConfigWriteModel(ctx, typ) events, err := filter(ctx, wm.Query()) if err != nil { @@ -88,7 +80,7 @@ func secretGeneratorConfigWithDefault(ctx context.Context, filter preparation.Fi return nil, err } if wm.State != domain.SecretGeneratorStateActive { - return defaultGenerator, nil + return defaultConfig, nil } return &crypto.GeneratorConfig{ Length: wm.Length, diff --git a/internal/command/crypto_test.go b/internal/command/crypto_test.go index 66c2c63c5c..33dca9ec37 100644 --- a/internal/command/crypto_test.go +++ b/internal/command/crypto_test.go @@ -33,6 +33,21 @@ func mockCode(code string, exp time.Duration) cryptoCodeFunc { } } +func mockCodeWithDefault(code string, exp time.Duration) cryptoCodeWithDefaultFunc { + return func(ctx context.Context, filter preparation.FilterToQueryReducer, _ domain.SecretGeneratorType, alg crypto.Crypto, _ *crypto.GeneratorConfig) (*CryptoCode, error) { + return &CryptoCode{ + Crypted: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte(code), + }, + Plain: code, + Expiry: exp, + }, nil + } +} + var ( testGeneratorConfig = crypto.GeneratorConfig{ Length: 12, @@ -175,8 +190,9 @@ func Test_verifyCryptoCode(t *testing.T) { func Test_secretGenerator(t *testing.T) { type args struct { - typ domain.SecretGeneratorType - alg crypto.Crypto + typ domain.SecretGeneratorType + alg crypto.Crypto + defaultConfig *crypto.GeneratorConfig } tests := []struct { name string @@ -190,8 +206,9 @@ func Test_secretGenerator(t *testing.T) { name: "filter config error", eventsore: eventstoreExpect(t, expectFilterError(io.ErrClosedPipe)), args: args{ - typ: domain.SecretGeneratorTypeVerifyEmailCode, - alg: crypto.CreateMockHashAlg(gomock.NewController(t)), + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: crypto.CreateMockHashAlg(gomock.NewController(t)), + defaultConfig: emptyConfig, }, wantErr: io.ErrClosedPipe, }, @@ -201,8 +218,9 @@ func Test_secretGenerator(t *testing.T) { eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)), )), args: args{ - typ: domain.SecretGeneratorTypeVerifyEmailCode, - alg: crypto.CreateMockHashAlg(gomock.NewController(t)), + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: crypto.CreateMockHashAlg(gomock.NewController(t)), + defaultConfig: emptyConfig, }, want: crypto.NewHashGenerator(testGeneratorConfig, crypto.CreateMockHashAlg(gomock.NewController(t))), wantConf: &testGeneratorConfig, @@ -213,8 +231,31 @@ func Test_secretGenerator(t *testing.T) { eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)), )), args: args{ - typ: domain.SecretGeneratorTypeVerifyEmailCode, - alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + defaultConfig: emptyConfig, + }, + want: crypto.NewEncryptionGenerator(testGeneratorConfig, crypto.CreateMockEncryptionAlg(gomock.NewController(t))), + wantConf: &testGeneratorConfig, + }, + { + name: "hash generator with default config", + eventsore: eventstoreExpect(t, expectFilter()), + args: args{ + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: crypto.CreateMockHashAlg(gomock.NewController(t)), + defaultConfig: &testGeneratorConfig, + }, + want: crypto.NewHashGenerator(testGeneratorConfig, crypto.CreateMockHashAlg(gomock.NewController(t))), + wantConf: &testGeneratorConfig, + }, + { + name: "encryption generator with default config", + eventsore: eventstoreExpect(t, expectFilter()), + args: args{ + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + defaultConfig: &testGeneratorConfig, }, want: crypto.NewEncryptionGenerator(testGeneratorConfig, crypto.CreateMockEncryptionAlg(gomock.NewController(t))), wantConf: &testGeneratorConfig, @@ -225,15 +266,16 @@ func Test_secretGenerator(t *testing.T) { eventFromEventPusher(testSecretGeneratorAddedEvent(domain.SecretGeneratorTypeVerifyEmailCode)), )), args: args{ - typ: domain.SecretGeneratorTypeVerifyEmailCode, - alg: nil, + typ: domain.SecretGeneratorTypeVerifyEmailCode, + alg: nil, + defaultConfig: emptyConfig, }, wantErr: errors.ThrowInternalf(nil, "COMMA-RreV6", "Errors.Internal unsupported crypto algorithm type %T", nil), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, gotConf, err := secretGenerator(context.Background(), tt.eventsore.Filter, tt.args.typ, tt.args.alg) + got, gotConf, err := secretGenerator(context.Background(), tt.eventsore.Filter, tt.args.typ, tt.args.alg, tt.args.defaultConfig) require.ErrorIs(t, err, tt.wantErr) assert.IsType(t, tt.want, got) assert.Equal(t, tt.wantConf, gotConf) diff --git a/internal/command/session.go b/internal/command/session.go index fead616c2f..a58c5d2d3a 100644 --- a/internal/command/session.go +++ b/internal/command/session.go @@ -33,6 +33,8 @@ type SessionCommands struct { hasher *crypto.PasswordHasher intentAlg crypto.EncryptionAlgorithm totpAlg crypto.EncryptionAlgorithm + otpAlg crypto.EncryptionAlgorithm + createCode cryptoCodeWithDefaultFunc createToken func(sessionID string) (id string, token string, err error) now func() time.Time } @@ -45,6 +47,8 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri hasher: c.userPasswordHasher, intentAlg: c.idpConfigEncryption, totpAlg: c.multifactors.OTP.CryptoMFA, + otpAlg: c.userEncryption, + createCode: c.newCodeWithDefault, createToken: c.sessionTokenCreator, now: time.Now, } @@ -204,6 +208,22 @@ func (s *SessionCommands) TOTPChecked(ctx context.Context, checkedAt time.Time) s.eventCommands = append(s.eventCommands, session.NewTOTPCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt)) } +func (s *SessionCommands) OTPSMSChallenged(ctx context.Context, code *crypto.CryptoValue, expiry time.Duration, returnCode bool) { + s.eventCommands = append(s.eventCommands, session.NewOTPSMSChallengedEvent(ctx, s.sessionWriteModel.aggregate, code, expiry, returnCode)) +} + +func (s *SessionCommands) OTPSMSChecked(ctx context.Context, checkedAt time.Time) { + s.eventCommands = append(s.eventCommands, session.NewOTPSMSCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt)) +} + +func (s *SessionCommands) OTPEmailChallenged(ctx context.Context, code *crypto.CryptoValue, expiry time.Duration, returnCode bool, urlTmpl string) { + s.eventCommands = append(s.eventCommands, session.NewOTPEmailChallengedEvent(ctx, s.sessionWriteModel.aggregate, code, expiry, returnCode, urlTmpl)) +} + +func (s *SessionCommands) OTPEmailChecked(ctx context.Context, checkedAt time.Time) { + s.eventCommands = append(s.eventCommands, session.NewOTPEmailCheckedEvent(ctx, s.sessionWriteModel.aggregate, checkedAt)) +} + func (s *SessionCommands) SetToken(ctx context.Context, tokenID string) { s.eventCommands = append(s.eventCommands, session.NewTokenSetEvent(ctx, s.sessionWriteModel.aggregate, tokenID)) } diff --git a/internal/command/session_model.go b/internal/command/session_model.go index 373e5b96b4..0ae23bac7c 100644 --- a/internal/command/session_model.go +++ b/internal/command/session_model.go @@ -3,6 +3,7 @@ package command import ( "time" + "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/repository/session" @@ -15,6 +16,12 @@ type WebAuthNChallengeModel struct { RPID string } +type OTPCode struct { + Code *crypto.CryptoValue + Expiry time.Duration + CreationDate time.Time +} + func (p *WebAuthNChallengeModel) WebAuthNLogin(human *domain.Human, credentialAssertionData []byte) *domain.WebAuthNLogin { return &domain.WebAuthNLogin{ ObjectRoot: human.ObjectRoot, @@ -36,11 +43,15 @@ type SessionWriteModel struct { IntentCheckedAt time.Time WebAuthNCheckedAt time.Time TOTPCheckedAt time.Time + OTPSMSCheckedAt time.Time + OTPEmailCheckedAt time.Time WebAuthNUserVerified bool Metadata map[string][]byte State domain.SessionState - WebAuthNChallenge *WebAuthNChallengeModel + WebAuthNChallenge *WebAuthNChallengeModel + OTPSMSCodeChallenge *OTPCode + OTPEmailCodeChallenge *OTPCode aggregate *eventstore.Aggregate } @@ -73,6 +84,14 @@ func (wm *SessionWriteModel) Reduce() error { wm.reduceWebAuthNChecked(e) case *session.TOTPCheckedEvent: wm.reduceTOTPChecked(e) + case *session.OTPSMSChallengedEvent: + wm.reduceOTPSMSChallenged(e) + case *session.OTPSMSCheckedEvent: + wm.reduceOTPSMSChecked(e) + case *session.OTPEmailChallengedEvent: + wm.reduceOTPEmailChallenged(e) + case *session.OTPEmailCheckedEvent: + wm.reduceOTPEmailChecked(e) case *session.TokenSetEvent: wm.reduceTokenSet(e) case *session.TerminateEvent: @@ -95,6 +114,10 @@ func (wm *SessionWriteModel) Query() *eventstore.SearchQueryBuilder { session.WebAuthNChallengedType, session.WebAuthNCheckedType, session.TOTPCheckedType, + session.OTPSMSChallengedType, + session.OTPSMSCheckedType, + session.OTPEmailChallengedType, + session.OTPEmailCheckedType, session.TokenSetType, session.MetadataSetType, session.TerminateType, @@ -143,6 +166,32 @@ func (wm *SessionWriteModel) reduceTOTPChecked(e *session.TOTPCheckedEvent) { wm.TOTPCheckedAt = e.CheckedAt } +func (wm *SessionWriteModel) reduceOTPSMSChallenged(e *session.OTPSMSChallengedEvent) { + wm.OTPSMSCodeChallenge = &OTPCode{ + Code: e.Code, + Expiry: e.Expiry, + CreationDate: e.CreationDate(), + } +} + +func (wm *SessionWriteModel) reduceOTPSMSChecked(e *session.OTPSMSCheckedEvent) { + wm.OTPSMSCodeChallenge = nil + wm.OTPSMSCheckedAt = e.CheckedAt +} + +func (wm *SessionWriteModel) reduceOTPEmailChallenged(e *session.OTPEmailChallengedEvent) { + wm.OTPEmailCodeChallenge = &OTPCode{ + Code: e.Code, + Expiry: e.Expiry, + CreationDate: e.CreationDate(), + } +} + +func (wm *SessionWriteModel) reduceOTPEmailChecked(e *session.OTPEmailCheckedEvent) { + wm.OTPEmailCodeChallenge = nil + wm.OTPEmailCheckedAt = e.CheckedAt +} + func (wm *SessionWriteModel) reduceTokenSet(e *session.TokenSetEvent) { wm.TokenID = e.TokenID } @@ -159,7 +208,8 @@ func (wm *SessionWriteModel) AuthenticationTime() time.Time { wm.WebAuthNCheckedAt, wm.TOTPCheckedAt, wm.IntentCheckedAt, - // TODO: add OTP (sms and email) check https://github.com/zitadel/zitadel/issues/6224 + wm.OTPSMSCheckedAt, + wm.OTPEmailCheckedAt, } { if check.After(authTime) { authTime = check @@ -187,14 +237,11 @@ func (wm *SessionWriteModel) AuthMethodTypes() []domain.UserAuthMethodType { if !wm.TOTPCheckedAt.IsZero() { types = append(types, domain.UserAuthMethodTypeTOTP) } - // TODO: add checks with https://github.com/zitadel/zitadel/issues/6224 - /* - if !wm.TOTPFactor.OTPSMSCheckedAt.IsZero() { - types = append(types, domain.UserAuthMethodTypeOTPSMS) - } - if !wm.TOTPFactor.OTPEmailCheckedAt.IsZero() { - types = append(types, domain.UserAuthMethodTypeOTPEmail) - } - */ + if !wm.OTPSMSCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeOTPSMS) + } + if !wm.OTPEmailCheckedAt.IsZero() { + types = append(types, domain.UserAuthMethodTypeOTPEmail) + } return types } diff --git a/internal/command/session_otp.go b/internal/command/session_otp.go new file mode 100644 index 0000000000..eecf47f90b --- /dev/null +++ b/internal/command/session_otp.go @@ -0,0 +1,148 @@ +package command + +import ( + "context" + "io" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/repository/session" +) + +func (c *Commands) CreateOTPSMSChallengeReturnCode(dst *string) SessionCommand { + return c.createOTPSMSChallenge(true, dst) +} + +func (c *Commands) CreateOTPSMSChallenge() SessionCommand { + return c.createOTPSMSChallenge(false, nil) +} + +func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand { + return func(ctx context.Context, cmd *SessionCommands) error { + if cmd.sessionWriteModel.UserID == "" { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing") + } + writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "") + if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil { + return err + } + if !writeModel.OTPAdded() { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady") + } + code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS) + if err != nil { + return err + } + if returnCode { + *dst = code.Plain + } + cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode) + return nil + } +} + +func (c *Commands) OTPSMSSent(ctx context.Context, sessionID, resourceOwner string) error { + sessionWriteModel := NewSessionWriteModel(sessionID, resourceOwner) + err := c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) + if err != nil { + return err + } + if sessionWriteModel.OTPSMSCodeChallenge == nil { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-G3t31", "Errors.User.Code.NotFound") + } + return c.pushAppendAndReduce(ctx, sessionWriteModel, + session.NewOTPSMSSentEvent(ctx, &session.NewAggregate(sessionID, sessionWriteModel.ResourceOwner).Aggregate), + ) +} + +func (c *Commands) CreateOTPEmailChallengeURLTemplate(urlTmpl string) (SessionCommand, error) { + if err := domain.RenderOTPEmailURLTemplate(io.Discard, urlTmpl, "code", "userID", "loginName", "displayName", language.English); err != nil { + return nil, err + } + return c.createOTPEmailChallenge(false, urlTmpl, nil), nil +} + +func (c *Commands) CreateOTPEmailChallengeReturnCode(dst *string) SessionCommand { + return c.createOTPEmailChallenge(true, "", dst) +} + +func (c *Commands) CreateOTPEmailChallenge() SessionCommand { + return c.createOTPEmailChallenge(false, "", nil) +} + +func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand { + return func(ctx context.Context, cmd *SessionCommands) error { + if cmd.sessionWriteModel.UserID == "" { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing") + } + writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "") + if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil { + return err + } + if !writeModel.OTPAdded() { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady") + } + code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail) + if err != nil { + return err + } + if returnCode { + *dst = code.Plain + } + cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl) + return nil + } +} + +func (c *Commands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error { + sessionWriteModel := NewSessionWriteModel(sessionID, resourceOwner) + err := c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) + if err != nil { + return err + } + if sessionWriteModel.OTPEmailCodeChallenge == nil { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SLr02", "Errors.User.Code.NotFound") + } + return c.pushAppendAndReduce(ctx, sessionWriteModel, + session.NewOTPEmailSentEvent(ctx, &session.NewAggregate(sessionID, sessionWriteModel.ResourceOwner).Aggregate), + ) +} + +func CheckOTPSMS(code string) SessionCommand { + return func(ctx context.Context, cmd *SessionCommands) (err error) { + if cmd.sessionWriteModel.UserID == "" { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing") + } + challenge := cmd.sessionWriteModel.OTPSMSCodeChallenge + if challenge == nil { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound") + } + err = crypto.VerifyCodeWithAlgorithm(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg) + if err != nil { + return err + } + cmd.OTPSMSChecked(ctx, cmd.now()) + return nil + } +} + +func CheckOTPEmail(code string) SessionCommand { + return func(ctx context.Context, cmd *SessionCommands) (err error) { + if cmd.sessionWriteModel.UserID == "" { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing") + } + challenge := cmd.sessionWriteModel.OTPEmailCodeChallenge + if challenge == nil { + return caos_errs.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound") + } + err = crypto.VerifyCodeWithAlgorithm(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg) + if err != nil { + return err + } + cmd.OTPEmailChecked(ctx, cmd.now()) + return nil + } +} diff --git a/internal/command/session_otp_test.go b/internal/command/session_otp_test.go new file mode 100644 index 0000000000..9cf957945f --- /dev/null +++ b/internal/command/session_otp_test.go @@ -0,0 +1,951 @@ +package command + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/repository/user" +) + +func TestCommands_CreateOTPSMSChallengeReturnCode(t *testing.T) { + type fields struct { + userID string + eventstore func(*testing.T) *eventstore.Eventstore + createCode cryptoCodeWithDefaultFunc + } + type res struct { + err error + returnCode string + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + res res + }{ + { + name: "userID missing, precondition error", + fields: fields{ + userID: "", + eventstore: expectEventstore(), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing"), + }, + }, + { + name: "otp not ready, precondition error", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter(), + ), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady"), + }, + }, + { + name: "generate code", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate), + ), + ), + ), + createCode: mockCodeWithDefault("1234567", 5*time.Minute), + }, + res: res{ + returnCode: "1234567", + commands: []eventstore.Command{ + session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + true, + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + // config will not be actively used for the test (is only for default), + // but not providing it would result in a nil pointer + defaultSecretGenerators: &SecretGenerators{ + OTPSMS: emptyConfig, + }, + } + var dst string + cmd := c.CreateOTPSMSChallengeReturnCode(&dst) + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + createCode: tt.fields.createCode, + now: time.Now, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.returnCode, dst) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCommands_CreateOTPSMSChallenge(t *testing.T) { + type fields struct { + userID string + eventstore func(*testing.T) *eventstore.Eventstore + createCode cryptoCodeWithDefaultFunc + } + type res struct { + err error + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + res res + }{ + { + name: "userID missing, precondition error", + fields: fields{ + userID: "", + eventstore: expectEventstore(), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing"), + }, + }, + { + name: "otp not ready, precondition error", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter(), + ), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady"), + }, + }, + { + name: "generate code", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate), + ), + ), + ), + createCode: mockCodeWithDefault("1234567", 5*time.Minute), + }, + res: res{ + commands: []eventstore.Command{ + session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + false, + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + // config will not be actively used for the test (is only for default), + // but not providing it would result in a nil pointer + defaultSecretGenerators: &SecretGenerators{ + OTPSMS: emptyConfig, + }, + } + + cmd := c.CreateOTPSMSChallenge() + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + createCode: tt.fields.createCode, + now: time.Now, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCommands_OTPSMSSent(t *testing.T) { + type fields struct { + eventstore func(*testing.T) *eventstore.Eventstore + } + type args struct { + ctx context.Context + sessionID string + resourceOwner string + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "not challenged, precondition error", + fields: fields{ + eventstore: expectEventstore( + expectFilter(), + ), + }, + args: args{ + ctx: context.Background(), + sessionID: "sessionID", + resourceOwner: "instanceID", + }, + wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-G3t31", "Errors.User.Code.NotFound"), + }, + { + name: "challenged and sent", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + session.NewOTPSMSChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + false, + ), + ), + ), + expectPush( + eventPusherToEvents( + session.NewOTPSMSSentEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate), + ), + ), + ), + }, + args: args{ + ctx: context.Background(), + sessionID: "sessionID", + resourceOwner: "instanceID", + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + } + err := c.OTPSMSSent(tt.args.ctx, tt.args.sessionID, tt.args.resourceOwner) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCommands_CreateOTPEmailChallengeURLTemplate(t *testing.T) { + type fields struct { + userID string + eventstore func(*testing.T) *eventstore.Eventstore + createCode cryptoCodeWithDefaultFunc + } + type args struct { + urlTmpl string + } + type res struct { + templateError error + err error + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "invalid template, precondition error", + args: args{ + urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.InvalidField}}", + }, + fields: fields{ + eventstore: expectEventstore(), + }, + res: res{ + templateError: caos_errs.ThrowInvalidArgument(nil, "DOMAIN-ieYa7", "Errors.User.InvalidURLTemplate"), + }, + }, + { + name: "userID missing, precondition error", + args: args{ + urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}", + }, + fields: fields{ + eventstore: expectEventstore(), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"), + }, + }, + { + name: "otp not ready, precondition error", + args: args{ + urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}", + }, + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter(), + ), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"), + }, + }, + { + name: "generate code", + args: args{ + urlTmpl: "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}", + }, + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate), + ), + ), + ), + createCode: mockCodeWithDefault("1234567", 5*time.Minute), + }, + res: res{ + commands: []eventstore.Command{ + session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + false, + "https://example.com/mfa/email?userID={{.UserID}}&code={{.Code}}&lang={{.PreferredLanguage}}", + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + // config will not be actively used for the test (is only for default), + // but not providing it would result in a nil pointer + defaultSecretGenerators: &SecretGenerators{ + OTPEmail: emptyConfig, + }, + } + + cmd, err := c.CreateOTPEmailChallengeURLTemplate(tt.args.urlTmpl) + assert.ErrorIs(t, err, tt.res.templateError) + if tt.res.templateError != nil { + return + } + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + createCode: tt.fields.createCode, + now: time.Now, + } + + err = cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCommands_CreateOTPEmailChallengeReturnCode(t *testing.T) { + type fields struct { + userID string + eventstore func(*testing.T) *eventstore.Eventstore + createCode cryptoCodeWithDefaultFunc + } + type res struct { + err error + returnCode string + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + res res + }{ + { + name: "userID missing, precondition error", + fields: fields{ + eventstore: expectEventstore(), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"), + }, + }, + { + name: "otp not ready, precondition error", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter(), + ), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"), + }, + }, + { + name: "generate code", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate), + ), + ), + ), + createCode: mockCodeWithDefault("1234567", 5*time.Minute), + }, + res: res{ + returnCode: "1234567", + commands: []eventstore.Command{ + session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + true, + "", + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + // config will not be actively used for the test (is only for default), + // but not providing it would result in a nil pointer + defaultSecretGenerators: &SecretGenerators{ + OTPEmail: emptyConfig, + }, + } + var dst string + cmd := c.CreateOTPEmailChallengeReturnCode(&dst) + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + createCode: tt.fields.createCode, + now: time.Now, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.returnCode, dst) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCommands_CreateOTPEmailChallenge(t *testing.T) { + type fields struct { + userID string + eventstore func(*testing.T) *eventstore.Eventstore + createCode cryptoCodeWithDefaultFunc + } + type res struct { + err error + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + res res + }{ + { + name: "userID missing, precondition error", + fields: fields{ + eventstore: expectEventstore(), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing"), + }, + }, + { + name: "otp not ready, precondition error", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter(), + ), + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady"), + }, + }, + { + name: "generate code", + fields: fields{ + userID: "userID", + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org").Aggregate), + ), + ), + ), + createCode: mockCodeWithDefault("1234567", 5*time.Minute), + }, + res: res{ + commands: []eventstore.Command{ + session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + false, + "", + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + // config will not be actively used for the test (is only for default), + // but not providing it would result in a nil pointer + defaultSecretGenerators: &SecretGenerators{ + OTPEmail: emptyConfig, + }, + } + + cmd := c.CreateOTPEmailChallenge() + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + createCode: tt.fields.createCode, + now: time.Now, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCommands_OTPEmailSent(t *testing.T) { + type fields struct { + eventstore func(*testing.T) *eventstore.Eventstore + } + type args struct { + ctx context.Context + sessionID string + resourceOwner string + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + name: "not challenged, precondition error", + fields: fields{ + eventstore: expectEventstore( + expectFilter(), + ), + }, + args: args{ + ctx: context.Background(), + sessionID: "sessionID", + resourceOwner: "instanceID", + }, + wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SLr02", "Errors.User.Code.NotFound"), + }, + { + name: "challenged and sent", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + session.NewOTPEmailChallengedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("1234567"), + }, + 5*time.Minute, + false, + "", + ), + ), + ), + expectPush( + eventPusherToEvents( + session.NewOTPEmailSentEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate), + ), + ), + ), + }, + args: args{ + ctx: context.Background(), + sessionID: "sessionID", + resourceOwner: "instanceID", + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + } + err := c.OTPEmailSent(tt.args.ctx, tt.args.sessionID, tt.args.resourceOwner) + assert.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestCheckOTPSMS(t *testing.T) { + type fields struct { + eventstore func(*testing.T) *eventstore.Eventstore + userID string + otpCodeChallenge *OTPCode + otpAlg crypto.EncryptionAlgorithm + } + type args struct { + code string + } + type res struct { + err error + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "missing userID", + fields: fields{ + eventstore: expectEventstore(), + userID: "", + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing"), + }, + }, + { + name: "missing challenge", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: nil, + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound"), + }, + }, + { + name: "invalid code", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: &OTPCode{ + Code: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + Expiry: 5 * time.Minute, + CreationDate: testNow.Add(-10 * time.Minute), + }, + otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"), + }, + }, + { + name: "check ok", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: &OTPCode{ + Code: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + Expiry: 5 * time.Minute, + CreationDate: testNow, + }, + otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + code: "code", + }, + res: res{ + commands: []eventstore.Command{ + session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + testNow, + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := CheckOTPSMS(tt.args.code) + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + OTPSMSCodeChallenge: tt.fields.otpCodeChallenge, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + otpAlg: tt.fields.otpAlg, + now: func() time.Time { + return testNow + }, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} + +func TestCheckOTPEmail(t *testing.T) { + type fields struct { + eventstore func(*testing.T) *eventstore.Eventstore + userID string + otpCodeChallenge *OTPCode + otpAlg crypto.EncryptionAlgorithm + } + type args struct { + code string + } + type res struct { + err error + commands []eventstore.Command + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "missing userID", + fields: fields{ + eventstore: expectEventstore(), + userID: "", + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing"), + }, + }, + { + name: "missing challenge", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: nil, + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound"), + }, + }, + { + name: "invalid code", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: &OTPCode{ + Code: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + Expiry: 5 * time.Minute, + CreationDate: testNow.Add(-10 * time.Minute), + }, + otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + code: "code", + }, + res: res{ + err: caos_errs.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"), + }, + }, + { + name: "check ok", + fields: fields{ + eventstore: expectEventstore(), + userID: "userID", + otpCodeChallenge: &OTPCode{ + Code: &crypto.CryptoValue{ + CryptoType: crypto.TypeEncryption, + Algorithm: "enc", + KeyID: "id", + Crypted: []byte("code"), + }, + Expiry: 5 * time.Minute, + CreationDate: testNow, + }, + otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + code: "code", + }, + res: res{ + commands: []eventstore.Command{ + session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + testNow, + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := CheckOTPEmail(tt.args.code) + + sessionModel := &SessionWriteModel{ + UserID: tt.fields.userID, + UserCheckedAt: testNow, + State: domain.SessionStateActive, + OTPEmailCodeChallenge: tt.fields.otpCodeChallenge, + aggregate: &session.NewAggregate("sessionID", "instanceID").Aggregate, + } + cmds := &SessionCommands{ + sessionCommands: []SessionCommand{cmd}, + sessionWriteModel: sessionModel, + eventstore: tt.fields.eventstore(t), + otpAlg: tt.fields.otpAlg, + now: func() time.Time { + return testNow + }, + } + + err := cmd(context.Background(), cmds) + assert.ErrorIs(t, err, tt.res.err) + assert.Equal(t, tt.res.commands, cmds.eventCommands) + }) + } +} diff --git a/internal/command/user_human_otp.go b/internal/command/user_human_otp.go index 170dde17a7..d8aac85e5b 100644 --- a/internal/command/user_human_otp.go +++ b/internal/command/user_human_otp.go @@ -310,7 +310,6 @@ func (c *Commands) HumanCheckOTPSMS(ctx context.Context, userID, code, resourceO resourceOwner, authRequest, writeModel, - domain.SecretGeneratorTypeOTPSMS, succeededEvent, failedEvent, ) @@ -431,7 +430,6 @@ func (c *Commands) HumanCheckOTPEmail(ctx context.Context, userID, code, resourc resourceOwner, authRequest, writeModel, - domain.SecretGeneratorTypeOTPEmail, succeededEvent, failedEvent, ) @@ -497,7 +495,6 @@ func (c *Commands) humanCheckOTP( userID, code, resourceOwner string, authRequest *domain.AuthRequest, writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error), - secretGeneratorType domain.SecretGeneratorType, checkSucceededEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command, checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command, ) error { diff --git a/internal/domain/session.go b/internal/domain/session.go index 84a74a8f63..56dda951da 100644 --- a/internal/domain/session.go +++ b/internal/domain/session.go @@ -1,5 +1,11 @@ package domain +import ( + "io" + + "golang.org/x/text/language" +) + type SessionState int32 const ( @@ -7,3 +13,23 @@ const ( SessionStateActive SessionStateTerminated ) + +type OTPEmailURLData struct { + Code string + UserID string + LoginName string + DisplayName string + PreferredLanguage language.Tag +} + +// RenderOTPEmailURLTemplate parses and renders tmpl. +// code, userID, (preferred) loginName, displayName and preferredLanguage are passed into the [OTPEmailURLData]. +func RenderOTPEmailURLTemplate(w io.Writer, tmpl, code, userID, loginName, displayName string, preferredLanguage language.Tag) error { + return renderURLTemplate(w, tmpl, &OTPEmailURLData{ + Code: code, + UserID: userID, + LoginName: loginName, + DisplayName: displayName, + PreferredLanguage: preferredLanguage, + }) +} diff --git a/internal/notification/handlers/user_notifier.go b/internal/notification/handlers/user_notifier.go index d5ad60b464..7b1d407662 100644 --- a/internal/notification/handlers/user_notifier.go +++ b/internal/notification/handlers/user_notifier.go @@ -2,9 +2,11 @@ package handlers import ( "context" + "strings" "time" "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" @@ -13,7 +15,9 @@ import ( "github.com/zitadel/zitadel/internal/eventstore/handler" "github.com/zitadel/zitadel/internal/eventstore/handler/crdb" "github.com/zitadel/zitadel/internal/notification/types" + "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/user" ) @@ -26,6 +30,7 @@ type userNotifier struct { commands *command.Commands queries *NotificationQueries assetsPrefix func(context.Context) string + otpEmailTmpl string metricSuccessfulDeliveriesEmail, metricFailedDeliveriesEmail, metricSuccessfulDeliveriesSMS, @@ -38,6 +43,7 @@ func NewUserNotifier( commands *command.Commands, queries *NotificationQueries, assetsPrefix func(context.Context) string, + otpEmailTmpl string, metricSuccessfulDeliveriesEmail, metricFailedDeliveriesEmail, metricSuccessfulDeliveriesSMS, @@ -50,6 +56,7 @@ func NewUserNotifier( p.commands = commands p.queries = queries p.assetsPrefix = assetsPrefix + p.otpEmailTmpl = otpEmailTmpl p.metricSuccessfulDeliveriesEmail = metricSuccessfulDeliveriesEmail p.metricFailedDeliveriesEmail = metricFailedDeliveriesEmail p.metricSuccessfulDeliveriesSMS = metricSuccessfulDeliveriesSMS @@ -117,6 +124,19 @@ func (u *userNotifier) reducers() []handler.AggregateReducer { }, }, }, + { + Aggregate: session.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: session.OTPSMSChallengedType, + Reduce: u.reduceSessionOTPSMSChallenged, + }, + { + Event: session.OTPEmailChallengedType, + Reduce: u.reduceSessionOTPEmailChallenged, + }, + }, + }, } } @@ -346,25 +366,70 @@ func (u *userNotifier) reduceOTPSMSCodeAdded(event eventstore.Event) (*handler.S if !ok { return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-ASF3g", "reduce.wrong.event.type %s", user.HumanOTPSMSCodeAddedType) } + return u.reduceOTPSMS( + e, + e.Code, + e.Expiry, + e.Aggregate().ID, + e.Aggregate().ResourceOwner, + u.commands.HumanOTPSMSCodeSent, + user.HumanOTPSMSCodeAddedType, + user.HumanOTPSMSCodeSentType, + ) +} + +func (u *userNotifier) reduceSessionOTPSMSChallenged(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*session.OTPSMSChallengedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sk32L", "reduce.wrong.event.type %s", session.OTPSMSChallengedType) + } + if e.CodeReturned { + return crdb.NewNoOpStatement(e), nil + } ctx := HandlerContext(event.Aggregate()) - alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, e.Expiry, nil, - user.HumanOTPSMSCodeAddedType, user.HumanOTPSMSCodeSentType) + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + if err != nil { + return nil, err + } + return u.reduceOTPSMS( + e, + e.Code, + e.Expiry, + s.UserFactor.UserID, + s.UserFactor.ResourceOwner, + u.commands.OTPSMSSent, + session.OTPSMSChallengedType, + session.OTPSMSSentType, + ) +} + +func (u *userNotifier) reduceOTPSMS( + event eventstore.Event, + code *crypto.CryptoValue, + expiry time.Duration, + userID, + resourceOwner string, + sentCommand func(ctx context.Context, userID string, resourceOwner string) (err error), + eventTypes ...eventstore.EventType, +) (*handler.Statement, error) { + ctx := HandlerContext(event.Aggregate()) + alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, expiry, nil, eventTypes...) if err != nil { return nil, err } if alreadyHandled { - return crdb.NewNoOpStatement(e), nil + return crdb.NewNoOpStatement(event), nil } - code, err := crypto.DecryptString(e.Code, u.queries.UserDataCrypto) + plainCode, err := crypto.DecryptString(code, u.queries.UserDataCrypto) if err != nil { return nil, err } - colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, e.Aggregate().ResourceOwner, false) + colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, resourceOwner, false) if err != nil { return nil, err } - notifyUser, err := u.queries.GetNotifyUserByID(ctx, true, e.Aggregate().ID, false) + notifyUser, err := u.queries.GetNotifyUserByID(ctx, true, userID, false) if err != nil { return nil, err } @@ -386,19 +451,19 @@ func (u *userNotifier) reduceOTPSMSCodeAdded(event eventstore.Event) (*handler.S u.queries.GetLogProvider, colors, u.assetsPrefix(ctx), - e, + event, u.metricSuccessfulDeliveriesSMS, u.metricFailedDeliveriesSMS, ) - err = notify.SendOTPSMSCode(authz.GetInstance(ctx).RequestedDomain(), origin, code, e.Expiry) + err = notify.SendOTPSMSCode(authz.GetInstance(ctx).RequestedDomain(), origin, plainCode, expiry) if err != nil { return nil, err } - err = u.commands.HumanOTPSMSCodeSent(ctx, e.Aggregate().ID, e.Aggregate().ResourceOwner) + err = sentCommand(ctx, userID, resourceOwner) if err != nil { return nil, err } - return crdb.NewNoOpStatement(e), nil + return crdb.NewNoOpStatement(event), nil } func (u *userNotifier) reduceOTPEmailCodeAdded(event eventstore.Event) (*handler.Statement, error) { @@ -406,34 +471,100 @@ func (u *userNotifier) reduceOTPEmailCodeAdded(event eventstore.Event) (*handler if !ok { return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-JL3hw", "reduce.wrong.event.type %s", user.HumanOTPEmailCodeAddedType) } + var authRequestID string + if e.AuthRequestInfo != nil { + authRequestID = e.AuthRequestInfo.ID + } + url := func(code, origin string, _ *query.NotifyUser) (string, error) { + return login.OTPLink(origin, authRequestID, code, domain.MFATypeOTPEmail), nil + } + return u.reduceOTPEmail( + e, + e.Code, + e.Expiry, + e.Aggregate().ID, + e.Aggregate().ResourceOwner, + url, + u.commands.HumanOTPEmailCodeSent, + user.HumanOTPEmailCodeAddedType, + user.HumanOTPEmailCodeSentType, + ) +} + +func (u *userNotifier) reduceSessionOTPEmailChallenged(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*session.OTPEmailChallengedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-zbsgt", "reduce.wrong.event.type %s", session.OTPEmailChallengedType) + } + if e.ReturnCode { + return crdb.NewNoOpStatement(e), nil + } ctx := HandlerContext(event.Aggregate()) - alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, e.Expiry, nil, - user.HumanOTPEmailCodeAddedType, user.HumanOTPEmailCodeSentType) + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + if err != nil { + return nil, err + } + url := func(code, origin string, user *query.NotifyUser) (string, error) { + var buf strings.Builder + urlTmpl := origin + u.otpEmailTmpl + if e.URLTmpl != "" { + urlTmpl = e.URLTmpl + } + if err := domain.RenderOTPEmailURLTemplate(&buf, urlTmpl, code, user.ID, user.PreferredLoginName, user.DisplayName, user.PreferredLanguage); err != nil { + return "", err + } + return buf.String(), nil + } + return u.reduceOTPEmail( + e, + e.Code, + e.Expiry, + s.UserFactor.UserID, + s.UserFactor.ResourceOwner, + url, + u.commands.OTPEmailSent, + user.HumanOTPEmailCodeAddedType, + user.HumanOTPEmailCodeSentType, + ) +} + +func (u *userNotifier) reduceOTPEmail( + event eventstore.Event, + code *crypto.CryptoValue, + expiry time.Duration, + userID, + resourceOwner string, + urlTmpl func(code, origin string, user *query.NotifyUser) (string, error), + sentCommand func(ctx context.Context, userID string, resourceOwner string) (err error), + eventTypes ...eventstore.EventType, +) (*handler.Statement, error) { + ctx := HandlerContext(event.Aggregate()) + alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, expiry, nil, eventTypes...) if err != nil { return nil, err } if alreadyHandled { - return crdb.NewNoOpStatement(e), nil + return crdb.NewNoOpStatement(event), nil } - code, err := crypto.DecryptString(e.Code, u.queries.UserDataCrypto) + plainCode, err := crypto.DecryptString(code, u.queries.UserDataCrypto) if err != nil { return nil, err } - colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, e.Aggregate().ResourceOwner, false) + colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, resourceOwner, false) if err != nil { return nil, err } - template, err := u.queries.MailTemplateByOrg(ctx, e.Aggregate().ResourceOwner, false) + template, err := u.queries.MailTemplateByOrg(ctx, resourceOwner, false) if err != nil { return nil, err } - notifyUser, err := u.queries.GetNotifyUserByID(ctx, true, e.Aggregate().ID, false) + notifyUser, err := u.queries.GetNotifyUserByID(ctx, true, userID, false) if err != nil { return nil, err } - translator, err := u.queries.GetTranslatorWithOrgTexts(ctx, notifyUser.ResourceOwner, domain.VerifyEmailOTPMessageType) + translator, err := u.queries.GetTranslatorWithOrgTexts(ctx, resourceOwner, domain.VerifyEmailOTPMessageType) if err != nil { return nil, err } @@ -442,9 +573,9 @@ func (u *userNotifier) reduceOTPEmailCodeAdded(event eventstore.Event) (*handler if err != nil { return nil, err } - var authRequestID string - if e.AuthRequestInfo != nil { - authRequestID = e.AuthRequestInfo.ID + url, err := urlTmpl(plainCode, origin, notifyUser) + if err != nil { + return nil, err } notify := types.SendEmail( ctx, @@ -456,19 +587,19 @@ func (u *userNotifier) reduceOTPEmailCodeAdded(event eventstore.Event) (*handler u.queries.GetLogProvider, colors, u.assetsPrefix(ctx), - e, + event, u.metricSuccessfulDeliveriesEmail, u.metricFailedDeliveriesEmail, ) - err = notify.SendOTPEmailCode(notifyUser, authz.GetInstance(ctx).RequestedDomain(), origin, code, authRequestID, e.Expiry) + err = notify.SendOTPEmailCode(notifyUser, url, authz.GetInstance(ctx).RequestedDomain(), origin, plainCode, expiry) if err != nil { return nil, err } - err = u.commands.HumanOTPEmailCodeSent(ctx, e.Aggregate().ID, e.Aggregate().ResourceOwner) + err = sentCommand(ctx, event.Aggregate().ID, event.Aggregate().ResourceOwner) if err != nil { return nil, err } - return crdb.NewNoOpStatement(e), nil + return crdb.NewNoOpStatement(event), nil } func (u *userNotifier) reduceDomainClaimed(event eventstore.Event) (*handler.Statement, error) { diff --git a/internal/notification/projections.go b/internal/notification/projections.go index 0a6f6f659c..380775254f 100644 --- a/internal/notification/projections.go +++ b/internal/notification/projections.go @@ -27,9 +27,7 @@ const ( func Start( ctx context.Context, - userHandlerCustomConfig projection.CustomConfig, - quotaHandlerCustomConfig projection.CustomConfig, - telemetryHandlerCustomConfig projection.CustomConfig, + userHandlerCustomConfig, quotaHandlerCustomConfig, telemetryHandlerCustomConfig projection.CustomConfig, telemetryCfg handlers.TelemetryPusherConfig, externalDomain string, externalPort uint16, @@ -38,10 +36,9 @@ func Start( queries *query.Queries, es *eventstore.Eventstore, assetsPrefix func(context.Context) string, + otpEmailTmpl string, fileSystemPath string, - userEncryption, - smtpEncryption, - smsEncryption crypto.EncryptionAlgorithm, + userEncryption, smtpEncryption, smsEncryption crypto.EncryptionAlgorithm, ) { statikFS, err := statik_fs.NewWithNamespace("notification") logging.OnError(err).Panic("unable to start listener") @@ -64,6 +61,7 @@ func Start( commands, q, assetsPrefix, + otpEmailTmpl, metricSuccessfulDeliveriesEmail, metricFailedDeliveriesEmail, metricSuccessfulDeliveriesSMS, diff --git a/internal/notification/types/otp.go b/internal/notification/types/otp.go index aea3a5c124..2079ecce01 100644 --- a/internal/notification/types/otp.go +++ b/internal/notification/types/otp.go @@ -3,7 +3,6 @@ package types import ( "time" - "github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" ) @@ -13,8 +12,7 @@ func (notify Notify) SendOTPSMSCode(requestedDomain, origin, code string, expiry return notify("", args, domain.VerifySMSOTPMessageType, false) } -func (notify Notify) SendOTPEmailCode(user *query.NotifyUser, requestedDomain, origin, code, authRequestID string, expiry time.Duration) error { - url := login.OTPLink(origin, authRequestID, code, domain.MFATypeOTPEmail) +func (notify Notify) SendOTPEmailCode(user *query.NotifyUser, url, requestedDomain, origin, code string, expiry time.Duration) error { args := otpArgs(code, origin, requestedDomain, expiry) return notify(url, args, domain.VerifyEmailOTPMessageType, false) } diff --git a/internal/query/projection/session.go b/internal/query/projection/session.go index 654e804270..7305d6a87c 100644 --- a/internal/query/projection/session.go +++ b/internal/query/projection/session.go @@ -14,7 +14,7 @@ import ( ) const ( - SessionsProjectionTable = "projections.sessions4" + SessionsProjectionTable = "projections.sessions5" SessionColumnID = "id" SessionColumnCreationDate = "creation_date" @@ -31,6 +31,8 @@ const ( SessionColumnWebAuthNCheckedAt = "webauthn_checked_at" SessionColumnWebAuthNUserVerified = "webauthn_user_verified" SessionColumnTOTPCheckedAt = "totp_checked_at" + SessionColumnOTPSMSCheckedAt = "otp_sms_checked_at" + SessionColumnOTPEmailCheckedAt = "otp_email_checked_at" SessionColumnMetadata = "metadata" SessionColumnTokenID = "token_id" ) @@ -60,6 +62,8 @@ func newSessionProjection(ctx context.Context, config crdb.StatementHandlerConfi crdb.NewColumn(SessionColumnWebAuthNCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()), crdb.NewColumn(SessionColumnWebAuthNUserVerified, crdb.ColumnTypeBool, crdb.Nullable()), crdb.NewColumn(SessionColumnTOTPCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()), + crdb.NewColumn(SessionColumnOTPSMSCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()), + crdb.NewColumn(SessionColumnOTPEmailCheckedAt, crdb.ColumnTypeTimestamp, crdb.Nullable()), crdb.NewColumn(SessionColumnMetadata, crdb.ColumnTypeJSONB, crdb.Nullable()), crdb.NewColumn(SessionColumnTokenID, crdb.ColumnTypeText, crdb.Nullable()), }, @@ -99,6 +103,14 @@ func (p *sessionProjection) reducers() []handler.AggregateReducer { Event: session.TOTPCheckedType, Reduce: p.reduceTOTPChecked, }, + { + Event: session.OTPSMSCheckedType, + Reduce: p.reduceOTPSMSChecked, + }, + { + Event: session.OTPEmailCheckedType, + Reduce: p.reduceOTPEmailChecked, + }, { Event: session.TokenSetType, Reduce: p.reduceTokenSet, @@ -255,6 +267,46 @@ func (p *sessionProjection) reduceTOTPChecked(event eventstore.Event) (*handler. ), nil } +func (p *sessionProjection) reduceOTPSMSChecked(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*session.OTPSMSCheckedEvent](event) + if err != nil { + return nil, err + } + + return crdb.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(SessionColumnChangeDate, e.CreationDate()), + handler.NewCol(SessionColumnSequence, e.Sequence()), + handler.NewCol(SessionColumnOTPSMSCheckedAt, e.CheckedAt), + }, + []handler.Condition{ + handler.NewCond(SessionColumnID, e.Aggregate().ID), + handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *sessionProjection) reduceOTPEmailChecked(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*session.OTPEmailCheckedEvent](event) + if err != nil { + return nil, err + } + + return crdb.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(SessionColumnChangeDate, e.CreationDate()), + handler.NewCol(SessionColumnSequence, e.Sequence()), + handler.NewCol(SessionColumnOTPEmailCheckedAt, e.CheckedAt), + }, + []handler.Condition{ + handler.NewCond(SessionColumnID, e.Aggregate().ID), + handler.NewCond(SessionColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + func (p *sessionProjection) reduceTokenSet(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*session.TokenSetEvent) if !ok { diff --git a/internal/query/projection/session_test.go b/internal/query/projection/session_test.go index c22310d620..8ac52b7484 100644 --- a/internal/query/projection/session_test.go +++ b/internal/query/projection/session_test.go @@ -43,7 +43,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "INSERT INTO projections.sessions4 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + expectedStmt: "INSERT INTO projections.sessions5 (id, instance_id, creation_date, change_date, resource_owner, state, sequence, creator) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", expectedArgs: []interface{}{ "agg-id", "instance-id", @@ -79,7 +79,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, user_id, user_checked_at) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -112,7 +112,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, password_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -145,7 +145,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, webauthn_checked_at, webauthn_user_verified) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, webauthn_checked_at, webauthn_user_verified) = ($1, $2, $3, $4) WHERE (id = $5) AND (instance_id = $6)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -178,7 +178,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, intent_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -210,7 +210,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, totp_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, totp_checked_at) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -242,7 +242,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, token_id) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -276,7 +276,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", + expectedStmt: "UPDATE projections.sessions5 SET (change_date, sequence, metadata) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)", expectedArgs: []interface{}{ anyArg{}, anyArg{}, @@ -308,7 +308,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions4 WHERE (id = $1) AND (instance_id = $2)", + expectedStmt: "DELETE FROM projections.sessions5 WHERE (id = $1) AND (instance_id = $2)", expectedArgs: []interface{}{ "agg-id", "instance-id", @@ -335,7 +335,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "DELETE FROM projections.sessions4 WHERE (instance_id = $1)", + expectedStmt: "DELETE FROM projections.sessions5 WHERE (instance_id = $1)", expectedArgs: []interface{}{ "agg-id", }, @@ -366,7 +366,7 @@ func TestSessionProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.sessions4 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", + expectedStmt: "UPDATE projections.sessions5 SET password_checked_at = $1 WHERE (user_id = $2) AND (password_checked_at < $3)", expectedArgs: []interface{}{ nil, "agg-id", diff --git a/internal/query/session.go b/internal/query/session.go index 2a1672a3fa..c098d3d110 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -35,6 +35,8 @@ type Session struct { IntentFactor SessionIntentFactor WebAuthNFactor SessionWebAuthNFactor TOTPFactor SessionTOTPFactor + OTPSMSFactor SessionOTPFactor + OTPEmailFactor SessionOTPFactor Metadata map[string][]byte } @@ -63,6 +65,10 @@ type SessionTOTPFactor struct { TOTPCheckedAt time.Time } +type SessionOTPFactor struct { + OTPCheckedAt time.Time +} + type SessionsSearchQueries struct { SearchRequest Queries []SearchQuery @@ -141,6 +147,14 @@ var ( name: projection.SessionColumnTOTPCheckedAt, table: sessionsTable, } + SessionColumnOTPSMSCheckedAt = Column{ + name: projection.SessionColumnOTPSMSCheckedAt, + table: sessionsTable, + } + SessionColumnOTPEmailCheckedAt = Column{ + name: projection.SessionColumnOTPEmailCheckedAt, + table: sessionsTable, + } SessionColumnMetadata = Column{ name: projection.SessionColumnMetadata, table: sessionsTable, @@ -243,6 +257,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil SessionColumnWebAuthNCheckedAt.identifier(), SessionColumnWebAuthNUserVerified.identifier(), SessionColumnTOTPCheckedAt.identifier(), + SessionColumnOTPSMSCheckedAt.identifier(), + SessionColumnOTPEmailCheckedAt.identifier(), SessionColumnMetadata.identifier(), SessionColumnToken.identifier(), ).From(sessionsTable.identifier()). @@ -263,6 +279,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil webAuthNCheckedAt sql.NullTime webAuthNUserPresent sql.NullBool totpCheckedAt sql.NullTime + otpSMSCheckedAt sql.NullTime + otpEmailCheckedAt sql.NullTime metadata database.Map[[]byte] token sql.NullString ) @@ -285,6 +303,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil &webAuthNCheckedAt, &webAuthNUserPresent, &totpCheckedAt, + &otpSMSCheckedAt, + &otpEmailCheckedAt, &metadata, &token, ) @@ -306,6 +326,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil session.WebAuthNFactor.WebAuthNCheckedAt = webAuthNCheckedAt.Time session.WebAuthNFactor.UserVerified = webAuthNUserPresent.Bool session.TOTPFactor.TOTPCheckedAt = totpCheckedAt.Time + session.OTPSMSFactor.OTPCheckedAt = otpSMSCheckedAt.Time + session.OTPEmailFactor.OTPCheckedAt = otpEmailCheckedAt.Time session.Metadata = metadata return session, token.String, nil @@ -331,6 +353,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui SessionColumnWebAuthNCheckedAt.identifier(), SessionColumnWebAuthNUserVerified.identifier(), SessionColumnTOTPCheckedAt.identifier(), + SessionColumnOTPSMSCheckedAt.identifier(), + SessionColumnOTPEmailCheckedAt.identifier(), SessionColumnMetadata.identifier(), countColumn.identifier(), ).From(sessionsTable.identifier()). @@ -354,6 +378,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui webAuthNCheckedAt sql.NullTime webAuthNUserPresent sql.NullBool totpCheckedAt sql.NullTime + otpSMSCheckedAt sql.NullTime + otpEmailCheckedAt sql.NullTime metadata database.Map[[]byte] ) @@ -375,6 +401,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui &webAuthNCheckedAt, &webAuthNUserPresent, &totpCheckedAt, + &otpSMSCheckedAt, + &otpEmailCheckedAt, &metadata, &sessions.Count, ) @@ -392,6 +420,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui session.WebAuthNFactor.WebAuthNCheckedAt = webAuthNCheckedAt.Time session.WebAuthNFactor.UserVerified = webAuthNUserPresent.Bool session.TOTPFactor.TOTPCheckedAt = totpCheckedAt.Time + session.OTPSMSFactor.OTPCheckedAt = otpSMSCheckedAt.Time + session.OTPEmailFactor.OTPCheckedAt = otpEmailCheckedAt.Time session.Metadata = metadata sessions.Sessions = append(sessions.Sessions, session) diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index c66868a9d6..fa5209bdd3 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -17,53 +17,57 @@ import ( ) var ( - expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions4.id,` + - ` projections.sessions4.creation_date,` + - ` projections.sessions4.change_date,` + - ` projections.sessions4.sequence,` + - ` projections.sessions4.state,` + - ` projections.sessions4.resource_owner,` + - ` projections.sessions4.creator,` + - ` projections.sessions4.user_id,` + - ` projections.sessions4.user_checked_at,` + + expectedSessionQuery = regexp.QuoteMeta(`SELECT projections.sessions5.id,` + + ` projections.sessions5.creation_date,` + + ` projections.sessions5.change_date,` + + ` projections.sessions5.sequence,` + + ` projections.sessions5.state,` + + ` projections.sessions5.resource_owner,` + + ` projections.sessions5.creator,` + + ` projections.sessions5.user_id,` + + ` projections.sessions5.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + ` projections.users8.resource_owner,` + - ` projections.sessions4.password_checked_at,` + - ` projections.sessions4.intent_checked_at,` + - ` projections.sessions4.webauthn_checked_at,` + - ` projections.sessions4.webauthn_user_verified,` + - ` projections.sessions4.totp_checked_at,` + - ` projections.sessions4.metadata,` + - ` projections.sessions4.token_id` + - ` FROM projections.sessions4` + - ` LEFT JOIN projections.login_names2 ON projections.sessions4.user_id = projections.login_names2.user_id AND projections.sessions4.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions4.user_id = projections.users8_humans.user_id AND projections.sessions4.instance_id = projections.users8_humans.instance_id` + - ` LEFT JOIN projections.users8 ON projections.sessions4.user_id = projections.users8.id AND projections.sessions4.instance_id = projections.users8.instance_id` + + ` projections.sessions5.password_checked_at,` + + ` projections.sessions5.intent_checked_at,` + + ` projections.sessions5.webauthn_checked_at,` + + ` projections.sessions5.webauthn_user_verified,` + + ` projections.sessions5.totp_checked_at,` + + ` projections.sessions5.otp_sms_checked_at,` + + ` projections.sessions5.otp_email_checked_at,` + + ` projections.sessions5.metadata,` + + ` projections.sessions5.token_id` + + ` FROM projections.sessions5` + + ` LEFT JOIN projections.login_names2 ON projections.sessions5.user_id = projections.login_names2.user_id AND projections.sessions5.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions5.user_id = projections.users8_humans.user_id AND projections.sessions5.instance_id = projections.users8_humans.instance_id` + + ` LEFT JOIN projections.users8 ON projections.sessions5.user_id = projections.users8.id AND projections.sessions5.instance_id = projections.users8.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) - expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions4.id,` + - ` projections.sessions4.creation_date,` + - ` projections.sessions4.change_date,` + - ` projections.sessions4.sequence,` + - ` projections.sessions4.state,` + - ` projections.sessions4.resource_owner,` + - ` projections.sessions4.creator,` + - ` projections.sessions4.user_id,` + - ` projections.sessions4.user_checked_at,` + + expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions5.id,` + + ` projections.sessions5.creation_date,` + + ` projections.sessions5.change_date,` + + ` projections.sessions5.sequence,` + + ` projections.sessions5.state,` + + ` projections.sessions5.resource_owner,` + + ` projections.sessions5.creator,` + + ` projections.sessions5.user_id,` + + ` projections.sessions5.user_checked_at,` + ` projections.login_names2.login_name,` + ` projections.users8_humans.display_name,` + ` projections.users8.resource_owner,` + - ` projections.sessions4.password_checked_at,` + - ` projections.sessions4.intent_checked_at,` + - ` projections.sessions4.webauthn_checked_at,` + - ` projections.sessions4.webauthn_user_verified,` + - ` projections.sessions4.totp_checked_at,` + - ` projections.sessions4.metadata,` + + ` projections.sessions5.password_checked_at,` + + ` projections.sessions5.intent_checked_at,` + + ` projections.sessions5.webauthn_checked_at,` + + ` projections.sessions5.webauthn_user_verified,` + + ` projections.sessions5.totp_checked_at,` + + ` projections.sessions5.otp_sms_checked_at,` + + ` projections.sessions5.otp_email_checked_at,` + + ` projections.sessions5.metadata,` + ` COUNT(*) OVER ()` + - ` FROM projections.sessions4` + - ` LEFT JOIN projections.login_names2 ON projections.sessions4.user_id = projections.login_names2.user_id AND projections.sessions4.instance_id = projections.login_names2.instance_id` + - ` LEFT JOIN projections.users8_humans ON projections.sessions4.user_id = projections.users8_humans.user_id AND projections.sessions4.instance_id = projections.users8_humans.instance_id` + - ` LEFT JOIN projections.users8 ON projections.sessions4.user_id = projections.users8.id AND projections.sessions4.instance_id = projections.users8.instance_id` + + ` FROM projections.sessions5` + + ` LEFT JOIN projections.login_names2 ON projections.sessions5.user_id = projections.login_names2.user_id AND projections.sessions5.instance_id = projections.login_names2.instance_id` + + ` LEFT JOIN projections.users8_humans ON projections.sessions5.user_id = projections.users8_humans.user_id AND projections.sessions5.instance_id = projections.users8_humans.instance_id` + + ` LEFT JOIN projections.users8 ON projections.sessions5.user_id = projections.users8.id AND projections.sessions5.instance_id = projections.users8.instance_id` + ` AS OF SYSTEM TIME '-1 ms'`) sessionCols = []string{ @@ -84,6 +88,8 @@ var ( "webauthn_checked_at", "webauthn_user_verified", "totp_checked_at", + "otp_sms_checked_at", + "otp_email_checked_at", "metadata", "token", } @@ -106,6 +112,8 @@ var ( "webauthn_checked_at", "webauthn_user_verified", "totp_checked_at", + "otp_sms_checked_at", + "otp_email_checked_at", "metadata", "count", } @@ -160,6 +168,8 @@ func Test_SessionsPrepare(t *testing.T) { testNow, true, testNow, + testNow, + testNow, []byte(`{"key": "dmFsdWU="}`), }, }, @@ -198,6 +208,12 @@ func Test_SessionsPrepare(t *testing.T) { TOTPFactor: SessionTOTPFactor{ TOTPCheckedAt: testNow, }, + OTPSMSFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, + OTPEmailFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, Metadata: map[string][]byte{ "key": []byte("value"), }, @@ -231,6 +247,8 @@ func Test_SessionsPrepare(t *testing.T) { testNow, true, testNow, + testNow, + testNow, []byte(`{"key": "dmFsdWU="}`), }, { @@ -251,6 +269,8 @@ func Test_SessionsPrepare(t *testing.T) { testNow, false, testNow, + testNow, + testNow, []byte(`{"key": "dmFsdWU="}`), }, }, @@ -289,6 +309,12 @@ func Test_SessionsPrepare(t *testing.T) { TOTPFactor: SessionTOTPFactor{ TOTPCheckedAt: testNow, }, + OTPSMSFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, + OTPEmailFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, Metadata: map[string][]byte{ "key": []byte("value"), }, @@ -321,6 +347,12 @@ func Test_SessionsPrepare(t *testing.T) { TOTPFactor: SessionTOTPFactor{ TOTPCheckedAt: testNow, }, + OTPSMSFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, + OTPEmailFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, Metadata: map[string][]byte{ "key": []byte("value"), }, @@ -407,6 +439,8 @@ func Test_SessionPrepare(t *testing.T) { testNow, true, testNow, + testNow, + testNow, []byte(`{"key": "dmFsdWU="}`), "tokenID", }, @@ -440,6 +474,12 @@ func Test_SessionPrepare(t *testing.T) { TOTPFactor: SessionTOTPFactor{ TOTPCheckedAt: testNow, }, + OTPSMSFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, + OTPEmailFactor: SessionOTPFactor{ + OTPCheckedAt: testNow, + }, Metadata: map[string][]byte{ "key": []byte("value"), }, diff --git a/internal/repository/session/eventstore.go b/internal/repository/session/eventstore.go index efa52b6582..2923e5239e 100644 --- a/internal/repository/session/eventstore.go +++ b/internal/repository/session/eventstore.go @@ -10,6 +10,12 @@ func RegisterEventMappers(es *eventstore.Eventstore) { RegisterFilterEventMapper(AggregateType, WebAuthNChallengedType, eventstore.GenericEventMapper[WebAuthNChallengedEvent]). RegisterFilterEventMapper(AggregateType, WebAuthNCheckedType, eventstore.GenericEventMapper[WebAuthNCheckedEvent]). RegisterFilterEventMapper(AggregateType, TOTPCheckedType, eventstore.GenericEventMapper[TOTPCheckedEvent]). + RegisterFilterEventMapper(AggregateType, OTPSMSChallengedType, eventstore.GenericEventMapper[OTPSMSChallengedEvent]). + RegisterFilterEventMapper(AggregateType, OTPSMSSentType, eventstore.GenericEventMapper[OTPSMSSentEvent]). + RegisterFilterEventMapper(AggregateType, OTPSMSCheckedType, eventstore.GenericEventMapper[OTPSMSCheckedEvent]). + RegisterFilterEventMapper(AggregateType, OTPEmailChallengedType, eventstore.GenericEventMapper[OTPEmailChallengedEvent]). + RegisterFilterEventMapper(AggregateType, OTPEmailSentType, eventstore.GenericEventMapper[OTPEmailSentEvent]). + RegisterFilterEventMapper(AggregateType, OTPEmailCheckedType, eventstore.GenericEventMapper[OTPEmailCheckedEvent]). RegisterFilterEventMapper(AggregateType, TokenSetType, TokenSetEventMapper). RegisterFilterEventMapper(AggregateType, MetadataSetType, MetadataSetEventMapper). RegisterFilterEventMapper(AggregateType, TerminateType, TerminateEventMapper) diff --git a/internal/repository/session/session.go b/internal/repository/session/session.go index 556cd033c7..76f4984d1d 100644 --- a/internal/repository/session/session.go +++ b/internal/repository/session/session.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore" @@ -20,6 +21,12 @@ const ( WebAuthNChallengedType = sessionEventPrefix + "webAuthN.challenged" WebAuthNCheckedType = sessionEventPrefix + "webAuthN.checked" TOTPCheckedType = sessionEventPrefix + "totp.checked" + OTPSMSChallengedType = sessionEventPrefix + "otp.sms.challenged" + OTPSMSSentType = sessionEventPrefix + "otp.sms.sent" + OTPSMSCheckedType = sessionEventPrefix + "otp.sms.checked" + OTPEmailChallengedType = sessionEventPrefix + "otp.email.challenged" + OTPEmailSentType = sessionEventPrefix + "otp.email.sent" + OTPEmailCheckedType = sessionEventPrefix + "otp.email.checked" TokenSetType = sessionEventPrefix + "token.set" MetadataSetType = sessionEventPrefix + "metadata.set" TerminateType = sessionEventPrefix + "terminated" @@ -298,6 +305,211 @@ func NewTOTPCheckedEvent( } } +type OTPSMSChallengedEvent struct { + eventstore.BaseEvent `json:"-"` + + Code *crypto.CryptoValue `json:"code"` + Expiry time.Duration `json:"expiry"` + CodeReturned bool `json:"codeReturned,omitempty"` +} + +func (e *OTPSMSChallengedEvent) Data() interface{} { + return e +} + +func (e *OTPSMSChallengedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPSMSChallengedEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPSMSChallengedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + code *crypto.CryptoValue, + expiry time.Duration, + codeReturned bool, +) *OTPSMSChallengedEvent { + return &OTPSMSChallengedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPSMSChallengedType, + ), + Code: code, + Expiry: expiry, + CodeReturned: codeReturned, + } +} + +type OTPSMSSentEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *OTPSMSSentEvent) Data() interface{} { + return e +} + +func (e *OTPSMSSentEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPSMSSentEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPSMSSentEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, +) *OTPSMSSentEvent { + return &OTPSMSSentEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPSMSSentType, + ), + } +} + +type OTPSMSCheckedEvent struct { + eventstore.BaseEvent `json:"-"` + + CheckedAt time.Time `json:"checkedAt"` +} + +func (e *OTPSMSCheckedEvent) Data() interface{} { + return e +} + +func (e *OTPSMSCheckedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPSMSCheckedEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPSMSCheckedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + checkedAt time.Time, +) *OTPSMSCheckedEvent { + return &OTPSMSCheckedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPSMSCheckedType, + ), + CheckedAt: checkedAt, + } +} + +type OTPEmailChallengedEvent struct { + eventstore.BaseEvent `json:"-"` + + Code *crypto.CryptoValue `json:"code"` + Expiry time.Duration `json:"expiry"` + ReturnCode bool `json:"returnCode,omitempty"` + URLTmpl string `json:"urlTmpl,omitempty"` +} + +func (e *OTPEmailChallengedEvent) Data() interface{} { + return e +} + +func (e *OTPEmailChallengedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPEmailChallengedEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPEmailChallengedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + code *crypto.CryptoValue, + expiry time.Duration, + returnCode bool, + urlTmpl string, +) *OTPEmailChallengedEvent { + return &OTPEmailChallengedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPEmailChallengedType, + ), + Code: code, + Expiry: expiry, + ReturnCode: returnCode, + URLTmpl: urlTmpl, + } +} + +type OTPEmailSentEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *OTPEmailSentEvent) Data() interface{} { + return e +} + +func (e *OTPEmailSentEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPEmailSentEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPEmailSentEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, +) *OTPEmailSentEvent { + return &OTPEmailSentEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPEmailSentType, + ), + } +} + +type OTPEmailCheckedEvent struct { + eventstore.BaseEvent `json:"-"` + + CheckedAt time.Time `json:"checkedAt"` +} + +func (e *OTPEmailCheckedEvent) Data() interface{} { + return e +} + +func (e *OTPEmailCheckedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func (e *OTPEmailCheckedEvent) SetBaseEvent(base *eventstore.BaseEvent) { + e.BaseEvent = *base +} + +func NewOTPEmailCheckedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + checkedAt time.Time, +) *OTPEmailCheckedEvent { + return &OTPEmailCheckedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + OTPEmailCheckedType, + ), + CheckedAt: checkedAt, + } +} + type TokenSetEvent struct { eventstore.BaseEvent `json:"-"` diff --git a/proto/zitadel/session/v2alpha/challenge.proto b/proto/zitadel/session/v2alpha/challenge.proto index ed1ef6e647..b8c6b0c089 100644 --- a/proto/zitadel/session/v2alpha/challenge.proto +++ b/proto/zitadel/session/v2alpha/challenge.proto @@ -37,8 +37,33 @@ message RequestChallenges { } ]; } + message OTPSMS { + bool return_code = 1; + } + message OTPEmail { + message SendCode { + optional string url_template = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + example: "\"https://example.com/otp/verify?userID={{.UserID}}&code={{.Code}}\""; + description: "\"Optionally set a url_template, which will be used in the mail sent by ZITADEL to guide the user to your verification page. If no template is set, the default ZITADEL url will be used.\"" + } + ]; + } + message ReturnCode {} + + // if no delivery_type is specified, an email is sent with the default url + oneof delivery_type { + SendCode send_code = 2; + ReturnCode return_code = 3; + } + } optional WebAuthN web_auth_n = 1; + optional OTPSMS otp_sms = 2; + optional OTPEmail otp_email = 3; } message Challenges { @@ -52,4 +77,6 @@ message Challenges { } optional WebAuthN web_auth_n = 1; + optional string otp_sms = 2; + optional string otp_email = 3; } diff --git a/proto/zitadel/session/v2alpha/session.proto b/proto/zitadel/session/v2alpha/session.proto index 44f337c0d6..5c0bcb3115 100644 --- a/proto/zitadel/session/v2alpha/session.proto +++ b/proto/zitadel/session/v2alpha/session.proto @@ -47,6 +47,8 @@ message Factors { WebAuthNFactor web_auth_n = 3; IntentFactor intent = 4; TOTPFactor totp = 5; + OTPFactor otp_sms = 6; + OTPFactor otp_email = 7; } message UserFactor { @@ -110,6 +112,14 @@ message TOTPFactor { ]; } +message OTPFactor { + google.protobuf.Timestamp verified_at = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"time when the One-Time Password was last checked\""; + } + ]; +} + message SearchQuery { oneof query { option (validate.required) = true; diff --git a/proto/zitadel/session/v2alpha/session_service.proto b/proto/zitadel/session/v2alpha/session_service.proto index 9a4d017a3b..533b07e999 100644 --- a/proto/zitadel/session/v2alpha/session_service.proto +++ b/proto/zitadel/session/v2alpha/session_service.proto @@ -380,6 +380,16 @@ message Checks { description: "\"Checks the Time-based One-Time Password and updates the session on success. Requires that the user is already checked, either in the previous or the same request.\""; } ]; + optional CheckOTP otp_sms = 6 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"Checks the One-Time Password sent over SMS and updates the session on success. Requires that the user is already checked, either in the previous or the same request.\""; + } + ]; + optional CheckOTP otp_email = 7 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "\"Checks the One-Time Password sent over Email and updates the session on success. Requires that the user is already checked, either in the previous or the same request.\""; + } + ]; } message CheckUser { @@ -456,4 +466,14 @@ message CheckTOTP { example: "\"323764\""; } ]; +} + +message CheckOTP { + string otp = 1 [ + (validate.rules).string = {min_len: 1}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + example: "\"3237642\""; + } + ]; } \ No newline at end of file