fix: prevent error reason leakage in case of IgnoreUnknownUsernames (#8372)

# Which Problems Are Solved

ZITADEL administrators can enable a setting called "Ignoring unknown
usernames" which helps mitigate attacks that try to guess/enumerate
usernames. If enabled, ZITADEL will show the password prompt even if the
user doesn't exist and report "Username or Password invalid".
Due to a implementation change to prevent deadlocks calling the
database, the flag would not be correctly respected in all cases and an
attacker would gain information if an account exist within ZITADEL,
since the error message shows "object not found" instead of the generic
error message.

# How the Problems Are Solved

- Proper check of the error using an error function / type and
`errors.Is`

# Additional Changes

None.

# Additional Context

- raised in a support request

Co-authored-by: Silvan <silvan.reusser@gmail.com>

(cherry picked from commit a1d24353db)
This commit is contained in:
Livio Spring 2024-07-31 15:56:20 +02:00
parent 38da602ee1
commit 3c7d12834e
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
6 changed files with 209 additions and 22 deletions

View File

@ -2,6 +2,7 @@ package eventstore
import (
"context"
"errors"
"slices"
"strings"
"time"
@ -29,6 +30,12 @@ import (
const unknownUserID = "UNKNOWN"
var (
ErrUserNotFound = func(err error) error {
return zerrors.ThrowNotFound(err, "EVENT-hodc6", "Errors.User.NotFound")
}
)
type AuthRequestRepo struct {
Command *command.Commands
Query *query.Queries
@ -51,6 +58,7 @@ type AuthRequestRepo struct {
ProjectProvider projectProvider
ApplicationProvider applicationProvider
CustomTextProvider customTextProvider
PasswordChecker passwordChecker
IdGenerator id.Generator
}
@ -70,7 +78,7 @@ type userSessionViewProvider interface {
}
type userViewProvider interface {
UserByID(string, string) (*user_view_model.UserView, error)
UserByID(context.Context, string, string) (*user_view_model.UserView, error)
}
type loginPolicyViewProvider interface {
@ -120,6 +128,10 @@ type customTextProvider interface {
CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error)
}
type passwordChecker interface {
HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest, lockoutPolicy *domain.LockoutPolicy) error
}
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
return repo.AuthRequests.Health(ctx)
}
@ -336,6 +348,7 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
request, err := repo.getAuthRequestEnsureUser(ctx, authReqID, userAgentID, userID)
if err != nil {
if isIgnoreUserNotFoundError(err, request) {
// use the same errorID as below (otherwise it would expose the error reason)
return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid")
}
return err
@ -344,19 +357,20 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
if err != nil {
return err
}
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info), lockoutPolicyToDomain(policy))
err = repo.PasswordChecker.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info), lockoutPolicyToDomain(policy))
if isIgnoreUserInvalidPasswordError(err, request) {
return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid")
// use the same errorID as above (otherwise it would expose the error reason)
return zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid")
}
return err
}
func isIgnoreUserNotFoundError(err error, request *domain.AuthRequest) bool {
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsNotFound(err) && zerrors.Contains(err, "Errors.User.NotFound")
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, ErrUserNotFound(nil))
}
func isIgnoreUserInvalidPasswordError(err error, request *domain.AuthRequest) bool {
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsErrorInvalidArgument(err) && zerrors.Contains(err, "Errors.User.Password.Invalid")
return request != nil && request.LoginPolicy != nil && request.LoginPolicy.IgnoreUnknownUsernames && errors.Is(err, command.ErrPasswordInvalid(nil))
}
func lockoutPolicyToDomain(policy *query.LockoutPolicy) *domain.LockoutPolicy {
@ -1590,7 +1604,7 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID())
user, viewErr := viewProvider.UserByID(ctx, userID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !zerrors.IsNotFound(viewErr) {
return nil, viewErr
} else if user == nil {
@ -1603,9 +1617,10 @@ func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider
}
if len(events) == 0 {
if viewErr != nil {
return nil, viewErr
// We already returned all errors apart from not found, but need to make sure that can be checked in case IgnoreUnknownUsernames option is active.
return nil, ErrUserNotFound(viewErr)
}
return user_view_model.UserToModel(user), viewErr
return user_view_model.UserToModel(user), nil
}
userCopy := *user
for _, event := range events {

View File

@ -10,9 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
cache "github.com/zitadel/zitadel/internal/auth_request/repository"
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
@ -103,7 +105,7 @@ func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context,
type mockViewNoUser struct{}
func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) {
func (m *mockViewNoUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) {
return nil, zerrors.ThrowNotFound(nil, "id", "user not found")
}
@ -189,7 +191,7 @@ func (m *mockLockoutPolicy) LockoutPolicyByOrg(context.Context, bool, string) (*
return m.policy, nil
}
func (m *mockViewUser) UserByID(string, string) (*user_view_model.UserView, error) {
func (m *mockViewUser) UserByID(context.Context, string, string) (*user_view_model.UserView, error) {
return &user_view_model.UserView{
State: int32(user_model.UserStateActive),
UserName: "UserName",
@ -289,6 +291,14 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU
return &query.IDPUserLinks{Links: m.idps}, nil
}
type mockPasswordChecker struct {
err error
}
func (m *mockPasswordChecker) HumanCheckPassword(ctx context.Context, resourceOwner, userID, password string, authReq *domain.AuthRequest, policy *domain.LockoutPolicy) error {
return m.err
}
func TestAuthRequestRepo_nextSteps(t *testing.T) {
type fields struct {
AuthRequests cache.AuthRequestCache
@ -2347,3 +2357,155 @@ func Test_userByID(t *testing.T) {
})
}
}
func TestAuthRequestRepo_VerifyPassword_IgnoreUnknownUsernames(t *testing.T) {
authRequest := func(userID string) *domain.AuthRequest {
a := &domain.AuthRequest{
ID: "authRequestID",
AgentID: "userAgentID",
UserID: userID,
LoginPolicy: &domain.LoginPolicy{
ObjectRoot: es_models.ObjectRoot{},
Default: true,
AllowUsernamePassword: true,
AllowRegister: true,
AllowExternalIDP: true,
IDPProviders: []*domain.IDPProvider{
{
ObjectRoot: es_models.ObjectRoot{},
Type: domain.IdentityProviderTypeSystem,
IDPConfigID: "idpConfig1",
Name: "IdP",
IDPType: domain.IDPTypeOIDC,
IDPState: domain.IDPConfigStateActive,
},
},
IgnoreUnknownUsernames: true,
},
AllowedExternalIDPs: []*domain.IDPProvider{
{
ObjectRoot: es_models.ObjectRoot{},
Type: domain.IdentityProviderTypeSystem,
IDPConfigID: "idpConfig1",
Name: "IdP",
IDPType: domain.IDPTypeOIDC,
IDPState: domain.IDPConfigStateActive,
},
},
LabelPolicy: &domain.LabelPolicy{
ObjectRoot: es_models.ObjectRoot{},
State: domain.LabelPolicyStateActive,
Default: true,
},
PrivacyPolicy: &domain.PrivacyPolicy{
ObjectRoot: es_models.ObjectRoot{},
State: domain.PolicyStateActive,
Default: true,
},
LockoutPolicy: &domain.LockoutPolicy{
Default: true,
},
DefaultTranslations: []*domain.CustomText{{}},
OrgTranslations: []*domain.CustomText{{}},
SAMLRequestID: "",
}
a.SetPolicyOrgID("instance1")
return a
}
type fields struct {
AuthRequests func(*testing.T, string) cache.AuthRequestCache
UserViewProvider userViewProvider
UserEventProvider userEventProvider
OrgViewProvider orgViewProvider
PasswordChecker passwordChecker
LockoutPolicyViewProvider lockoutPolicyViewProvider
}
type args struct {
ctx context.Context
authReqID string
userID string
resourceOwner string
password string
userAgentID string
info *domain.BrowserInfo
}
tests := []struct {
name string
fields fields
args args
}{
{
name: "no user",
fields: fields{
AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(tt))
a := authRequest(userID)
m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil)
m.EXPECT().CacheAuthRequest(gomock.Any(), a)
return m
},
UserViewProvider: &mockViewNoUser{},
UserEventProvider: &mockEventUser{},
},
args: args{
ctx: authz.NewMockContext("instance1", "", ""),
authReqID: "authRequestID",
userID: unknownUserID,
resourceOwner: "org1",
password: "password",
userAgentID: "userAgentID",
info: &domain.BrowserInfo{
UserAgent: "useragent",
},
},
},
{
name: "invalid password",
fields: fields{
AuthRequests: func(tt *testing.T, userID string) cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(tt))
a := authRequest(userID)
m.EXPECT().GetAuthRequestByID(gomock.Any(), "authRequestID").Return(a, nil)
m.EXPECT().CacheAuthRequest(gomock.Any(), a)
return m
},
UserViewProvider: &mockViewUser{},
UserEventProvider: &mockEventUser{},
OrgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
PasswordChecker: &mockPasswordChecker{
err: command.ErrPasswordInvalid(nil),
},
LockoutPolicyViewProvider: &mockLockoutPolicy{
policy: &query.LockoutPolicy{
ShowFailures: true,
},
},
},
args: args{
ctx: authz.NewMockContext("instance1", "", ""),
authReqID: "authRequestID",
userID: "user1",
resourceOwner: "org1",
password: "password",
userAgentID: "userAgentID",
info: &domain.BrowserInfo{
UserAgent: "useragent",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
repo := &AuthRequestRepo{
AuthRequests: tt.fields.AuthRequests(t, tt.args.userID),
UserViewProvider: tt.fields.UserViewProvider,
UserEventProvider: tt.fields.UserEventProvider,
OrgViewProvider: tt.fields.OrgViewProvider,
PasswordChecker: tt.fields.PasswordChecker,
LockoutPolicyViewProvider: tt.fields.LockoutPolicyViewProvider,
}
err := repo.VerifyPassword(tt.args.ctx, tt.args.authReqID, tt.args.userID, tt.args.resourceOwner, tt.args.password, tt.args.userAgentID, tt.args.info)
assert.ErrorIs(t, err, zerrors.ThrowInvalidArgument(nil, "EVENT-SDe2f", "Errors.User.UsernameOrPassword.Invalid"))
})
}
}

View File

@ -77,6 +77,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
ProjectProvider: queryView,
ApplicationProvider: queries,
CustomTextProvider: queries,
PasswordChecker: command,
IdGenerator: id.SonyFlakeGenerator(),
},
eventstore.TokenRepo{

View File

@ -16,8 +16,8 @@ const (
userTable = "auth.users3"
)
func (v *View) UserByID(userID, instanceID string) (*model.UserView, error) {
return view.UserByID(v.Db, userTable, userID, instanceID)
func (v *View) UserByID(ctx context.Context, userID, instanceID string) (*model.UserView, error) {
return view.UserByID(ctx, v.Db, userID, instanceID)
}
func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string) (*model.UserView, error) {
@ -27,7 +27,7 @@ func (v *View) UserByLoginName(ctx context.Context, loginName, instanceID string
}
//nolint: contextcheck // no lint was added because refactor would change too much code
return view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
return view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
}
func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, resourceOwner, instanceID string) (*model.UserView, error) {
@ -37,7 +37,7 @@ func (v *View) UserByLoginNameAndResourceOwner(ctx context.Context, loginName, r
}
//nolint: contextcheck // no lint was added because refactor would change too much code
user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
if err != nil {
return nil, err
}
@ -103,7 +103,7 @@ func (v *View) userByID(ctx context.Context, instanceID string, queries ...query
OnError(err).
Errorf("could not get current sequence for userByID")
user, err := view.UserByID(v.Db, userTable, queriedUser.ID, instanceID)
user, err := view.UserByID(ctx, v.Db, queriedUser.ID, instanceID)
if err != nil && !zerrors.IsNotFound(err) {
return nil, err
}

View File

@ -16,6 +16,15 @@ import (
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
ErrPasswordInvalid = func(err error) error {
return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid")
}
ErrPasswordUnchanged = func(err error) error {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged")
}
)
func (c *Commands) SetPassword(ctx context.Context, orgID, userID, password string, oneTime bool) (objectDetails *domain.ObjectDetails, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@ -383,10 +392,10 @@ func convertPasswapErr(err error) error {
return nil
}
if errors.Is(err, passwap.ErrPasswordMismatch) {
return zerrors.ThrowInvalidArgument(err, "COMMAND-3M0fs", "Errors.User.Password.Invalid")
return ErrPasswordInvalid(err)
}
if errors.Is(err, passwap.ErrPasswordNoChange) {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Aesh5", "Errors.User.Password.NotChanged")
return ErrPasswordUnchanged(err)
}
return zerrors.ThrowInternal(err, "COMMAND-CahN2", "Errors.Internal")
}

View File

@ -16,12 +16,12 @@ import (
//go:embed user_by_id.sql
var userByIDQuery string
func UserByID(db *gorm.DB, table, userID, instanceID string) (*model.UserView, error) {
func UserByID(ctx context.Context, db *gorm.DB, userID, instanceID string) (*model.UserView, error) {
user := new(model.UserView)
query := db.Raw(userByIDQuery, instanceID, userID)
tx := query.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true})
tx := query.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
defer func() {
if err := tx.Commit().Error; err != nil {
logging.OnError(err).Info("commit failed")
@ -35,8 +35,8 @@ func UserByID(db *gorm.DB, table, userID, instanceID string) (*model.UserView, e
return user, nil
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, zerrors.ThrowNotFound(err, "VIEW-hodc6", "object not found")
return nil, zerrors.ThrowNotFound(err, "VIEW-hodc6", "Errors.User.NotFound")
}
logging.WithFields("table ", table).WithError(err).Warn("get from cache error")
return nil, zerrors.ThrowInternal(err, "VIEW-qJBg9", "cache error")
logging.WithError(err).Warn("unable to get user by id")
return nil, zerrors.ThrowInternal(err, "VIEW-qJBg9", "unable to get user by id")
}