diff --git a/docs/docs/apis/proto/user.md b/docs/docs/apis/proto/user.md index 6143aab915..b3fd8e9bd3 100644 --- a/docs/docs/apis/proto/user.md +++ b/docs/docs/apis/proto/user.md @@ -113,6 +113,18 @@ title: zitadel/user.proto +### LoginNameQuery + + + +| Field | Type | Description | Validation | +| ----- | ---- | ----------- | ----------- | +| login_name | string | - | string.max_len: 200
| +| method | zitadel.v1.TextQueryMethod | - | enum.defined_only: true
| + + + + ### Machine @@ -288,6 +300,7 @@ this query is always equals | [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) query.email_query | EmailQuery | - | | | [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) query.state_query | StateQuery | - | | | [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) query.type_query | TypeQuery | - | | +| [**oneof**](https://developers.google.com/protocol-buffers/docs/proto3#oneof) query.login_name_query | LoginNameQuery | - | | diff --git a/internal/api/grpc/server/middleware/instance_interceptor.go b/internal/api/grpc/server/middleware/instance_interceptor.go index 3fd1727c04..799382437e 100644 --- a/internal/api/grpc/server/middleware/instance_interceptor.go +++ b/internal/api/grpc/server/middleware/instance_interceptor.go @@ -85,9 +85,9 @@ func hostFromContext(ctx context.Context, headerName string) (string, error) { return host[0], nil } -//isAllowedToSendHTTP1Header check if the gRPC call was sent to `localhost` -//this is only possible when calling the server directly running on localhost -//or through the gRPC gateway +// isAllowedToSendHTTP1Header check if the gRPC call was sent to `localhost` +// this is only possible when calling the server directly running on localhost +// or through the gRPC gateway func isAllowedToSendHTTP1Header(md metadata.MD) bool { authority, ok := md[":authority"] return ok && len(authority) == 1 && strings.Split(authority[0], ":")[0] == "localhost" diff --git a/internal/api/grpc/user/query.go b/internal/api/grpc/user/query.go index 190b8a8159..81fc738e90 100644 --- a/internal/api/grpc/user/query.go +++ b/internal/api/grpc/user/query.go @@ -36,6 +36,8 @@ func UserQueryToQuery(query *user_pb.SearchQuery) (query.SearchQuery, error) { return StateQueryToQuery(q.StateQuery) case *user_pb.SearchQuery_TypeQuery: return TypeQueryToQuery(q.TypeQuery) + case *user_pb.SearchQuery_LoginNameQuery: + return LoginNameQueryToQuery(q.LoginNameQuery) case *user_pb.SearchQuery_ResourceOwner: return ResourceOwnerQueryToQuery(q.ResourceOwner) default: @@ -75,6 +77,10 @@ func TypeQueryToQuery(q *user_pb.TypeQuery) (query.SearchQuery, error) { return query.NewUserTypeSearchQuery(int32(q.Type)) } +func LoginNameQueryToQuery(q *user_pb.LoginNameQuery) (query.SearchQuery, error) { + return query.NewUserLoginNameExistsQuery(q.LoginName, object.TextMethodToQuery(q.Method)) +} + func ResourceOwnerQueryToQuery(q *user_pb.ResourceOwnerQuery) (query.SearchQuery, error) { return query.NewUserResourceOwnerSearchQuery(q.OrgID, query.TextEquals) } diff --git a/internal/api/ui/console/console.go b/internal/api/ui/console/console.go index 5b299f166b..49a2df0e19 100644 --- a/internal/api/ui/console/console.go +++ b/internal/api/ui/console/console.go @@ -65,8 +65,8 @@ func (i *spaHandler) Open(name string) (http.File, error) { return &file{File: f}, nil } -//file wraps the http.File and fs.FileInfo interfaces -//to return the build.Date() as ModTime() of the file +// file wraps the http.File and fs.FileInfo interfaces +// to return the build.Date() as ModTime() of the file type file struct { http.File fs.FileInfo diff --git a/internal/database/type.go b/internal/database/type.go index 2377aa3f10..5757529c8e 100644 --- a/internal/database/type.go +++ b/internal/database/type.go @@ -20,7 +20,7 @@ func (s *StringArray) Scan(src any) error { return nil } -// Value implements the `database/sql/driver.Valuer`` interface. +// Value implements the `database/sql/driver.Valuer` interface. func (s StringArray) Value() (driver.Value, error) { if len(s) == 0 { return nil, nil @@ -57,7 +57,7 @@ func (s *EnumArray[F]) Scan(src any) error { return nil } -// Value implements the `database/sql/driver.Valuer`` interface. +// Value implements the `database/sql/driver.Valuer` interface. func (s EnumArray[F]) Value() (driver.Value, error) { if len(s) == 0 { return nil, nil diff --git a/internal/eventstore/repository/search_query.go b/internal/eventstore/repository/search_query.go index 08a04bd7ce..6270fdc62b 100644 --- a/internal/eventstore/repository/search_query.go +++ b/internal/eventstore/repository/search_query.go @@ -6,7 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/errors" ) -//SearchQuery defines the which and how data are queried +// SearchQuery defines the which and how data are queried type SearchQuery struct { Columns Columns Limit uint64 @@ -15,7 +15,7 @@ type SearchQuery struct { Tx *sql.Tx } -//Columns defines which fields of the event are needed for the query +// Columns defines which fields of the event are needed for the query type Columns int32 const ( @@ -36,14 +36,14 @@ func (c Columns) Validate() error { return nil } -//Filter represents all fields needed to compare a field of an event with a value +// Filter represents all fields needed to compare a field of an event with a value type Filter struct { Field Field Value interface{} Operation Operation } -//Operation defines how fields are compared +// Operation defines how fields are compared type Operation int32 const ( @@ -63,7 +63,7 @@ const ( operationCount ) -//Field is the representation of a field from the event +// Field is the representation of a field from the event type Field int32 const ( @@ -91,7 +91,7 @@ const ( fieldCount ) -//NewFilter is used in tests. Use searchQuery.*Filter() instead +// NewFilter is used in tests. Use searchQuery.*Filter() instead func NewFilter(field Field, value interface{}, operation Operation) *Filter { return &Filter{ Field: field, @@ -100,7 +100,7 @@ func NewFilter(field Field, value interface{}, operation Operation) *Filter { } } -//Validate checks if the fields of the filter have valid values +// Validate checks if the fields of the filter have valid values func (f *Filter) Validate() error { if f == nil { return errors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil") diff --git a/internal/query/search_query.go b/internal/query/search_query.go index 4239e64a04..50529a6634 100644 --- a/internal/query/search_query.go +++ b/internal/query/search_query.go @@ -89,6 +89,52 @@ func (q *orQuery) comp() sq.Sqlizer { return or } +type ColumnComparisonQuery struct { + Column1 Column + Compare ColumnComparison + Column2 Column +} + +func NewColumnComparisonQuery(col1 Column, col2 Column, compare ColumnComparison) (*ColumnComparisonQuery, error) { + if compare < 0 || compare >= columnCompareMax { + return nil, ErrInvalidCompare + } + if col1.isZero() { + return nil, ErrMissingColumn + } + if col2.isZero() { + return nil, ErrMissingColumn + } + return &ColumnComparisonQuery{ + Column1: col1, + Column2: col2, + Compare: compare, + }, nil +} + +func (q *ColumnComparisonQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { + return query.Where(q.comp()) +} + +func (s *ColumnComparisonQuery) comp() sq.Sqlizer { + switch s.Compare { + case ColumnEquals: + return sq.Expr(s.Column1.identifier() + " = " + s.Column2.identifier()) + case ColumnNotEquals: + return sq.Expr(s.Column1.identifier() + " != " + s.Column2.identifier()) + } + return nil +} + +type ColumnComparison int + +const ( + ColumnEquals ColumnComparison = iota + ColumnNotEquals + + columnCompareMax +) + type TextQuery struct { Column Column Text string @@ -96,9 +142,10 @@ type TextQuery struct { } var ( - ErrInvalidCompare = errors.New("invalid compare") - ErrMissingColumn = errors.New("missing column") - ErrInvalidNumber = errors.New("value is no number") + ErrNothingSelected = errors.New("nothing selected") + ErrInvalidCompare = errors.New("invalid compare") + ErrMissingColumn = errors.New("missing column") + ErrInvalidNumber = errors.New("value is no number") ) func NewTextQuery(col Column, value string, compare TextComparison) (*TextQuery, error) { @@ -262,13 +309,44 @@ func NumberComparisonFromMethod(m domain.SearchMethod) NumberComparison { } } +type SubSelect struct { + Column Column + Queries []SearchQuery +} + +func NewSubSelect(c Column, queries []SearchQuery) (*SubSelect, error) { + if len(queries) == 0 { + return nil, ErrNothingSelected + } + if c.isZero() { + return nil, ErrMissingColumn + } + + return &SubSelect{ + Column: c, + Queries: queries, + }, nil +} + +func (q *SubSelect) toQuery(query sq.SelectBuilder) sq.SelectBuilder { + return query.Where(q.comp()) +} + +func (q *SubSelect) comp() sq.Sqlizer { + selectQuery := sq.Select(q.Column.identifier()).From(q.Column.table.identifier()) + for _, query := range q.Queries { + selectQuery = query.toQuery(selectQuery) + } + return selectQuery +} + type ListQuery struct { Column Column - List []interface{} + Data interface{} Compare ListComparison } -func NewListQuery(column Column, value []interface{}, compare ListComparison) (*ListQuery, error) { +func NewListQuery(column Column, value interface{}, compare ListComparison) (*ListQuery, error) { if compare < 0 || compare >= listCompareMax { return nil, ErrInvalidCompare } @@ -277,7 +355,7 @@ func NewListQuery(column Column, value []interface{}, compare ListComparison) (* } return &ListQuery{ Column: column, - List: value, + Data: value, Compare: compare, }, nil } @@ -289,7 +367,14 @@ func (q *ListQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { func (s *ListQuery) comp() sq.Sqlizer { switch s.Compare { case ListIn: - return sq.Eq{s.Column.identifier(): s.List} + if subSelect, ok := s.Data.(*SubSelect); ok { + subSelect, args, err := subSelect.comp().ToSql() + if err != nil { + return nil + } + return sq.Expr(s.Column.identifier()+" IN ( "+subSelect+" )", args...) + } + return sq.Eq{s.Column.identifier(): s.Data} } return nil } diff --git a/internal/query/search_query_test.go b/internal/query/search_query_test.go index 77b9066ecc..db74c7c048 100644 --- a/internal/query/search_query_test.go +++ b/internal/query/search_query_test.go @@ -13,13 +13,29 @@ import ( var ( testTable = table{ name: "test_table", - alias: "test_table", instanceIDCol: "instance_id", } + testTableAlias = table{ + name: "test_table", + alias: "test_alias", + instanceIDCol: "instance_id", + } + testTable2 = table{ + name: "test_table2", + alias: "test_table2", + } testCol = Column{ name: "test_col", table: testTable, } + testColAlias = Column{ + name: "test_col", + table: testTableAlias, + } + testCol2 = Column{ + name: "test_col2", + table: testTable2, + } testLowerCol = Column{ name: "test_lower_col", table: testTable, @@ -140,6 +156,620 @@ func TestSearchRequest_ToQuery(t *testing.T) { } } +func TestNewSubSelect(t *testing.T) { + type args struct { + column Column + queries []SearchQuery + } + tests := []struct { + name string + args args + want *SubSelect + wantErr func(error) bool + }{ + { + name: "no query nil", + args: args{ + column: testCol, + queries: nil, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrNothingSelected) + }, + }, + { + name: "no query zero", + args: args{ + column: testCol, + queries: []SearchQuery{}, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrNothingSelected) + }, + }, + { + name: "no column 1", + args: args{ + column: Column{}, + queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "no column name 1", + args: args{ + column: testNoCol, + queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "correct 1", + args: args{ + column: testCol, + queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}, + }, + want: &SubSelect{ + Column: testCol, + Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}, + }, + }, + { + name: "correct 3", + args: args{ + column: testCol, + queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}}, + }, + want: &SubSelect{ + Column: testCol, + Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewSubSelect(tt.args.column, tt.args.queries) + if err != nil && tt.wantErr == nil { + t.Errorf("NewTextQuery() no error expected got %v", err) + return + } else if tt.wantErr != nil && !tt.wantErr(err) { + t.Errorf("NewTextQuery() unexpeted error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTextQuery() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSubSelect_comp(t *testing.T) { + type fields struct { + Column Column + Queries []SearchQuery + } + type want struct { + query interface{} + isNil bool + } + tests := []struct { + name string + fields fields + want want + }{ + { + name: "no queries", + fields: fields{ + Column: testCol, + Queries: []SearchQuery{}, + }, + want: want{ + query: sq.Select("test_table.test_col").From("test_table"), + }, + }, + { + name: "queries 1", + fields: fields{ + Column: testCol, + Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}, + }, + want: want{ + query: sq.Select("test_table.test_col").From("test_table").Where(sq.Eq{"test_table.test_col": interface{}("horst")}), + }, + }, + { + name: "queries 1 with alias", + fields: fields{ + Column: testColAlias, + Queries: []SearchQuery{&TextQuery{testColAlias, "horst", TextEquals}}, + }, + want: want{ + query: sq.Select("test_alias.test_col").From("test_table AS test_alias").Where(sq.Eq{"test_alias.test_col": interface{}("horst")}), + }, + }, + { + name: "queries 3", + fields: fields{ + Column: testCol, + Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}}, + }, + want: want{ + query: sq.Select("test_table.test_col").From("test_table").From("test_table").Where(sq.Eq{"test_table.test_col": "horst1"}).From("test_table").Where(sq.Eq{"test_table.test_col": "horst2"}).From("test_table").Where(sq.Eq{"test_table.test_col": "horst3"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SubSelect{ + Column: tt.fields.Column, + Queries: tt.fields.Queries, + } + query := s.comp() + if query == nil && tt.want.isNil { + return + } else if tt.want.isNil && query != nil { + t.Error("query should not be nil") + } + + if !reflect.DeepEqual(query, tt.want.query) { + t.Errorf("wrong query: want: %v, (%T), got: %v, (%T)", tt.want.query, tt.want.query, query, query) + } + }) + } +} + +func TestNewColumnComparisonQuery(t *testing.T) { + type args struct { + column Column + columnCompare Column + compare ColumnComparison + } + tests := []struct { + name string + args args + want *ColumnComparisonQuery + wantErr func(error) bool + }{ + { + name: "too low compare", + args: args{ + column: testCol, + columnCompare: testCol2, + compare: -1, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrInvalidCompare) + }, + }, + { + name: "too high compare", + args: args{ + column: testCol, + columnCompare: testCol2, + compare: columnCompareMax, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrInvalidCompare) + }, + }, + { + name: "no column 1", + args: args{ + column: Column{}, + columnCompare: testCol2, + compare: ColumnEquals, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "no column 2", + args: args{ + column: testCol, + columnCompare: Column{}, + compare: ColumnEquals, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "no column name 1", + args: args{ + column: testNoCol, + columnCompare: testCol2, + compare: ColumnEquals, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "no column name 2", + args: args{ + column: testCol, + columnCompare: testNoCol, + compare: ColumnEquals, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "correct", + args: args{ + column: testCol, + columnCompare: testCol2, + compare: ColumnEquals, + }, + want: &ColumnComparisonQuery{ + Column1: testCol, + Column2: testCol2, + Compare: ColumnEquals, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewColumnComparisonQuery(tt.args.column, tt.args.columnCompare, tt.args.compare) + if err != nil && tt.wantErr == nil { + t.Errorf("NewTextQuery() no error expected got %v", err) + return + } else if tt.wantErr != nil && !tt.wantErr(err) { + t.Errorf("NewTextQuery() unexpeted error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTextQuery() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestColumnComparisonQuery_comp(t *testing.T) { + type fields struct { + Column Column + ColumnCompare Column + Compare ColumnComparison + } + type want struct { + query interface{} + isNil bool + } + tests := []struct { + name string + fields fields + want want + }{ + { + name: "equals", + fields: fields{ + Column: testCol, + ColumnCompare: testCol2, + Compare: ColumnEquals, + }, + want: want{ + query: sq.Expr("test_table.test_col = test_table2.test_col2"), + }, + }, + { + name: "not equals", + fields: fields{ + Column: testCol, + ColumnCompare: testCol2, + Compare: ColumnNotEquals, + }, + want: want{ + query: sq.Expr("test_table.test_col != test_table2.test_col2"), + }, + }, + { + name: "too high comparison", + fields: fields{ + Column: testCol, + ColumnCompare: testCol2, + Compare: columnCompareMax, + }, + want: want{ + isNil: true, + }, + }, + { + name: "too low comparison", + fields: fields{ + Column: testCol, + ColumnCompare: testCol2, + Compare: -1, + }, + want: want{ + isNil: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &ColumnComparisonQuery{ + Column1: tt.fields.Column, + Column2: tt.fields.ColumnCompare, + Compare: tt.fields.Compare, + } + query := s.comp() + if query == nil && tt.want.isNil { + return + } else if tt.want.isNil && query != nil { + t.Error("query should not be nil") + } + + if !reflect.DeepEqual(query, tt.want.query) { + t.Errorf("wrong query: want: %v, (%T), got: %v, (%T)", tt.want.query, tt.want.query, query, query) + } + }) + } +} + +func TestNewListQuery(t *testing.T) { + type args struct { + column Column + data interface{} + compare ListComparison + } + tests := []struct { + name string + args args + want *ListQuery + wantErr func(error) bool + }{ + { + name: "too low compare", + args: args{ + column: testCol, + data: []interface{}{"hurst"}, + compare: -1, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrInvalidCompare) + }, + }, + { + name: "too high compare", + args: args{ + column: testCol, + data: []interface{}{"hurst"}, + compare: listCompareMax, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrInvalidCompare) + }, + }, + { + name: "no column", + args: args{ + column: Column{}, + data: []interface{}{"hurst"}, + compare: ListIn, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "no column name", + args: args{ + column: testNoCol, + data: []interface{}{"hurst"}, + compare: ListIn, + }, + wantErr: func(err error) bool { + return errors.Is(err, ErrMissingColumn) + }, + }, + { + name: "correct slice", + args: args{ + column: testCol, + data: []interface{}{"hurst"}, + compare: ListIn, + }, + want: &ListQuery{ + Column: testCol, + Data: []interface{}{"hurst"}, + Compare: ListIn, + }, + }, + { + name: "correct", + args: args{ + column: testCol, + data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}}}, + compare: ListIn, + }, + want: &ListQuery{ + Column: testCol, + Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}}}, + Compare: ListIn, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewListQuery(tt.args.column, tt.args.data, tt.args.compare) + if err != nil && tt.wantErr == nil { + t.Errorf("NewTextQuery() no error expected got %v", err) + return + } else if tt.wantErr != nil && !tt.wantErr(err) { + t.Errorf("NewTextQuery() unexpeted error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTextQuery() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestListQuery_comp(t *testing.T) { + type fields struct { + Column Column + Data interface{} + Compare ListComparison + } + type want struct { + query interface{} + isNil bool + } + tests := []struct { + name string + fields fields + want want + }{ + { + name: "in list one element", + fields: fields{ + Column: testCol, + Data: []interface{}{"hurst"}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []interface{}{"hurst"}}, + }, + }, + { + name: "in list three elements", + fields: fields{ + Column: testCol, + Data: []interface{}{"hurst1", "hurst2", "hurst3"}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []interface{}{"hurst1", "hurst2", "hurst3"}}, + }, + }, + { + name: "in string list one element", + fields: fields{ + Column: testCol, + Data: []string{"hurst"}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []string{"hurst"}}, + }, + }, + { + name: "in string list three elements", + fields: fields{ + Column: testCol, + Data: []string{"hurst1", "hurst2", "hurst3"}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []string{"hurst1", "hurst2", "hurst3"}}, + }, + }, + { + name: "in int list one element", + fields: fields{ + Column: testCol, + Data: []int{1}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []int{1}}, + }, + }, + { + name: "in int list three elements", + fields: fields{ + Column: testCol, + Data: []int{1, 2, 3}, + Compare: ListIn, + }, + want: want{ + query: sq.Eq{"test_table.test_col": []int{1, 2, 3}}, + }, + }, + { + name: "in subquery text", + fields: fields{ + Column: testCol, + Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}}, + Compare: ListIn, + }, + want: want{ + query: sq.Expr("test_table.test_col IN ( SELECT test_table.test_col FROM test_table WHERE test_table.test_col = ? )", "horst"), + }, + }, + { + name: "in subquery number", + fields: fields{ + Column: testCol, + Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&NumberQuery{testCol, 1, NumberEquals}}}, + Compare: ListIn, + }, + want: want{ + query: sq.Expr("test_table.test_col IN ( SELECT test_table.test_col FROM test_table WHERE test_table.test_col = ? )", 1), + }, + }, + { + name: "in subquery column", + fields: fields{ + Column: testCol, + Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&ColumnComparisonQuery{testCol, ColumnEquals, testCol2}}}, + Compare: ListIn, + }, + want: want{ + query: sq.Expr("test_table.test_col IN ( SELECT test_table.test_col FROM test_table WHERE test_table.test_col = test_table2.test_col2 )"), + }, + }, + { + name: "too high comparison", + fields: fields{ + Column: testCol, + Data: []interface{}{"hurst"}, + Compare: listCompareMax, + }, + want: want{ + isNil: true, + }, + }, + { + name: "too low comparison", + fields: fields{ + Column: testCol, + Data: []interface{}{"hurst"}, + Compare: -1, + }, + want: want{ + isNil: true, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &ListQuery{ + Column: tt.fields.Column, + Data: tt.fields.Data, + Compare: tt.fields.Compare, + } + query := s.comp() + if query == nil && tt.want.isNil { + return + } else if tt.want.isNil && query != nil { + t.Error("query should not be nil") + } + + if !reflect.DeepEqual(query, tt.want.query) { + t.Errorf("wrong query: want: %v, (%T), got: %v, (%T)", tt.want.query, tt.want.query, query, query) + } + }) + } +} + func TestNewTextQuery(t *testing.T) { type args struct { column Column diff --git a/internal/query/user.go b/internal/query/user.go index eecb603897..6a9ab7f39a 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -569,6 +569,34 @@ func NewUserLoginNamesSearchQuery(value string) (SearchQuery, error) { return NewTextQuery(userLoginNamesListCol, value, TextListContains) } +func NewUserLoginNameExistsQuery(value string, comparison TextComparison) (SearchQuery, error) { + //linking queries for the subselect + instanceQuery, err := NewColumnComparisonQuery(LoginNameInstanceIDCol, UserInstanceIDCol, ColumnEquals) + if err != nil { + return nil, err + } + userIDQuery, err := NewColumnComparisonQuery(LoginNameUserIDCol, UserIDCol, ColumnEquals) + if err != nil { + return nil, err + } + //text query to select data from the linked sub select + loginNameQuery, err := NewTextQuery(LoginNameNameCol, value, comparison) + if err != nil { + return nil, err + } + //full definition of the sub select + subSelect, err := NewSubSelect(LoginNameUserIDCol, []SearchQuery{instanceQuery, userIDQuery, loginNameQuery}) + if err != nil { + return nil, err + } + // "WHERE * IN (*)" query with subquery as list-data provider + return NewListQuery( + UserIDCol, + subSelect, + ListIn, + ) +} + func prepareUserQuery(instanceID string) (sq.SelectBuilder, func(*sql.Row) (*User, error)) { loginNamesQuery, loginNamesArgs, err := sq.Select( userLoginNamesUserIDCol.identifier(), diff --git a/proto/zitadel/user.proto b/proto/zitadel/user.proto index be10315e54..258bf44f1e 100644 --- a/proto/zitadel/user.proto +++ b/proto/zitadel/user.proto @@ -168,6 +168,7 @@ message SearchQuery { EmailQuery email_query = 6; StateQuery state_query = 7; TypeQuery type_query = 8; + LoginNameQuery login_name_query = 9; } } @@ -262,6 +263,22 @@ message EmailQuery { ]; } +message LoginNameQuery { + string login_name = 1 [ + (validate.rules).string = {max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + max_length: 200; + example: "\"gigi@zitadel.cloud\""; + } + ]; + zitadel.v1.TextQueryMethod method = 2 [ + (validate.rules).enum.defined_only = true, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "defines which text equality method is used"; + } + ]; +} + //UserStateQuery is always equals message StateQuery { UserState state = 1 [