package repository import ( "context" "database/sql" "fmt" "github.com/jinzhu/gorm" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) type SearchRequest interface { GetLimit() uint64 GetOffset() uint64 GetSortingColumn() ColumnKey GetAsc() bool GetQueries() []SearchQuery } type SearchQuery interface { GetKey() ColumnKey GetMethod() domain.SearchMethod GetValue() interface{} } type ColumnKey interface { ToColumnName() string } func PrepareSearchQuery(table string, request SearchRequest) func(db *gorm.DB, res interface{}) (uint64, error) { return func(db *gorm.DB, res interface{}) (uint64, error) { var count uint64 = 0 query := db.Table(table) if column := request.GetSortingColumn(); column != nil { order := "DESC" if request.GetAsc() { order = "ASC" } query = query.Order(fmt.Sprintf("%s %s", column.ToColumnName(), order)) } for _, q := range request.GetQueries() { var err error query, err = SetQuery(query, q.GetKey(), q.GetValue(), q.GetMethod()) if err != nil { return count, zerrors.ThrowInvalidArgument(err, "VIEW-KaGue", "query is invalid") } } query = query.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) defer func() { if err := query.Commit().Error; err != nil { logging.OnError(err).Info("commit failed") } query.RollbackUnlessCommitted() }() query = query.Count(&count) if res == nil { return count, nil } if request.GetLimit() != 0 { query = query.Limit(request.GetLimit()) } query = query.Offset(request.GetOffset()) err := query.Find(res).Error if err != nil { return count, zerrors.ThrowInternal(err, "VIEW-muSDK", "unable to find result") } return count, nil } } func SetQuery(query *gorm.DB, key ColumnKey, value interface{}, method domain.SearchMethod) (*gorm.DB, error) { column := key.ToColumnName() if column == "" { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-7dz3w", "Column name missing") } switch method { case domain.SearchMethodEquals: query = query.Where(""+column+" = ?", value) case domain.SearchMethodEqualsIgnoreCase: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-idu8e", "Equal ignore case only possible for strings") } query = query.Where("LOWER("+column+") = LOWER(?)", valueText) case domain.SearchMethodStartsWith: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-SLj7s", "Starts with only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where(column+" LIKE ?", valueText+"%") case domain.SearchMethodStartsWithIgnoreCase: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-eidus", "Starts with ignore case only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where("LOWER("+column+") LIKE LOWER(?)", valueText+"%") case domain.SearchMethodEndsWith: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-Hswd3", "Ends with only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where(column+" LIKE ?", "%"+valueText) case domain.SearchMethodEndsWithIgnoreCase: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-dAG31", "Ends with ignore case only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where("LOWER("+column+") LIKE LOWER(?)", "%"+valueText) case domain.SearchMethodContains: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-3ids", "Contains with only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where(column+" LIKE ?", "%"+valueText+"%") case domain.SearchMethodContainsIgnoreCase: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-eid73", "Contains with ignore case only possible for strings") } valueText = database.EscapeLikeWildcards(valueText) query = query.Where("LOWER("+column+") LIKE LOWER(?)", "%"+valueText+"%") case domain.SearchMethodNotEquals: query = query.Where(""+column+" <> ?", value) case domain.SearchMethodGreaterThan: query = query.Where(column+" > ?", value) case domain.SearchMethodLessThan: query = query.Where(column+" < ?", value) case domain.SearchMethodIsOneOf: query = query.Where(column+" IN (?)", value) case domain.SearchMethodListContains: valueText, ok := value.(string) if !ok { return nil, zerrors.ThrowInvalidArgument(nil, "VIEW-Psois", "list contains only possible for strings") } query = query.Where("? <@ "+column, database.TextArray[string]{valueText}) default: return nil, nil } return query, nil }