diff --git a/internal/v2/database/filter.go b/internal/v2/database/filter.go new file mode 100644 index 0000000000..3653abee79 --- /dev/null +++ b/internal/v2/database/filter.go @@ -0,0 +1,33 @@ +package database + +type Condition interface { + Write(stmt *Statement, columnName string) +} + +type Filter[C compare, V value] struct { + comp C + value V +} + +func (f Filter[C, V]) Write(stmt *Statement, columnName string) { + prepareWrite(stmt, columnName, f.comp) + stmt.WriteArg(f.value) +} + +func prepareWrite[C compare](stmt *Statement, columnName string, comp C) { + stmt.WriteString(columnName) + stmt.WriteRune(' ') + stmt.WriteString(comp.String()) + stmt.WriteRune(' ') +} + +type compare interface { + numberCompare | textCompare | listCompare + String() string +} + +type value interface { + number | text + // TODO: condition must know if it's args are named parameters or not + // number | text | placeholder +} diff --git a/internal/v2/database/list_filter.go b/internal/v2/database/list_filter.go new file mode 100644 index 0000000000..834a49ae0b --- /dev/null +++ b/internal/v2/database/list_filter.go @@ -0,0 +1,57 @@ +package database + +import "github.com/zitadel/logging" + +type ListFilter[V value] struct { + comp listCompare + list []V +} + +func NewListEquals[V value](list ...V) *ListFilter[V] { + return newListFilter[V](listEqual, list) +} + +func NewListContains[V value](list ...V) *ListFilter[V] { + return newListFilter[V](listContain, list) +} + +func NewListNotContains[V value](list ...V) *ListFilter[V] { + return newListFilter[V](listNotContain, list) +} + +func newListFilter[V value](comp listCompare, list []V) *ListFilter[V] { + return &ListFilter[V]{ + comp: comp, + list: list, + } +} + +func (f ListFilter[V]) Write(stmt *Statement, columnName string) { + if len(f.list) == 0 { + logging.WithFields("column", columnName).Debug("skip list filter because no entries defined") + return + } + if f.comp == listNotContain { + stmt.WriteString("NOT(") + } + stmt.WriteString(columnName) + stmt.WriteString(" = ") + if f.comp != listEqual { + stmt.WriteString("ANY(") + } + stmt.WriteArg(f.list) + if f.comp != listEqual { + stmt.WriteString(")") + } + if f.comp == listNotContain { + stmt.WriteRune(')') + } +} + +type listCompare uint8 + +const ( + listEqual listCompare = iota + listContain + listNotContain +) diff --git a/internal/v2/database/list_filter_test.go b/internal/v2/database/list_filter_test.go new file mode 100644 index 0000000000..9f5ffaad60 --- /dev/null +++ b/internal/v2/database/list_filter_test.go @@ -0,0 +1,122 @@ +package database + +import ( + "reflect" + "testing" +) + +func TestNewListConstructors(t *testing.T) { + type args struct { + constructor func(t ...string) *ListFilter[string] + t []string + } + tests := []struct { + name string + args args + want *ListFilter[string] + }{ + { + name: "NewListEquals", + args: args{ + constructor: NewListEquals[string], + t: []string{"as", "df"}, + }, + want: &ListFilter[string]{ + comp: listEqual, + list: []string{"as", "df"}, + }, + }, + { + name: "NewListContains", + args: args{ + constructor: NewListContains[string], + t: []string{"as", "df"}, + }, + want: &ListFilter[string]{ + comp: listContain, + list: []string{"as", "df"}, + }, + }, + { + name: "NewListNotContains", + args: args{ + constructor: NewListNotContains[string], + t: []string{"as", "df"}, + }, + want: &ListFilter[string]{ + comp: listNotContain, + list: []string{"as", "df"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.args.constructor(tt.args.t...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("number constructor = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewListConditionWrite(t *testing.T) { + type args struct { + constructor func(t ...string) *ListFilter[string] + t []string + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "ListEquals", + args: args{ + constructor: NewListEquals[string], + t: []string{"as", "df"}, + }, + want: wantQuery{ + query: "test = $1", + args: []any{[]string{"as", "df"}}, + }, + }, + { + name: "ListContains", + args: args{ + constructor: NewListContains[string], + t: []string{"as", "df"}, + }, + want: wantQuery{ + query: "test = ANY($1)", + args: []any{[]string{"as", "df"}}, + }, + }, + { + name: "ListNotContains", + args: args{ + constructor: NewListNotContains[string], + t: []string{"as", "df"}, + }, + want: wantQuery{ + query: "NOT(test = ANY($1))", + args: []any{[]string{"as", "df"}}, + }, + }, + { + name: "empty list", + args: args{ + constructor: NewListNotContains[string], + }, + want: wantQuery{ + query: "", + args: nil, + }, + }, + } + for _, tt := range tests { + var stmt Statement + t.Run(tt.name, func(t *testing.T) { + tt.args.constructor(tt.args.t...).Write(&stmt, "test") + assertQuery(t, &stmt, tt.want) + }) + } +} diff --git a/internal/v2/database/mock/sql_mock.go b/internal/v2/database/mock/sql_mock.go new file mode 100644 index 0000000000..c693671d6f --- /dev/null +++ b/internal/v2/database/mock/sql_mock.go @@ -0,0 +1,139 @@ +package mock + +import ( + "database/sql" + "database/sql/driver" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" +) + +type SQLMock struct { + DB *sql.DB + mock sqlmock.Sqlmock +} + +type Expectation func(m sqlmock.Sqlmock) + +func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock { + db, mock, err := sqlmock.New( + sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual), + sqlmock.ValueConverterOption(new(TypeConverter)), + ) + if err != nil { + t.Fatal("create mock failed", err) + } + + for _, expectation := range expectations { + expectation(mock) + } + + return &SQLMock{ + DB: db, + mock: mock, + } +} + +func (m *SQLMock) Assert(t *testing.T) { + t.Helper() + + if err := m.mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations not met: %v", err) + } + + m.DB.Close() +} + +func ExpectBegin(err error) Expectation { + return func(m sqlmock.Sqlmock) { + e := m.ExpectBegin() + if err != nil { + e.WillReturnError(err) + } + } +} + +func ExpectCommit(err error) Expectation { + return func(m sqlmock.Sqlmock) { + e := m.ExpectCommit() + if err != nil { + e.WillReturnError(err) + } + } +} + +type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec + +func WithExecArgs(args ...driver.Value) ExecOpt { + return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec { + return e.WithArgs(args...) + } +} + +func WithExecErr(err error) ExecOpt { + return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec { + return e.WillReturnError(err) + } +} + +func WithExecNoRowsAffected() ExecOpt { + return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec { + return e.WillReturnResult(driver.ResultNoRows) + } +} + +func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt { + return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec { + return e.WillReturnResult(affected) + } +} + +func ExpectExec(stmt string, opts ...ExecOpt) Expectation { + return func(m sqlmock.Sqlmock) { + e := m.ExpectExec(stmt) + for _, opt := range opts { + e = opt(e) + } + } +} + +type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery + +func WithQueryArgs(args ...driver.Value) QueryOpt { + return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery { + return e.WithArgs(args...) + } +} + +func WithQueryErr(err error) QueryOpt { + return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery { + return e.WillReturnError(err) + } +} + +func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt { + return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery { + mockedRows := m.NewRows(columns) + for _, row := range rows { + mockedRows = mockedRows.AddRow(row...) + } + return e.WillReturnRows(mockedRows) + } +} + +func ExpectQuery(stmt string, opts ...QueryOpt) Expectation { + return func(m sqlmock.Sqlmock) { + e := m.ExpectQuery(stmt) + for _, opt := range opts { + e = opt(m, e) + } + } +} + +type AnyType[T interface{}] struct{} + +// Match satisfies sqlmock.Argument interface +func (a AnyType[T]) Match(v driver.Value) bool { + return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String() +} diff --git a/internal/v2/database/mock/type_converter.go b/internal/v2/database/mock/type_converter.go new file mode 100644 index 0000000000..f27fc3456f --- /dev/null +++ b/internal/v2/database/mock/type_converter.go @@ -0,0 +1,78 @@ +package mock + +import ( + "database/sql/driver" + "encoding/hex" + "encoding/json" + "reflect" + "strconv" + "strings" +) + +var _ driver.ValueConverter = (*TypeConverter)(nil) + +type TypeConverter struct{} + +// ConvertValue converts a value to a driver Value. +func (s TypeConverter) ConvertValue(v any) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + value := reflect.ValueOf(v) + + if rawMessage, ok := v.(json.RawMessage); ok { + return convertBytes(rawMessage), nil + } + + if value.Kind() == reflect.Slice { + //nolint: exhaustive + // only defined types + switch value.Type().Elem().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: + return convertSigned(value), nil + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return convertUnsigned(value), nil + case reflect.String: + return convertText(value), nil + } + } + return v, nil +} + +// converts a text array to valid pgx v5 representation +func convertSigned(array reflect.Value) string { + slice := make([]string, array.Len()) + for i := 0; i < array.Len(); i++ { + slice[i] = strconv.FormatInt(array.Index(i).Int(), 10) + } + + return "{" + strings.Join(slice, ",") + "}" +} + +// converts a text array to valid pgx v5 representation +func convertUnsigned(array reflect.Value) string { + slice := make([]string, array.Len()) + for i := 0; i < array.Len(); i++ { + slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10) + } + + return "{" + strings.Join(slice, ",") + "}" +} + +// converts a text array to valid pgx v5 representation +func convertText(array reflect.Value) string { + slice := make([]string, array.Len()) + for i := 0; i < array.Len(); i++ { + slice[i] = array.Index(i).String() + } + + return "{" + strings.Join(slice, ",") + "}" +} + +func convertBytes(array []byte) string { + var builder strings.Builder + builder.Grow(hex.EncodedLen(len(array)) + 4) + builder.WriteString(`\x`) + builder.Write(hex.AppendEncode(nil, array)) + return builder.String() +} diff --git a/internal/v2/database/number_filter.go b/internal/v2/database/number_filter.go new file mode 100644 index 0000000000..ce263ceeee --- /dev/null +++ b/internal/v2/database/number_filter.go @@ -0,0 +1,100 @@ +package database + +import ( + "time" + + "github.com/zitadel/logging" + "golang.org/x/exp/constraints" +) + +type NumberFilter[N number] struct { + Filter[numberCompare, N] +} + +func NewNumberEquals[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberEqual, n) +} + +func NewNumberAtLeast[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberAtLeast, n) +} + +func NewNumberAtMost[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberAtMost, n) +} + +func NewNumberGreater[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberGreater, n) +} + +func NewNumberLess[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberLess, n) +} + +func NewNumberUnequal[N number](n N) *NumberFilter[N] { + return newNumberFilter(numberUnequal, n) +} + +func newNumberFilter[N number](comp numberCompare, n N) *NumberFilter[N] { + return &NumberFilter[N]{ + Filter: Filter[numberCompare, N]{ + comp: comp, + value: n, + }, + } +} + +// NumberBetweenFilter combines [AtLeast] and [AtMost] comparisons +type NumberBetweenFilter[N number] struct { + min, max N +} + +func NewNumberBetween[N number](min, max N) *NumberBetweenFilter[N] { + return &NumberBetweenFilter[N]{ + min: min, + max: max, + } +} + +func (f NumberBetweenFilter[N]) Write(stmt *Statement, columnName string) { + NewNumberAtLeast[N](f.min).Write(stmt, columnName) + stmt.WriteString(" AND ") + NewNumberAtMost[N](f.max).Write(stmt, columnName) +} + +type numberCompare uint8 + +const ( + numberEqual numberCompare = iota + numberAtLeast + numberAtMost + numberGreater + numberLess + numberUnequal +) + +func (c numberCompare) String() string { + switch c { + case numberEqual: + return "=" + case numberAtLeast: + return ">=" + case numberAtMost: + return "<=" + case numberGreater: + return ">" + case numberLess: + return "<" + case numberUnequal: + return "<>" + default: + logging.WithFields("compare", c).Panic("comparison type not implemented") + return "" + } +} + +type number interface { + constraints.Integer | constraints.Float | time.Time + // TODO: condition must know if it's args are named parameters or not + // constraints.Integer | constraints.Float | time.Time | placeholder +} diff --git a/internal/v2/database/number_filter_test.go b/internal/v2/database/number_filter_test.go new file mode 100644 index 0000000000..5f934b88e1 --- /dev/null +++ b/internal/v2/database/number_filter_test.go @@ -0,0 +1,216 @@ +package database + +import ( + "reflect" + "testing" +) + +func TestNewNumberConstructors(t *testing.T) { + type args struct { + constructor func(t int8) *NumberFilter[int8] + t int8 + } + tests := []struct { + name string + args args + want *NumberFilter[int8] + }{ + { + name: "NewNumberEqual", + args: args{ + constructor: NewNumberEquals[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberEqual, + value: 10, + }, + }, + }, + { + name: "NewNumberAtLeast", + args: args{ + constructor: NewNumberAtLeast[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberAtLeast, + value: 10, + }, + }, + }, + { + name: "NewNumberAtMost", + args: args{ + constructor: NewNumberAtMost[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberAtMost, + value: 10, + }, + }, + }, + { + name: "NewNumberGreater", + args: args{ + constructor: NewNumberGreater[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberGreater, + value: 10, + }, + }, + }, + { + name: "NewNumberLess", + args: args{ + constructor: NewNumberLess[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberLess, + value: 10, + }, + }, + }, + { + name: "NewNumberUnequal", + args: args{ + constructor: NewNumberUnequal[int8], + t: 10, + }, + want: &NumberFilter[int8]{ + Filter: Filter[numberCompare, int8]{ + comp: numberUnequal, + value: 10, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) { + t.Errorf("number constructor = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewNumberConditionWrite(t *testing.T) { + type args struct { + constructor func(t int8) *NumberFilter[int8] + t int8 + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "NewNumberEqual", + args: args{ + constructor: NewNumberEquals[int8], + t: 10, + }, + want: wantQuery{ + query: "test = $1", + args: []any{int8(10)}, + }, + }, + { + name: "NewNumberAtLeast", + args: args{ + constructor: NewNumberAtLeast[int8], + t: 10, + }, + want: wantQuery{ + query: "test >= $1", + args: []any{int8(10)}, + }, + }, + { + name: "NewNumberAtMost", + args: args{ + constructor: NewNumberAtMost[int8], + t: 10, + }, + want: wantQuery{ + query: "test <= $1", + args: []any{int8(10)}, + }, + }, + { + name: "NewNumberGreater", + args: args{ + constructor: NewNumberGreater[int8], + t: 10, + }, + want: wantQuery{ + query: "test > $1", + args: []any{int8(10)}, + }, + }, + { + name: "NewNumberLess", + args: args{ + constructor: NewNumberLess[int8], + t: 10, + }, + want: wantQuery{ + query: "test < $1", + args: []any{int8(10)}, + }, + }, + { + name: "NewNumberUnequal", + args: args{ + constructor: NewNumberUnequal[int8], + t: 10, + }, + want: wantQuery{ + query: "test <> $1", + args: []any{int8(10)}, + }, + }, + } + for _, tt := range tests { + var stmt Statement + t.Run(tt.name, func(t *testing.T) { + tt.args.constructor(tt.args.t).Write(&stmt, "test") + assertQuery(t, &stmt, tt.want) + }) + } +} + +func TestNumberBetween(t *testing.T) { + filter := NewNumberBetween[int8](10, 20) + + if !reflect.DeepEqual(filter, &NumberBetweenFilter[int8]{min: 10, max: 20}) { + t.Errorf("unexpected filter: %v", filter) + } + + var stmt Statement + filter.Write(&stmt, "test") + if stmt.String() != "test >= $1 AND test <= $2" { + t.Errorf("unexpected query: got: %q", stmt.String()) + } + + if len(stmt.Args()) != 2 { + t.Errorf("unexpected length of args: got %d", len(stmt.Args())) + return + } + + if !reflect.DeepEqual(int8(10), stmt.Args()[0]) { + t.Errorf("unexpected arg at position 0: want: 10, got: %v", stmt.Args()[0]) + } + if !reflect.DeepEqual(int8(20), stmt.Args()[1]) { + t.Errorf("unexpected arg at position 1: want: 20, got: %v", stmt.Args()[1]) + } +} diff --git a/internal/v2/database/pagination.go b/internal/v2/database/pagination.go new file mode 100644 index 0000000000..07d83fb2e3 --- /dev/null +++ b/internal/v2/database/pagination.go @@ -0,0 +1,17 @@ +package database + +type Pagination struct { + Limit uint32 + Offset uint32 +} + +func (p *Pagination) Write(stmt *Statement) { + if p.Limit > 0 { + stmt.WriteString(" LIMIT ") + stmt.WriteArg(p.Limit) + } + if p.Offset > 0 { + stmt.WriteString(" OFFSET ") + stmt.WriteArg(p.Offset) + } +} diff --git a/internal/v2/database/pagination_test.go b/internal/v2/database/pagination_test.go new file mode 100644 index 0000000000..2158e8f8f4 --- /dev/null +++ b/internal/v2/database/pagination_test.go @@ -0,0 +1,73 @@ +package database + +import ( + "testing" +) + +func TestPagination_Write(t *testing.T) { + type fields struct { + Limit uint32 + Offset uint32 + } + tests := []struct { + name string + fields fields + want wantQuery + }{ + { + name: "no values", + fields: fields{ + Limit: 0, + Offset: 0, + }, + want: wantQuery{ + query: "", + args: []any{}, + }, + }, + { + name: "limit", + fields: fields{ + Limit: 10, + Offset: 0, + }, + want: wantQuery{ + query: " LIMIT $1", + args: []any{uint32(10)}, + }, + }, + { + name: "offset", + fields: fields{ + Limit: 0, + Offset: 10, + }, + want: wantQuery{ + query: " OFFSET $1", + args: []any{uint32(10)}, + }, + }, + { + name: "both", + fields: fields{ + Limit: 10, + Offset: 10, + }, + want: wantQuery{ + query: " LIMIT $1 OFFSET $2", + args: []any{uint32(10), uint32(10)}, + }, + }, + } + for _, tt := range tests { + var stmt Statement + t.Run(tt.name, func(t *testing.T) { + p := &Pagination{ + Limit: tt.fields.Limit, + Offset: tt.fields.Offset, + } + p.Write(&stmt) + assertQuery(t, &stmt, tt.want) + }) + } +} diff --git a/internal/v2/database/sql_helper.go b/internal/v2/database/sql_helper.go new file mode 100644 index 0000000000..4efa2f6d92 --- /dev/null +++ b/internal/v2/database/sql_helper.go @@ -0,0 +1,75 @@ +package database + +import ( + "context" + "database/sql" + + "github.com/zitadel/logging" +) + +type Tx interface { + Commit() error + Rollback() error +} + +func CloseTx(tx Tx, err error) error { + if err != nil { + rollbackErr := tx.Rollback() + logging.OnError(rollbackErr).Debug("unable to rollback") + return err + } + + return tx.Commit() +} + +type DestMapper[R any] func(index int, scan func(dest ...any) error) (*R, error) + +type Rows interface { + Close() error + Err() error + Next() bool + Scan(dest ...any) error +} + +func MapRows[R any](rows Rows, mapper DestMapper[R]) (result []*R, err error) { + defer func() { + closeErr := rows.Close() + logging.OnError(closeErr).Debug("unable to close rows") + + if err == nil && rows.Err() != nil { + result = nil + err = rows.Err() + } + }() + for i := 0; rows.Next(); i++ { + res, err := mapper(i, rows.Scan) + if err != nil { + return nil, err + } + result = append(result, res) + } + + return result, nil +} + +func MapRowsToObject(rows Rows, mapper func(scan func(dest ...any) error) error) (err error) { + defer func() { + closeErr := rows.Close() + logging.OnError(closeErr).Debug("unable to close rows") + + if err == nil && rows.Err() != nil { + err = rows.Err() + } + }() + for rows.Next() { + err = mapper(rows.Scan) + if err != nil { + return err + } + } + return nil +} + +type Querier interface { + QueryContext(context.Context, string, ...any) (*sql.Rows, error) +} diff --git a/internal/v2/database/sql_helper_test.go b/internal/v2/database/sql_helper_test.go new file mode 100644 index 0000000000..c59e4221ec --- /dev/null +++ b/internal/v2/database/sql_helper_test.go @@ -0,0 +1,512 @@ +package database + +import ( + "errors" + "reflect" + "testing" +) + +func TestCloseTx(t *testing.T) { + type args struct { + tx *testTx + err error + } + tests := []struct { + name string + args args + assertErr func(t *testing.T, err error) bool + }{ + { + name: "exec err", + args: args{ + tx: &testTx{ + rollback: execution{ + shouldExecute: true, + }, + }, + err: errExec, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("execution error expected, got: %v", err) + } + return is + }, + }, + { + name: "exec err and rollback err", + args: args{ + tx: &testTx{ + rollback: execution{ + err: true, + shouldExecute: true, + }, + }, + err: errExec, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("execution error expected, got: %v", err) + } + return is + }, + }, + { + name: "commit Err", + args: args{ + tx: &testTx{ + commit: execution{ + err: true, + shouldExecute: true, + }, + }, + err: nil, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errCommit) + if !is { + t.Errorf("commit error expected, got: %v", err) + } + return is + }, + }, + { + name: "no err", + args: args{ + tx: &testTx{ + commit: execution{ + shouldExecute: true, + }, + }, + err: nil, + }, + assertErr: func(t *testing.T, err error) bool { + is := err == nil + if !is { + t.Errorf("no error expected, got: %v", err) + } + return is + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := CloseTx(tt.args.tx, tt.args.err) + tt.assertErr(t, err) + tt.args.tx.assert(t) + }) + } +} + +func TestMapRows(t *testing.T) { + type args struct { + rows *testRows + mapper DestMapper[string] + } + var emptyString string + tests := []struct { + name string + args args + wantResult []*string + assertErr func(t *testing.T, err error) bool + }{ + { + name: "no rows, close err", + args: args{ + rows: &testRows{ + closeErr: true, + }, + mapper: nil, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errClose) + if !is { + t.Errorf("close error expected, got: %v", err) + } + return is + }, + }, + { + name: "no rows, close err", + args: args{ + rows: &testRows{ + hasErr: true, + }, + mapper: nil, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errRows) + if !is { + t.Errorf("rows error expected, got: %v", err) + } + return is + }, + }, + { + name: "scan err", + args: args{ + rows: &testRows{ + scanErr: true, + nextCount: 1, + }, + mapper: func(index int, scan func(dest ...any) error) (*string, error) { + var s string + if err := scan(&s); err != nil { + return nil, err + } + return &s, nil + }, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errScan) + if !is { + t.Errorf("scan error expected, got: %v", err) + } + return is + }, + }, + { + name: "exec err", + args: args{ + rows: &testRows{ + nextCount: 1, + }, + mapper: func(index int, scan func(dest ...any) error) (*string, error) { + return nil, errExec + }, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("exec error expected, got: %v", err) + } + return is + }, + }, + { + name: "exec err, close err", + args: args{ + rows: &testRows{ + closeErr: true, + nextCount: 1, + }, + mapper: func(index int, scan func(dest ...any) error) (*string, error) { + return nil, errExec + }, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("exec error expected, got: %v", err) + } + return is + }, + }, + { + name: "rows err", + args: args{ + rows: &testRows{ + nextCount: 1, + hasErr: true, + }, + mapper: func(index int, scan func(dest ...any) error) (*string, error) { + var s string + if err := scan(&s); err != nil { + return nil, err + } + return &s, nil + }, + }, + wantResult: nil, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errRows) + if !is { + t.Errorf("rows error expected, got: %v", err) + } + return is + }, + }, + { + name: "no err", + args: args{ + rows: &testRows{ + nextCount: 1, + }, + mapper: func(index int, scan func(dest ...any) error) (*string, error) { + var s string + if err := scan(&s); err != nil { + return nil, err + } + return &s, nil + }, + }, + wantResult: []*string{&emptyString}, + assertErr: func(t *testing.T, err error) bool { + is := err == nil + if !is { + t.Errorf("no error expected, got: %v", err) + } + return is + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotResult, err := MapRows(tt.args.rows, tt.args.mapper) + tt.assertErr(t, err) + if !reflect.DeepEqual(gotResult, tt.wantResult) { + t.Errorf("MapRows() = %v, want %v", gotResult, tt.wantResult) + } + }) + } +} + +func TestMapRowsToObject(t *testing.T) { + type args struct { + rows *testRows + mapper func(scan func(dest ...any) error) error + } + tests := []struct { + name string + args args + assertErr func(t *testing.T, err error) bool + }{ + { + name: "no rows, close err", + args: args{ + rows: &testRows{ + closeErr: true, + }, + mapper: nil, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errClose) + if !is { + t.Errorf("close error expected, got: %v", err) + } + return is + }, + }, + { + name: "no rows, close err", + args: args{ + rows: &testRows{ + hasErr: true, + }, + mapper: nil, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errRows) + if !is { + t.Errorf("rows error expected, got: %v", err) + } + return is + }, + }, + { + name: "scan err", + args: args{ + rows: &testRows{ + scanErr: true, + nextCount: 1, + }, + mapper: func(scan func(dest ...any) error) error { + var s string + if err := scan(&s); err != nil { + return err + } + return nil + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errScan) + if !is { + t.Errorf("scan error expected, got: %v", err) + } + return is + }, + }, + { + name: "exec err", + args: args{ + rows: &testRows{ + nextCount: 1, + }, + mapper: func(scan func(dest ...any) error) error { + return errExec + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("exec error expected, got: %v", err) + } + return is + }, + }, + { + name: "exec err, close err", + args: args{ + rows: &testRows{ + closeErr: true, + nextCount: 1, + }, + mapper: func(scan func(dest ...any) error) error { + return errExec + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errExec) + if !is { + t.Errorf("exec error expected, got: %v", err) + } + return is + }, + }, + { + name: "rows err", + args: args{ + rows: &testRows{ + nextCount: 1, + hasErr: true, + }, + mapper: func(scan func(dest ...any) error) error { + var s string + return scan(&s) + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := errors.Is(err, errRows) + if !is { + t.Errorf("rows error expected, got: %v", err) + } + return is + }, + }, + { + name: "no err", + args: args{ + rows: &testRows{ + nextCount: 1, + }, + mapper: func(scan func(dest ...any) error) error { + var s string + return scan(&s) + }, + }, + assertErr: func(t *testing.T, err error) bool { + is := err == nil + if !is { + t.Errorf("no error expected, got: %v", err) + } + return is + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := MapRowsToObject(tt.args.rows, tt.args.mapper) + tt.assertErr(t, err) + }) + } +} + +var _ Tx = (*testTx)(nil) + +type testTx struct { + commit, rollback execution +} + +type execution struct { + err bool + didExecute bool + shouldExecute bool +} + +var ( + errCommit = errors.New("commit err") + errRollback = errors.New("rollback err") + errExec = errors.New("exec err") +) + +// Commit implements Tx. +func (t *testTx) Commit() error { + t.commit.didExecute = true + if t.commit.err { + return errCommit + } + return nil +} + +// Rollback implements Tx. +func (t *testTx) Rollback() error { + t.rollback.didExecute = true + if t.rollback.err { + return errRollback + } + return nil +} + +func (tx *testTx) assert(t *testing.T) { + if tx.commit.didExecute != tx.commit.shouldExecute { + t.Errorf("unexpected execution of commit: should %v, did: %v", tx.commit.shouldExecute, tx.commit.didExecute) + } + if tx.rollback.didExecute != tx.rollback.shouldExecute { + t.Errorf("unexpected execution of rollback: should %v, did: %v", tx.rollback.shouldExecute, tx.rollback.didExecute) + } +} + +var _ Rows = (*testRows)(nil) + +var ( + errClose = errors.New("err close") + errRows = errors.New("err rows") + errScan = errors.New("err scan") +) + +type testRows struct { + closeErr bool + scanErr bool + hasErr bool + nextCount int +} + +// Close implements Rows. +func (t *testRows) Close() error { + if t.closeErr { + return errClose + } + return nil +} + +// Err implements Rows. +func (t *testRows) Err() error { + if t.hasErr { + return errRows + } + if t.closeErr { + return errClose + } + return nil +} + +// Next implements Rows. +func (t *testRows) Next() bool { + t.nextCount-- + return t.nextCount >= 0 +} + +// Scan implements Rows. +func (t *testRows) Scan(dest ...any) error { + if t.scanErr { + return errScan + } + return nil +} diff --git a/internal/v2/database/statement.go b/internal/v2/database/statement.go new file mode 100644 index 0000000000..08fbe8aa6c --- /dev/null +++ b/internal/v2/database/statement.go @@ -0,0 +1,222 @@ +package database + +import ( + "fmt" + "slices" + "strconv" + "strings" + "time" + "unsafe" + + "github.com/zitadel/logging" +) + +type Statement struct { + addr *Statement + builder strings.Builder + + args []any + // key is the name of the arg and value is the placeholder + // TODO: condition must know if it's args are named parameters or not + // namedArgs map[placeholder]string +} + +func (stmt *Statement) Args() []any { + if stmt == nil { + return nil + } + return stmt.args +} + +func (stmt *Statement) Reset() { + stmt.builder.Reset() + stmt.addr = nil + stmt.args = nil +} + +// TODO: condition must know if it's args are named parameters or not +// SetNamedArg sets the arg and makes it available for query construction +// func (stmt *Statement) SetNamedArg(name placeholder, value any) (placeholder string) { +// stmt.copyCheck() +// stmt.args = append(stmt.args, value) +// placeholder = fmt.Sprintf("$%d", len(stmt.args)) +// if !strings.HasPrefix(name.string, "@") { +// name.string = "@" + name.string +// } +// stmt.namedArgs[name] = placeholder +// return placeholder +// } + +// AppendArgs appends the args without writing it to Builder +// if any arg is a [placeholder] it's replaced with the placeholders parameter +func (stmt *Statement) AppendArgs(args ...any) { + stmt.copyCheck() + stmt.args = slices.Grow(stmt.args, len(args)) + for _, arg := range args { + stmt.AppendArg(arg) + } +} + +// AppendArg appends the arg without writing it to Builder +// if the arg is a [placeholder] it's replaced with the placeholders parameter +func (stmt *Statement) AppendArg(arg any) int { + stmt.copyCheck() + + // TODO: condition must know if it's args are named parameters or not + // if namedArg, ok := arg.(sql.NamedArg); ok { + // stmt.SetNamedArg(placeholder{namedArg.Name}, namedArg.Value) + // return + // } + stmt.args = append(stmt.args, arg) + return len(stmt.args) +} + +// TODO: condition must know if it's args are named parameters or not +// func Placeholder(name string) placeholder { +// return placeholder{name} +// } + +// TODO: condition must know if it's args are named parameters or not +// type placeholder struct { +// string +// } + +// WriteArgs appends the args and adds the placeholders comma separated to [stmt.Builder] +// if any arg is a [placeholder] it's replaced with the placeholders parameter +func (stmt *Statement) WriteArgs(args ...any) { + stmt.copyCheck() + stmt.args = slices.Grow(stmt.args, len(args)) + for i, arg := range args { + if i > 0 { + stmt.WriteString(", ") + } + stmt.WriteArg(arg) + } +} + +// WriteArg appends the arg and adds the placeholder to [stmt.Builder] +// if the arg is a [placeholder] it's replaced with the placeholders parameter +func (stmt *Statement) WriteArg(arg any) { + stmt.copyCheck() + // TODO: condition must know if it's args are named parameters or not + // if namedPlaceholder, ok := arg.(placeholder); ok { + // stmt.writeNamedPlaceholder(namedPlaceholder) + // return + // } + placeholder := stmt.AppendArg(arg) + stmt.WriteString("$") + stmt.WriteString(strconv.Itoa(placeholder)) +} + +// WriteString extends [strings.Builder.WriteString] +// it replaces named args with the previously provided named args +func (stmt *Statement) WriteString(s string) { + // TODO: condition must know if it's args are named parameters or not + // for name, placeholder := range stmt.namedArgs { + // s = strings.ReplaceAll(s, name.string, placeholder) + // } + stmt.builder.WriteString(s) +} + +// WriteRune extends [strings.Builder.WriteRune] +func (stmt *Statement) WriteRune(r rune) { + // TODO: condition must know if it's args are named parameters or not + // for name, placeholder := range stmt.namedArgs { + // s = strings.ReplaceAll(s, name.string, placeholder) + // } + stmt.builder.WriteRune(r) +} + +// WriteByte extends [strings.Builder.WriteByte] +func (stmt *Statement) WriteByte(b byte) { + // TODO: condition must know if it's args are named parameters or not + // for name, placeholder := range stmt.namedArgs { + // s = strings.ReplaceAll(s, name.string, placeholder) + // } + err := stmt.builder.WriteByte(b) + logging.OnError(err).Warn("unable to write bytes") +} + +// Write extends [strings.Builder.Write] +// it replaces named args with the previously provided named args +func (stmt *Statement) Write(b []byte) { + // TODO: condition must know if it's args are named parameters or not + // for name, placeholder := range stmt.namedArgs { + // bytes.ReplaceAll(b, []byte(name.string), []byte(placeholder)) + // } + stmt.builder.Write(b) +} + +// String builds the query and replaces placeholders starting with "@" +// with the corresponding named arg placeholder +func (stmt *Statement) String() string { + return stmt.builder.String() +} + +// Debug builds the statement and replaces the placeholders with the parameters +func (stmt *Statement) Debug() string { + query := stmt.String() + + for i := len(stmt.args) - 1; i >= 0; i-- { + var argText string + switch arg := stmt.args[i].(type) { + case time.Time: + argText = "'" + arg.Format("2006-01-02 15:04:05Z07:00") + "'" + case string: + argText = "'" + arg + "'" + case []string: + argText = "ARRAY[" + for i, a := range arg { + if i > 0 { + argText += ", " + } + argText += "'" + a + "'" + } + argText += "]" + default: + argText = fmt.Sprint(arg) + } + query = strings.ReplaceAll(query, "$"+strconv.Itoa(i+1), argText) + } + + return query +} + +// TODO: condition must know if it's args are named parameters or not +// func (stmt *Statement) writeNamedPlaceholder(arg placeholder) { +// placeholder, ok := stmt.namedArgs[arg] +// if !ok { +// logging.WithFields("named_placeholder", arg).Fatal("named placeholder not defined") +// } +// stmt.Builder.WriteString(placeholder) +// } + +// copyCheck allows uninitialized usage of stmt +func (stmt *Statement) copyCheck() { + if stmt.addr == nil { + // This hack works around a failing of Go's escape analysis + // that was causing b to escape and be heap allocated. + // See issue 23382. + // TODO: once issue 7921 is fixed, this should be reverted to + // just "stmt.addr = stmt". + stmt.addr = (*Statement)(noescape(unsafe.Pointer(stmt))) + // TODO: condition must know if it's args are named parameters or not + // stmt.namedArgs = make(map[placeholder]string) + } else if stmt.addr != stmt { + panic("statement: illegal use of non-zero Builder copied by value") + } +} + +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime; see issues 23382 and 7921. +// +//go:nosplit +//go:nocheckptr +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + //nolint: staticcheck + return unsafe.Pointer(x ^ 0) +} diff --git a/internal/v2/database/statement_test.go b/internal/v2/database/statement_test.go new file mode 100644 index 0000000000..85407b91c7 --- /dev/null +++ b/internal/v2/database/statement_test.go @@ -0,0 +1,73 @@ +package database + +import ( + "reflect" + "testing" +) + +func TestStatement_WriteArgs(t *testing.T) { + type args struct { + args []any + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "no args", + args: args{ + args: nil, + }, + }, + { + name: "1 arg", + args: args{ + args: []any{"asdf"}, + }, + want: wantQuery{ + query: "$1", + args: []any{"asdf"}, + }, + }, + { + name: "n args", + args: args{ + args: []any{"asdf", "jkl", 1}, + }, + want: wantQuery{ + query: "$1, $2, $3", + args: []any{"asdf", "jkl", 1}, + }, + }, + } + for _, tt := range tests { + var stmt Statement + t.Run(tt.name, func(t *testing.T) { + stmt.WriteArgs(tt.args.args...) + assertQuery(t, &stmt, tt.want) + }) + } +} + +type wantQuery struct { + query string + args []any +} + +func assertQuery(t *testing.T, stmt *Statement, want wantQuery) { + if want.query != stmt.String() { + t.Errorf("unexpected query: want: %q got: %q", want.query, stmt.String()) + } + + if len(want.args) != len(stmt.Args()) { + t.Errorf("unexpected length of args: want %d, got %d", len(want.args), len(stmt.Args())) + return + } + + for i, wantArg := range want.args { + if !reflect.DeepEqual(wantArg, stmt.Args()[i]) { + t.Errorf("unexpected arg at position %d: want: %v, got: %v", i, wantArg, stmt.Args()[i]) + } + } +} diff --git a/internal/v2/database/text_filter.go b/internal/v2/database/text_filter.go new file mode 100644 index 0000000000..a44adbf976 --- /dev/null +++ b/internal/v2/database/text_filter.go @@ -0,0 +1,132 @@ +package database + +import ( + "fmt" + "strings" + + "github.com/zitadel/logging" +) + +type TextFilter[T text] struct { + Filter[textCompare, T] +} + +func NewTextEqual[T text](t T) *TextFilter[T] { + return newTextFilter(textEqual, t) +} + +func NewTextUnequal[T text](t T) *TextFilter[T] { + return newTextFilter(textUnequal, t) +} + +func NewTextEqualInsensitive[T text](t T) *TextFilter[string] { + return newTextFilter(textEqualInsensitive, strings.ToLower(string(t))) +} + +func NewTextUnequalInsensitive[T text](t T) *TextFilter[string] { + return newTextFilter(textUnequalInsensitive, strings.ToLower(string(t))) +} + +func NewTextStartsWith[T text](t T) *TextFilter[T] { + return newTextFilter(textStartsWith, t) +} + +func NewTextStartsWithInsensitive[T text](t T) *TextFilter[string] { + return newTextFilter(textStartsWithInsensitive, strings.ToLower(string(t))) +} + +func NewTextEndsWith[T text](t T) *TextFilter[T] { + return newTextFilter(textEndsWith, t) +} + +func NewTextEndsWithInsensitive[T text](t T) *TextFilter[string] { + return newTextFilter(textEndsWithInsensitive, strings.ToLower(string(t))) +} + +func NewTextContains[T text](t T) *TextFilter[T] { + return newTextFilter(textContains, t) +} + +func NewTextContainsInsensitive[T text](t T) *TextFilter[string] { + return newTextFilter(textContainsInsensitive, strings.ToLower(string(t))) +} + +func newTextFilter[T text](comp textCompare, t T) *TextFilter[T] { + return &TextFilter[T]{ + Filter: Filter[textCompare, T]{ + comp: comp, + value: t, + }, + } +} + +func (f *TextFilter[T]) Write(stmt *Statement, columnName string) { + if f.comp.isInsensitive() { + f.writeCaseInsensitive(stmt, columnName) + return + } + f.Filter.Write(stmt, columnName) +} + +func (f *TextFilter[T]) writeCaseInsensitive(stmt *Statement, columnName string) { + stmt.WriteString("LOWER(") + stmt.WriteString(columnName) + stmt.WriteString(") ") + stmt.WriteString(f.comp.String()) + stmt.WriteRune(' ') + f.writeArg(stmt) +} + +func (f *TextFilter[T]) writeArg(stmt *Statement) { + // TODO: condition must know if it's args are named parameters or not + // var v any = f.value + // workaround for placeholder + // if placeholder, ok := v.(placeholder); ok { + // stmt.Builder.WriteString(" LOWER(") + // stmt.WriteArg(placeholder) + // stmt.Builder.WriteString(")") + // } + stmt.WriteArg(strings.ToLower(fmt.Sprint(f.value))) +} + +type textCompare uint8 + +const ( + textEqual textCompare = iota + textUnequal + textEqualInsensitive + textUnequalInsensitive + textStartsWith + textStartsWithInsensitive + textEndsWith + textEndsWithInsensitive + textContains + textContainsInsensitive +) + +func (c textCompare) String() string { + switch c { + case textEqual, textEqualInsensitive: + return "=" + case textUnequal, textUnequalInsensitive: + return "<>" + case textStartsWith, textStartsWithInsensitive, textEndsWith, textEndsWithInsensitive, textContains, textContainsInsensitive: + return "LIKE" + default: + logging.WithFields("compare", c).Panic("comparison type not implemented") + return "" + } +} + +func (c textCompare) isInsensitive() bool { + return c == textEqualInsensitive || + c == textStartsWithInsensitive || + c == textEndsWithInsensitive || + c == textContainsInsensitive +} + +type text interface { + ~string + // TODO: condition must know if it's args are named parameters or not + // ~string | placeholder +} diff --git a/internal/v2/database/text_filter_test.go b/internal/v2/database/text_filter_test.go new file mode 100644 index 0000000000..e5365c8d66 --- /dev/null +++ b/internal/v2/database/text_filter_test.go @@ -0,0 +1,351 @@ +package database + +import ( + "reflect" + "testing" +) + +func TestNewTextEqual(t *testing.T) { + type args struct { + constructor func(t string) *TextFilter[string] + t string + } + tests := []struct { + name string + args args + want *TextFilter[string] + }{ + { + name: "NewTextEqual", + args: args{ + constructor: NewTextEqual[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEqual, + value: "text", + }, + }, + }, + { + name: "NewTextUnequal", + args: args{ + constructor: NewTextUnequal[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textUnequal, + value: "text", + }, + }, + }, + { + name: "NewTextEqualInsensitive", + args: args{ + constructor: NewTextEqualInsensitive[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEqualInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextEqualInsensitive check lower", + args: args{ + constructor: NewTextEqualInsensitive[string], + t: "tEXt", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEqualInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextUnequalInsensitive", + args: args{ + constructor: NewTextUnequalInsensitive[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textUnequalInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextUnequalInsensitive check lower", + args: args{ + constructor: NewTextUnequalInsensitive[string], + t: "tEXt", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textUnequalInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextStartsWith", + args: args{ + constructor: NewTextStartsWith[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textStartsWith, + value: "text", + }, + }, + }, + { + name: "NewTextStartsWithInsensitive", + args: args{ + constructor: NewTextStartsWithInsensitive[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textStartsWithInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextStartsWithInsensitive check lower", + args: args{ + constructor: NewTextStartsWithInsensitive[string], + t: "tEXt", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textStartsWithInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextEndsWith", + args: args{ + constructor: NewTextEndsWith[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEndsWith, + value: "text", + }, + }, + }, + { + name: "NewTextEndsWithInsensitive", + args: args{ + constructor: NewTextEndsWithInsensitive[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEndsWithInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextEndsWithInsensitive check lower", + args: args{ + constructor: NewTextEndsWithInsensitive[string], + t: "tEXt", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textEndsWithInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextContains", + args: args{ + constructor: NewTextContains[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textContains, + value: "text", + }, + }, + }, + { + name: "NewTextContainsInsensitive", + args: args{ + constructor: NewTextContainsInsensitive[string], + t: "text", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textContainsInsensitive, + value: "text", + }, + }, + }, + { + name: "NewTextContainsInsensitive to lower", + args: args{ + constructor: NewTextContainsInsensitive[string], + t: "tEXt", + }, + want: &TextFilter[string]{ + Filter: Filter[textCompare, string]{ + comp: textContainsInsensitive, + value: "text", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTextEqual() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTextConditionWrite(t *testing.T) { + type args struct { + constructor func(t string) *TextFilter[string] + t string + } + tests := []struct { + name string + args args + want wantQuery + }{ + { + name: "NewTextEqual", + args: args{ + constructor: NewTextEqual[string], + t: "text", + }, + want: wantQuery{ + query: "test = $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextUnequal", + args: args{ + constructor: NewTextUnequal[string], + t: "text", + }, + want: wantQuery{ + query: "test <> $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextEqualInsensitive", + args: args{ + constructor: NewTextEqualInsensitive[string], + t: "text", + }, + want: wantQuery{ + query: "LOWER(test) = $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextUnequalInsensitive", + args: args{ + constructor: NewTextUnequalInsensitive[string], + t: "text", + }, + want: wantQuery{ + query: "test <> $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextStartsWith", + args: args{ + constructor: NewTextStartsWith[string], + t: "text", + }, + want: wantQuery{ + query: "test LIKE $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextStartsWithInsensitive", + args: args{ + constructor: NewTextStartsWithInsensitive[string], + t: "text", + }, + want: wantQuery{ + query: "LOWER(test) LIKE $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextEndsWith", + args: args{ + constructor: NewTextEndsWith[string], + t: "text", + }, + want: wantQuery{ + query: "test LIKE $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextEndsWithInsensitive", + args: args{ + constructor: NewTextEndsWithInsensitive[string], + t: "text", + }, + want: wantQuery{ + query: "LOWER(test) LIKE $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextContains", + args: args{ + constructor: NewTextContains[string], + t: "text", + }, + want: wantQuery{ + query: "test LIKE $1", + args: []any{"text"}, + }, + }, + { + name: "NewTextContainsInsensitive", + args: args{ + constructor: NewTextContainsInsensitive[string], + t: "text", + }, + want: wantQuery{ + query: "LOWER(test) LIKE $1", + args: []any{"text"}, + }, + }, + } + for _, tt := range tests { + var stmt Statement + t.Run(tt.name, func(t *testing.T) { + tt.args.constructor(tt.args.t).Write(&stmt, "test") + assertQuery(t, &stmt, tt.want) + }) + } +}