diff --git a/internal/api/grpc/server/middleware/access_interceptor.go b/internal/api/grpc/server/middleware/access_interceptor.go index 100264c3f5..f95c3225ed 100644 --- a/internal/api/grpc/server/middleware/access_interceptor.go +++ b/internal/api/grpc/server/middleware/access_interceptor.go @@ -20,7 +20,6 @@ func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.Una if !svc.Enabled() { return handler(ctx, req) } - reqMd, _ := metadata.FromIncomingContext(ctx) resp, handlerErr := handler(ctx, req) diff --git a/internal/api/grpc/server/middleware/execution_interceptor.go b/internal/api/grpc/server/middleware/execution_interceptor.go index 3288f28ad8..053386caae 100644 --- a/internal/api/grpc/server/middleware/execution_interceptor.go +++ b/internal/api/grpc/server/middleware/execution_interceptor.go @@ -34,7 +34,7 @@ func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor { func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req interface{}) (_ interface{}, err error) { ctx, span := tracing.NewSpan(ctx) - defer span.EndWithError(err) + defer func() { span.EndWithError(err) }() // if no targets are found, return without any calls if len(targets) == 0 { @@ -56,7 +56,7 @@ func executeTargetsForRequest(ctx context.Context, targets []execution.Target, f func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req, resp interface{}) (_ interface{}, err error) { ctx, span := tracing.NewSpan(ctx) - defer span.EndWithError(err) + defer func() { span.EndWithError(err) }() // if no targets are found, return without any calls if len(targets) == 0 { diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index 7562d64350..08f19368ef 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -255,7 +255,7 @@ type userSearchByID struct { } func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) { - return q.GetUserByID(ctx, true, u.id) + return q.GetUserByID(ctx, false, u.id) } type userSearchByLoginName struct { diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index a113392df8..f750b2a3ea 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -150,7 +150,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) if err != nil { return nil, err } - appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true) + appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false) if err != nil { return nil, err } @@ -546,11 +546,7 @@ func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authoriz code: code, state: authReq.GetState(), } - callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) - if err != nil { - return "", err - } - return callback, err + return op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) } func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) { diff --git a/internal/auth/repository/eventsourcing/eventstore/org.go b/internal/auth/repository/eventsourcing/eventstore/org.go index 938f0d27cd..78c69d63c9 100644 --- a/internal/auth/repository/eventsourcing/eventstore/org.go +++ b/internal/auth/repository/eventsourcing/eventstore/org.go @@ -23,7 +23,7 @@ type OrgRepository struct { } func (repo *OrgRepository) GetMyPasswordComplexityPolicy(ctx context.Context) (*iam_model.PasswordComplexityPolicyView, error) { - policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, true, authz.GetCtxData(ctx).OrgID, false) + policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, false, authz.GetCtxData(ctx).OrgID, false) if err != nil { return nil, err } diff --git a/internal/eventstore/subscription.go b/internal/eventstore/subscription.go index c76c81df19..076d16ad52 100644 --- a/internal/eventstore/subscription.go +++ b/internal/eventstore/subscription.go @@ -1,6 +1,7 @@ package eventstore import ( + "slices" "sync" "github.com/zitadel/logging" @@ -8,7 +9,7 @@ import ( var ( subscriptions = map[AggregateType][]*Subscription{} - subsMutext sync.Mutex + subsMutex sync.RWMutex ) type Subscription struct { @@ -27,8 +28,8 @@ func SubscribeAggregates(eventQueue chan Event, aggregates ...AggregateType) *Su types: types, } - subsMutext.Lock() - defer subsMutext.Unlock() + subsMutex.Lock() + defer subsMutex.Unlock() for _, aggregate := range aggregates { subscriptions[aggregate] = append(subscriptions[aggregate], sub) @@ -45,8 +46,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT types: types, } - subsMutext.Lock() - defer subsMutext.Unlock() + subsMutex.Lock() + defer subsMutex.Unlock() for aggregate := range types { subscriptions[aggregate] = append(subscriptions[aggregate], sub) @@ -56,8 +57,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT } func (es *Eventstore) notify(events []Event) { - subsMutext.Lock() - defer subsMutext.Unlock() + subsMutex.RLock() + defer subsMutex.RUnlock() for _, event := range events { subs, ok := subscriptions[event.Aggregate().Type] if !ok { @@ -71,14 +72,11 @@ func (es *Eventstore) notify(events []Event) { continue } //subscription for certain events - for _, eventType := range eventTypes { - if event.Type() == eventType { - select { - case sub.Events <- event: - default: - logging.Debug("unable to push event") - } - break + if slices.Contains(eventTypes, event.Type()) { + select { + case sub.Events <- event: + default: + logging.Debug("unable to push event") } } } @@ -86,8 +84,8 @@ func (es *Eventstore) notify(events []Event) { } func (s *Subscription) Unsubscribe() { - subsMutext.Lock() - defer subsMutext.Unlock() + subsMutex.Lock() + defer subsMutex.Unlock() for aggregate := range s.types { subs, ok := subscriptions[aggregate] if !ok { diff --git a/internal/eventstore/v3/field.go b/internal/eventstore/v3/field.go index b399e7f5e8..372c224c6c 100644 --- a/internal/eventstore/v3/field.go +++ b/internal/eventstore/v3/field.go @@ -47,7 +47,7 @@ func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillF // Search implements the [eventstore.Search] method func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) { ctx, span := tracing.NewSpan(ctx) - defer span.EndWithError(err) + defer func() { span.EndWithError(err) }() var builder strings.Builder args := buildSearchStatement(ctx, &builder, conditions...) diff --git a/internal/execution/execution.go b/internal/execution/execution.go index 116f377e17..575c86ecc4 100644 --- a/internal/execution/execution.go +++ b/internal/execution/execution.go @@ -42,7 +42,7 @@ func CallTargets( info ContextInfo, ) (_ interface{}, err error) { ctx, span := tracing.NewSpan(ctx) - defer span.EndWithError(err) + defer func() { span.EndWithError(err) }() for _, target := range targets { // call the type of target @@ -72,7 +72,7 @@ func CallTarget( info ContextInfoRequest, ) (res []byte, err error) { ctx, span := tracing.NewSpan(ctx) - defer span.EndWithError(err) + defer func() { span.EndWithError(err) }() switch target.GetTargetType() { // get request, ignore response and return request and error for handling in list of targets diff --git a/internal/integration/client.go b/internal/integration/client.go index abc774f452..47458cf4cd 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -3,6 +3,7 @@ package integration import ( "context" "fmt" + "sync" "testing" "time" @@ -157,6 +158,7 @@ func (i *Instance) CreateHumanUser(ctx context.Context) *user_v2.AddHumanUserRes }, }) logging.OnError(err).Panic("create human user") + i.TriggerUserByID(ctx, resp.GetUserId()) return resp } @@ -181,6 +183,7 @@ func (i *Instance) CreateHumanUserNoPhone(ctx context.Context) *user_v2.AddHuman }, }) logging.OnError(err).Panic("create human user") + i.TriggerUserByID(ctx, resp.GetUserId()) return resp } @@ -212,9 +215,26 @@ func (i *Instance) CreateHumanUserWithTOTP(ctx context.Context, secret string) * TotpSecret: gu.Ptr(secret), }) logging.OnError(err).Panic("create human user") + i.TriggerUserByID(ctx, resp.GetUserId()) return resp } +// TriggerUserByID makes sure the user projection gets triggered after creation. +func (i *Instance) TriggerUserByID(ctx context.Context, users ...string) { + var wg sync.WaitGroup + wg.Add(len(users)) + for _, user := range users { + go func(user string) { + defer wg.Done() + _, err := i.Client.UserV2.GetUserByID(ctx, &user_v2.GetUserByIDRequest{ + UserId: user, + }) + logging.OnError(err).Warn("get user by ID for trigger failed") + }(user) + } + wg.Wait() +} + func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail string) *org.AddOrganizationResponse { resp, err := i.Client.OrgV2.AddOrganization(ctx, &org.AddOrganizationRequest{ Name: name, @@ -238,6 +258,13 @@ func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail stri }, }) logging.OnError(err).Panic("create org") + + users := make([]string, len(resp.GetCreatedAdmins())) + for i, admin := range resp.GetCreatedAdmins() { + users[i] = admin.GetUserId() + } + i.TriggerUserByID(ctx, users...) + return resp } @@ -302,6 +329,7 @@ func (i *Instance) CreateHumanUserVerified(ctx context.Context, org, email, phon }, }) logging.OnError(err).Panic("create human user") + i.TriggerUserByID(ctx, resp.GetUserId()) return resp } @@ -313,6 +341,7 @@ func (i *Instance) CreateMachineUser(ctx context.Context) *mgmt.AddMachineUserRe AccessTokenType: user_pb.AccessTokenType_ACCESS_TOKEN_TYPE_BEARER, }) logging.OnError(err).Panic("create human user") + i.TriggerUserByID(ctx, resp.GetUserId()) return resp } diff --git a/internal/query/user.go b/internal/query/user.go index 3ee9a48463..4ea167d004 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -428,34 +428,6 @@ func (q *Queries) GetUserByLoginName(ctx context.Context, shouldTriggered bool, return user, err } -// Deprecated: use either GetUserByID or GetUserByLoginName -func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (user *User, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - if shouldTriggerBulk { - triggerUserProjections(ctx) - } - - query, scan := prepareUserQuery(ctx, q.client) - for _, q := range queries { - query = q.toQuery(query) - } - eq := sq.Eq{ - UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), - } - stmt, args, err := query.Where(eq).ToSql() - if err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-Dnhr2", "Errors.Query.SQLStatment") - } - - err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { - user, err = scan(row) - return err - }, stmt, args...) - return user, err -} - func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries ...SearchQuery) (profile *Profile, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 0687545aef..ab6c464bad 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -3,6 +3,7 @@ package query import ( "context" "database/sql" + _ "embed" "errors" "slices" "time" @@ -212,6 +213,9 @@ type UserAuthMethodRequirements struct { ForceMFALocalOnly bool } +//go:embed user_auth_method_types_required.sql +var listUserAuthMethodTypesStmt string + func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) { ctxData := authz.GetCtxData(ctx) if ctxData.UserID != userID { @@ -222,20 +226,33 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client) - eq := sq.Eq{ - UserIDCol.identifier(): userID, - UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), - } - stmt, args, err := query.Where(eq).ToSql() - if err != nil { - return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") - } - - err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { - requirements, err = scan(row) - return err - }, stmt, args...) + err = q.client.QueryRowContext(ctx, + func(row *sql.Row) error { + var userType sql.NullInt32 + var forceMFA sql.NullBool + var forceMFALocalOnly sql.NullBool + err := row.Scan( + &userType, + &forceMFA, + &forceMFALocalOnly, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal") + } + return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") + } + requirements = &UserAuthMethodRequirements{ + UserType: domain.UserType(userType.Int32), + ForceMFA: forceMFA.Bool, + ForceMFALocalOnly: forceMFALocalOnly.Bool, + } + return nil + }, + listUserAuthMethodTypesStmt, + userID, + authz.GetInstance(ctx).InstanceID(), + ) if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal") } @@ -461,45 +478,6 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac } } -func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() - if err != nil { - return sq.SelectBuilder{}, nil - } - return sq.Select( - UserTypeCol.identifier(), - forceMFAForce.identifier(), - forceMFAForceLocalOnly.identifier()). - From(userTable.identifier()). - LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + - "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + - forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). - OrderBy(forceMFAIsDefault.identifier()). - Limit(1). - PlaceholderFormat(sq.Dollar), - func(row *sql.Row) (*UserAuthMethodRequirements, error) { - var userType sql.NullInt32 - var forceMFA sql.NullBool - var forceMFALocalOnly sql.NullBool - err := row.Scan( - &userType, - &forceMFA, - &forceMFALocalOnly, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal") - } - return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") - } - return &UserAuthMethodRequirements{ - UserType: domain.UserType(userType.Int32), - ForceMFA: forceMFA.Bool, - ForceMFALocalOnly: forceMFALocalOnly.Bool, - }, nil - } -} - func prepareAuthMethodsIDPsQuery() (string, error) { idpsQuery, _, err := sq.Select( userIDPsCountUserID.identifier(), @@ -536,16 +514,3 @@ func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDom return q.ToSql() } - -func prepareAuthMethodsForceMFAQuery() (string, error) { - loginPolicyQuery, _, err := sq.Select( - forceMFAForce.identifier(), - forceMFAForceLocalOnly.identifier(), - forceMFAInstanceID.identifier(), - forceMFAOrgID.identifier(), - forceMFAIsDefault.identifier(), - ). - From(forceMFATable.identifier()). - ToSql() - return loginPolicyQuery, err -} diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 041e4f8e9e..47c50c4505 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -14,7 +14,6 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" - "github.com/zitadel/zitadel/internal/zerrors" ) func TestUser_authMethodsCheckPermission(t *testing.T) { @@ -664,106 +663,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, object: (*AuthMethodTypes)(nil), }, - { - name: "prepareUserAuthMethodTypesRequiredQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { - return scan(row) - } - }, - want: want{ - sqlExpectations: mockQueriesScanErr( - regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), - nil, - nil, - ), - err: func(err error) (error, bool) { - if !zerrors.IsNotFound(err) { - return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false - } - return nil, true - }, - }, - object: (*UserAuthMethodRequirements)(nil), - }, - { - name: "prepareUserAuthMethodTypesRequiredQuery one second factor", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { - return scan(row) - } - }, - want: want{ - sqlExpectations: mockQueries( - regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), - prepareAuthMethodTypesRequiredCols, - [][]driver.Value{ - { - domain.UserTypeHuman, - true, - true, - }, - }, - ), - }, - object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - ForceMFA: true, - ForceMFALocalOnly: true, - }, - }, - { - name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { - return scan(row) - } - }, - want: want{ - sqlExpectations: mockQueries( - regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), - prepareAuthMethodTypesRequiredCols, - [][]driver.Value{ - { - domain.UserTypeHuman, - true, - true, - }, - }, - ), - }, - - object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - ForceMFA: true, - ForceMFALocalOnly: true, - }, - }, - { - name: "prepareUserAuthMethodTypesRequiredQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { - return scan(row) - } - }, - want: want{ - sqlExpectations: mockQueryErr( - regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), - sql.ErrConnDone, - ), - err: func(err error) (error, bool) { - if !errors.Is(err, sql.ErrConnDone) { - return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false - } - return nil, true - }, - }, - object: nil, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/query/user_auth_method_types_required.sql b/internal/query/user_auth_method_types_required.sql new file mode 100644 index 0000000000..d10420f0eb --- /dev/null +++ b/internal/query/user_auth_method_types_required.sql @@ -0,0 +1,17 @@ +SELECT + projections.users14.type + , auth_methods_force_mfa.force_mfa + , auth_methods_force_mfa.force_mfa_local_only +FROM + projections.users14 +LEFT JOIN + projections.login_policies5 AS auth_methods_force_mfa +ON + auth_methods_force_mfa.instance_id = projections.users14.instance_id + AND auth_methods_force_mfa.aggregate_id = ANY(ARRAY[projections.users14.instance_id, projections.users14.resource_owner]) +WHERE + projections.users14.id = $1 + AND projections.users14.instance_id = $2 +ORDER BY + auth_methods_force_mfa.is_default +LIMIT 1; \ No newline at end of file