diff --git a/internal/api/grpc/admin/org.go b/internal/api/grpc/admin/org.go index 93e6936d42..90b99ca208 100644 --- a/internal/api/grpc/admin/org.go +++ b/internal/api/grpc/admin/org.go @@ -9,7 +9,6 @@ import ( http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" - "github.com/zitadel/zitadel/internal/query" admin_pb "github.com/zitadel/zitadel/pkg/grpc/admin" ) @@ -104,17 +103,5 @@ func (s *Server) SetUpOrg(ctx context.Context, req *admin_pb.SetUpOrgRequest) (* } func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain string) ([]string, error) { - loginName, err := query.NewUserPreferredLoginNameSearchQuery("@"+orgDomain, query.TextEndsWithIgnoreCase) - if err != nil { - return nil, err - } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) - if err != nil { - return nil, err - } - userIDs := make([]string, len(users.Users)) - for i, user := range users.Users { - userIDs[i] = user.ID - } - return userIDs, nil + return s.query.SearchClaimedUserIDsOfOrgDomain(ctx, orgDomain, "") } diff --git a/internal/api/grpc/management/org.go b/internal/api/grpc/management/org.go index a6a934160a..57caeda3ce 100644 --- a/internal/api/grpc/management/org.go +++ b/internal/api/grpc/management/org.go @@ -316,28 +316,7 @@ func (s *Server) RemoveOrgMember(ctx context.Context, req *mgmt_pb.RemoveOrgMemb } func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, orgID string) ([]string, error) { - queries := make([]query.SearchQuery, 0, 2) - loginName, err := query.NewUserPreferredLoginNameSearchQuery("@"+orgDomain, query.TextEndsWithIgnoreCase) - if err != nil { - return nil, err - } - queries = append(queries, loginName) - if orgID != "" { - owner, err := query.NewUserResourceOwnerSearchQuery(orgID, query.TextNotEquals) - if err != nil { - return nil, err - } - queries = append(queries, owner) - } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil) - if err != nil { - return nil, err - } - userIDs := make([]string, len(users.Users)) - for i, user := range users.Users { - userIDs[i] = user.ID - } - return userIDs, nil + return s.query.SearchClaimedUserIDsOfOrgDomain(ctx, orgDomain, orgID) } func (s *Server) ListOrgMetadata(ctx context.Context, req *mgmt_pb.ListOrgMetadataRequest) (*mgmt_pb.ListOrgMetadataResponse, error) { diff --git a/internal/api/ui/login/login.go b/internal/api/ui/login/login.go index 5fa97ddd56..5ff27c14fc 100644 --- a/internal/api/ui/login/login.go +++ b/internal/api/ui/login/login.go @@ -178,19 +178,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string if err != nil { return nil, err } - loginName, err := query.NewUserPreferredLoginNameSearchQuery("@"+orgDomain, query.TextEndsWithIgnoreCase) - if err != nil { - return nil, err - } - users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) - if err != nil { - return nil, err - } - userIDs := make([]string, len(users.Users)) - for i, user := range users.Users { - userIDs[i] = user.ID - } - return userIDs, nil + return l.query.SearchClaimedUserIDsOfOrgDomain(ctx, orgDomain, "") } func setContext(ctx context.Context, resourceOwner string) context.Context { diff --git a/internal/query/user.go b/internal/query/user.go index 56d3d130f1..c0fc3de97c 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -685,6 +685,35 @@ func (q *Queries) IsUserUnique(ctx context.Context, username, email, resourceOwn return isUnique, err } +//go:embed user_claimed_user_ids.sql +var userClaimedUserIDOfOrgDomain string + +func (q *Queries) SearchClaimedUserIDsOfOrgDomain(ctx context.Context, domain, orgID string) (userIDs []string, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + err = q.client.QueryContext(ctx, + func(rows *sql.Rows) error { + userIDs = make([]string, 0) + for rows.Next() { + var userID string + err := rows.Scan(&userID) + if err != nil { + return err + } + userIDs = append(userIDs, userID) + } + return nil + }, + userClaimedUserIDOfOrgDomain, + authz.GetInstance(ctx).InstanceID(), + "%@"+domain, + orgID, + ) + + return userIDs, err +} + func (q *UserSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { query = q.SearchRequest.toQuery(query) for _, q := range q.Queries { diff --git a/internal/query/user_claimed_user_ids.sql b/internal/query/user_claimed_user_ids.sql new file mode 100644 index 0000000000..5d4639be46 --- /dev/null +++ b/internal/query/user_claimed_user_ids.sql @@ -0,0 +1,13 @@ +SELECT u.id +FROM projections.login_names3_users u + LEFT JOIN projections.login_names3_policies p_custom + ON u.instance_id = p_custom.instance_id + AND p_custom.instance_id = $1 + AND p_custom.resource_owner = u.resource_owner + JOIN projections.login_names3_policies p_default + ON u.instance_id = p_default.instance_id + AND p_default.instance_id = $1 AND p_default.is_default IS TRUE +WHERE u.instance_id = $1 + AND COALESCE(p_custom.must_be_domain, p_default.must_be_domain) = false + AND u.user_name_lower like $2 + AND u.resource_owner <> $3; \ No newline at end of file