diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 79d986dc5e2..1b059320676 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -2,6 +2,7 @@ package eventstore import ( "context" + "errors" "slices" "strings" "time" @@ -29,6 +30,12 @@ import ( const unknownUserID = "UNKNOWN" +var ( + ErrUserNotFound = func(err error) error { + return zerrors.ThrowNotFound(err, "EVENT-hodc6", "Errors.User.NotFound") + } +) + type AuthRequestRepo struct { Command *command.Commands Query *query.Queries @@ -52,6 +59,7 @@ type AuthRequestRepo struct { ProjectProvider projectProvider ApplicationProvider applicationProvider CustomTextProvider customTextProvider + PasswordChecker passwordChecker IdGenerator id.Generator } @@ -71,7 +79,7 @@ type userSessionViewProvider interface { } type userViewProvider interface { - UserByID(string, string) (*user_view_model.UserView, error) + UserByID(context.Context, string, string) (*user_view_model.UserView, error) } type loginPolicyViewProvider interface { @@ -125,6 +133,10 @@ type customTextProvider interface { CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) } +type passwordChecker interface { + HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest) error +} + func (repo *AuthRequestRepo) Health(ctx context.Context) error { return repo.AuthRequests.Health(ctx) } @@ -341,23 +353,25 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user request, err := repo.getAuthRequestEnsureUser(ctx, authReqID, userAgentID, userID) if err != nil { if isIgnoreUserNotFoundError(err, request) { + // use the same errorID as below (otherwise it would expose the error reason) return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid") } return err } - err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info)) + err = repo.PasswordChecker.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info)) if isIgnoreUserInvalidPasswordError(err, request) { - return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid") + // use the same errorID as above (otherwise it would expose the error reason) + return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid") } return err } func isIgnoreUserNotFoundError(err error, request *domain.AuthRequest) bool { - return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsNotFound(err) && zerrors.Contains(err, "Errors.User.NotFound") + return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, ErrUserNotFound(nil)) } func isIgnoreUserInvalidPasswordError(err error, request *domain.AuthRequest) bool { - return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsErrorInvalidArgument(err) && zerrors.Contains(err, "Errors.User.Password.Invalid") + return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, command.ErrPasswordInvalid(nil)) } func lockoutPolicyToDomain(policy *query.LockoutPolicy) *domain.LockoutPolicy { @@ -1629,7 +1643,7 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID()) + user, viewErr := viewProvider.UserByID(ctx, userID, authz.GetInstance(ctx).InstanceID()) if viewErr != nil && !zerrors.IsNotFound(viewErr) { return nil, viewErr } else if user == nil { @@ -1642,9 +1656,10 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider } if len(events) == 0 { if viewErr != nil { - return nil, viewErr + // We already returned all errors apart from not found, but need to make sure that can be checked in case IgnoreUnknownUsernames option is active. + return nil, ErrUserNotFound(viewErr) } - return user_view_model.UserToModel(user), viewErr + return user_view_model.UserToModel(user), nil } userCopy := *user for _, event := range events { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index e64c7fed91a..10528a8a25d 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -10,9 +10,11 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" cache "github.com/zitadel/zitadel/internal/auth_request/repository" "github.com/zitadel/zitadel/internal/auth_request/repository/mock" + "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" @@ -103,7 +105,7 @@ func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context, type mockViewNoUser struct{} -func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) { +func (m *mockViewNoUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) { return nil, zerrors.ThrowNotFound(nil, "id", "user not found") } @@ -198,7 +200,7 @@ func (m *mockPasswordAgePolicy) PasswordAgePolicyByOrg(context.Context, bool, st return m.policy, nil } -func (m *mockViewUser) UserByID(string, string) (*user_view_model.UserView, error) { +func (m *mockViewUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) { return &user_view_model.UserView{ State: int32(user_model.UserStateActive), UserName: "UserName", @@ -298,6 +300,14 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU return &query.IDPUserLinks{Links: m.idps}, nil } +type mockPasswordChecker struct { + err error +} + +func (m *mockPasswordChecker) HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest) error { + return m.err +} + func TestAuthRequestRepo_nextSteps(t *testing.T) { type fields struct { AuthRequests cache.AuthRequestCache @@ -2403,3 +2413,153 @@ func Test_userByID(t *testing.T) { }) } } + +func TestAuthRequestRepo_VerifyPassword_IgnoreUnknownUsernames(t *testing.T) { + authRequest := func(userID string) *domain.AuthRequest { + a := &domain.AuthRequest{ + ID: "authRequestID", + AgentID: "userAgentID", + UserID: userID, + LoginPolicy: &domain.LoginPolicy{ + ObjectRoot: es_models.ObjectRoot{}, + Default: true, + AllowUsernamePassword: true, + AllowRegister: true, + AllowExternalIDP: true, + IDPProviders: []*domain.IDPProvider{ + { + ObjectRoot: es_models.ObjectRoot{}, + Type: domain.IdentityProviderTypeSystem, + IDPConfigID: "idpConfig1", + Name: "IdP", + IDPType: domain.IDPTypeOIDC, + IDPState: domain.IDPConfigStateActive, + }, + }, + IgnoreUnknownUsernames: true, + }, + AllowedExternalIDPs: []*domain.IDPProvider{ + { + ObjectRoot: es_models.ObjectRoot{}, + Type: domain.IdentityProviderTypeSystem, + IDPConfigID: "idpConfig1", + Name: "IdP", + IDPType: domain.IDPTypeOIDC, + IDPState: domain.IDPConfigStateActive, + }, + }, + LabelPolicy: &domain.LabelPolicy{ + ObjectRoot: es_models.ObjectRoot{}, + State: domain.LabelPolicyStateActive, + Default: true, + }, + PrivacyPolicy: &domain.PrivacyPolicy{ + ObjectRoot: es_models.ObjectRoot{}, + State: domain.PolicyStateActive, + Default: true, + }, + LockoutPolicy: &domain.LockoutPolicy{ + Default: true, + }, + PasswordAgePolicy: &domain.PasswordAgePolicy{ + ObjectRoot: es_models.ObjectRoot{}, + MaxAgeDays: 0, + ExpireWarnDays: 0, + }, + DefaultTranslations: []*domain.CustomText{{}}, + OrgTranslations: []*domain.CustomText{{}}, + SAMLRequestID: "", + } + a.SetPolicyOrgID("instance1") + return a + } + type fields struct { + AuthRequests func(*testing.T, string) cache.AuthRequestCache + UserViewProvider userViewProvider + UserEventProvider userEventProvider + OrgViewProvider orgViewProvider + PasswordChecker passwordChecker + } + type args struct { + ctx context.Context + authReqID string + userID string + resourceOwner string + password string + userAgentID string + info *domain.BrowserInfo + } + tests := []struct { + name string + fields fields + args args + }{ + { + name: "no user", + fields: fields{ + AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache { + m := mock.NewMockAuthRequestCache(gomock.NewController(tt)) + a := authRequest(userID) + m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil) + m.EXPECT().CacheAuthRequest(gomock.Any(), a) + return m + }, + UserViewProvider: &mockViewNoUser{}, + UserEventProvider: &mockEventUser{}, + }, + args: args{ + ctx: authz.NewMockContext("instance1", "", ""), + authReqID: "authRequestID", + userID: unknownUserID, + resourceOwner: "org1", + password: "password", + userAgentID: "userAgentID", + info: &domain.BrowserInfo{ + UserAgent: "useragent", + }, + }, + }, + { + name: "invalid password", + fields: fields{ + AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache { + m := mock.NewMockAuthRequestCache(gomock.NewController(tt)) + a := authRequest(userID) + m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil) + m.EXPECT().CacheAuthRequest(gomock.Any(), a) + return m + }, + UserViewProvider: &mockViewUser{}, + UserEventProvider: &mockEventUser{}, + OrgViewProvider: &mockViewOrg{State: domain.OrgStateActive}, + PasswordChecker: &mockPasswordChecker{ + err: command.ErrPasswordInvalid(nil), + }, + }, + args: args{ + ctx: authz.NewMockContext("instance1", "", ""), + authReqID: "authRequestID", + userID: "user1", + resourceOwner: "org1", + password: "password", + userAgentID: "userAgentID", + info: &domain.BrowserInfo{ + UserAgent: "useragent", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &AuthRequestRepo{ + AuthRequests: tt.fields.AuthRequests(t, tt.args.userID), + UserViewProvider: tt.fields.UserViewProvider, + UserEventProvider: tt.fields.UserEventProvider, + OrgViewProvider: tt.fields.OrgViewProvider, + PasswordChecker: tt.fields.PasswordChecker, + } + err := repo.VerifyPassword(tt.args.ctx, tt.args.authReqID, tt.args.userID, tt.args.resourceOwner, tt.args.password, tt.args.userAgentID, tt.args.info) + assert.ErrorIs(t, err, zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid")) + }) + } +} diff --git a/internal/auth/repository/eventsourcing/repository.go b/internal/auth/repository/eventsourcing/repository.go index a9ec6aa0582..0d5e3609cc0 100644 --- a/internal/auth/repository/eventsourcing/repository.go +++ b/internal/auth/repository/eventsourcing/repository.go @@ -78,6 +78,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c ProjectProvider: queryView, ApplicationProvider: queries, CustomTextProvider: queries, + PasswordChecker: command, IdGenerator: id.SonyFlakeGenerator(), }, eventstore.TokenRepo{ diff --git a/internal/auth/repository/eventsourcing/view/user.go b/internal/auth/repository/eventsourcing/view/user.go index e75846471d6..812c36e62dd 100644 --- a/internal/auth/repository/eventsourcing/view/user.go +++ b/internal/auth/repository/eventsourcing/view/user.go @@ -16,8 +16,8 @@ const ( userTable = "auth.users3" ) -func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) { - return view.UserByID(v.Db, userTable, userID, instanceID) +func (v *View) UserByID(ctx context.Context, userID, instanceID string) (*model.UserView, error) { + return view.UserByID(ctx, v.Db, userID, instanceID) } func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string) (*model.UserView, error) { @@ -27,7 +27,7 @@ func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string } //nolint: contextcheck // no lint was added because refactor would change too much code - return view.UserByID(v.Db, userTable, queriedUser.ID, instanceID) + return view.UserByID(ctx, v.Db, queriedUser.ID, instanceID) } func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, resourceOwner, instanceID string) (*model.UserView, error) { @@ -37,7 +37,7 @@ func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, r } //nolint: contextcheck // no lint was added because refactor would change too much code - user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID) + user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (v *View) userByID(ctx context.Context, instanceID string, queries ...query OnError(err). Errorf("could not get current sequence for userByID") - user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID) + user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID) if err != nil && !zerrors.IsNotFound(err) { return nil, err } diff --git a/internal/command/user_human_password.go b/internal/command/user_human_password.go index 839977b384c..4fb7a32099b 100644 --- a/internal/command/user_human_password.go +++ b/internal/command/user_human_password.go @@ -16,6 +16,15 @@ import ( "github.com/zitadel/zitadel/internal/zerrors" ) +var ( + ErrPasswordInvalid = func(err error) error { + return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid") + } + ErrPasswordUnchanged = func(err error) error { + return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged") + } +) + func (c *Commands) SetPassword(ctx context.Context, orgID, userID, password string, oneTime bool) (objectDetails *domain.ObjectDetails, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -393,10 +402,10 @@ func convertPasswapErr(err error) error { return nil } if errors.Is(err, passwap.ErrPasswordMismatch) { - return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid") + return ErrPasswordInvalid(err) } if errors.Is(err, passwap.ErrPasswordNoChange) { - return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged") + return ErrPasswordUnchanged(err) } return zerrors.ThrowInternal(err, "COMMAND-CahN2", "Errors.Internal") } diff --git a/internal/user/repository/view/user_view.go b/internal/user/repository/view/user_view.go index 98b23f46612..c09ef157c98 100644 --- a/internal/user/repository/view/user_view.go +++ b/internal/user/repository/view/user_view.go @@ -16,12 +16,12 @@ import ( //go:embed user_by_id.sql var userByIDQuery string -func UserByID(db *gorm.DB, table, userID, instanceID string) (*model.UserView, error) { +func UserByID(ctx context.Context, db *gorm.DB, userID, instanceID string) (*model.UserView, error) { user := new(model.UserView) query := db.Raw(userByIDQuery, instanceID, userID) - tx := query.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + tx := query.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) defer func() { if err := tx.Commit().Error; err != nil { logging.OnError(err).Info("commit failed") @@ -35,8 +35,8 @@ func UserByID(db *gorm.DB, table, userID, instanceID string) (*model.UserView, e return user, nil } if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, zerrors.ThrowNotFound(err, "VIEW-hodc6", "object not found") + return nil, zerrors.ThrowNotFound(err, "VIEW-hodc6", "Errors.User.NotFound") } - logging.WithFields("table ", table).WithError(err).Warn("get from cache error") - return nil, zerrors.ThrowInternal(err, "VIEW-qJBg9", "cache error") + logging.WithError(err).Warn("unable to get user by id") + return nil, zerrors.ThrowInternal(err, "VIEW-qJBg9", "unable to get user by id") }