refactor(v2): init database package (#7802)

This commit is contained in:
Silvan 2024-04-25 08:45:34 +02:00 committed by GitHub
parent 207b20ff0f
commit 5131328291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 2200 additions and 0 deletions

View File

@ -0,0 +1,33 @@
package database
type Condition interface {
Write(stmt *Statement, columnName string)
}
type Filter[C compare, V value] struct {
comp C
value V
}
func (f Filter[C, V]) Write(stmt *Statement, columnName string) {
prepareWrite(stmt, columnName, f.comp)
stmt.WriteArg(f.value)
}
func prepareWrite[C compare](stmt *Statement, columnName string, comp C) {
stmt.WriteString(columnName)
stmt.WriteRune(' ')
stmt.WriteString(comp.String())
stmt.WriteRune(' ')
}
type compare interface {
numberCompare | textCompare | listCompare
String() string
}
type value interface {
number | text
// TODO: condition must know if it's args are named parameters or not
// number | text | placeholder
}

View File

@ -0,0 +1,57 @@
package database
import "github.com/zitadel/logging"
type ListFilter[V value] struct {
comp listCompare
list []V
}
func NewListEquals[V value](list ...V) *ListFilter[V] {
return newListFilter[V](listEqual, list)
}
func NewListContains[V value](list ...V) *ListFilter[V] {
return newListFilter[V](listContain, list)
}
func NewListNotContains[V value](list ...V) *ListFilter[V] {
return newListFilter[V](listNotContain, list)
}
func newListFilter[V value](comp listCompare, list []V) *ListFilter[V] {
return &ListFilter[V]{
comp: comp,
list: list,
}
}
func (f ListFilter[V]) Write(stmt *Statement, columnName string) {
if len(f.list) == 0 {
logging.WithFields("column", columnName).Debug("skip list filter because no entries defined")
return
}
if f.comp == listNotContain {
stmt.WriteString("NOT(")
}
stmt.WriteString(columnName)
stmt.WriteString(" = ")
if f.comp != listEqual {
stmt.WriteString("ANY(")
}
stmt.WriteArg(f.list)
if f.comp != listEqual {
stmt.WriteString(")")
}
if f.comp == listNotContain {
stmt.WriteRune(')')
}
}
type listCompare uint8
const (
listEqual listCompare = iota
listContain
listNotContain
)

View File

@ -0,0 +1,122 @@
package database
import (
"reflect"
"testing"
)
func TestNewListConstructors(t *testing.T) {
type args struct {
constructor func(t ...string) *ListFilter[string]
t []string
}
tests := []struct {
name string
args args
want *ListFilter[string]
}{
{
name: "NewListEquals",
args: args{
constructor: NewListEquals[string],
t: []string{"as", "df"},
},
want: &ListFilter[string]{
comp: listEqual,
list: []string{"as", "df"},
},
},
{
name: "NewListContains",
args: args{
constructor: NewListContains[string],
t: []string{"as", "df"},
},
want: &ListFilter[string]{
comp: listContain,
list: []string{"as", "df"},
},
},
{
name: "NewListNotContains",
args: args{
constructor: NewListNotContains[string],
t: []string{"as", "df"},
},
want: &ListFilter[string]{
comp: listNotContain,
list: []string{"as", "df"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.args.constructor(tt.args.t...); !reflect.DeepEqual(got, tt.want) {
t.Errorf("number constructor = %v, want %v", got, tt.want)
}
})
}
}
func TestNewListConditionWrite(t *testing.T) {
type args struct {
constructor func(t ...string) *ListFilter[string]
t []string
}
tests := []struct {
name string
args args
want wantQuery
}{
{
name: "ListEquals",
args: args{
constructor: NewListEquals[string],
t: []string{"as", "df"},
},
want: wantQuery{
query: "test = $1",
args: []any{[]string{"as", "df"}},
},
},
{
name: "ListContains",
args: args{
constructor: NewListContains[string],
t: []string{"as", "df"},
},
want: wantQuery{
query: "test = ANY($1)",
args: []any{[]string{"as", "df"}},
},
},
{
name: "ListNotContains",
args: args{
constructor: NewListNotContains[string],
t: []string{"as", "df"},
},
want: wantQuery{
query: "NOT(test = ANY($1))",
args: []any{[]string{"as", "df"}},
},
},
{
name: "empty list",
args: args{
constructor: NewListNotContains[string],
},
want: wantQuery{
query: "",
args: nil,
},
},
}
for _, tt := range tests {
var stmt Statement
t.Run(tt.name, func(t *testing.T) {
tt.args.constructor(tt.args.t...).Write(&stmt, "test")
assertQuery(t, &stmt, tt.want)
})
}
}

View File

@ -0,0 +1,139 @@
package mock
import (
"database/sql"
"database/sql/driver"
"reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
type SQLMock struct {
DB *sql.DB
mock sqlmock.Sqlmock
}
type Expectation func(m sqlmock.Sqlmock)
func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock {
db, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
sqlmock.ValueConverterOption(new(TypeConverter)),
)
if err != nil {
t.Fatal("create mock failed", err)
}
for _, expectation := range expectations {
expectation(mock)
}
return &SQLMock{
DB: db,
mock: mock,
}
}
func (m *SQLMock) Assert(t *testing.T) {
t.Helper()
if err := m.mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations not met: %v", err)
}
m.DB.Close()
}
func ExpectBegin(err error) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectBegin()
if err != nil {
e.WillReturnError(err)
}
}
}
func ExpectCommit(err error) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectCommit()
if err != nil {
e.WillReturnError(err)
}
}
}
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
func WithExecArgs(args ...driver.Value) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WithArgs(args...)
}
}
func WithExecErr(err error) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnError(err)
}
}
func WithExecNoRowsAffected() ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnResult(driver.ResultNoRows)
}
}
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnResult(affected)
}
}
func ExpectExec(stmt string, opts ...ExecOpt) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectExec(stmt)
for _, opt := range opts {
e = opt(e)
}
}
}
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
func WithQueryArgs(args ...driver.Value) QueryOpt {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WithArgs(args...)
}
}
func WithQueryErr(err error) QueryOpt {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WillReturnError(err)
}
}
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
mockedRows := m.NewRows(columns)
for _, row := range rows {
mockedRows = mockedRows.AddRow(row...)
}
return e.WillReturnRows(mockedRows)
}
}
func ExpectQuery(stmt string, opts ...QueryOpt) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectQuery(stmt)
for _, opt := range opts {
e = opt(m, e)
}
}
}
type AnyType[T interface{}] struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyType[T]) Match(v driver.Value) bool {
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
}

View File

@ -0,0 +1,78 @@
package mock
import (
"database/sql/driver"
"encoding/hex"
"encoding/json"
"reflect"
"strconv"
"strings"
)
var _ driver.ValueConverter = (*TypeConverter)(nil)
type TypeConverter struct{}
// ConvertValue converts a value to a driver Value.
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
value := reflect.ValueOf(v)
if rawMessage, ok := v.(json.RawMessage); ok {
return convertBytes(rawMessage), nil
}
if value.Kind() == reflect.Slice {
//nolint: exhaustive
// only defined types
switch value.Type().Elem().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return convertSigned(value), nil
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return convertUnsigned(value), nil
case reflect.String:
return convertText(value), nil
}
}
return v, nil
}
// converts a text array to valid pgx v5 representation
func convertSigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertUnsigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertText(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = array.Index(i).String()
}
return "{" + strings.Join(slice, ",") + "}"
}
func convertBytes(array []byte) string {
var builder strings.Builder
builder.Grow(hex.EncodedLen(len(array)) + 4)
builder.WriteString(`\x`)
builder.Write(hex.AppendEncode(nil, array))
return builder.String()
}

View File

@ -0,0 +1,100 @@
package database
import (
"time"
"github.com/zitadel/logging"
"golang.org/x/exp/constraints"
)
type NumberFilter[N number] struct {
Filter[numberCompare, N]
}
func NewNumberEquals[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberEqual, n)
}
func NewNumberAtLeast[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberAtLeast, n)
}
func NewNumberAtMost[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberAtMost, n)
}
func NewNumberGreater[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberGreater, n)
}
func NewNumberLess[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberLess, n)
}
func NewNumberUnequal[N number](n N) *NumberFilter[N] {
return newNumberFilter(numberUnequal, n)
}
func newNumberFilter[N number](comp numberCompare, n N) *NumberFilter[N] {
return &NumberFilter[N]{
Filter: Filter[numberCompare, N]{
comp: comp,
value: n,
},
}
}
// NumberBetweenFilter combines [AtLeast] and [AtMost] comparisons
type NumberBetweenFilter[N number] struct {
min, max N
}
func NewNumberBetween[N number](min, max N) *NumberBetweenFilter[N] {
return &NumberBetweenFilter[N]{
min: min,
max: max,
}
}
func (f NumberBetweenFilter[N]) Write(stmt *Statement, columnName string) {
NewNumberAtLeast[N](f.min).Write(stmt, columnName)
stmt.WriteString(" AND ")
NewNumberAtMost[N](f.max).Write(stmt, columnName)
}
type numberCompare uint8
const (
numberEqual numberCompare = iota
numberAtLeast
numberAtMost
numberGreater
numberLess
numberUnequal
)
func (c numberCompare) String() string {
switch c {
case numberEqual:
return "="
case numberAtLeast:
return ">="
case numberAtMost:
return "<="
case numberGreater:
return ">"
case numberLess:
return "<"
case numberUnequal:
return "<>"
default:
logging.WithFields("compare", c).Panic("comparison type not implemented")
return ""
}
}
type number interface {
constraints.Integer | constraints.Float | time.Time
// TODO: condition must know if it's args are named parameters or not
// constraints.Integer | constraints.Float | time.Time | placeholder
}

View File

@ -0,0 +1,216 @@
package database
import (
"reflect"
"testing"
)
func TestNewNumberConstructors(t *testing.T) {
type args struct {
constructor func(t int8) *NumberFilter[int8]
t int8
}
tests := []struct {
name string
args args
want *NumberFilter[int8]
}{
{
name: "NewNumberEqual",
args: args{
constructor: NewNumberEquals[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberEqual,
value: 10,
},
},
},
{
name: "NewNumberAtLeast",
args: args{
constructor: NewNumberAtLeast[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberAtLeast,
value: 10,
},
},
},
{
name: "NewNumberAtMost",
args: args{
constructor: NewNumberAtMost[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberAtMost,
value: 10,
},
},
},
{
name: "NewNumberGreater",
args: args{
constructor: NewNumberGreater[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberGreater,
value: 10,
},
},
},
{
name: "NewNumberLess",
args: args{
constructor: NewNumberLess[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberLess,
value: 10,
},
},
},
{
name: "NewNumberUnequal",
args: args{
constructor: NewNumberUnequal[int8],
t: 10,
},
want: &NumberFilter[int8]{
Filter: Filter[numberCompare, int8]{
comp: numberUnequal,
value: 10,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
t.Errorf("number constructor = %v, want %v", got, tt.want)
}
})
}
}
func TestNewNumberConditionWrite(t *testing.T) {
type args struct {
constructor func(t int8) *NumberFilter[int8]
t int8
}
tests := []struct {
name string
args args
want wantQuery
}{
{
name: "NewNumberEqual",
args: args{
constructor: NewNumberEquals[int8],
t: 10,
},
want: wantQuery{
query: "test = $1",
args: []any{int8(10)},
},
},
{
name: "NewNumberAtLeast",
args: args{
constructor: NewNumberAtLeast[int8],
t: 10,
},
want: wantQuery{
query: "test >= $1",
args: []any{int8(10)},
},
},
{
name: "NewNumberAtMost",
args: args{
constructor: NewNumberAtMost[int8],
t: 10,
},
want: wantQuery{
query: "test <= $1",
args: []any{int8(10)},
},
},
{
name: "NewNumberGreater",
args: args{
constructor: NewNumberGreater[int8],
t: 10,
},
want: wantQuery{
query: "test > $1",
args: []any{int8(10)},
},
},
{
name: "NewNumberLess",
args: args{
constructor: NewNumberLess[int8],
t: 10,
},
want: wantQuery{
query: "test < $1",
args: []any{int8(10)},
},
},
{
name: "NewNumberUnequal",
args: args{
constructor: NewNumberUnequal[int8],
t: 10,
},
want: wantQuery{
query: "test <> $1",
args: []any{int8(10)},
},
},
}
for _, tt := range tests {
var stmt Statement
t.Run(tt.name, func(t *testing.T) {
tt.args.constructor(tt.args.t).Write(&stmt, "test")
assertQuery(t, &stmt, tt.want)
})
}
}
func TestNumberBetween(t *testing.T) {
filter := NewNumberBetween[int8](10, 20)
if !reflect.DeepEqual(filter, &NumberBetweenFilter[int8]{min: 10, max: 20}) {
t.Errorf("unexpected filter: %v", filter)
}
var stmt Statement
filter.Write(&stmt, "test")
if stmt.String() != "test >= $1 AND test <= $2" {
t.Errorf("unexpected query: got: %q", stmt.String())
}
if len(stmt.Args()) != 2 {
t.Errorf("unexpected length of args: got %d", len(stmt.Args()))
return
}
if !reflect.DeepEqual(int8(10), stmt.Args()[0]) {
t.Errorf("unexpected arg at position 0: want: 10, got: %v", stmt.Args()[0])
}
if !reflect.DeepEqual(int8(20), stmt.Args()[1]) {
t.Errorf("unexpected arg at position 1: want: 20, got: %v", stmt.Args()[1])
}
}

View File

@ -0,0 +1,17 @@
package database
type Pagination struct {
Limit uint32
Offset uint32
}
func (p *Pagination) Write(stmt *Statement) {
if p.Limit > 0 {
stmt.WriteString(" LIMIT ")
stmt.WriteArg(p.Limit)
}
if p.Offset > 0 {
stmt.WriteString(" OFFSET ")
stmt.WriteArg(p.Offset)
}
}

View File

@ -0,0 +1,73 @@
package database
import (
"testing"
)
func TestPagination_Write(t *testing.T) {
type fields struct {
Limit uint32
Offset uint32
}
tests := []struct {
name string
fields fields
want wantQuery
}{
{
name: "no values",
fields: fields{
Limit: 0,
Offset: 0,
},
want: wantQuery{
query: "",
args: []any{},
},
},
{
name: "limit",
fields: fields{
Limit: 10,
Offset: 0,
},
want: wantQuery{
query: " LIMIT $1",
args: []any{uint32(10)},
},
},
{
name: "offset",
fields: fields{
Limit: 0,
Offset: 10,
},
want: wantQuery{
query: " OFFSET $1",
args: []any{uint32(10)},
},
},
{
name: "both",
fields: fields{
Limit: 10,
Offset: 10,
},
want: wantQuery{
query: " LIMIT $1 OFFSET $2",
args: []any{uint32(10), uint32(10)},
},
},
}
for _, tt := range tests {
var stmt Statement
t.Run(tt.name, func(t *testing.T) {
p := &Pagination{
Limit: tt.fields.Limit,
Offset: tt.fields.Offset,
}
p.Write(&stmt)
assertQuery(t, &stmt, tt.want)
})
}
}

View File

@ -0,0 +1,75 @@
package database
import (
"context"
"database/sql"
"github.com/zitadel/logging"
)
type Tx interface {
Commit() error
Rollback() error
}
func CloseTx(tx Tx, err error) error {
if err != nil {
rollbackErr := tx.Rollback()
logging.OnError(rollbackErr).Debug("unable to rollback")
return err
}
return tx.Commit()
}
type DestMapper[R any] func(index int, scan func(dest ...any) error) (*R, error)
type Rows interface {
Close() error
Err() error
Next() bool
Scan(dest ...any) error
}
func MapRows[R any](rows Rows, mapper DestMapper[R]) (result []*R, err error) {
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Debug("unable to close rows")
if err == nil && rows.Err() != nil {
result = nil
err = rows.Err()
}
}()
for i := 0; rows.Next(); i++ {
res, err := mapper(i, rows.Scan)
if err != nil {
return nil, err
}
result = append(result, res)
}
return result, nil
}
func MapRowsToObject(rows Rows, mapper func(scan func(dest ...any) error) error) (err error) {
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Debug("unable to close rows")
if err == nil && rows.Err() != nil {
err = rows.Err()
}
}()
for rows.Next() {
err = mapper(rows.Scan)
if err != nil {
return err
}
}
return nil
}
type Querier interface {
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
}

View File

@ -0,0 +1,512 @@
package database
import (
"errors"
"reflect"
"testing"
)
func TestCloseTx(t *testing.T) {
type args struct {
tx *testTx
err error
}
tests := []struct {
name string
args args
assertErr func(t *testing.T, err error) bool
}{
{
name: "exec err",
args: args{
tx: &testTx{
rollback: execution{
shouldExecute: true,
},
},
err: errExec,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("execution error expected, got: %v", err)
}
return is
},
},
{
name: "exec err and rollback err",
args: args{
tx: &testTx{
rollback: execution{
err: true,
shouldExecute: true,
},
},
err: errExec,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("execution error expected, got: %v", err)
}
return is
},
},
{
name: "commit Err",
args: args{
tx: &testTx{
commit: execution{
err: true,
shouldExecute: true,
},
},
err: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errCommit)
if !is {
t.Errorf("commit error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
tx: &testTx{
commit: execution{
shouldExecute: true,
},
},
err: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CloseTx(tt.args.tx, tt.args.err)
tt.assertErr(t, err)
tt.args.tx.assert(t)
})
}
}
func TestMapRows(t *testing.T) {
type args struct {
rows *testRows
mapper DestMapper[string]
}
var emptyString string
tests := []struct {
name string
args args
wantResult []*string
assertErr func(t *testing.T, err error) bool
}{
{
name: "no rows, close err",
args: args{
rows: &testRows{
closeErr: true,
},
mapper: nil,
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errClose)
if !is {
t.Errorf("close error expected, got: %v", err)
}
return is
},
},
{
name: "no rows, close err",
args: args{
rows: &testRows{
hasErr: true,
},
mapper: nil,
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "scan err",
args: args{
rows: &testRows{
scanErr: true,
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errScan)
if !is {
t.Errorf("scan error expected, got: %v", err)
}
return is
},
},
{
name: "exec err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
return nil, errExec
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "exec err, close err",
args: args{
rows: &testRows{
closeErr: true,
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
return nil, errExec
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "rows err",
args: args{
rows: &testRows{
nextCount: 1,
hasErr: true,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: []*string{&emptyString},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotResult, err := MapRows(tt.args.rows, tt.args.mapper)
tt.assertErr(t, err)
if !reflect.DeepEqual(gotResult, tt.wantResult) {
t.Errorf("MapRows() = %v, want %v", gotResult, tt.wantResult)
}
})
}
}
func TestMapRowsToObject(t *testing.T) {
type args struct {
rows *testRows
mapper func(scan func(dest ...any) error) error
}
tests := []struct {
name string
args args
assertErr func(t *testing.T, err error) bool
}{
{
name: "no rows, close err",
args: args{
rows: &testRows{
closeErr: true,
},
mapper: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errClose)
if !is {
t.Errorf("close error expected, got: %v", err)
}
return is
},
},
{
name: "no rows, close err",
args: args{
rows: &testRows{
hasErr: true,
},
mapper: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "scan err",
args: args{
rows: &testRows{
scanErr: true,
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
var s string
if err := scan(&s); err != nil {
return err
}
return nil
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errScan)
if !is {
t.Errorf("scan error expected, got: %v", err)
}
return is
},
},
{
name: "exec err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
return errExec
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "exec err, close err",
args: args{
rows: &testRows{
closeErr: true,
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
return errExec
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "rows err",
args: args{
rows: &testRows{
nextCount: 1,
hasErr: true,
},
mapper: func(scan func(dest ...any) error) error {
var s string
return scan(&s)
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
var s string
return scan(&s)
},
},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := MapRowsToObject(tt.args.rows, tt.args.mapper)
tt.assertErr(t, err)
})
}
}
var _ Tx = (*testTx)(nil)
type testTx struct {
commit, rollback execution
}
type execution struct {
err bool
didExecute bool
shouldExecute bool
}
var (
errCommit = errors.New("commit err")
errRollback = errors.New("rollback err")
errExec = errors.New("exec err")
)
// Commit implements Tx.
func (t *testTx) Commit() error {
t.commit.didExecute = true
if t.commit.err {
return errCommit
}
return nil
}
// Rollback implements Tx.
func (t *testTx) Rollback() error {
t.rollback.didExecute = true
if t.rollback.err {
return errRollback
}
return nil
}
func (tx *testTx) assert(t *testing.T) {
if tx.commit.didExecute != tx.commit.shouldExecute {
t.Errorf("unexpected execution of commit: should %v, did: %v", tx.commit.shouldExecute, tx.commit.didExecute)
}
if tx.rollback.didExecute != tx.rollback.shouldExecute {
t.Errorf("unexpected execution of rollback: should %v, did: %v", tx.rollback.shouldExecute, tx.rollback.didExecute)
}
}
var _ Rows = (*testRows)(nil)
var (
errClose = errors.New("err close")
errRows = errors.New("err rows")
errScan = errors.New("err scan")
)
type testRows struct {
closeErr bool
scanErr bool
hasErr bool
nextCount int
}
// Close implements Rows.
func (t *testRows) Close() error {
if t.closeErr {
return errClose
}
return nil
}
// Err implements Rows.
func (t *testRows) Err() error {
if t.hasErr {
return errRows
}
if t.closeErr {
return errClose
}
return nil
}
// Next implements Rows.
func (t *testRows) Next() bool {
t.nextCount--
return t.nextCount >= 0
}
// Scan implements Rows.
func (t *testRows) Scan(dest ...any) error {
if t.scanErr {
return errScan
}
return nil
}

View File

@ -0,0 +1,222 @@
package database
import (
"fmt"
"slices"
"strconv"
"strings"
"time"
"unsafe"
"github.com/zitadel/logging"
)
type Statement struct {
addr *Statement
builder strings.Builder
args []any
// key is the name of the arg and value is the placeholder
// TODO: condition must know if it's args are named parameters or not
// namedArgs map[placeholder]string
}
func (stmt *Statement) Args() []any {
if stmt == nil {
return nil
}
return stmt.args
}
func (stmt *Statement) Reset() {
stmt.builder.Reset()
stmt.addr = nil
stmt.args = nil
}
// TODO: condition must know if it's args are named parameters or not
// SetNamedArg sets the arg and makes it available for query construction
// func (stmt *Statement) SetNamedArg(name placeholder, value any) (placeholder string) {
// stmt.copyCheck()
// stmt.args = append(stmt.args, value)
// placeholder = fmt.Sprintf("$%d", len(stmt.args))
// if !strings.HasPrefix(name.string, "@") {
// name.string = "@" + name.string
// }
// stmt.namedArgs[name] = placeholder
// return placeholder
// }
// AppendArgs appends the args without writing it to Builder
// if any arg is a [placeholder] it's replaced with the placeholders parameter
func (stmt *Statement) AppendArgs(args ...any) {
stmt.copyCheck()
stmt.args = slices.Grow(stmt.args, len(args))
for _, arg := range args {
stmt.AppendArg(arg)
}
}
// AppendArg appends the arg without writing it to Builder
// if the arg is a [placeholder] it's replaced with the placeholders parameter
func (stmt *Statement) AppendArg(arg any) int {
stmt.copyCheck()
// TODO: condition must know if it's args are named parameters or not
// if namedArg, ok := arg.(sql.NamedArg); ok {
// stmt.SetNamedArg(placeholder{namedArg.Name}, namedArg.Value)
// return
// }
stmt.args = append(stmt.args, arg)
return len(stmt.args)
}
// TODO: condition must know if it's args are named parameters or not
// func Placeholder(name string) placeholder {
// return placeholder{name}
// }
// TODO: condition must know if it's args are named parameters or not
// type placeholder struct {
// string
// }
// WriteArgs appends the args and adds the placeholders comma separated to [stmt.Builder]
// if any arg is a [placeholder] it's replaced with the placeholders parameter
func (stmt *Statement) WriteArgs(args ...any) {
stmt.copyCheck()
stmt.args = slices.Grow(stmt.args, len(args))
for i, arg := range args {
if i > 0 {
stmt.WriteString(", ")
}
stmt.WriteArg(arg)
}
}
// WriteArg appends the arg and adds the placeholder to [stmt.Builder]
// if the arg is a [placeholder] it's replaced with the placeholders parameter
func (stmt *Statement) WriteArg(arg any) {
stmt.copyCheck()
// TODO: condition must know if it's args are named parameters or not
// if namedPlaceholder, ok := arg.(placeholder); ok {
// stmt.writeNamedPlaceholder(namedPlaceholder)
// return
// }
placeholder := stmt.AppendArg(arg)
stmt.WriteString("$")
stmt.WriteString(strconv.Itoa(placeholder))
}
// WriteString extends [strings.Builder.WriteString]
// it replaces named args with the previously provided named args
func (stmt *Statement) WriteString(s string) {
// TODO: condition must know if it's args are named parameters or not
// for name, placeholder := range stmt.namedArgs {
// s = strings.ReplaceAll(s, name.string, placeholder)
// }
stmt.builder.WriteString(s)
}
// WriteRune extends [strings.Builder.WriteRune]
func (stmt *Statement) WriteRune(r rune) {
// TODO: condition must know if it's args are named parameters or not
// for name, placeholder := range stmt.namedArgs {
// s = strings.ReplaceAll(s, name.string, placeholder)
// }
stmt.builder.WriteRune(r)
}
// WriteByte extends [strings.Builder.WriteByte]
func (stmt *Statement) WriteByte(b byte) {
// TODO: condition must know if it's args are named parameters or not
// for name, placeholder := range stmt.namedArgs {
// s = strings.ReplaceAll(s, name.string, placeholder)
// }
err := stmt.builder.WriteByte(b)
logging.OnError(err).Warn("unable to write bytes")
}
// Write extends [strings.Builder.Write]
// it replaces named args with the previously provided named args
func (stmt *Statement) Write(b []byte) {
// TODO: condition must know if it's args are named parameters or not
// for name, placeholder := range stmt.namedArgs {
// bytes.ReplaceAll(b, []byte(name.string), []byte(placeholder))
// }
stmt.builder.Write(b)
}
// String builds the query and replaces placeholders starting with "@"
// with the corresponding named arg placeholder
func (stmt *Statement) String() string {
return stmt.builder.String()
}
// Debug builds the statement and replaces the placeholders with the parameters
func (stmt *Statement) Debug() string {
query := stmt.String()
for i := len(stmt.args) - 1; i >= 0; i-- {
var argText string
switch arg := stmt.args[i].(type) {
case time.Time:
argText = "'" + arg.Format("2006-01-02 15:04:05Z07:00") + "'"
case string:
argText = "'" + arg + "'"
case []string:
argText = "ARRAY["
for i, a := range arg {
if i > 0 {
argText += ", "
}
argText += "'" + a + "'"
}
argText += "]"
default:
argText = fmt.Sprint(arg)
}
query = strings.ReplaceAll(query, "$"+strconv.Itoa(i+1), argText)
}
return query
}
// TODO: condition must know if it's args are named parameters or not
// func (stmt *Statement) writeNamedPlaceholder(arg placeholder) {
// placeholder, ok := stmt.namedArgs[arg]
// if !ok {
// logging.WithFields("named_placeholder", arg).Fatal("named placeholder not defined")
// }
// stmt.Builder.WriteString(placeholder)
// }
// copyCheck allows uninitialized usage of stmt
func (stmt *Statement) copyCheck() {
if stmt.addr == nil {
// This hack works around a failing of Go's escape analysis
// that was causing b to escape and be heap allocated.
// See issue 23382.
// TODO: once issue 7921 is fixed, this should be reverted to
// just "stmt.addr = stmt".
stmt.addr = (*Statement)(noescape(unsafe.Pointer(stmt)))
// TODO: condition must know if it's args are named parameters or not
// stmt.namedArgs = make(map[placeholder]string)
} else if stmt.addr != stmt {
panic("statement: illegal use of non-zero Builder copied by value")
}
}
// noescape hides a pointer from escape analysis. It is the identity function
// but escape analysis doesn't think the output depends on the input.
// noescape is inlined and currently compiles down to zero instructions.
// USE CAREFULLY!
// This was copied from the runtime; see issues 23382 and 7921.
//
//go:nosplit
//go:nocheckptr
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
//nolint: staticcheck
return unsafe.Pointer(x ^ 0)
}

View File

@ -0,0 +1,73 @@
package database
import (
"reflect"
"testing"
)
func TestStatement_WriteArgs(t *testing.T) {
type args struct {
args []any
}
tests := []struct {
name string
args args
want wantQuery
}{
{
name: "no args",
args: args{
args: nil,
},
},
{
name: "1 arg",
args: args{
args: []any{"asdf"},
},
want: wantQuery{
query: "$1",
args: []any{"asdf"},
},
},
{
name: "n args",
args: args{
args: []any{"asdf", "jkl", 1},
},
want: wantQuery{
query: "$1, $2, $3",
args: []any{"asdf", "jkl", 1},
},
},
}
for _, tt := range tests {
var stmt Statement
t.Run(tt.name, func(t *testing.T) {
stmt.WriteArgs(tt.args.args...)
assertQuery(t, &stmt, tt.want)
})
}
}
type wantQuery struct {
query string
args []any
}
func assertQuery(t *testing.T, stmt *Statement, want wantQuery) {
if want.query != stmt.String() {
t.Errorf("unexpected query: want: %q got: %q", want.query, stmt.String())
}
if len(want.args) != len(stmt.Args()) {
t.Errorf("unexpected length of args: want %d, got %d", len(want.args), len(stmt.Args()))
return
}
for i, wantArg := range want.args {
if !reflect.DeepEqual(wantArg, stmt.Args()[i]) {
t.Errorf("unexpected arg at position %d: want: %v, got: %v", i, wantArg, stmt.Args()[i])
}
}
}

View File

@ -0,0 +1,132 @@
package database
import (
"fmt"
"strings"
"github.com/zitadel/logging"
)
type TextFilter[T text] struct {
Filter[textCompare, T]
}
func NewTextEqual[T text](t T) *TextFilter[T] {
return newTextFilter(textEqual, t)
}
func NewTextUnequal[T text](t T) *TextFilter[T] {
return newTextFilter(textUnequal, t)
}
func NewTextEqualInsensitive[T text](t T) *TextFilter[string] {
return newTextFilter(textEqualInsensitive, strings.ToLower(string(t)))
}
func NewTextUnequalInsensitive[T text](t T) *TextFilter[string] {
return newTextFilter(textUnequalInsensitive, strings.ToLower(string(t)))
}
func NewTextStartsWith[T text](t T) *TextFilter[T] {
return newTextFilter(textStartsWith, t)
}
func NewTextStartsWithInsensitive[T text](t T) *TextFilter[string] {
return newTextFilter(textStartsWithInsensitive, strings.ToLower(string(t)))
}
func NewTextEndsWith[T text](t T) *TextFilter[T] {
return newTextFilter(textEndsWith, t)
}
func NewTextEndsWithInsensitive[T text](t T) *TextFilter[string] {
return newTextFilter(textEndsWithInsensitive, strings.ToLower(string(t)))
}
func NewTextContains[T text](t T) *TextFilter[T] {
return newTextFilter(textContains, t)
}
func NewTextContainsInsensitive[T text](t T) *TextFilter[string] {
return newTextFilter(textContainsInsensitive, strings.ToLower(string(t)))
}
func newTextFilter[T text](comp textCompare, t T) *TextFilter[T] {
return &TextFilter[T]{
Filter: Filter[textCompare, T]{
comp: comp,
value: t,
},
}
}
func (f *TextFilter[T]) Write(stmt *Statement, columnName string) {
if f.comp.isInsensitive() {
f.writeCaseInsensitive(stmt, columnName)
return
}
f.Filter.Write(stmt, columnName)
}
func (f *TextFilter[T]) writeCaseInsensitive(stmt *Statement, columnName string) {
stmt.WriteString("LOWER(")
stmt.WriteString(columnName)
stmt.WriteString(") ")
stmt.WriteString(f.comp.String())
stmt.WriteRune(' ')
f.writeArg(stmt)
}
func (f *TextFilter[T]) writeArg(stmt *Statement) {
// TODO: condition must know if it's args are named parameters or not
// var v any = f.value
// workaround for placeholder
// if placeholder, ok := v.(placeholder); ok {
// stmt.Builder.WriteString(" LOWER(")
// stmt.WriteArg(placeholder)
// stmt.Builder.WriteString(")")
// }
stmt.WriteArg(strings.ToLower(fmt.Sprint(f.value)))
}
type textCompare uint8
const (
textEqual textCompare = iota
textUnequal
textEqualInsensitive
textUnequalInsensitive
textStartsWith
textStartsWithInsensitive
textEndsWith
textEndsWithInsensitive
textContains
textContainsInsensitive
)
func (c textCompare) String() string {
switch c {
case textEqual, textEqualInsensitive:
return "="
case textUnequal, textUnequalInsensitive:
return "<>"
case textStartsWith, textStartsWithInsensitive, textEndsWith, textEndsWithInsensitive, textContains, textContainsInsensitive:
return "LIKE"
default:
logging.WithFields("compare", c).Panic("comparison type not implemented")
return ""
}
}
func (c textCompare) isInsensitive() bool {
return c == textEqualInsensitive ||
c == textStartsWithInsensitive ||
c == textEndsWithInsensitive ||
c == textContainsInsensitive
}
type text interface {
~string
// TODO: condition must know if it's args are named parameters or not
// ~string | placeholder
}

View File

@ -0,0 +1,351 @@
package database
import (
"reflect"
"testing"
)
func TestNewTextEqual(t *testing.T) {
type args struct {
constructor func(t string) *TextFilter[string]
t string
}
tests := []struct {
name string
args args
want *TextFilter[string]
}{
{
name: "NewTextEqual",
args: args{
constructor: NewTextEqual[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEqual,
value: "text",
},
},
},
{
name: "NewTextUnequal",
args: args{
constructor: NewTextUnequal[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textUnequal,
value: "text",
},
},
},
{
name: "NewTextEqualInsensitive",
args: args{
constructor: NewTextEqualInsensitive[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEqualInsensitive,
value: "text",
},
},
},
{
name: "NewTextEqualInsensitive check lower",
args: args{
constructor: NewTextEqualInsensitive[string],
t: "tEXt",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEqualInsensitive,
value: "text",
},
},
},
{
name: "NewTextUnequalInsensitive",
args: args{
constructor: NewTextUnequalInsensitive[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textUnequalInsensitive,
value: "text",
},
},
},
{
name: "NewTextUnequalInsensitive check lower",
args: args{
constructor: NewTextUnequalInsensitive[string],
t: "tEXt",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textUnequalInsensitive,
value: "text",
},
},
},
{
name: "NewTextStartsWith",
args: args{
constructor: NewTextStartsWith[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textStartsWith,
value: "text",
},
},
},
{
name: "NewTextStartsWithInsensitive",
args: args{
constructor: NewTextStartsWithInsensitive[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textStartsWithInsensitive,
value: "text",
},
},
},
{
name: "NewTextStartsWithInsensitive check lower",
args: args{
constructor: NewTextStartsWithInsensitive[string],
t: "tEXt",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textStartsWithInsensitive,
value: "text",
},
},
},
{
name: "NewTextEndsWith",
args: args{
constructor: NewTextEndsWith[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEndsWith,
value: "text",
},
},
},
{
name: "NewTextEndsWithInsensitive",
args: args{
constructor: NewTextEndsWithInsensitive[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEndsWithInsensitive,
value: "text",
},
},
},
{
name: "NewTextEndsWithInsensitive check lower",
args: args{
constructor: NewTextEndsWithInsensitive[string],
t: "tEXt",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textEndsWithInsensitive,
value: "text",
},
},
},
{
name: "NewTextContains",
args: args{
constructor: NewTextContains[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textContains,
value: "text",
},
},
},
{
name: "NewTextContainsInsensitive",
args: args{
constructor: NewTextContainsInsensitive[string],
t: "text",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textContainsInsensitive,
value: "text",
},
},
},
{
name: "NewTextContainsInsensitive to lower",
args: args{
constructor: NewTextContainsInsensitive[string],
t: "tEXt",
},
want: &TextFilter[string]{
Filter: Filter[textCompare, string]{
comp: textContainsInsensitive,
value: "text",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.args.constructor(tt.args.t); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewTextEqual() = %v, want %v", got, tt.want)
}
})
}
}
func TestTextConditionWrite(t *testing.T) {
type args struct {
constructor func(t string) *TextFilter[string]
t string
}
tests := []struct {
name string
args args
want wantQuery
}{
{
name: "NewTextEqual",
args: args{
constructor: NewTextEqual[string],
t: "text",
},
want: wantQuery{
query: "test = $1",
args: []any{"text"},
},
},
{
name: "NewTextUnequal",
args: args{
constructor: NewTextUnequal[string],
t: "text",
},
want: wantQuery{
query: "test <> $1",
args: []any{"text"},
},
},
{
name: "NewTextEqualInsensitive",
args: args{
constructor: NewTextEqualInsensitive[string],
t: "text",
},
want: wantQuery{
query: "LOWER(test) = $1",
args: []any{"text"},
},
},
{
name: "NewTextUnequalInsensitive",
args: args{
constructor: NewTextUnequalInsensitive[string],
t: "text",
},
want: wantQuery{
query: "test <> $1",
args: []any{"text"},
},
},
{
name: "NewTextStartsWith",
args: args{
constructor: NewTextStartsWith[string],
t: "text",
},
want: wantQuery{
query: "test LIKE $1",
args: []any{"text"},
},
},
{
name: "NewTextStartsWithInsensitive",
args: args{
constructor: NewTextStartsWithInsensitive[string],
t: "text",
},
want: wantQuery{
query: "LOWER(test) LIKE $1",
args: []any{"text"},
},
},
{
name: "NewTextEndsWith",
args: args{
constructor: NewTextEndsWith[string],
t: "text",
},
want: wantQuery{
query: "test LIKE $1",
args: []any{"text"},
},
},
{
name: "NewTextEndsWithInsensitive",
args: args{
constructor: NewTextEndsWithInsensitive[string],
t: "text",
},
want: wantQuery{
query: "LOWER(test) LIKE $1",
args: []any{"text"},
},
},
{
name: "NewTextContains",
args: args{
constructor: NewTextContains[string],
t: "text",
},
want: wantQuery{
query: "test LIKE $1",
args: []any{"text"},
},
},
{
name: "NewTextContainsInsensitive",
args: args{
constructor: NewTextContainsInsensitive[string],
t: "text",
},
want: wantQuery{
query: "LOWER(test) LIKE $1",
args: []any{"text"},
},
},
}
for _, tt := range tests {
var stmt Statement
t.Run(tt.name, func(t *testing.T) {
tt.args.constructor(tt.args.t).Write(&stmt, "test")
assertQuery(t, &stmt, tt.want)
})
}
}