fix(query): escape wildcards in text search (#7131) (#7135)

* fix(query): escape like wildcards

* test: search query wildcards

* add do nothing
This commit is contained in:
Silvan
2024-01-02 16:27:36 +01:00
committed by GitHub
parent 9892fd92b6
commit 8bc56f6fe7
6 changed files with 736 additions and 51 deletions

View File

@@ -8,6 +8,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
)
@@ -255,7 +256,7 @@ func NewInTextQuery(col Column, values []string) (*InTextQuery, error) {
}, nil
}
type TextQuery struct {
type textQuery struct {
Column Column
Text string
Compare TextComparison
@@ -269,21 +270,38 @@ var (
ErrEmptyValues = errors.New("values array must not be empty")
)
func NewTextQuery(col Column, value string, compare TextComparison) (*TextQuery, error) {
func NewTextQuery(col Column, value string, compare TextComparison) (*textQuery, error) {
if compare < 0 || compare >= textCompareMax {
return nil, ErrInvalidCompare
}
if col.isZero() {
return nil, ErrMissingColumn
}
return &TextQuery{
// handle the comparisons which use (i)like and therefore need to escape potential wildcards in the value
switch compare {
case TextEqualsIgnoreCase,
TextStartsWith,
TextStartsWithIgnoreCase,
TextEndsWith,
TextEndsWithIgnoreCase,
TextContains,
TextContainsIgnoreCase:
value = database.EscapeLikeWildcards(value)
case TextEquals,
TextListContains,
TextNotEquals,
textCompareMax:
// do nothing
}
return &textQuery{
Column: col,
Text: value,
Compare: compare,
}, nil
}
func (q *TextQuery) Col() Column {
func (q *textQuery) Col() Column {
return q.Column
}
@@ -296,11 +314,11 @@ func (q *InTextQuery) comp() sq.Sqlizer {
return sq.Eq{q.Column.identifier(): q.Values}
}
func (q *TextQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
func (q *textQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query.Where(q.comp())
}
func (q *TextQuery) comp() sq.Sqlizer {
func (q *textQuery) comp() sq.Sqlizer {
switch q.Compare {
case TextEquals:
return sq.Eq{q.Column.identifier(): q.Text}
@@ -346,32 +364,6 @@ const (
textCompareMax
)
// Deprecated: Use TextComparison, will be removed as soon as all calls are changed to query
func TextComparisonFromMethod(m domain.SearchMethod) TextComparison {
switch m {
case domain.SearchMethodEquals:
return TextEquals
case domain.SearchMethodEqualsIgnoreCase:
return TextEqualsIgnoreCase
case domain.SearchMethodStartsWith:
return TextStartsWith
case domain.SearchMethodStartsWithIgnoreCase:
return TextStartsWithIgnoreCase
case domain.SearchMethodContains:
return TextContains
case domain.SearchMethodContainsIgnoreCase:
return TextContainsIgnoreCase
case domain.SearchMethodEndsWith:
return TextEndsWith
case domain.SearchMethodEndsWithIgnoreCase:
return TextEndsWithIgnoreCase
case domain.SearchMethodListContains:
return TextListContains
default:
return textCompareMax
}
}
type NumberQuery struct {
Column Column
Number interface{}

View File

@@ -191,7 +191,7 @@ func TestNewSubSelect(t *testing.T) {
name: "no column 1",
args: args{
column: Column{},
queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}},
queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}},
},
wantErr: func(err error) bool {
return errors.Is(err, ErrMissingColumn)
@@ -201,7 +201,7 @@ func TestNewSubSelect(t *testing.T) {
name: "no column name 1",
args: args{
column: testNoCol,
queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}},
queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}},
},
wantErr: func(err error) bool {
return errors.Is(err, ErrMissingColumn)
@@ -211,22 +211,22 @@ func TestNewSubSelect(t *testing.T) {
name: "correct 1",
args: args{
column: testCol,
queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}},
queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}},
},
want: &SubSelect{
Column: testCol,
Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}},
Queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}},
},
},
{
name: "correct 3",
args: args{
column: testCol,
queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}},
queries: []SearchQuery{&textQuery{testCol, "horst1", TextEquals}, &textQuery{testCol, "horst2", TextEquals}, &textQuery{testCol, "horst3", TextEquals}},
},
want: &SubSelect{
Column: testCol,
Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}},
Queries: []SearchQuery{&textQuery{testCol, "horst1", TextEquals}, &textQuery{testCol, "horst2", TextEquals}, &textQuery{testCol, "horst3", TextEquals}},
},
},
}
@@ -275,7 +275,7 @@ func TestSubSelect_comp(t *testing.T) {
name: "queries 1",
fields: fields{
Column: testCol,
Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}},
Queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}},
},
want: want{
query: sq.Select("test_table.test_col").From("test_table").Where(sq.Eq{"test_table.test_col": interface{}("horst")}),
@@ -285,7 +285,7 @@ func TestSubSelect_comp(t *testing.T) {
name: "queries 1 with alias",
fields: fields{
Column: testColAlias,
Queries: []SearchQuery{&TextQuery{testColAlias, "horst", TextEquals}},
Queries: []SearchQuery{&textQuery{testColAlias, "horst", TextEquals}},
},
want: want{
query: sq.Select("test_alias.test_col").From("test_table AS test_alias").Where(sq.Eq{"test_alias.test_col": interface{}("horst")}),
@@ -295,7 +295,7 @@ func TestSubSelect_comp(t *testing.T) {
name: "queries 3",
fields: fields{
Column: testCol,
Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}, &TextQuery{testCol, "horst2", TextEquals}, &TextQuery{testCol, "horst3", TextEquals}},
Queries: []SearchQuery{&textQuery{testCol, "horst1", TextEquals}, &textQuery{testCol, "horst2", TextEquals}, &textQuery{testCol, "horst3", TextEquals}},
},
want: want{
query: sq.Select("test_table.test_col").From("test_table").From("test_table").Where(sq.Eq{"test_table.test_col": "horst1"}).From("test_table").Where(sq.Eq{"test_table.test_col": "horst2"}).From("test_table").Where(sq.Eq{"test_table.test_col": "horst3"}),
@@ -585,12 +585,12 @@ func TestNewListQuery(t *testing.T) {
name: "correct",
args: args{
column: testCol,
data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}}},
data: &SubSelect{Column: testCol, Queries: []SearchQuery{&textQuery{testCol, "horst1", TextEquals}}},
compare: ListIn,
},
want: &ListQuery{
Column: testCol,
Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst1", TextEquals}}},
Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&textQuery{testCol, "horst1", TextEquals}}},
Compare: ListIn,
},
},
@@ -697,7 +697,7 @@ func TestListQuery_comp(t *testing.T) {
name: "in subquery text",
fields: fields{
Column: testCol,
Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&TextQuery{testCol, "horst", TextEquals}}},
Data: &SubSelect{Column: testCol, Queries: []SearchQuery{&textQuery{testCol, "horst", TextEquals}}},
Compare: ListIn,
},
want: want{
@@ -779,7 +779,7 @@ func TestNewTextQuery(t *testing.T) {
tests := []struct {
name string
args args
want *TextQuery
want *textQuery
wantErr func(error) bool
}{
{
@@ -827,18 +827,317 @@ func TestNewTextQuery(t *testing.T) {
},
},
{
name: "correct",
name: "equals",
args: args{
column: testCol,
value: "hurst",
compare: TextEquals,
},
want: &TextQuery{
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextEquals,
},
},
{
name: "equals ignore case",
args: args{
column: testCol,
value: "hurst",
compare: TextEqualsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextEqualsIgnoreCase,
},
},
{
name: "equals ignore case % wildcard",
args: args{
column: testCol,
value: "hu%rst",
compare: TextEqualsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hu\\%rst",
Compare: TextEqualsIgnoreCase,
},
},
{
name: "equals ignore case _ wildcard",
args: args{
column: testCol,
value: "hu_rst",
compare: TextEqualsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hu\\_rst",
Compare: TextEqualsIgnoreCase,
},
},
{
name: "equals ignore case _, % wildcards",
args: args{
column: testCol,
value: "h_urst%",
compare: TextEqualsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "h\\_urst\\%",
Compare: TextEqualsIgnoreCase,
},
},
{
name: "not equal",
args: args{
column: testCol,
value: "hurst",
compare: TextNotEquals,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextNotEquals,
},
},
{
name: "starts with",
args: args{
column: testCol,
value: "hurst",
compare: TextStartsWith,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextStartsWith,
},
},
{
name: "starts with _ wildcard",
args: args{
column: testCol,
value: "_hurst",
compare: TextStartsWith,
},
want: &textQuery{
Column: testCol,
Text: "\\_hurst",
Compare: TextStartsWith,
},
},
{
name: "starts with % wildcard",
args: args{
column: testCol,
value: "hurst%",
compare: TextStartsWith,
},
want: &textQuery{
Column: testCol,
Text: "hurst\\%",
Compare: TextStartsWith,
},
},
{
name: "starts with %, % wildcard",
args: args{
column: testCol,
value: "hu%%rst",
compare: TextStartsWith,
},
want: &textQuery{
Column: testCol,
Text: "hu\\%\\%rst",
Compare: TextStartsWith,
},
},
{
name: "starts with ignore case",
args: args{
column: testCol,
value: "hurst",
compare: TextStartsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextStartsWithIgnoreCase,
},
},
{
name: "starts with ignore case _ wildcard",
args: args{
column: testCol,
value: "hur_st",
compare: TextStartsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hur\\_st",
Compare: TextStartsWithIgnoreCase,
},
},
{
name: "starts with ignore case % wildcard",
args: args{
column: testCol,
value: "hurst%",
compare: TextStartsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurst\\%",
Compare: TextStartsWithIgnoreCase,
},
},
{
name: "starts with ignore case _, _ wildcard",
args: args{
column: testCol,
value: "h_r_t",
compare: TextStartsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "h\\_r\\_t",
Compare: TextStartsWithIgnoreCase,
},
},
{
name: "ends with",
args: args{
column: testCol,
value: "hurst",
compare: TextEndsWith,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextEndsWith,
},
},
{
name: "ends with % wildcard",
args: args{
column: testCol,
value: "%hurst",
compare: TextEndsWith,
},
want: &textQuery{
Column: testCol,
Text: "\\%hurst",
Compare: TextEndsWith,
},
},
{
name: "ends with _ wildcard",
args: args{
column: testCol,
value: "hurst_",
compare: TextEndsWith,
},
want: &textQuery{
Column: testCol,
Text: "hurst\\_",
Compare: TextEndsWith,
},
},
{
name: "ends with _, % wildcard",
args: args{
column: testCol,
value: "hurst_%",
compare: TextEndsWith,
},
want: &textQuery{
Column: testCol,
Text: "hurst\\_\\%",
Compare: TextEndsWith,
},
},
{
name: "ends with ignore case",
args: args{
column: testCol,
value: "hurst",
compare: TextEndsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextEndsWithIgnoreCase,
},
},
{
name: "ends with ignore case _, %, _ wildcards",
args: args{
column: testCol,
value: "h_r_t%",
compare: TextEndsWithIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "h\\_r\\_t\\%",
Compare: TextEndsWithIgnoreCase,
},
},
{
name: "contains",
args: args{
column: testCol,
value: "hurst",
compare: TextContains,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextContains,
},
},
{
name: "contains % wildcard",
args: args{
column: testCol,
value: "%",
compare: TextContains,
},
want: &textQuery{
Column: testCol,
Text: "\\%",
Compare: TextContains,
},
},
{
name: "contains ignore csae",
args: args{
column: testCol,
value: "hurst",
compare: TextContainsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurst",
Compare: TextContainsIgnoreCase,
},
},
{
name: "contains ignore csae _ wildcard",
args: args{
column: testCol,
value: "hurs_",
compare: TextContainsIgnoreCase,
},
want: &textQuery{
Column: testCol,
Text: "hurs\\_",
Compare: TextContainsIgnoreCase,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -894,6 +1193,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.ILike{"test_table.test_col": "Hurst"},
},
},
{
name: "equals ignore case wildcard",
fields: fields{
Column: testCol,
Text: "Hu%%rst",
Compare: TextEqualsIgnoreCase,
},
want: want{
query: sq.ILike{"test_table.test_col": "Hu\\%\\%rst"},
},
},
{
name: "starts with",
fields: fields{
@@ -905,6 +1215,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.Like{"test_table.test_col": "Hurst%"},
},
},
{
name: "starts with wildcards",
fields: fields{
Column: testCol,
Text: "_Hurst%",
Compare: TextStartsWith,
},
want: want{
query: sq.Like{"test_table.test_col": "\\_Hurst\\%%"},
},
},
{
name: "starts with ignore case",
fields: fields{
@@ -916,6 +1237,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.ILike{"test_table.test_col": "Hurst%"},
},
},
{
name: "starts with ignore case wildcards",
fields: fields{
Column: testCol,
Text: "Hurst%",
Compare: TextStartsWithIgnoreCase,
},
want: want{
query: sq.ILike{"test_table.test_col": "Hurst\\%%"},
},
},
{
name: "ends with",
fields: fields{
@@ -927,6 +1259,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.Like{"test_table.test_col": "%Hurst"},
},
},
{
name: "ends with wildcards",
fields: fields{
Column: testCol,
Text: "Hurst%",
Compare: TextEndsWith,
},
want: want{
query: sq.Like{"test_table.test_col": "%Hurst\\%"},
},
},
{
name: "ends with ignore case",
fields: fields{
@@ -938,6 +1281,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.ILike{"test_table.test_col": "%Hurst"},
},
},
{
name: "ends with ignore case wildcards",
fields: fields{
Column: testCol,
Text: "%Hurst",
Compare: TextEndsWithIgnoreCase,
},
want: want{
query: sq.ILike{"test_table.test_col": "%\\%Hurst"},
},
},
{
name: "contains",
fields: fields{
@@ -949,6 +1303,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.Like{"test_table.test_col": "%Hurst%"},
},
},
{
name: "contains wildcards",
fields: fields{
Column: testCol,
Text: "Hu%rst%",
Compare: TextContains,
},
want: want{
query: sq.Like{"test_table.test_col": "%Hu\\%rst\\%%"},
},
},
{
name: "containts ignore case",
fields: fields{
@@ -960,6 +1325,17 @@ func TestTextQuery_comp(t *testing.T) {
query: sq.ILike{"test_table.test_col": "%Hurst%"},
},
},
{
name: "contains ignore case wildcards",
fields: fields{
Column: testCol,
Text: "%Hurst%",
Compare: TextContainsIgnoreCase,
},
want: want{
query: sq.ILike{"test_table.test_col": "%\\%Hurst\\%%"},
},
},
{
name: "list containts",
fields: fields{
@@ -999,10 +1375,10 @@ func TestTextQuery_comp(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &TextQuery{
Column: tt.fields.Column,
Text: tt.fields.Text,
Compare: tt.fields.Compare,
s, _ := NewTextQuery(tt.fields.Column, tt.fields.Text, tt.fields.Compare)
if s == nil {
// used to check correct behavior of comp
s = &textQuery{Column: tt.fields.Column, Text: tt.fields.Text, Compare: tt.fields.Compare}
}
query := s.comp()
if query == nil && tt.want.isNil {