mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 02:54:20 +00:00
refactor(v2): init database package (#7802)
This commit is contained in:
parent
207b20ff0f
commit
5131328291
33
internal/v2/database/filter.go
Normal file
33
internal/v2/database/filter.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
type Condition interface {
|
||||||
|
Write(stmt *Statement, columnName string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Filter[C compare, V value] struct {
|
||||||
|
comp C
|
||||||
|
value V
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Filter[C, V]) Write(stmt *Statement, columnName string) {
|
||||||
|
prepareWrite(stmt, columnName, f.comp)
|
||||||
|
stmt.WriteArg(f.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareWrite[C compare](stmt *Statement, columnName string, comp C) {
|
||||||
|
stmt.WriteString(columnName)
|
||||||
|
stmt.WriteRune(' ')
|
||||||
|
stmt.WriteString(comp.String())
|
||||||
|
stmt.WriteRune(' ')
|
||||||
|
}
|
||||||
|
|
||||||
|
type compare interface {
|
||||||
|
numberCompare | textCompare | listCompare
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type value interface {
|
||||||
|
number | text
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// number | text | placeholder
|
||||||
|
}
|
57
internal/v2/database/list_filter.go
Normal file
57
internal/v2/database/list_filter.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import "github.com/zitadel/logging"
|
||||||
|
|
||||||
|
type ListFilter[V value] struct {
|
||||||
|
comp listCompare
|
||||||
|
list []V
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListEquals[V value](list ...V) *ListFilter[V] {
|
||||||
|
return newListFilter[V](listEqual, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListContains[V value](list ...V) *ListFilter[V] {
|
||||||
|
return newListFilter[V](listContain, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewListNotContains[V value](list ...V) *ListFilter[V] {
|
||||||
|
return newListFilter[V](listNotContain, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newListFilter[V value](comp listCompare, list []V) *ListFilter[V] {
|
||||||
|
return &ListFilter[V]{
|
||||||
|
comp: comp,
|
||||||
|
list: list,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f ListFilter[V]) Write(stmt *Statement, columnName string) {
|
||||||
|
if len(f.list) == 0 {
|
||||||
|
logging.WithFields("column", columnName).Debug("skip list filter because no entries defined")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f.comp == listNotContain {
|
||||||
|
stmt.WriteString("NOT(")
|
||||||
|
}
|
||||||
|
stmt.WriteString(columnName)
|
||||||
|
stmt.WriteString(" = ")
|
||||||
|
if f.comp != listEqual {
|
||||||
|
stmt.WriteString("ANY(")
|
||||||
|
}
|
||||||
|
stmt.WriteArg(f.list)
|
||||||
|
if f.comp != listEqual {
|
||||||
|
stmt.WriteString(")")
|
||||||
|
}
|
||||||
|
if f.comp == listNotContain {
|
||||||
|
stmt.WriteRune(')')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type listCompare uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
listEqual listCompare = iota
|
||||||
|
listContain
|
||||||
|
listNotContain
|
||||||
|
)
|
122
internal/v2/database/list_filter_test.go
Normal file
122
internal/v2/database/list_filter_test.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewListConstructors(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t ...string) *ListFilter[string]
|
||||||
|
t []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *ListFilter[string]
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NewListEquals",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListEquals[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: &ListFilter[string]{
|
||||||
|
comp: listEqual,
|
||||||
|
list: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewListContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListContains[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: &ListFilter[string]{
|
||||||
|
comp: listContain,
|
||||||
|
list: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewListNotContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListNotContains[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: &ListFilter[string]{
|
||||||
|
comp: listNotContain,
|
||||||
|
list: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.args.constructor(tt.args.t...); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("number constructor = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewListConditionWrite(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t ...string) *ListFilter[string]
|
||||||
|
t []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want wantQuery
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ListEquals",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListEquals[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test = $1",
|
||||||
|
args: []any{[]string{"as", "df"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ListContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListContains[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test = ANY($1)",
|
||||||
|
args: []any{[]string{"as", "df"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ListNotContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListNotContains[string],
|
||||||
|
t: []string{"as", "df"},
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "NOT(test = ANY($1))",
|
||||||
|
args: []any{[]string{"as", "df"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list",
|
||||||
|
args: args{
|
||||||
|
constructor: NewListNotContains[string],
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
var stmt Statement
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.args.constructor(tt.args.t...).Write(&stmt, "test")
|
||||||
|
assertQuery(t, &stmt, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
139
internal/v2/database/mock/sql_mock.go
Normal file
139
internal/v2/database/mock/sql_mock.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SQLMock struct {
|
||||||
|
DB *sql.DB
|
||||||
|
mock sqlmock.Sqlmock
|
||||||
|
}
|
||||||
|
|
||||||
|
type Expectation func(m sqlmock.Sqlmock)
|
||||||
|
|
||||||
|
func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock {
|
||||||
|
db, mock, err := sqlmock.New(
|
||||||
|
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
|
||||||
|
sqlmock.ValueConverterOption(new(TypeConverter)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("create mock failed", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expectation := range expectations {
|
||||||
|
expectation(mock)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SQLMock{
|
||||||
|
DB: db,
|
||||||
|
mock: mock,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SQLMock) Assert(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := m.mock.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("expectations not met: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.DB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectBegin(err error) Expectation {
|
||||||
|
return func(m sqlmock.Sqlmock) {
|
||||||
|
e := m.ExpectBegin()
|
||||||
|
if err != nil {
|
||||||
|
e.WillReturnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectCommit(err error) Expectation {
|
||||||
|
return func(m sqlmock.Sqlmock) {
|
||||||
|
e := m.ExpectCommit()
|
||||||
|
if err != nil {
|
||||||
|
e.WillReturnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
|
||||||
|
|
||||||
|
func WithExecArgs(args ...driver.Value) ExecOpt {
|
||||||
|
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||||
|
return e.WithArgs(args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithExecErr(err error) ExecOpt {
|
||||||
|
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||||
|
return e.WillReturnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithExecNoRowsAffected() ExecOpt {
|
||||||
|
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||||
|
return e.WillReturnResult(driver.ResultNoRows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
|
||||||
|
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||||
|
return e.WillReturnResult(affected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectExec(stmt string, opts ...ExecOpt) Expectation {
|
||||||
|
return func(m sqlmock.Sqlmock) {
|
||||||
|
e := m.ExpectExec(stmt)
|
||||||
|
for _, opt := range opts {
|
||||||
|
e = opt(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||||
|
|
||||||
|
func WithQueryArgs(args ...driver.Value) QueryOpt {
|
||||||
|
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||||
|
return e.WithArgs(args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithQueryErr(err error) QueryOpt {
|
||||||
|
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||||
|
return e.WillReturnError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
|
||||||
|
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||||
|
mockedRows := m.NewRows(columns)
|
||||||
|
for _, row := range rows {
|
||||||
|
mockedRows = mockedRows.AddRow(row...)
|
||||||
|
}
|
||||||
|
return e.WillReturnRows(mockedRows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectQuery(stmt string, opts ...QueryOpt) Expectation {
|
||||||
|
return func(m sqlmock.Sqlmock) {
|
||||||
|
e := m.ExpectQuery(stmt)
|
||||||
|
for _, opt := range opts {
|
||||||
|
e = opt(m, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AnyType[T interface{}] struct{}
|
||||||
|
|
||||||
|
// Match satisfies sqlmock.Argument interface
|
||||||
|
func (a AnyType[T]) Match(v driver.Value) bool {
|
||||||
|
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
|
||||||
|
}
|
78
internal/v2/database/mock/type_converter.go
Normal file
78
internal/v2/database/mock/type_converter.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ driver.ValueConverter = (*TypeConverter)(nil)
|
||||||
|
|
||||||
|
type TypeConverter struct{}
|
||||||
|
|
||||||
|
// ConvertValue converts a value to a driver Value.
|
||||||
|
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
|
||||||
|
if driver.IsValue(v) {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
value := reflect.ValueOf(v)
|
||||||
|
|
||||||
|
if rawMessage, ok := v.(json.RawMessage); ok {
|
||||||
|
return convertBytes(rawMessage), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
//nolint: exhaustive
|
||||||
|
// only defined types
|
||||||
|
switch value.Type().Elem().Kind() {
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
return convertSigned(value), nil
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
return convertUnsigned(value), nil
|
||||||
|
case reflect.String:
|
||||||
|
return convertText(value), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// converts a text array to valid pgx v5 representation
|
||||||
|
func convertSigned(array reflect.Value) string {
|
||||||
|
slice := make([]string, array.Len())
|
||||||
|
for i := 0; i < array.Len(); i++ {
|
||||||
|
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "{" + strings.Join(slice, ",") + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// converts a text array to valid pgx v5 representation
|
||||||
|
func convertUnsigned(array reflect.Value) string {
|
||||||
|
slice := make([]string, array.Len())
|
||||||
|
for i := 0; i < array.Len(); i++ {
|
||||||
|
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "{" + strings.Join(slice, ",") + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// converts a text array to valid pgx v5 representation
|
||||||
|
func convertText(array reflect.Value) string {
|
||||||
|
slice := make([]string, array.Len())
|
||||||
|
for i := 0; i < array.Len(); i++ {
|
||||||
|
slice[i] = array.Index(i).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return "{" + strings.Join(slice, ",") + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertBytes(array []byte) string {
|
||||||
|
var builder strings.Builder
|
||||||
|
builder.Grow(hex.EncodedLen(len(array)) + 4)
|
||||||
|
builder.WriteString(`\x`)
|
||||||
|
builder.Write(hex.AppendEncode(nil, array))
|
||||||
|
return builder.String()
|
||||||
|
}
|
100
internal/v2/database/number_filter.go
Normal file
100
internal/v2/database/number_filter.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
"golang.org/x/exp/constraints"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NumberFilter[N number] struct {
|
||||||
|
Filter[numberCompare, N]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberEquals[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberEqual, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberAtLeast[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberAtLeast, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberAtMost[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberAtMost, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberGreater[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberGreater, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberLess[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberLess, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberUnequal[N number](n N) *NumberFilter[N] {
|
||||||
|
return newNumberFilter(numberUnequal, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNumberFilter[N number](comp numberCompare, n N) *NumberFilter[N] {
|
||||||
|
return &NumberFilter[N]{
|
||||||
|
Filter: Filter[numberCompare, N]{
|
||||||
|
comp: comp,
|
||||||
|
value: n,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberBetweenFilter combines [AtLeast] and [AtMost] comparisons
|
||||||
|
type NumberBetweenFilter[N number] struct {
|
||||||
|
min, max N
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNumberBetween[N number](min, max N) *NumberBetweenFilter[N] {
|
||||||
|
return &NumberBetweenFilter[N]{
|
||||||
|
min: min,
|
||||||
|
max: max,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f NumberBetweenFilter[N]) Write(stmt *Statement, columnName string) {
|
||||||
|
NewNumberAtLeast[N](f.min).Write(stmt, columnName)
|
||||||
|
stmt.WriteString(" AND ")
|
||||||
|
NewNumberAtMost[N](f.max).Write(stmt, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
type numberCompare uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
numberEqual numberCompare = iota
|
||||||
|
numberAtLeast
|
||||||
|
numberAtMost
|
||||||
|
numberGreater
|
||||||
|
numberLess
|
||||||
|
numberUnequal
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c numberCompare) String() string {
|
||||||
|
switch c {
|
||||||
|
case numberEqual:
|
||||||
|
return "="
|
||||||
|
case numberAtLeast:
|
||||||
|
return ">="
|
||||||
|
case numberAtMost:
|
||||||
|
return "<="
|
||||||
|
case numberGreater:
|
||||||
|
return ">"
|
||||||
|
case numberLess:
|
||||||
|
return "<"
|
||||||
|
case numberUnequal:
|
||||||
|
return "<>"
|
||||||
|
default:
|
||||||
|
logging.WithFields("compare", c).Panic("comparison type not implemented")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type number interface {
|
||||||
|
constraints.Integer | constraints.Float | time.Time
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// constraints.Integer | constraints.Float | time.Time | placeholder
|
||||||
|
}
|
216
internal/v2/database/number_filter_test.go
Normal file
216
internal/v2/database/number_filter_test.go
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewNumberConstructors(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t int8) *NumberFilter[int8]
|
||||||
|
t int8
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *NumberFilter[int8]
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NewNumberEqual",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberEquals[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberEqual,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberAtLeast",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberAtLeast[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberAtLeast,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberAtMost",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberAtMost[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberAtMost,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberGreater",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberGreater[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberGreater,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberLess",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberLess[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberLess,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberUnequal",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberUnequal[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: &NumberFilter[int8]{
|
||||||
|
Filter: Filter[numberCompare, int8]{
|
||||||
|
comp: numberUnequal,
|
||||||
|
value: 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("number constructor = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewNumberConditionWrite(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t int8) *NumberFilter[int8]
|
||||||
|
t int8
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want wantQuery
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NewNumberEqual",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberEquals[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test = $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberAtLeast",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberAtLeast[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test >= $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberAtMost",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberAtMost[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test <= $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberGreater",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberGreater[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test > $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberLess",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberLess[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test < $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewNumberUnequal",
|
||||||
|
args: args{
|
||||||
|
constructor: NewNumberUnequal[int8],
|
||||||
|
t: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test <> $1",
|
||||||
|
args: []any{int8(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
var stmt Statement
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.args.constructor(tt.args.t).Write(&stmt, "test")
|
||||||
|
assertQuery(t, &stmt, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNumberBetween(t *testing.T) {
|
||||||
|
filter := NewNumberBetween[int8](10, 20)
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(filter, &NumberBetweenFilter[int8]{min: 10, max: 20}) {
|
||||||
|
t.Errorf("unexpected filter: %v", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
var stmt Statement
|
||||||
|
filter.Write(&stmt, "test")
|
||||||
|
if stmt.String() != "test >= $1 AND test <= $2" {
|
||||||
|
t.Errorf("unexpected query: got: %q", stmt.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(stmt.Args()) != 2 {
|
||||||
|
t.Errorf("unexpected length of args: got %d", len(stmt.Args()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(int8(10), stmt.Args()[0]) {
|
||||||
|
t.Errorf("unexpected arg at position 0: want: 10, got: %v", stmt.Args()[0])
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(int8(20), stmt.Args()[1]) {
|
||||||
|
t.Errorf("unexpected arg at position 1: want: 20, got: %v", stmt.Args()[1])
|
||||||
|
}
|
||||||
|
}
|
17
internal/v2/database/pagination.go
Normal file
17
internal/v2/database/pagination.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
type Pagination struct {
|
||||||
|
Limit uint32
|
||||||
|
Offset uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pagination) Write(stmt *Statement) {
|
||||||
|
if p.Limit > 0 {
|
||||||
|
stmt.WriteString(" LIMIT ")
|
||||||
|
stmt.WriteArg(p.Limit)
|
||||||
|
}
|
||||||
|
if p.Offset > 0 {
|
||||||
|
stmt.WriteString(" OFFSET ")
|
||||||
|
stmt.WriteArg(p.Offset)
|
||||||
|
}
|
||||||
|
}
|
73
internal/v2/database/pagination_test.go
Normal file
73
internal/v2/database/pagination_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPagination_Write(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Limit uint32
|
||||||
|
Offset uint32
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want wantQuery
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no values",
|
||||||
|
fields: fields{
|
||||||
|
Limit: 0,
|
||||||
|
Offset: 0,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "",
|
||||||
|
args: []any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limit",
|
||||||
|
fields: fields{
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 0,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: " LIMIT $1",
|
||||||
|
args: []any{uint32(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "offset",
|
||||||
|
fields: fields{
|
||||||
|
Limit: 0,
|
||||||
|
Offset: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: " OFFSET $1",
|
||||||
|
args: []any{uint32(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both",
|
||||||
|
fields: fields{
|
||||||
|
Limit: 10,
|
||||||
|
Offset: 10,
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: " LIMIT $1 OFFSET $2",
|
||||||
|
args: []any{uint32(10), uint32(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
var stmt Statement
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Pagination{
|
||||||
|
Limit: tt.fields.Limit,
|
||||||
|
Offset: tt.fields.Offset,
|
||||||
|
}
|
||||||
|
p.Write(&stmt)
|
||||||
|
assertQuery(t, &stmt, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
75
internal/v2/database/sql_helper.go
Normal file
75
internal/v2/database/sql_helper.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tx interface {
|
||||||
|
Commit() error
|
||||||
|
Rollback() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func CloseTx(tx Tx, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
rollbackErr := tx.Rollback()
|
||||||
|
logging.OnError(rollbackErr).Debug("unable to rollback")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
type DestMapper[R any] func(index int, scan func(dest ...any) error) (*R, error)
|
||||||
|
|
||||||
|
type Rows interface {
|
||||||
|
Close() error
|
||||||
|
Err() error
|
||||||
|
Next() bool
|
||||||
|
Scan(dest ...any) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapRows[R any](rows Rows, mapper DestMapper[R]) (result []*R, err error) {
|
||||||
|
defer func() {
|
||||||
|
closeErr := rows.Close()
|
||||||
|
logging.OnError(closeErr).Debug("unable to close rows")
|
||||||
|
|
||||||
|
if err == nil && rows.Err() != nil {
|
||||||
|
result = nil
|
||||||
|
err = rows.Err()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for i := 0; rows.Next(); i++ {
|
||||||
|
res, err := mapper(i, rows.Scan)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MapRowsToObject(rows Rows, mapper func(scan func(dest ...any) error) error) (err error) {
|
||||||
|
defer func() {
|
||||||
|
closeErr := rows.Close()
|
||||||
|
logging.OnError(closeErr).Debug("unable to close rows")
|
||||||
|
|
||||||
|
if err == nil && rows.Err() != nil {
|
||||||
|
err = rows.Err()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for rows.Next() {
|
||||||
|
err = mapper(rows.Scan)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Querier interface {
|
||||||
|
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
|
||||||
|
}
|
512
internal/v2/database/sql_helper_test.go
Normal file
512
internal/v2/database/sql_helper_test.go
Normal file
@ -0,0 +1,512 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCloseTx(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
tx *testTx
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
assertErr func(t *testing.T, err error) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exec err",
|
||||||
|
args: args{
|
||||||
|
tx: &testTx{
|
||||||
|
rollback: execution{
|
||||||
|
shouldExecute: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errExec,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("execution error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec err and rollback err",
|
||||||
|
args: args{
|
||||||
|
tx: &testTx{
|
||||||
|
rollback: execution{
|
||||||
|
err: true,
|
||||||
|
shouldExecute: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: errExec,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("execution error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "commit Err",
|
||||||
|
args: args{
|
||||||
|
tx: &testTx{
|
||||||
|
commit: execution{
|
||||||
|
err: true,
|
||||||
|
shouldExecute: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errCommit)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("commit error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no err",
|
||||||
|
args: args{
|
||||||
|
tx: &testTx{
|
||||||
|
commit: execution{
|
||||||
|
shouldExecute: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := err == nil
|
||||||
|
if !is {
|
||||||
|
t.Errorf("no error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := CloseTx(tt.args.tx, tt.args.err)
|
||||||
|
tt.assertErr(t, err)
|
||||||
|
tt.args.tx.assert(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapRows(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
rows *testRows
|
||||||
|
mapper DestMapper[string]
|
||||||
|
}
|
||||||
|
var emptyString string
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantResult []*string
|
||||||
|
assertErr func(t *testing.T, err error) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no rows, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
closeErr: true,
|
||||||
|
},
|
||||||
|
mapper: nil,
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errClose)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("close error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no rows, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
mapper: nil,
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errRows)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("rows error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scan err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
scanErr: true,
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||||
|
var s string
|
||||||
|
if err := scan(&s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errScan)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("scan error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||||
|
return nil, errExec
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("exec error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec err, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
closeErr: true,
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||||
|
return nil, errExec
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("exec error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rows err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||||
|
var s string
|
||||||
|
if err := scan(&s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantResult: nil,
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errRows)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("rows error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
|
||||||
|
var s string
|
||||||
|
if err := scan(&s); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &s, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantResult: []*string{&emptyString},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := err == nil
|
||||||
|
if !is {
|
||||||
|
t.Errorf("no error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotResult, err := MapRows(tt.args.rows, tt.args.mapper)
|
||||||
|
tt.assertErr(t, err)
|
||||||
|
if !reflect.DeepEqual(gotResult, tt.wantResult) {
|
||||||
|
t.Errorf("MapRows() = %v, want %v", gotResult, tt.wantResult)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapRowsToObject(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
rows *testRows
|
||||||
|
mapper func(scan func(dest ...any) error) error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
assertErr func(t *testing.T, err error) bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no rows, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
closeErr: true,
|
||||||
|
},
|
||||||
|
mapper: nil,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errClose)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("close error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no rows, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
mapper: nil,
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errRows)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("rows error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "scan err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
scanErr: true,
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(scan func(dest ...any) error) error {
|
||||||
|
var s string
|
||||||
|
if err := scan(&s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errScan)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("scan error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(scan func(dest ...any) error) error {
|
||||||
|
return errExec
|
||||||
|
},
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("exec error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exec err, close err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
closeErr: true,
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(scan func(dest ...any) error) error {
|
||||||
|
return errExec
|
||||||
|
},
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errExec)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("exec error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rows err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
hasErr: true,
|
||||||
|
},
|
||||||
|
mapper: func(scan func(dest ...any) error) error {
|
||||||
|
var s string
|
||||||
|
return scan(&s)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := errors.Is(err, errRows)
|
||||||
|
if !is {
|
||||||
|
t.Errorf("rows error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no err",
|
||||||
|
args: args{
|
||||||
|
rows: &testRows{
|
||||||
|
nextCount: 1,
|
||||||
|
},
|
||||||
|
mapper: func(scan func(dest ...any) error) error {
|
||||||
|
var s string
|
||||||
|
return scan(&s)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
assertErr: func(t *testing.T, err error) bool {
|
||||||
|
is := err == nil
|
||||||
|
if !is {
|
||||||
|
t.Errorf("no error expected, got: %v", err)
|
||||||
|
}
|
||||||
|
return is
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := MapRowsToObject(tt.args.rows, tt.args.mapper)
|
||||||
|
tt.assertErr(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Tx = (*testTx)(nil)
|
||||||
|
|
||||||
|
type testTx struct {
|
||||||
|
commit, rollback execution
|
||||||
|
}
|
||||||
|
|
||||||
|
type execution struct {
|
||||||
|
err bool
|
||||||
|
didExecute bool
|
||||||
|
shouldExecute bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errCommit = errors.New("commit err")
|
||||||
|
errRollback = errors.New("rollback err")
|
||||||
|
errExec = errors.New("exec err")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Commit implements Tx.
|
||||||
|
func (t *testTx) Commit() error {
|
||||||
|
t.commit.didExecute = true
|
||||||
|
if t.commit.err {
|
||||||
|
return errCommit
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rollback implements Tx.
|
||||||
|
func (t *testTx) Rollback() error {
|
||||||
|
t.rollback.didExecute = true
|
||||||
|
if t.rollback.err {
|
||||||
|
return errRollback
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *testTx) assert(t *testing.T) {
|
||||||
|
if tx.commit.didExecute != tx.commit.shouldExecute {
|
||||||
|
t.Errorf("unexpected execution of commit: should %v, did: %v", tx.commit.shouldExecute, tx.commit.didExecute)
|
||||||
|
}
|
||||||
|
if tx.rollback.didExecute != tx.rollback.shouldExecute {
|
||||||
|
t.Errorf("unexpected execution of rollback: should %v, did: %v", tx.rollback.shouldExecute, tx.rollback.didExecute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Rows = (*testRows)(nil)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errClose = errors.New("err close")
|
||||||
|
errRows = errors.New("err rows")
|
||||||
|
errScan = errors.New("err scan")
|
||||||
|
)
|
||||||
|
|
||||||
|
type testRows struct {
|
||||||
|
closeErr bool
|
||||||
|
scanErr bool
|
||||||
|
hasErr bool
|
||||||
|
nextCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements Rows.
|
||||||
|
func (t *testRows) Close() error {
|
||||||
|
if t.closeErr {
|
||||||
|
return errClose
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err implements Rows.
|
||||||
|
func (t *testRows) Err() error {
|
||||||
|
if t.hasErr {
|
||||||
|
return errRows
|
||||||
|
}
|
||||||
|
if t.closeErr {
|
||||||
|
return errClose
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next implements Rows.
|
||||||
|
func (t *testRows) Next() bool {
|
||||||
|
t.nextCount--
|
||||||
|
return t.nextCount >= 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements Rows.
|
||||||
|
func (t *testRows) Scan(dest ...any) error {
|
||||||
|
if t.scanErr {
|
||||||
|
return errScan
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
222
internal/v2/database/statement.go
Normal file
222
internal/v2/database/statement.go
Normal file
@ -0,0 +1,222 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Statement struct {
|
||||||
|
addr *Statement
|
||||||
|
builder strings.Builder
|
||||||
|
|
||||||
|
args []any
|
||||||
|
// key is the name of the arg and value is the placeholder
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// namedArgs map[placeholder]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Statement) Args() []any {
|
||||||
|
if stmt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return stmt.args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stmt *Statement) Reset() {
|
||||||
|
stmt.builder.Reset()
|
||||||
|
stmt.addr = nil
|
||||||
|
stmt.args = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// SetNamedArg sets the arg and makes it available for query construction
|
||||||
|
// func (stmt *Statement) SetNamedArg(name placeholder, value any) (placeholder string) {
|
||||||
|
// stmt.copyCheck()
|
||||||
|
// stmt.args = append(stmt.args, value)
|
||||||
|
// placeholder = fmt.Sprintf("$%d", len(stmt.args))
|
||||||
|
// if !strings.HasPrefix(name.string, "@") {
|
||||||
|
// name.string = "@" + name.string
|
||||||
|
// }
|
||||||
|
// stmt.namedArgs[name] = placeholder
|
||||||
|
// return placeholder
|
||||||
|
// }
|
||||||
|
|
||||||
|
// AppendArgs appends the args without writing it to Builder
|
||||||
|
// if any arg is a [placeholder] it's replaced with the placeholders parameter
|
||||||
|
func (stmt *Statement) AppendArgs(args ...any) {
|
||||||
|
stmt.copyCheck()
|
||||||
|
stmt.args = slices.Grow(stmt.args, len(args))
|
||||||
|
for _, arg := range args {
|
||||||
|
stmt.AppendArg(arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendArg appends the arg without writing it to Builder
|
||||||
|
// if the arg is a [placeholder] it's replaced with the placeholders parameter
|
||||||
|
func (stmt *Statement) AppendArg(arg any) int {
|
||||||
|
stmt.copyCheck()
|
||||||
|
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// if namedArg, ok := arg.(sql.NamedArg); ok {
|
||||||
|
// stmt.SetNamedArg(placeholder{namedArg.Name}, namedArg.Value)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
stmt.args = append(stmt.args, arg)
|
||||||
|
return len(stmt.args)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// func Placeholder(name string) placeholder {
|
||||||
|
// return placeholder{name}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// type placeholder struct {
|
||||||
|
// string
|
||||||
|
// }
|
||||||
|
|
||||||
|
// WriteArgs appends the args and adds the placeholders comma separated to [stmt.Builder]
|
||||||
|
// if any arg is a [placeholder] it's replaced with the placeholders parameter
|
||||||
|
func (stmt *Statement) WriteArgs(args ...any) {
|
||||||
|
stmt.copyCheck()
|
||||||
|
stmt.args = slices.Grow(stmt.args, len(args))
|
||||||
|
for i, arg := range args {
|
||||||
|
if i > 0 {
|
||||||
|
stmt.WriteString(", ")
|
||||||
|
}
|
||||||
|
stmt.WriteArg(arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteArg appends the arg and adds the placeholder to [stmt.Builder]
|
||||||
|
// if the arg is a [placeholder] it's replaced with the placeholders parameter
|
||||||
|
func (stmt *Statement) WriteArg(arg any) {
|
||||||
|
stmt.copyCheck()
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// if namedPlaceholder, ok := arg.(placeholder); ok {
|
||||||
|
// stmt.writeNamedPlaceholder(namedPlaceholder)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
placeholder := stmt.AppendArg(arg)
|
||||||
|
stmt.WriteString("$")
|
||||||
|
stmt.WriteString(strconv.Itoa(placeholder))
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteString extends [strings.Builder.WriteString]
|
||||||
|
// it replaces named args with the previously provided named args
|
||||||
|
func (stmt *Statement) WriteString(s string) {
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// for name, placeholder := range stmt.namedArgs {
|
||||||
|
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||||
|
// }
|
||||||
|
stmt.builder.WriteString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteRune extends [strings.Builder.WriteRune]
|
||||||
|
func (stmt *Statement) WriteRune(r rune) {
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// for name, placeholder := range stmt.namedArgs {
|
||||||
|
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||||
|
// }
|
||||||
|
stmt.builder.WriteRune(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteByte extends [strings.Builder.WriteByte]
|
||||||
|
func (stmt *Statement) WriteByte(b byte) {
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// for name, placeholder := range stmt.namedArgs {
|
||||||
|
// s = strings.ReplaceAll(s, name.string, placeholder)
|
||||||
|
// }
|
||||||
|
err := stmt.builder.WriteByte(b)
|
||||||
|
logging.OnError(err).Warn("unable to write bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write extends [strings.Builder.Write]
|
||||||
|
// it replaces named args with the previously provided named args
|
||||||
|
func (stmt *Statement) Write(b []byte) {
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// for name, placeholder := range stmt.namedArgs {
|
||||||
|
// bytes.ReplaceAll(b, []byte(name.string), []byte(placeholder))
|
||||||
|
// }
|
||||||
|
stmt.builder.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// String builds the query and replaces placeholders starting with "@"
|
||||||
|
// with the corresponding named arg placeholder
|
||||||
|
func (stmt *Statement) String() string {
|
||||||
|
return stmt.builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug builds the statement and replaces the placeholders with the parameters
|
||||||
|
func (stmt *Statement) Debug() string {
|
||||||
|
query := stmt.String()
|
||||||
|
|
||||||
|
for i := len(stmt.args) - 1; i >= 0; i-- {
|
||||||
|
var argText string
|
||||||
|
switch arg := stmt.args[i].(type) {
|
||||||
|
case time.Time:
|
||||||
|
argText = "'" + arg.Format("2006-01-02 15:04:05Z07:00") + "'"
|
||||||
|
case string:
|
||||||
|
argText = "'" + arg + "'"
|
||||||
|
case []string:
|
||||||
|
argText = "ARRAY["
|
||||||
|
for i, a := range arg {
|
||||||
|
if i > 0 {
|
||||||
|
argText += ", "
|
||||||
|
}
|
||||||
|
argText += "'" + a + "'"
|
||||||
|
}
|
||||||
|
argText += "]"
|
||||||
|
default:
|
||||||
|
argText = fmt.Sprint(arg)
|
||||||
|
}
|
||||||
|
query = strings.ReplaceAll(query, "$"+strconv.Itoa(i+1), argText)
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// func (stmt *Statement) writeNamedPlaceholder(arg placeholder) {
|
||||||
|
// placeholder, ok := stmt.namedArgs[arg]
|
||||||
|
// if !ok {
|
||||||
|
// logging.WithFields("named_placeholder", arg).Fatal("named placeholder not defined")
|
||||||
|
// }
|
||||||
|
// stmt.Builder.WriteString(placeholder)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// copyCheck allows uninitialized usage of stmt
|
||||||
|
func (stmt *Statement) copyCheck() {
|
||||||
|
if stmt.addr == nil {
|
||||||
|
// This hack works around a failing of Go's escape analysis
|
||||||
|
// that was causing b to escape and be heap allocated.
|
||||||
|
// See issue 23382.
|
||||||
|
// TODO: once issue 7921 is fixed, this should be reverted to
|
||||||
|
// just "stmt.addr = stmt".
|
||||||
|
stmt.addr = (*Statement)(noescape(unsafe.Pointer(stmt)))
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// stmt.namedArgs = make(map[placeholder]string)
|
||||||
|
} else if stmt.addr != stmt {
|
||||||
|
panic("statement: illegal use of non-zero Builder copied by value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// noescape hides a pointer from escape analysis. It is the identity function
|
||||||
|
// but escape analysis doesn't think the output depends on the input.
|
||||||
|
// noescape is inlined and currently compiles down to zero instructions.
|
||||||
|
// USE CAREFULLY!
|
||||||
|
// This was copied from the runtime; see issues 23382 and 7921.
|
||||||
|
//
|
||||||
|
//go:nosplit
|
||||||
|
//go:nocheckptr
|
||||||
|
func noescape(p unsafe.Pointer) unsafe.Pointer {
|
||||||
|
x := uintptr(p)
|
||||||
|
//nolint: staticcheck
|
||||||
|
return unsafe.Pointer(x ^ 0)
|
||||||
|
}
|
73
internal/v2/database/statement_test.go
Normal file
73
internal/v2/database/statement_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatement_WriteArgs(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want wantQuery
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no args",
|
||||||
|
args: args{
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "1 arg",
|
||||||
|
args: args{
|
||||||
|
args: []any{"asdf"},
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "$1",
|
||||||
|
args: []any{"asdf"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "n args",
|
||||||
|
args: args{
|
||||||
|
args: []any{"asdf", "jkl", 1},
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "$1, $2, $3",
|
||||||
|
args: []any{"asdf", "jkl", 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
var stmt Statement
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
stmt.WriteArgs(tt.args.args...)
|
||||||
|
assertQuery(t, &stmt, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type wantQuery struct {
|
||||||
|
query string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertQuery(t *testing.T, stmt *Statement, want wantQuery) {
|
||||||
|
if want.query != stmt.String() {
|
||||||
|
t.Errorf("unexpected query: want: %q got: %q", want.query, stmt.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(want.args) != len(stmt.Args()) {
|
||||||
|
t.Errorf("unexpected length of args: want %d, got %d", len(want.args), len(stmt.Args()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, wantArg := range want.args {
|
||||||
|
if !reflect.DeepEqual(wantArg, stmt.Args()[i]) {
|
||||||
|
t.Errorf("unexpected arg at position %d: want: %v, got: %v", i, wantArg, stmt.Args()[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
132
internal/v2/database/text_filter.go
Normal file
132
internal/v2/database/text_filter.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TextFilter[T text] struct {
|
||||||
|
Filter[textCompare, T]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextEqual[T text](t T) *TextFilter[T] {
|
||||||
|
return newTextFilter(textEqual, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextUnequal[T text](t T) *TextFilter[T] {
|
||||||
|
return newTextFilter(textUnequal, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextEqualInsensitive[T text](t T) *TextFilter[string] {
|
||||||
|
return newTextFilter(textEqualInsensitive, strings.ToLower(string(t)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextUnequalInsensitive[T text](t T) *TextFilter[string] {
|
||||||
|
return newTextFilter(textUnequalInsensitive, strings.ToLower(string(t)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextStartsWith[T text](t T) *TextFilter[T] {
|
||||||
|
return newTextFilter(textStartsWith, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextStartsWithInsensitive[T text](t T) *TextFilter[string] {
|
||||||
|
return newTextFilter(textStartsWithInsensitive, strings.ToLower(string(t)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextEndsWith[T text](t T) *TextFilter[T] {
|
||||||
|
return newTextFilter(textEndsWith, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextEndsWithInsensitive[T text](t T) *TextFilter[string] {
|
||||||
|
return newTextFilter(textEndsWithInsensitive, strings.ToLower(string(t)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextContains[T text](t T) *TextFilter[T] {
|
||||||
|
return newTextFilter(textContains, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTextContainsInsensitive[T text](t T) *TextFilter[string] {
|
||||||
|
return newTextFilter(textContainsInsensitive, strings.ToLower(string(t)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTextFilter[T text](comp textCompare, t T) *TextFilter[T] {
|
||||||
|
return &TextFilter[T]{
|
||||||
|
Filter: Filter[textCompare, T]{
|
||||||
|
comp: comp,
|
||||||
|
value: t,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TextFilter[T]) Write(stmt *Statement, columnName string) {
|
||||||
|
if f.comp.isInsensitive() {
|
||||||
|
f.writeCaseInsensitive(stmt, columnName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.Filter.Write(stmt, columnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TextFilter[T]) writeCaseInsensitive(stmt *Statement, columnName string) {
|
||||||
|
stmt.WriteString("LOWER(")
|
||||||
|
stmt.WriteString(columnName)
|
||||||
|
stmt.WriteString(") ")
|
||||||
|
stmt.WriteString(f.comp.String())
|
||||||
|
stmt.WriteRune(' ')
|
||||||
|
f.writeArg(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TextFilter[T]) writeArg(stmt *Statement) {
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// var v any = f.value
|
||||||
|
// workaround for placeholder
|
||||||
|
// if placeholder, ok := v.(placeholder); ok {
|
||||||
|
// stmt.Builder.WriteString(" LOWER(")
|
||||||
|
// stmt.WriteArg(placeholder)
|
||||||
|
// stmt.Builder.WriteString(")")
|
||||||
|
// }
|
||||||
|
stmt.WriteArg(strings.ToLower(fmt.Sprint(f.value)))
|
||||||
|
}
|
||||||
|
|
||||||
|
type textCompare uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
textEqual textCompare = iota
|
||||||
|
textUnequal
|
||||||
|
textEqualInsensitive
|
||||||
|
textUnequalInsensitive
|
||||||
|
textStartsWith
|
||||||
|
textStartsWithInsensitive
|
||||||
|
textEndsWith
|
||||||
|
textEndsWithInsensitive
|
||||||
|
textContains
|
||||||
|
textContainsInsensitive
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c textCompare) String() string {
|
||||||
|
switch c {
|
||||||
|
case textEqual, textEqualInsensitive:
|
||||||
|
return "="
|
||||||
|
case textUnequal, textUnequalInsensitive:
|
||||||
|
return "<>"
|
||||||
|
case textStartsWith, textStartsWithInsensitive, textEndsWith, textEndsWithInsensitive, textContains, textContainsInsensitive:
|
||||||
|
return "LIKE"
|
||||||
|
default:
|
||||||
|
logging.WithFields("compare", c).Panic("comparison type not implemented")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c textCompare) isInsensitive() bool {
|
||||||
|
return c == textEqualInsensitive ||
|
||||||
|
c == textStartsWithInsensitive ||
|
||||||
|
c == textEndsWithInsensitive ||
|
||||||
|
c == textContainsInsensitive
|
||||||
|
}
|
||||||
|
|
||||||
|
type text interface {
|
||||||
|
~string
|
||||||
|
// TODO: condition must know if it's args are named parameters or not
|
||||||
|
// ~string | placeholder
|
||||||
|
}
|
351
internal/v2/database/text_filter_test.go
Normal file
351
internal/v2/database/text_filter_test.go
Normal file
@ -0,0 +1,351 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTextEqual(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t string) *TextFilter[string]
|
||||||
|
t string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *TextFilter[string]
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NewTextEqual",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEqual[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEqual,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextUnequal",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextUnequal[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textUnequal,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEqualInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEqualInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEqualInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEqualInsensitive check lower",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEqualInsensitive[string],
|
||||||
|
t: "tEXt",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEqualInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextUnequalInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextUnequalInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textUnequalInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextUnequalInsensitive check lower",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextUnequalInsensitive[string],
|
||||||
|
t: "tEXt",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textUnequalInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextStartsWith",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextStartsWith[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textStartsWith,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextStartsWithInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextStartsWithInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textStartsWithInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextStartsWithInsensitive check lower",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextStartsWithInsensitive[string],
|
||||||
|
t: "tEXt",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textStartsWithInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEndsWith",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEndsWith[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEndsWith,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEndsWithInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEndsWithInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEndsWithInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEndsWithInsensitive check lower",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEndsWithInsensitive[string],
|
||||||
|
t: "tEXt",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textEndsWithInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextContains[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textContains,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextContainsInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextContainsInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textContainsInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextContainsInsensitive to lower",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextContainsInsensitive[string],
|
||||||
|
t: "tEXt",
|
||||||
|
},
|
||||||
|
want: &TextFilter[string]{
|
||||||
|
Filter: Filter[textCompare, string]{
|
||||||
|
comp: textContainsInsensitive,
|
||||||
|
value: "text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewTextEqual() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTextConditionWrite(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
constructor func(t string) *TextFilter[string]
|
||||||
|
t string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want wantQuery
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "NewTextEqual",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEqual[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test = $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextUnequal",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextUnequal[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test <> $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEqualInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEqualInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "LOWER(test) = $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextUnequalInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextUnequalInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test <> $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextStartsWith",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextStartsWith[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextStartsWithInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextStartsWithInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "LOWER(test) LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEndsWith",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEndsWith[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextEndsWithInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextEndsWithInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "LOWER(test) LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextContains",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextContains[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "test LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NewTextContainsInsensitive",
|
||||||
|
args: args{
|
||||||
|
constructor: NewTextContainsInsensitive[string],
|
||||||
|
t: "text",
|
||||||
|
},
|
||||||
|
want: wantQuery{
|
||||||
|
query: "LOWER(test) LIKE $1",
|
||||||
|
args: []any{"text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
var stmt Statement
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.args.constructor(tt.args.t).Write(&stmt, "test")
|
||||||
|
assertQuery(t, &stmt, tt.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user