mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 02:54:20 +00:00
refactor(v2): init database package (#7802)
This commit is contained in:
parent
207b20ff0f
commit
5131328291
33
internal/v2/database/filter.go
Normal file
33
internal/v2/database/filter.go
Normal file
@ -0,0 +1,33 @@
|
||||
package database
|
||||
|
||||
type Condition interface {
|
||||
Write(stmt *Statement, columnName string)
|
||||
}
|
||||
|
||||
type Filter[C compare, V value] struct {
|
||||
comp C
|
||||
value V
|
||||
}
|
||||
|
||||
func (f Filter[C, V]) Write(stmt *Statement, columnName string) {
|
||||
prepareWrite(stmt, columnName, f.comp)
|
||||
stmt.WriteArg(f.value)
|
||||
}
|
||||
|
||||
func prepareWrite[C compare](stmt *Statement, columnName string, comp C) {
|
||||
stmt.WriteString(columnName)
|
||||
stmt.WriteRune(' ')
|
||||
stmt.WriteString(comp.String())
|
||||
stmt.WriteRune(' ')
|
||||
}
|
||||
|
||||
type compare interface {
|
||||
numberCompare | textCompare | listCompare
|
||||
String() string
|
||||
}
|
||||
|
||||
type value interface {
|
||||
number | text
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// number | text | placeholder
|
||||
}
|
57
internal/v2/database/list_filter.go
Normal file
57
internal/v2/database/list_filter.go
Normal file
@ -0,0 +1,57 @@
|
||||
package database
|
||||
|
||||
import "github.com/zitadel/logging"
|
||||
|
||||
type ListFilter[V value] struct {
|
||||
comp listCompare
|
||||
list []V
|
||||
}
|
||||
|
||||
func NewListEquals[V value](list ...V) *ListFilter[V] {
|
||||
return newListFilter[V](listEqual, list)
|
||||
}
|
||||
|
||||
func NewListContains[V value](list ...V) *ListFilter[V] {
|
||||
return newListFilter[V](listContain, list)
|
||||
}
|
||||
|
||||
func NewListNotContains[V value](list ...V) *ListFilter[V] {
|
||||
return newListFilter[V](listNotContain, list)
|
||||
}
|
||||
|
||||
func newListFilter[V value](comp listCompare, list []V) *ListFilter[V] {
|
||||
return &ListFilter[V]{
|
||||
comp: comp,
|
||||
list: list,
|
||||
}
|
||||
}
|
||||
|
||||
func (f ListFilter[V]) Write(stmt *Statement, columnName string) {
|
||||
if len(f.list) == 0 {
|
||||
logging.WithFields("column", columnName).Debug("skip list filter because no entries defined")
|
||||
return
|
||||
}
|
||||
if f.comp == listNotContain {
|
||||
stmt.WriteString("NOT(")
|
||||
}
|
||||
stmt.WriteString(columnName)
|
||||
stmt.WriteString(" = ")
|
||||
if f.comp != listEqual {
|
||||
stmt.WriteString("ANY(")
|
||||
}
|
||||
stmt.WriteArg(f.list)
|
||||
if f.comp != listEqual {
|
||||
stmt.WriteString(")")
|
||||
}
|
||||
if f.comp == listNotContain {
|
||||
stmt.WriteRune(')')
|
||||
}
|
||||
}
|
||||
|
||||
type listCompare uint8
|
||||
|
||||
const (
|
||||
listEqual listCompare = iota
|
||||
listContain
|
||||
listNotContain
|
||||
)
|
122
internal/v2/database/list_filter_test.go
Normal file
122
internal/v2/database/list_filter_test.go
Normal file
@ -0,0 +1,122 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewListConstructors(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t ...string) *ListFilter[string]
|
||||
t []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *ListFilter[string]
|
||||
}{
|
||||
{
|
||||
name: "NewListEquals",
|
||||
args: args{
|
||||
constructor: NewListEquals[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: &ListFilter[string]{
|
||||
comp: listEqual,
|
||||
list: []string{"as", "df"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewListContains",
|
||||
args: args{
|
||||
constructor: NewListContains[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: &ListFilter[string]{
|
||||
comp: listContain,
|
||||
list: []string{"as", "df"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewListNotContains",
|
||||
args: args{
|
||||
constructor: NewListNotContains[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: &ListFilter[string]{
|
||||
comp: listNotContain,
|
||||
list: []string{"as", "df"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.args.constructor(tt.args.t...); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("number constructor = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewListConditionWrite(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t ...string) *ListFilter[string]
|
||||
t []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want wantQuery
|
||||
}{
|
||||
{
|
||||
name: "ListEquals",
|
||||
args: args{
|
||||
constructor: NewListEquals[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test = $1",
|
||||
args: []any{[]string{"as", "df"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ListContains",
|
||||
args: args{
|
||||
constructor: NewListContains[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test = ANY($1)",
|
||||
args: []any{[]string{"as", "df"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ListNotContains",
|
||||
args: args{
|
||||
constructor: NewListNotContains[string],
|
||||
t: []string{"as", "df"},
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "NOT(test = ANY($1))",
|
||||
args: []any{[]string{"as", "df"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty list",
|
||||
args: args{
|
||||
constructor: NewListNotContains[string],
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var stmt Statement
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.args.constructor(tt.args.t...).Write(&stmt, "test")
|
||||
assertQuery(t, &stmt, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
139
internal/v2/database/mock/sql_mock.go
Normal file
139
internal/v2/database/mock/sql_mock.go
Normal file
@ -0,0 +1,139 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
type SQLMock struct {
|
||||
DB *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
type Expectation func(m sqlmock.Sqlmock)
|
||||
|
||||
func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock {
|
||||
db, mock, err := sqlmock.New(
|
||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
|
||||
sqlmock.ValueConverterOption(new(TypeConverter)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal("create mock failed", err)
|
||||
}
|
||||
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
return &SQLMock{
|
||||
DB: db,
|
||||
mock: mock,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SQLMock) Assert(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if err := m.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations not met: %v", err)
|
||||
}
|
||||
|
||||
m.DB.Close()
|
||||
}
|
||||
|
||||
func ExpectBegin(err error) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectBegin()
|
||||
if err != nil {
|
||||
e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectCommit(err error) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectCommit()
|
||||
if err != nil {
|
||||
e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
|
||||
|
||||
func WithExecArgs(args ...driver.Value) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecErr(err error) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecNoRowsAffected() ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(affected)
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectExec(stmt string, opts ...ExecOpt) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectExec(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||
|
||||
func WithQueryArgs(args ...driver.Value) QueryOpt {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryErr(err error) QueryOpt {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
|
||||
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
mockedRows := m.NewRows(columns)
|
||||
for _, row := range rows {
|
||||
mockedRows = mockedRows.AddRow(row...)
|
||||
}
|
||||
return e.WillReturnRows(mockedRows)
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectQuery(stmt string, opts ...QueryOpt) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectQuery(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(m, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type AnyType[T interface{}] struct{}
|
||||
|
||||
// Match satisfies sqlmock.Argument interface
|
||||
func (a AnyType[T]) Match(v driver.Value) bool {
|
||||
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
|
||||
}
|
78
internal/v2/database/mock/type_converter.go
Normal file
78
internal/v2/database/mock/type_converter.go
Normal file
@ -0,0 +1,78 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ driver.ValueConverter = (*TypeConverter)(nil)
|
||||
|
||||
type TypeConverter struct{}
|
||||
|
||||
// ConvertValue converts a value to a driver Value.
|
||||
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
value := reflect.ValueOf(v)
|
||||
|
||||
if rawMessage, ok := v.(json.RawMessage); ok {
|
||||
return convertBytes(rawMessage), nil
|
||||
}
|
||||
|
||||
if value.Kind() == reflect.Slice {
|
||||
//nolint: exhaustive
|
||||
// only defined types
|
||||
switch value.Type().Elem().Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return convertSigned(value), nil
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return convertUnsigned(value), nil
|
||||
case reflect.String:
|
||||
return convertText(value), nil
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertSigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertUnsigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertText(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = array.Index(i).String()
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
func convertBytes(array []byte) string {
|
||||
var builder strings.Builder
|
||||
builder.Grow(hex.EncodedLen(len(array)) + 4)
|
||||
builder.WriteString(`\x`)
|
||||
builder.Write(hex.AppendEncode(nil, array))
|
||||
return builder.String()
|
||||
}
|
100
internal/v2/database/number_filter.go
Normal file
100
internal/v2/database/number_filter.go
Normal file
@ -0,0 +1,100 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type NumberFilter[N number] struct {
|
||||
Filter[numberCompare, N]
|
||||
}
|
||||
|
||||
func NewNumberEquals[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberEqual, n)
|
||||
}
|
||||
|
||||
func NewNumberAtLeast[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberAtLeast, n)
|
||||
}
|
||||
|
||||
func NewNumberAtMost[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberAtMost, n)
|
||||
}
|
||||
|
||||
func NewNumberGreater[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberGreater, n)
|
||||
}
|
||||
|
||||
func NewNumberLess[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberLess, n)
|
||||
}
|
||||
|
||||
func NewNumberUnequal[N number](n N) *NumberFilter[N] {
|
||||
return newNumberFilter(numberUnequal, n)
|
||||
}
|
||||
|
||||
func newNumberFilter[N number](comp numberCompare, n N) *NumberFilter[N] {
|
||||
return &NumberFilter[N]{
|
||||
Filter: Filter[numberCompare, N]{
|
||||
comp: comp,
|
||||
value: n,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NumberBetweenFilter combines [AtLeast] and [AtMost] comparisons
|
||||
type NumberBetweenFilter[N number] struct {
|
||||
min, max N
|
||||
}
|
||||
|
||||
func NewNumberBetween[N number](min, max N) *NumberBetweenFilter[N] {
|
||||
return &NumberBetweenFilter[N]{
|
||||
min: min,
|
||||
max: max,
|
||||
}
|
||||
}
|
||||
|
||||
func (f NumberBetweenFilter[N]) Write(stmt *Statement, columnName string) {
|
||||
NewNumberAtLeast[N](f.min).Write(stmt, columnName)
|
||||
stmt.WriteString(" AND ")
|
||||
NewNumberAtMost[N](f.max).Write(stmt, columnName)
|
||||
}
|
||||
|
||||
type numberCompare uint8
|
||||
|
||||
const (
|
||||
numberEqual numberCompare = iota
|
||||
numberAtLeast
|
||||
numberAtMost
|
||||
numberGreater
|
||||
numberLess
|
||||
numberUnequal
|
||||
)
|
||||
|
||||
func (c numberCompare) String() string {
|
||||
switch c {
|
||||
case numberEqual:
|
||||
return "="
|
||||
case numberAtLeast:
|
||||
return ">="
|
||||
case numberAtMost:
|
||||
return "<="
|
||||
case numberGreater:
|
||||
return ">"
|
||||
case numberLess:
|
||||
return "<"
|
||||
case numberUnequal:
|
||||
return "<>"
|
||||
default:
|
||||
logging.WithFields("compare", c).Panic("comparison type not implemented")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
type number interface {
|
||||
constraints.Integer | constraints.Float | time.Time
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// constraints.Integer | constraints.Float | time.Time | placeholder
|
||||
}
|
216
internal/v2/database/number_filter_test.go
Normal file
216
internal/v2/database/number_filter_test.go
Normal file
@ -0,0 +1,216 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewNumberConstructors(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t int8) *NumberFilter[int8]
|
||||
t int8
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *NumberFilter[int8]
|
||||
}{
|
||||
{
|
||||
name: "NewNumberEqual",
|
||||
args: args{
|
||||
constructor: NewNumberEquals[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberEqual,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberAtLeast",
|
||||
args: args{
|
||||
constructor: NewNumberAtLeast[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberAtLeast,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberAtMost",
|
||||
args: args{
|
||||
constructor: NewNumberAtMost[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberAtMost,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberGreater",
|
||||
args: args{
|
||||
constructor: NewNumberGreater[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberGreater,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberLess",
|
||||
args: args{
|
||||
constructor: NewNumberLess[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberLess,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberUnequal",
|
||||
args: args{
|
||||
constructor: NewNumberUnequal[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: &NumberFilter[int8]{
|
||||
Filter: Filter[numberCompare, int8]{
|
||||
comp: numberUnequal,
|
||||
value: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("number constructor = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNumberConditionWrite(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t int8) *NumberFilter[int8]
|
||||
t int8
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want wantQuery
|
||||
}{
|
||||
{
|
||||
name: "NewNumberEqual",
|
||||
args: args{
|
||||
constructor: NewNumberEquals[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test = $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberAtLeast",
|
||||
args: args{
|
||||
constructor: NewNumberAtLeast[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test >= $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberAtMost",
|
||||
args: args{
|
||||
constructor: NewNumberAtMost[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test <= $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberGreater",
|
||||
args: args{
|
||||
constructor: NewNumberGreater[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test > $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberLess",
|
||||
args: args{
|
||||
constructor: NewNumberLess[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test < $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewNumberUnequal",
|
||||
args: args{
|
||||
constructor: NewNumberUnequal[int8],
|
||||
t: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test <> $1",
|
||||
args: []any{int8(10)},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var stmt Statement
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.args.constructor(tt.args.t).Write(&stmt, "test")
|
||||
assertQuery(t, &stmt, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumberBetween(t *testing.T) {
|
||||
filter := NewNumberBetween[int8](10, 20)
|
||||
|
||||
if !reflect.DeepEqual(filter, &NumberBetweenFilter[int8]{min: 10, max: 20}) {
|
||||
t.Errorf("unexpected filter: %v", filter)
|
||||
}
|
||||
|
||||
var stmt Statement
|
||||
filter.Write(&stmt, "test")
|
||||
if stmt.String() != "test >= $1 AND test <= $2" {
|
||||
t.Errorf("unexpected query: got: %q", stmt.String())
|
||||
}
|
||||
|
||||
if len(stmt.Args()) != 2 {
|
||||
t.Errorf("unexpected length of args: got %d", len(stmt.Args()))
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(int8(10), stmt.Args()[0]) {
|
||||
t.Errorf("unexpected arg at position 0: want: 10, got: %v", stmt.Args()[0])
|
||||
}
|
||||
if !reflect.DeepEqual(int8(20), stmt.Args()[1]) {
|
||||
t.Errorf("unexpected arg at position 1: want: 20, got: %v", stmt.Args()[1])
|
||||
}
|
||||
}
|
17
internal/v2/database/pagination.go
Normal file
17
internal/v2/database/pagination.go
Normal file
@ -0,0 +1,17 @@
|
||||
package database
|
||||
|
||||
type Pagination struct {
|
||||
Limit uint32
|
||||
Offset uint32
|
||||
}
|
||||
|
||||
func (p *Pagination) Write(stmt *Statement) {
|
||||
if p.Limit > 0 {
|
||||
stmt.WriteString(" LIMIT ")
|
||||
stmt.WriteArg(p.Limit)
|
||||
}
|
||||
if p.Offset > 0 {
|
||||
stmt.WriteString(" OFFSET ")
|
||||
stmt.WriteArg(p.Offset)
|
||||
}
|
||||
}
|
73
internal/v2/database/pagination_test.go
Normal file
73
internal/v2/database/pagination_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPagination_Write(t *testing.T) {
|
||||
type fields struct {
|
||||
Limit uint32
|
||||
Offset uint32
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want wantQuery
|
||||
}{
|
||||
{
|
||||
name: "no values",
|
||||
fields: fields{
|
||||
Limit: 0,
|
||||
Offset: 0,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "",
|
||||
args: []any{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "limit",
|
||||
fields: fields{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: " LIMIT $1",
|
||||
args: []any{uint32(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "offset",
|
||||
fields: fields{
|
||||
Limit: 0,
|
||||
Offset: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: " OFFSET $1",
|
||||
args: []any{uint32(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "both",
|
||||
fields: fields{
|
||||
Limit: 10,
|
||||
Offset: 10,
|
||||
},
|
||||
want: wantQuery{
|
||||
query: " LIMIT $1 OFFSET $2",
|
||||
args: []any{uint32(10), uint32(10)},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var stmt Statement
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Pagination{
|
||||
Limit: tt.fields.Limit,
|
||||
Offset: tt.fields.Offset,
|
||||
}
|
||||
p.Write(&stmt)
|
||||
assertQuery(t, &stmt, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
75
internal/v2/database/sql_helper.go
Normal file
75
internal/v2/database/sql_helper.go
Normal file
@ -0,0 +1,75 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
type Tx interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
func CloseTx(tx Tx, err error) error {
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
logging.OnError(rollbackErr).Debug("unable to rollback")
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
type DestMapper[R any] func(index int, scan func(dest ...any) error) (*R, error)
|
||||
|
||||
type Rows interface {
|
||||
Close() error
|
||||
Err() error
|
||||
Next() bool
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func MapRows[R any](rows Rows, mapper DestMapper[R]) (result []*R, err error) {
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Debug("unable to close rows")
|
||||
|
||||
if err == nil && rows.Err() != nil {
|
||||
result = nil
|
||||
err = rows.Err()
|
||||
}
|
||||
}()
|
||||
for i := 0; rows.Next(); i++ {
|
||||
res, err := mapper(i, rows.Scan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, res)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func MapRowsToObject(rows Rows, mapper func(scan func(dest ...any) error) error) (err error) {
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Debug("unable to close rows")
|
||||
|
||||
if err == nil && rows.Err() != nil {
|
||||
err = rows.Err()
|
||||
}
|
||||
}()
|
||||
for rows.Next() {
|
||||
err = mapper(rows.Scan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Querier interface {
|
||||
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
|
||||
}
|
512
internal/v2/database/sql_helper_test.go
Normal file
512
internal/v2/database/sql_helper_test.go
Normal file
@ -0,0 +1,512 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCloseTx(t *testing.T) {
|
||||
type args struct {
|
||||
tx *testTx
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
assertErr func(t *testing.T, err error) bool
|
||||
}{
|
||||
{
|
||||
name: "exec err",
|
||||
args: args{
|
||||
tx: &testTx{
|
||||
rollback: execution{
|
||||
shouldExecute: true,
|
||||
},
|
||||
},
|
||||
err: errExec,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("execution error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exec err and rollback err",
|
||||
args: args{
|
||||
tx: &testTx{
|
||||
rollback: execution{
|
||||
err: true,
|
||||
shouldExecute: true,
|
||||
},
|
||||
},
|
||||
err: errExec,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("execution error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "commit Err",
|
||||
args: args{
|
||||
tx: &testTx{
|
||||
commit: execution{
|
||||
err: true,
|
||||
shouldExecute: true,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errCommit)
|
||||
if !is {
|
||||
t.Errorf("commit error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no err",
|
||||
args: args{
|
||||
tx: &testTx{
|
||||
commit: execution{
|
||||
shouldExecute: true,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := err == nil
|
||||
if !is {
|
||||
t.Errorf("no error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := CloseTx(tt.args.tx, tt.args.err)
|
||||
tt.assertErr(t, err)
|
||||
tt.args.tx.assert(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRows(t *testing.T) {
|
||||
type args struct {
|
||||
rows *testRows
|
||||
mapper DestMapper[string]
|
||||
}
|
||||
var emptyString string
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantResult []*string
|
||||
assertErr func(t *testing.T, err error) bool
|
||||
}{
|
||||
{
|
||||
name: "no rows, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
closeErr: true,
|
||||
},
|
||||
mapper: nil,
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errClose)
|
||||
if !is {
|
||||
t.Errorf("close error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
hasErr: true,
|
||||
},
|
||||
mapper: nil,
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errRows)
|
||||
if !is {
|
||||
t.Errorf("rows error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "scan err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
scanErr: true,
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||
var s string
|
||||
if err := scan(&s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
},
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errScan)
|
||||
if !is {
|
||||
t.Errorf("scan error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exec err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||
return nil, errExec
|
||||
},
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("exec error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exec err, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
closeErr: true,
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||
return nil, errExec
|
||||
},
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("exec error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rows err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
hasErr: true,
|
||||
},
|
||||
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||
var s string
|
||||
if err := scan(&s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
},
|
||||
},
|
||||
wantResult: nil,
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errRows)
|
||||
if !is {
|
||||
t.Errorf("rows error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||
var s string
|
||||
if err := scan(&s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &s, nil
|
||||
},
|
||||
},
|
||||
wantResult: []*string{&emptyString},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := err == nil
|
||||
if !is {
|
||||
t.Errorf("no error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotResult, err := MapRows(tt.args.rows, tt.args.mapper)
|
||||
tt.assertErr(t, err)
|
||||
if !reflect.DeepEqual(gotResult, tt.wantResult) {
|
||||
t.Errorf("MapRows() = %v, want %v", gotResult, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapRowsToObject(t *testing.T) {
|
||||
type args struct {
|
||||
rows *testRows
|
||||
mapper func(scan func(dest ...any) error) error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
assertErr func(t *testing.T, err error) bool
|
||||
}{
|
||||
{
|
||||
name: "no rows, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
closeErr: true,
|
||||
},
|
||||
mapper: nil,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errClose)
|
||||
if !is {
|
||||
t.Errorf("close error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
hasErr: true,
|
||||
},
|
||||
mapper: nil,
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errRows)
|
||||
if !is {
|
||||
t.Errorf("rows error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "scan err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
scanErr: true,
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(scan func(dest ...any) error) error {
|
||||
var s string
|
||||
if err := scan(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errScan)
|
||||
if !is {
|
||||
t.Errorf("scan error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exec err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(scan func(dest ...any) error) error {
|
||||
return errExec
|
||||
},
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("exec error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exec err, close err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
closeErr: true,
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(scan func(dest ...any) error) error {
|
||||
return errExec
|
||||
},
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errExec)
|
||||
if !is {
|
||||
t.Errorf("exec error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rows err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
hasErr: true,
|
||||
},
|
||||
mapper: func(scan func(dest ...any) error) error {
|
||||
var s string
|
||||
return scan(&s)
|
||||
},
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := errors.Is(err, errRows)
|
||||
if !is {
|
||||
t.Errorf("rows error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no err",
|
||||
args: args{
|
||||
rows: &testRows{
|
||||
nextCount: 1,
|
||||
},
|
||||
mapper: func(scan func(dest ...any) error) error {
|
||||
var s string
|
||||
return scan(&s)
|
||||
},
|
||||
},
|
||||
assertErr: func(t *testing.T, err error) bool {
|
||||
is := err == nil
|
||||
if !is {
|
||||
t.Errorf("no error expected, got: %v", err)
|
||||
}
|
||||
return is
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := MapRowsToObject(tt.args.rows, tt.args.mapper)
|
||||
tt.assertErr(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var _ Tx = (*testTx)(nil)
|
||||
|
||||
type testTx struct {
|
||||
commit, rollback execution
|
||||
}
|
||||
|
||||
type execution struct {
|
||||
err bool
|
||||
didExecute bool
|
||||
shouldExecute bool
|
||||
}
|
||||
|
||||
var (
|
||||
errCommit = errors.New("commit err")
|
||||
errRollback = errors.New("rollback err")
|
||||
errExec = errors.New("exec err")
|
||||
)
|
||||
|
||||
// Commit implements Tx.
|
||||
func (t *testTx) Commit() error {
|
||||
t.commit.didExecute = true
|
||||
if t.commit.err {
|
||||
return errCommit
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback implements Tx.
|
||||
func (t *testTx) Rollback() error {
|
||||
t.rollback.didExecute = true
|
||||
if t.rollback.err {
|
||||
return errRollback
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *testTx) assert(t *testing.T) {
|
||||
if tx.commit.didExecute != tx.commit.shouldExecute {
|
||||
t.Errorf("unexpected execution of commit: should %v, did: %v", tx.commit.shouldExecute, tx.commit.didExecute)
|
||||
}
|
||||
if tx.rollback.didExecute != tx.rollback.shouldExecute {
|
||||
t.Errorf("unexpected execution of rollback: should %v, did: %v", tx.rollback.shouldExecute, tx.rollback.didExecute)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Rows = (*testRows)(nil)
|
||||
|
||||
var (
|
||||
errClose = errors.New("err close")
|
||||
errRows = errors.New("err rows")
|
||||
errScan = errors.New("err scan")
|
||||
)
|
||||
|
||||
type testRows struct {
|
||||
closeErr bool
|
||||
scanErr bool
|
||||
hasErr bool
|
||||
nextCount int
|
||||
}
|
||||
|
||||
// Close implements Rows.
|
||||
func (t *testRows) Close() error {
|
||||
if t.closeErr {
|
||||
return errClose
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Err implements Rows.
|
||||
func (t *testRows) Err() error {
|
||||
if t.hasErr {
|
||||
return errRows
|
||||
}
|
||||
if t.closeErr {
|
||||
return errClose
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Next implements Rows.
|
||||
func (t *testRows) Next() bool {
|
||||
t.nextCount--
|
||||
return t.nextCount >= 0
|
||||
}
|
||||
|
||||
// Scan implements Rows.
|
||||
func (t *testRows) Scan(dest ...any) error {
|
||||
if t.scanErr {
|
||||
return errScan
|
||||
}
|
||||
return nil
|
||||
}
|
222
internal/v2/database/statement.go
Normal file
222
internal/v2/database/statement.go
Normal file
@ -0,0 +1,222 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
type Statement struct {
|
||||
addr *Statement
|
||||
builder strings.Builder
|
||||
|
||||
args []any
|
||||
// key is the name of the arg and value is the placeholder
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// namedArgs map[placeholder]string
|
||||
}
|
||||
|
||||
func (stmt *Statement) Args() []any {
|
||||
if stmt == nil {
|
||||
return nil
|
||||
}
|
||||
return stmt.args
|
||||
}
|
||||
|
||||
func (stmt *Statement) Reset() {
|
||||
stmt.builder.Reset()
|
||||
stmt.addr = nil
|
||||
stmt.args = nil
|
||||
}
|
||||
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// SetNamedArg sets the arg and makes it available for query construction
|
||||
// func (stmt *Statement) SetNamedArg(name placeholder, value any) (placeholder string) {
|
||||
// stmt.copyCheck()
|
||||
// stmt.args = append(stmt.args, value)
|
||||
// placeholder = fmt.Sprintf("$%d", len(stmt.args))
|
||||
// if !strings.HasPrefix(name.string, "@") {
|
||||
// name.string = "@" + name.string
|
||||
// }
|
||||
// stmt.namedArgs[name] = placeholder
|
||||
// return placeholder
|
||||
// }
|
||||
|
||||
// AppendArgs appends the args without writing it to Builder
|
||||
// if any arg is a [placeholder] it's replaced with the placeholders parameter
|
||||
func (stmt *Statement) AppendArgs(args ...any) {
|
||||
stmt.copyCheck()
|
||||
stmt.args = slices.Grow(stmt.args, len(args))
|
||||
for _, arg := range args {
|
||||
stmt.AppendArg(arg)
|
||||
}
|
||||
}
|
||||
|
||||
// AppendArg appends the arg without writing it to Builder
|
||||
// if the arg is a [placeholder] it's replaced with the placeholders parameter
|
||||
func (stmt *Statement) AppendArg(arg any) int {
|
||||
stmt.copyCheck()
|
||||
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// if namedArg, ok := arg.(sql.NamedArg); ok {
|
||||
// stmt.SetNamedArg(placeholder{namedArg.Name}, namedArg.Value)
|
||||
// return
|
||||
// }
|
||||
stmt.args = append(stmt.args, arg)
|
||||
return len(stmt.args)
|
||||
}
|
||||
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// func Placeholder(name string) placeholder {
|
||||
// return placeholder{name}
|
||||
// }
|
||||
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// type placeholder struct {
|
||||
// string
|
||||
// }
|
||||
|
||||
// WriteArgs appends the args and adds the placeholders comma separated to [stmt.Builder]
|
||||
// if any arg is a [placeholder] it's replaced with the placeholders parameter
|
||||
func (stmt *Statement) WriteArgs(args ...any) {
|
||||
stmt.copyCheck()
|
||||
stmt.args = slices.Grow(stmt.args, len(args))
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
stmt.WriteString(", ")
|
||||
}
|
||||
stmt.WriteArg(arg)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteArg appends the arg and adds the placeholder to [stmt.Builder]
|
||||
// if the arg is a [placeholder] it's replaced with the placeholders parameter
|
||||
func (stmt *Statement) WriteArg(arg any) {
|
||||
stmt.copyCheck()
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// if namedPlaceholder, ok := arg.(placeholder); ok {
|
||||
// stmt.writeNamedPlaceholder(namedPlaceholder)
|
||||
// return
|
||||
// }
|
||||
placeholder := stmt.AppendArg(arg)
|
||||
stmt.WriteString("$")
|
||||
stmt.WriteString(strconv.Itoa(placeholder))
|
||||
}
|
||||
|
||||
// WriteString extends [strings.Builder.WriteString]
|
||||
// it replaces named args with the previously provided named args
|
||||
func (stmt *Statement) WriteString(s string) {
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// for name, placeholder := range stmt.namedArgs {
|
||||
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||
// }
|
||||
stmt.builder.WriteString(s)
|
||||
}
|
||||
|
||||
// WriteRune extends [strings.Builder.WriteRune]
|
||||
func (stmt *Statement) WriteRune(r rune) {
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// for name, placeholder := range stmt.namedArgs {
|
||||
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||
// }
|
||||
stmt.builder.WriteRune(r)
|
||||
}
|
||||
|
||||
// WriteByte extends [strings.Builder.WriteByte]
|
||||
func (stmt *Statement) WriteByte(b byte) {
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// for name, placeholder := range stmt.namedArgs {
|
||||
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||
// }
|
||||
err := stmt.builder.WriteByte(b)
|
||||
logging.OnError(err).Warn("unable to write bytes")
|
||||
}
|
||||
|
||||
// Write extends [strings.Builder.Write]
|
||||
// it replaces named args with the previously provided named args
|
||||
func (stmt *Statement) Write(b []byte) {
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// for name, placeholder := range stmt.namedArgs {
|
||||
// bytes.ReplaceAll(b, []byte(name.string), []byte(placeholder))
|
||||
// }
|
||||
stmt.builder.Write(b)
|
||||
}
|
||||
|
||||
// String builds the query and replaces placeholders starting with "@"
|
||||
// with the corresponding named arg placeholder
|
||||
func (stmt *Statement) String() string {
|
||||
return stmt.builder.String()
|
||||
}
|
||||
|
||||
// Debug builds the statement and replaces the placeholders with the parameters
|
||||
func (stmt *Statement) Debug() string {
|
||||
query := stmt.String()
|
||||
|
||||
for i := len(stmt.args) - 1; i >= 0; i-- {
|
||||
var argText string
|
||||
switch arg := stmt.args[i].(type) {
|
||||
case time.Time:
|
||||
argText = "'" + arg.Format("2006-01-02 15:04:05Z07:00") + "'"
|
||||
case string:
|
||||
argText = "'" + arg + "'"
|
||||
case []string:
|
||||
argText = "ARRAY["
|
||||
for i, a := range arg {
|
||||
if i > 0 {
|
||||
argText += ", "
|
||||
}
|
||||
argText += "'" + a + "'"
|
||||
}
|
||||
argText += "]"
|
||||
default:
|
||||
argText = fmt.Sprint(arg)
|
||||
}
|
||||
query = strings.ReplaceAll(query, "$"+strconv.Itoa(i+1), argText)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// func (stmt *Statement) writeNamedPlaceholder(arg placeholder) {
|
||||
// placeholder, ok := stmt.namedArgs[arg]
|
||||
// if !ok {
|
||||
// logging.WithFields("named_placeholder", arg).Fatal("named placeholder not defined")
|
||||
// }
|
||||
// stmt.Builder.WriteString(placeholder)
|
||||
// }
|
||||
|
||||
// copyCheck allows uninitialized usage of stmt
|
||||
func (stmt *Statement) copyCheck() {
|
||||
if stmt.addr == nil {
|
||||
// This hack works around a failing of Go's escape analysis
|
||||
// that was causing b to escape and be heap allocated.
|
||||
// See issue 23382.
|
||||
// TODO: once issue 7921 is fixed, this should be reverted to
|
||||
// just "stmt.addr = stmt".
|
||||
stmt.addr = (*Statement)(noescape(unsafe.Pointer(stmt)))
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// stmt.namedArgs = make(map[placeholder]string)
|
||||
} else if stmt.addr != stmt {
|
||||
panic("statement: illegal use of non-zero Builder copied by value")
|
||||
}
|
||||
}
|
||||
|
||||
// noescape hides a pointer from escape analysis. It is the identity function
|
||||
// but escape analysis doesn't think the output depends on the input.
|
||||
// noescape is inlined and currently compiles down to zero instructions.
|
||||
// USE CAREFULLY!
|
||||
// This was copied from the runtime; see issues 23382 and 7921.
|
||||
//
|
||||
//go:nosplit
|
||||
//go:nocheckptr
|
||||
func noescape(p unsafe.Pointer) unsafe.Pointer {
|
||||
x := uintptr(p)
|
||||
//nolint: staticcheck
|
||||
return unsafe.Pointer(x ^ 0)
|
||||
}
|
73
internal/v2/database/statement_test.go
Normal file
73
internal/v2/database/statement_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatement_WriteArgs(t *testing.T) {
|
||||
type args struct {
|
||||
args []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want wantQuery
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
args: args{
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "1 arg",
|
||||
args: args{
|
||||
args: []any{"asdf"},
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "$1",
|
||||
args: []any{"asdf"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "n args",
|
||||
args: args{
|
||||
args: []any{"asdf", "jkl", 1},
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "$1, $2, $3",
|
||||
args: []any{"asdf", "jkl", 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var stmt Statement
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
stmt.WriteArgs(tt.args.args...)
|
||||
assertQuery(t, &stmt, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type wantQuery struct {
|
||||
query string
|
||||
args []any
|
||||
}
|
||||
|
||||
func assertQuery(t *testing.T, stmt *Statement, want wantQuery) {
|
||||
if want.query != stmt.String() {
|
||||
t.Errorf("unexpected query: want: %q got: %q", want.query, stmt.String())
|
||||
}
|
||||
|
||||
if len(want.args) != len(stmt.Args()) {
|
||||
t.Errorf("unexpected length of args: want %d, got %d", len(want.args), len(stmt.Args()))
|
||||
return
|
||||
}
|
||||
|
||||
for i, wantArg := range want.args {
|
||||
if !reflect.DeepEqual(wantArg, stmt.Args()[i]) {
|
||||
t.Errorf("unexpected arg at position %d: want: %v, got: %v", i, wantArg, stmt.Args()[i])
|
||||
}
|
||||
}
|
||||
}
|
132
internal/v2/database/text_filter.go
Normal file
132
internal/v2/database/text_filter.go
Normal file
@ -0,0 +1,132 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
type TextFilter[T text] struct {
|
||||
Filter[textCompare, T]
|
||||
}
|
||||
|
||||
func NewTextEqual[T text](t T) *TextFilter[T] {
|
||||
return newTextFilter(textEqual, t)
|
||||
}
|
||||
|
||||
func NewTextUnequal[T text](t T) *TextFilter[T] {
|
||||
return newTextFilter(textUnequal, t)
|
||||
}
|
||||
|
||||
func NewTextEqualInsensitive[T text](t T) *TextFilter[string] {
|
||||
return newTextFilter(textEqualInsensitive, strings.ToLower(string(t)))
|
||||
}
|
||||
|
||||
func NewTextUnequalInsensitive[T text](t T) *TextFilter[string] {
|
||||
return newTextFilter(textUnequalInsensitive, strings.ToLower(string(t)))
|
||||
}
|
||||
|
||||
func NewTextStartsWith[T text](t T) *TextFilter[T] {
|
||||
return newTextFilter(textStartsWith, t)
|
||||
}
|
||||
|
||||
func NewTextStartsWithInsensitive[T text](t T) *TextFilter[string] {
|
||||
return newTextFilter(textStartsWithInsensitive, strings.ToLower(string(t)))
|
||||
}
|
||||
|
||||
func NewTextEndsWith[T text](t T) *TextFilter[T] {
|
||||
return newTextFilter(textEndsWith, t)
|
||||
}
|
||||
|
||||
func NewTextEndsWithInsensitive[T text](t T) *TextFilter[string] {
|
||||
return newTextFilter(textEndsWithInsensitive, strings.ToLower(string(t)))
|
||||
}
|
||||
|
||||
func NewTextContains[T text](t T) *TextFilter[T] {
|
||||
return newTextFilter(textContains, t)
|
||||
}
|
||||
|
||||
func NewTextContainsInsensitive[T text](t T) *TextFilter[string] {
|
||||
return newTextFilter(textContainsInsensitive, strings.ToLower(string(t)))
|
||||
}
|
||||
|
||||
func newTextFilter[T text](comp textCompare, t T) *TextFilter[T] {
|
||||
return &TextFilter[T]{
|
||||
Filter: Filter[textCompare, T]{
|
||||
comp: comp,
|
||||
value: t,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *TextFilter[T]) Write(stmt *Statement, columnName string) {
|
||||
if f.comp.isInsensitive() {
|
||||
f.writeCaseInsensitive(stmt, columnName)
|
||||
return
|
||||
}
|
||||
f.Filter.Write(stmt, columnName)
|
||||
}
|
||||
|
||||
func (f *TextFilter[T]) writeCaseInsensitive(stmt *Statement, columnName string) {
|
||||
stmt.WriteString("LOWER(")
|
||||
stmt.WriteString(columnName)
|
||||
stmt.WriteString(") ")
|
||||
stmt.WriteString(f.comp.String())
|
||||
stmt.WriteRune(' ')
|
||||
f.writeArg(stmt)
|
||||
}
|
||||
|
||||
func (f *TextFilter[T]) writeArg(stmt *Statement) {
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// var v any = f.value
|
||||
// workaround for placeholder
|
||||
// if placeholder, ok := v.(placeholder); ok {
|
||||
// stmt.Builder.WriteString(" LOWER(")
|
||||
// stmt.WriteArg(placeholder)
|
||||
// stmt.Builder.WriteString(")")
|
||||
// }
|
||||
stmt.WriteArg(strings.ToLower(fmt.Sprint(f.value)))
|
||||
}
|
||||
|
||||
type textCompare uint8
|
||||
|
||||
const (
|
||||
textEqual textCompare = iota
|
||||
textUnequal
|
||||
textEqualInsensitive
|
||||
textUnequalInsensitive
|
||||
textStartsWith
|
||||
textStartsWithInsensitive
|
||||
textEndsWith
|
||||
textEndsWithInsensitive
|
||||
textContains
|
||||
textContainsInsensitive
|
||||
)
|
||||
|
||||
func (c textCompare) String() string {
|
||||
switch c {
|
||||
case textEqual, textEqualInsensitive:
|
||||
return "="
|
||||
case textUnequal, textUnequalInsensitive:
|
||||
return "<>"
|
||||
case textStartsWith, textStartsWithInsensitive, textEndsWith, textEndsWithInsensitive, textContains, textContainsInsensitive:
|
||||
return "LIKE"
|
||||
default:
|
||||
logging.WithFields("compare", c).Panic("comparison type not implemented")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (c textCompare) isInsensitive() bool {
|
||||
return c == textEqualInsensitive ||
|
||||
c == textStartsWithInsensitive ||
|
||||
c == textEndsWithInsensitive ||
|
||||
c == textContainsInsensitive
|
||||
}
|
||||
|
||||
type text interface {
|
||||
~string
|
||||
// TODO: condition must know if it's args are named parameters or not
|
||||
// ~string | placeholder
|
||||
}
|
351
internal/v2/database/text_filter_test.go
Normal file
351
internal/v2/database/text_filter_test.go
Normal file
@ -0,0 +1,351 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewTextEqual(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t string) *TextFilter[string]
|
||||
t string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *TextFilter[string]
|
||||
}{
|
||||
{
|
||||
name: "NewTextEqual",
|
||||
args: args{
|
||||
constructor: NewTextEqual[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEqual,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextUnequal",
|
||||
args: args{
|
||||
constructor: NewTextUnequal[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textUnequal,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEqualInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextEqualInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEqualInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEqualInsensitive check lower",
|
||||
args: args{
|
||||
constructor: NewTextEqualInsensitive[string],
|
||||
t: "tEXt",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEqualInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextUnequalInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextUnequalInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textUnequalInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextUnequalInsensitive check lower",
|
||||
args: args{
|
||||
constructor: NewTextUnequalInsensitive[string],
|
||||
t: "tEXt",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textUnequalInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextStartsWith",
|
||||
args: args{
|
||||
constructor: NewTextStartsWith[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textStartsWith,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextStartsWithInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextStartsWithInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textStartsWithInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextStartsWithInsensitive check lower",
|
||||
args: args{
|
||||
constructor: NewTextStartsWithInsensitive[string],
|
||||
t: "tEXt",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textStartsWithInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEndsWith",
|
||||
args: args{
|
||||
constructor: NewTextEndsWith[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEndsWith,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEndsWithInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextEndsWithInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEndsWithInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEndsWithInsensitive check lower",
|
||||
args: args{
|
||||
constructor: NewTextEndsWithInsensitive[string],
|
||||
t: "tEXt",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textEndsWithInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextContains",
|
||||
args: args{
|
||||
constructor: NewTextContains[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textContains,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextContainsInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextContainsInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textContainsInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextContainsInsensitive to lower",
|
||||
args: args{
|
||||
constructor: NewTextContainsInsensitive[string],
|
||||
t: "tEXt",
|
||||
},
|
||||
want: &TextFilter[string]{
|
||||
Filter: Filter[textCompare, string]{
|
||||
comp: textContainsInsensitive,
|
||||
value: "text",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewTextEqual() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextConditionWrite(t *testing.T) {
|
||||
type args struct {
|
||||
constructor func(t string) *TextFilter[string]
|
||||
t string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want wantQuery
|
||||
}{
|
||||
{
|
||||
name: "NewTextEqual",
|
||||
args: args{
|
||||
constructor: NewTextEqual[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test = $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextUnequal",
|
||||
args: args{
|
||||
constructor: NewTextUnequal[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test <> $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEqualInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextEqualInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "LOWER(test) = $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextUnequalInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextUnequalInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test <> $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextStartsWith",
|
||||
args: args{
|
||||
constructor: NewTextStartsWith[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextStartsWithInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextStartsWithInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "LOWER(test) LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEndsWith",
|
||||
args: args{
|
||||
constructor: NewTextEndsWith[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextEndsWithInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextEndsWithInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "LOWER(test) LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextContains",
|
||||
args: args{
|
||||
constructor: NewTextContains[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "test LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NewTextContainsInsensitive",
|
||||
args: args{
|
||||
constructor: NewTextContainsInsensitive[string],
|
||||
t: "text",
|
||||
},
|
||||
want: wantQuery{
|
||||
query: "LOWER(test) LIKE $1",
|
||||
args: []any{"text"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var stmt Statement
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.args.constructor(tt.args.t).Write(&stmt, "test")
|
||||
assertQuery(t, &stmt, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user