package query

import (
	"errors"
	"fmt"
	"reflect"
	"time"

	sq "github.com/Masterminds/squirrel"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/domain"
)

// SearchResponse is the common part of search results: the number of
// matching entries and the embedded projection State.
// NOTE(review): Count is presumably filled from countColumn — confirm in callers.
type SearchResponse struct {
	Count uint64
	*State
}

// SearchRequest describes paging and sorting of a search.
// A zero Offset or Limit is not applied, and a zero SortingColumn adds no
// ORDER BY clause.
type SearchRequest struct {
	Offset        uint64
	Limit         uint64
	SortingColumn Column
	Asc           bool
}

// toQuery applies offset, limit and ordering of req to query.
// The sort order is descending unless Asc is set.
func (req *SearchRequest) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	if req.Offset > 0 {
		query = query.Offset(req.Offset)
	}
	if req.Limit > 0 {
		query = query.Limit(req.Limit)
	}
	if !req.SortingColumn.isZero() {
		clause := req.SortingColumn.orderBy()
		if !req.Asc {
			clause += " DESC"
		}
		query = query.OrderByClause(clause)
	}
	return query
}

// SearchQuery is a filter condition that can be rendered as SQL.
type SearchQuery interface {
	// toQuery adds the condition as a WHERE clause to the given builder.
	toQuery(sq.SelectBuilder) sq.SelectBuilder
	// comp returns the bare condition so it can be combined with others
	// (see OrQuery, AndQuery and NotQuery).
	comp() sq.Sqlizer
	// Col returns the filtered column, or a zero Column if the condition
	// does not refer to a single column (e.g. OrQuery, AndQuery).
	Col() Column
}

// NotNullQuery matches rows where Column is not NULL.
type NotNullQuery struct {
	Column Column
}

// NewNotNullQuery returns a NotNullQuery for col.
// It returns ErrMissingColumn if col is zero.
func NewNotNullQuery(col Column) (*NotNullQuery, error) {
	if col.isZero() {
		return nil, ErrMissingColumn
	}
	return &NotNullQuery{
		Column: col,
	}, nil
}

func (q *NotNullQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *NotNullQuery) comp() sq.Sqlizer {
	return sq.NotEq{q.Column.identifier(): nil}
}

// Col returns the column checked for NULL.
func (q *NotNullQuery) Col() Column {
	return q.Column
}

// IsNullQuery matches rows where Column is NULL.
type IsNullQuery struct {
	Column Column
}

// NewIsNullQuery returns an IsNullQuery for col.
// It returns ErrMissingColumn if col is zero.
func NewIsNullQuery(col Column) (*IsNullQuery, error) {
	if col.isZero() {
		return nil, ErrMissingColumn
	}
	return &IsNullQuery{
		Column: col,
	}, nil
}

func (q *IsNullQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *IsNullQuery) comp() sq.Sqlizer {
	return sq.Eq{q.Column.identifier(): nil}
}

// Col returns the column checked for NULL.
func (q *IsNullQuery) Col() Column {
	return q.Column
}

// OrQuery matches rows for which at least one of its queries matches.
type OrQuery struct {
	queries []SearchQuery
}

// NewOrQuery combines queries with OR.
// It returns ErrMissingColumn if no query is given.
func NewOrQuery(queries ...SearchQuery) (*OrQuery, error) {
	if len(queries) == 0 {
		return nil, ErrMissingColumn
	}
	return &OrQuery{queries: queries}, nil
}

func (q *OrQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *OrQuery) comp() sq.Sqlizer {
	or := make(sq.Or, len(q.queries))
	for i, query := range q.queries {
		or[i] = query.comp()
	}
	return or
}

// Col returns a zero Column because an OrQuery spans several conditions.
func (q *OrQuery) Col() Column {
	return Column{}
}

// AndQuery matches rows for which all of its queries match.
type AndQuery struct {
	queries []SearchQuery
}

// Col returns a zero Column because an AndQuery spans several conditions.
func (q *AndQuery) Col() Column {
	return Column{}
}

// NewAndQuery combines queries with AND.
// It returns ErrMissingColumn if no query is given.
func NewAndQuery(queries ...SearchQuery) (*AndQuery, error) {
	if len(queries) == 0 {
		return nil, ErrMissingColumn
	}
	return &AndQuery{queries: queries}, nil
}

func (q *AndQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *AndQuery) comp() sq.Sqlizer {
	and := make(sq.And, len(q.queries))
	for i, query := range q.queries {
		and[i] = query.comp()
	}
	return and
}

// NotQuery negates the wrapped query.
type NotQuery struct {
	query SearchQuery
}

// Col returns the column of the wrapped query.
func (q *NotQuery) Col() Column {
	return q.query.Col()
}

// NewNotQuery returns a NotQuery negating query.
// It returns ErrMissingColumn if query is nil.
func NewNotQuery(query SearchQuery) (*NotQuery, error) {
	if query == nil {
		return nil, ErrMissingColumn
	}
	return &NotQuery{query: query}, nil
}

func (q *NotQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// ToSql implements sq.Sqlizer by rendering the wrapped condition as
// "NOT (<condition>)" with the condition's arguments.
// Any error of the wrapped condition is returned unchanged.
func (q NotQuery) ToSql() (string, []interface{}, error) {
	innerSQL, innerArgs, err := q.query.comp().ToSql()
	if err != nil {
		// results are meaningless on error, so don't hand out partial args
		return "", nil, err
	}
	return fmt.Sprintf("NOT (%s)", innerSQL), innerArgs, nil
}

func (q *NotQuery) comp() sq.Sqlizer {
	return q
}

// ColumnComparisonQuery compares two columns with each other.
type ColumnComparisonQuery struct {
	Column1 Column
	Compare ColumnComparison
	Column2 Column
}

// NewColumnComparisonQuery returns a query comparing col1 to col2.
// It returns ErrInvalidCompare for an unknown compare and ErrMissingColumn
// if either column is zero.
func NewColumnComparisonQuery(col1 Column, col2 Column, compare ColumnComparison) (*ColumnComparisonQuery, error) {
	if compare < 0 || compare >= columnCompareMax {
		return nil, ErrInvalidCompare
	}
	if col1.isZero() {
		return nil, ErrMissingColumn
	}
	if col2.isZero() {
		return nil, ErrMissingColumn
	}
	return &ColumnComparisonQuery{
		Column1: col1,
		Column2: col2,
		Compare: compare,
	}, nil
}

func (q *ColumnComparisonQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// Col returns a zero Column because the comparison involves two columns.
func (q *ColumnComparisonQuery) Col() Column {
	return Column{}
}

// comp returns nil for an unknown Compare; NewColumnComparisonQuery
// prevents that case.
func (q *ColumnComparisonQuery) comp() sq.Sqlizer {
	switch q.Compare {
	case ColumnEquals:
		return sq.Expr(q.Column1.identifier() + " = " + q.Column2.identifier())
	case ColumnNotEquals:
		return sq.Expr(q.Column1.identifier() + " != " + q.Column2.identifier())
	}
	return nil
}

// ColumnComparison is the operator of a ColumnComparisonQuery.
type ColumnComparison int

const (
	ColumnEquals ColumnComparison = iota
	ColumnNotEquals

	columnCompareMax
)

// InTextQuery matches rows where Column equals one of Values.
type InTextQuery struct {
	Column Column
	Values []string
}

// Col returns the filtered column.
func (q *InTextQuery) Col() Column {
	return q.Column
}

// NewInTextQuery returns an InTextQuery for col and values.
// It returns ErrEmptyValues if values is empty and ErrMissingColumn if col
// is zero.
func NewInTextQuery(col Column, values []string) (*InTextQuery, error) {
	if len(values) == 0 {
		return nil, ErrEmptyValues
	}
	if col.isZero() {
		return nil, ErrMissingColumn
	}
	return &InTextQuery{
		Column: col,
		Values: values,
	}, nil
}

type textQuery struct {
	Column  Column
	Text    string
	Compare TextComparison
}

// Errors returned by the query constructors in this file.
var (
	ErrNothingSelected = errors.New("nothing selected")
	ErrInvalidCompare  = errors.New("invalid compare")
	ErrMissingColumn   = errors.New("missing column")
	ErrInvalidNumber   = errors.New("value is no number")
	ErrEmptyValues     = errors.New("values array must not be empty")
)

// NewTextQuery returns a text filter on col.
// For comparisons rendered with (I)LIKE, wildcards in value are escaped so
// they match literally. It returns ErrInvalidCompare for an unknown compare
// and ErrMissingColumn if col is zero.
func NewTextQuery(col Column, value string, compare TextComparison) (*textQuery, error) {
	if compare < 0 || compare >= textCompareMax {
		return nil, ErrInvalidCompare
	}
	if col.isZero() {
		return nil, ErrMissingColumn
	}
	// handle the comparisons which use (i)like and therefore need to escape potential wildcards in the value
	switch compare {
	case TextEqualsIgnoreCase,
		TextStartsWith,
		TextStartsWithIgnoreCase,
		TextEndsWith,
		TextEndsWithIgnoreCase,
		TextContains,
		TextContainsIgnoreCase:
		value = database.EscapeLikeWildcards(value)
	case TextEquals,
		TextListContains,
		TextNotEquals,
		textCompareMax:
		// do nothing
	}

	return &textQuery{
		Column:  col,
		Text:    value,
		Compare: compare,
	}, nil
}

func (q *textQuery) Col() Column {
	return q.Column
}

func (q *InTextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *InTextQuery) comp() sq.Sqlizer {
	// This translates to an IN query
	return sq.Eq{q.Column.identifier(): q.Values}
}

func (q *textQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// comp returns nil for an unknown Compare; NewTextQuery prevents that case.
func (q *textQuery) comp() sq.Sqlizer {
	switch q.Compare {
	case TextEquals:
		return sq.Eq{q.Column.identifier(): q.Text}
	case TextNotEquals:
		return sq.NotEq{q.Column.identifier(): q.Text}
	case TextEqualsIgnoreCase:
		return sq.ILike{q.Column.identifier(): q.Text}
	case TextStartsWith:
		return sq.Like{q.Column.identifier(): q.Text + "%"}
	case TextStartsWithIgnoreCase:
		return sq.ILike{q.Column.identifier(): q.Text + "%"}
	case TextEndsWith:
		return sq.Like{q.Column.identifier(): "%" + q.Text}
	case TextEndsWithIgnoreCase:
		return sq.ILike{q.Column.identifier(): "%" + q.Text}
	case TextContains:
		return sq.Like{q.Column.identifier(): "%" + q.Text + "%"}
	case TextContainsIgnoreCase:
		return sq.ILike{q.Column.identifier(): "%" + q.Text + "%"}
	case TextListContains:
		return &listContains{col: q.Column, args: []interface{}{q.Text}}
	case textCompareMax:
		return nil
	}
	return nil
}

// TextComparison is the operator of a text query.
type TextComparison int

const (
	TextEquals TextComparison = iota
	TextEqualsIgnoreCase
	TextStartsWith
	TextStartsWithIgnoreCase
	TextEndsWith
	TextEndsWithIgnoreCase
	TextContains
	TextContainsIgnoreCase
	TextListContains
	TextNotEquals

	textCompareMax
)

// NumberQuery compares Column with a numeric value.
type NumberQuery struct {
	Column  Column
	Number  interface{}
	Compare NumberComparison
}

// NewNumberQuery returns a NumberQuery comparing c with value.
// value must be a non-nil integer or float (not a pointer to one).
// It returns ErrInvalidCompare for an unknown compare, ErrMissingColumn if c
// is zero and ErrInvalidNumber for any other value, including nil.
func NewNumberQuery(c Column, value interface{}, compare NumberComparison) (*NumberQuery, error) {
	if compare < 0 || compare >= numberCompareMax {
		return nil, ErrInvalidCompare
	}
	if c.isZero() {
		return nil, ErrMissingColumn
	}
	// reflect.ValueOf(nil).Kind() is reflect.Invalid, so a nil value lands in
	// the default branch. reflect.TypeOf(nil) would be a nil Type and
	// calling Kind on it panics.
	switch reflect.ValueOf(value).Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		// everything fine
	default:
		return nil, ErrInvalidNumber
	}
	return &NumberQuery{
		Column:  c,
		Number:  value,
		Compare: compare,
	}, nil
}

func (q *NumberQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// Col returns the compared column.
func (q *NumberQuery) Col() Column {
	return q.Column
}

// comp returns nil for an unknown Compare; NewNumberQuery prevents that case.
func (q *NumberQuery) comp() sq.Sqlizer {
	switch q.Compare {
	case NumberEquals:
		return sq.Eq{q.Column.identifier(): q.Number}
	case NumberNotEquals:
		return sq.NotEq{q.Column.identifier(): q.Number}
	case NumberLess:
		return sq.Lt{q.Column.identifier(): q.Number}
	case NumberGreater:
		return sq.Gt{q.Column.identifier(): q.Number}
	case NumberListContains:
		return &listContains{col: q.Column, args: []interface{}{q.Number}}
	case numberCompareMax:
		return nil
	}
	return nil
}

// NumberComparison is the operator of a NumberQuery.
type NumberComparison int

const (
	NumberEquals NumberComparison = iota
	NumberNotEquals
	NumberLess
	NumberGreater
	NumberListContains

	numberCompareMax
)

// NumberComparisonFromMethod maps a domain.SearchMethod to the matching
// NumberComparison. Unsupported methods map to numberCompareMax, which
// NewNumberQuery rejects with ErrInvalidCompare.
//
// Deprecated: Use NumberComparison, will be removed as soon as all calls are changed to query
func NumberComparisonFromMethod(m domain.SearchMethod) NumberComparison {
	switch m {
	case domain.SearchMethodEquals:
		return NumberEquals
	case domain.SearchMethodNotEquals:
		return NumberNotEquals
	case domain.SearchMethodGreaterThan:
		return NumberGreater
	case domain.SearchMethodLessThan:
		return NumberLess
	case domain.SearchMethodListContains:
		return NumberListContains
	default:
		return numberCompareMax
	}
}

// SubSelect selects Column from its table, filtered by Queries.
// It is used as the Data of a ListQuery.
type SubSelect struct {
	Column  Column
	Queries []SearchQuery
}

// NewSubSelect returns a SubSelect of c filtered by queries.
// It returns ErrNothingSelected if queries is empty and ErrMissingColumn if
// c is zero.
func NewSubSelect(c Column, queries []SearchQuery) (*SubSelect, error) {
	if len(queries) == 0 {
		return nil, ErrNothingSelected
	}
	if c.isZero() {
		return nil, ErrMissingColumn
	}

	return &SubSelect{
		Column:  c,
		Queries: queries,
	}, nil
}

func (q *SubSelect) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *SubSelect) comp() sq.Sqlizer {
	selectQuery := sq.Select(q.Column.identifier()).From(q.Column.table.identifier())
	for _, query := range q.Queries {
		selectQuery = query.toQuery(selectQuery)
	}
	return selectQuery
}

// ListQuery matches rows where Column is in Data. Data is either a
// *SubSelect, rendered as "IN ( <sub select> )", or a value handed to sq.Eq.
type ListQuery struct {
	Column  Column
	Data    interface{}
	Compare ListComparison
}

// NewListQuery returns a ListQuery for column and value.
// It returns ErrInvalidCompare for an unknown compare and ErrMissingColumn
// if column is zero.
func NewListQuery(column Column, value interface{}, compare ListComparison) (*ListQuery, error) {
	if compare < 0 || compare >= listCompareMax {
		return nil, ErrInvalidCompare
	}
	if column.isZero() {
		return nil, ErrMissingColumn
	}
	return &ListQuery{
		Column:  column,
		Data:    value,
		Compare: compare,
	}, nil
}

func (q *ListQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// comp returns nil for an unknown Compare; NewListQuery prevents that case.
// If rendering a *SubSelect fails, the returned condition reports that error
// when the statement is built.
func (q *ListQuery) comp() sq.Sqlizer {
	if q.Compare != ListIn {
		return nil
	}

	if subSelect, ok := q.Data.(*SubSelect); ok {
		subSQL, subArgs, err := subSelect.comp().ToSql()
		if err != nil {
			// Returning nil here would lose the error and, depending on how
			// the builder treats a nil condition, silently drop the filter.
			return errorSqlizer{err: err}
		}
		return sq.Expr(q.Column.identifier()+" IN ( "+subSQL+" )", subArgs...)
	}
	return sq.Eq{q.Column.identifier(): q.Data}
}

// Col returns the filtered column.
func (q *ListQuery) Col() Column {
	return q.Column
}

// errorSqlizer is an sq.Sqlizer that always fails with err. comp
// implementations return it to report a problem that occurred while
// building their condition.
type errorSqlizer struct {
	err error
}

func (e errorSqlizer) ToSql() (string, []interface{}, error) {
	return "", nil, e.err
}

// ListComparison is the operator of a ListQuery.
type ListComparison int

const (
	ListIn ListComparison = iota

	listCompareMax
)

// ListComparisonFromMethod maps a domain.SearchMethod to a ListComparison.
// Only domain.SearchMethodEquals is supported; everything else maps to
// listCompareMax, which NewListQuery rejects with ErrInvalidCompare.
func ListComparisonFromMethod(m domain.SearchMethod) ListComparison {
	switch m {
	case domain.SearchMethodEquals:
		return ListIn
	default:
		return listCompareMax
	}
}

type or struct {
	queries []SearchQuery
}

// Or combines queries with OR. Unlike NewOrQuery it does not reject an
// empty list.
func Or(queries ...SearchQuery) *or {
	return &or{
		queries: queries,
	}
}

func (q *or) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *or) comp() sq.Sqlizer {
	queries := make([]sq.Sqlizer, 0, len(q.queries))
	for _, query := range q.queries {
		queries = append(queries, query.comp())
	}
	return sq.Or(queries)
}

func (q *or) Col() Column {
	return Column{}
}

// BoolQuery matches rows where Column equals Value.
type BoolQuery struct {
	Column Column
	Value  bool
}

// NewBoolQuery returns a BoolQuery for c and value. It never fails.
func NewBoolQuery(c Column, value bool) (*BoolQuery, error) {
	return &BoolQuery{
		Column: c,
		Value:  value,
	}, nil
}

// Col returns the filtered column.
func (q *BoolQuery) Col() Column {
	return q.Column
}

func (q *BoolQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

func (q *BoolQuery) comp() sq.Sqlizer {
	return sq.Eq{q.Column.identifier(): q.Value}
}

// TimestampComparison is the operator of a TimestampQuery.
type TimestampComparison int

const (
	TimestampEquals TimestampComparison = iota
	TimestampGreater
	TimestampGreaterOrEquals
	TimestampLess
	TimestampLessOrEquals

	timestampCompareMax
)

// TimestampQuery compares Column with the point in time Value.
type TimestampQuery struct {
	Column  Column
	Compare TimestampComparison
	Value   time.Time
}

// NewTimestampQuery returns a TimestampQuery comparing c with value.
// It returns ErrInvalidCompare for an unknown compare, consistent with the
// other query constructors.
func NewTimestampQuery(c Column, value time.Time, compare TimestampComparison) (*TimestampQuery, error) {
	if compare < 0 || compare >= timestampCompareMax {
		return nil, ErrInvalidCompare
	}
	return &TimestampQuery{
		Column:  c,
		Compare: compare,
		Value:   value,
	}, nil
}

// Col returns the compared column.
func (q *TimestampQuery) Col() Column {
	return q.Column
}

func (q *TimestampQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
	return query.Where(q.comp())
}

// comp returns nil for an unknown Compare; NewTimestampQuery prevents that case.
func (q *TimestampQuery) comp() sq.Sqlizer {
	switch q.Compare {
	case TimestampEquals:
		return sq.Eq{q.Column.identifier(): q.Value}
	case TimestampGreater:
		return sq.Gt{q.Column.identifier(): q.Value}
	case TimestampGreaterOrEquals:
		return sq.GtOrEq{q.Column.identifier(): q.Value}
	case TimestampLess:
		return sq.Lt{q.Column.identifier(): q.Value}
	case TimestampLessOrEquals:
		return sq.LtOrEq{q.Column.identifier(): q.Value}
	}
	return nil
}

var (
	// countColumn represents the default counter for search responses
	countColumn = Column{
		name: "COUNT(*) OVER ()",
	}
	// uniqueColumn evaluates to true if the query matched no rows
	uniqueColumn = Column{
		name: "COUNT(*) = 0",
	}
)

// table names a database table, optionally with an alias and the name of
// its instance ID column.
type table struct {
	name          string
	alias         string
	instanceIDCol string
}

// setAlias returns a copy of t with alias a.
func (t table) setAlias(a string) table {
	t.alias = a
	return t
}

// identifier returns the table name for FROM/JOIN, as "name AS alias" if
// an alias is set.
func (t table) identifier() string {
	if t.alias == "" {
		return t.name
	}
	return t.name + " AS " + t.alias
}

func (t table) isZero() bool {
	return t.name == ""
}

// InstanceIDIdentifier returns the qualified instance ID column, prefixed by
// the alias if set and by the table name otherwise.
func (t table) InstanceIDIdentifier() string {
	if t.alias != "" {
		return t.alias + "." + t.instanceIDCol
	}
	return t.name + "." + t.instanceIDCol
}

// Column names a column, optionally bound to a table.
// If isOrderByLower is set, sorting by the column is case-insensitive via LOWER().
type Column struct {
	name           string
	table          table
	isOrderByLower bool
}

// identifier returns the column name qualified by the table alias or,
// without alias, by the table name.
func (c Column) identifier() string {
	if c.table.alias != "" {
		return c.table.alias + "." + c.name
	}
	if c.table.name != "" {
		return c.table.name + "." + c.name
	}
	return c.name
}

func (c Column) orderBy() string {
	if !c.isOrderByLower {
		return c.identifier()
	}
	return "LOWER(" + c.identifier() + ")"
}

// setTable returns a copy of c bound to t.
func (c Column) setTable(t table) Column {
	c.table = t
	return c
}

// isZero reports whether c lacks a name or a table.
func (c Column) isZero() bool {
	return c.table.isZero() || c.name == ""
}

// join returns the target of a JOIN clause: join's table ON from = join.
// Unless join is the instance ID column of its table, the instance IDs of
// both tables are required to match as well.
func join(join, from Column) string {
	if join.identifier() == join.table.InstanceIDIdentifier() {
		return join.table.identifier() + " ON " + from.identifier() + " = " + join.identifier()
	}
	return join.table.identifier() + " ON " + from.identifier() + " = " + join.identifier() + " AND " + from.table.InstanceIDIdentifier() + " = " + join.table.InstanceIDIdentifier()
}

// listContains renders "<col> @> ?" with args bound as the single
// placeholder value.
type listContains struct {
	col  Column
	args interface{}
}

func (q *listContains) ToSql() (string, []interface{}, error) {
	return q.col.identifier() + " @> ? ", []interface{}{q.args}, nil
}