multiple tries

This commit is contained in:
adlerhurst
2025-04-29 06:03:47 +02:00
parent 77c4cc8185
commit 986c62b61a
131 changed files with 9805 additions and 47 deletions

View File

@@ -0,0 +1,9 @@
package database
import (
"context"
)
type Connector interface {
Connect(ctx context.Context) (Pool, error)
}

View File

@@ -0,0 +1,60 @@
package database
import (
"context"
)
var (
db *database
)
type database struct {
connector Connector
pool Pool
}
type Pool interface {
Beginner
QueryExecutor
Acquire(ctx context.Context) (Client, error)
Close(ctx context.Context) error
}
type Client interface {
Beginner
QueryExecutor
Release(ctx context.Context) error
}
type Querier interface {
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
QueryRow(ctx context.Context, stmt string, args ...any) Row
}
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) error
}
type QueryExecutor interface {
Querier
Executor
}
type Scanner interface {
Scan(dest ...any) error
}
type Row interface {
Scanner
}
type Rows interface {
Row
Next() bool
Close() error
Err() error
}
type Query[T any] func(querier Querier) (result T, err error)

View File

@@ -0,0 +1,92 @@
package dialect
import (
"context"
"errors"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
)
type Hook struct {
Match func(string) bool
Decode func(config any) (database.Connector, error)
Name string
Constructor func() database.Connector
}
var hooks = []Hook{
{
Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig,
Name: postgres.Name,
Constructor: func() database.Connector { return new(postgres.Config) },
},
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
}
type Config struct {
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 {
return nil, errors.New("Exactly one dialect must be configured")
}
return c.connector.Connect(ctx)
}
// Hooks implements [configure.Unmarshaller].
func (c Config) Hooks() []viper.DecoderConfigOption {
return []viper.DecoderConfigOption{
viper.DecodeHook(decodeHook),
}
}
func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
if err = config.decodeDialect(); err != nil {
return nil, err
}
return config, nil
}
func (c *Config) decodeDialect() error {
for _, hook := range hooks {
for name, config := range c.Dialects {
if !hook.Match(name) {
continue
}
connector, err := hook.Decode(config)
if err != nil {
return err
}
c.connector = connector
return nil
}
}
return errors.New("no dialect found")
}

View File

@@ -0,0 +1,80 @@
package postgres
import (
"context"
"errors"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Connector = (*Config)(nil)
Name = "postgres"
)
type Config struct {
*pgxpool.Config
// Host string
// Port int32
// Database string
// MaxOpenConns uint32
// MaxIdleConns uint32
// MaxConnLifetime time.Duration
// MaxConnIdleTime time.Duration
// User User
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
configuredFields []string
}
// Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := pgxpool.NewWithConfig(ctx, c.Config)
if err != nil {
return nil, err
}
if err = pool.Ping(ctx); err != nil {
return nil, err
}
return &pgxPool{pool}, nil
}
func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
}
func DecodeConfig(input any) (database.Connector, error) {
switch c := input.(type) {
case string:
config, err := pgxpool.ParseConfig(c)
if err != nil {
return nil, err
}
return &Config{Config: config}, nil
case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{
Config: &pgxpool.Config{},
}, nil
}
return nil, errors.New("invalid configuration")
}

View File

@@ -0,0 +1,48 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxConn struct{ *pgxpool.Conn }
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Conn.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Conn.Exec(ctx, sql, args...)
return err
}

View File

@@ -0,0 +1,57 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxPool struct{ *pgxpool.Pool }
var _ database.Pool = (*pgxPool)(nil)
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
if err != nil {
return nil, err
}
return &pgxConn{conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Pool.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Pool.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
return nil
}

View File

@@ -0,0 +1,18 @@
package postgres
import (
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ database.Rows = (*Rows)(nil)
type Rows struct{ pgx.Rows }
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
r.Rows.Close()
return nil
}

View File

@@ -0,0 +1,95 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxTx struct{ pgx.Tx }
var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error {
return tx.Tx.Commit(ctx)
}
// Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error {
return tx.Tx.Rollback(ctx)
}
// End implements [database.Transaction].
func (tx *pgxTx) End(ctx context.Context, err error) error {
if err != nil {
tx.Rollback(ctx)
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return tx.Tx.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
_, err := tx.Tx.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx)
if err != nil {
return nil, err
}
return &pgxTx{savepoint}, nil
}
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
if opts == nil {
return pgx.TxOptions{}
}
return pgx.TxOptions{
IsoLevel: isolationToPgx(opts.IsolationLevel),
AccessMode: accessModeToPgx(opts.AccessMode),
}
}
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
switch isolation {
case database.IsolationLevelSerializable:
return pgx.Serializable
case database.IsolationLevelReadCommitted:
return pgx.ReadCommitted
default:
return pgx.Serializable
}
}
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
switch accessMode {
case database.AccessModeReadWrite:
return pgx.ReadWrite
case database.AccessModeReadOnly:
return pgx.ReadOnly
default:
return pgx.ReadWrite
}
}

View File

@@ -0,0 +1,3 @@
package database
//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,160 @@
package repository
import (
"fmt"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type field interface {
fmt.Stringer
}
type fieldDescriptor struct {
schema string
table string
name string
}
func (f fieldDescriptor) String() string {
return f.schema + "." + f.table + "." + f.name
}
type ignoreCaseFieldDescriptor struct {
fieldDescriptor
fieldNameSuffix string
}
func (f ignoreCaseFieldDescriptor) String() string {
return f.fieldDescriptor.String() + f.fieldNameSuffix
}
type textFieldDescriptor struct {
field
isIgnoreCase bool
}
type clause[Op domain.Operation] struct {
field field
op Op
}
const (
schema = "zitadel"
userTable = "users"
)
var userFields = map[domain.UserField]field{
domain.UserFieldInstanceID: fieldDescriptor{
schema: schema,
table: userTable,
name: "instance_id",
},
domain.UserFieldOrgID: fieldDescriptor{
schema: schema,
table: userTable,
name: "org_id",
},
domain.UserFieldID: fieldDescriptor{
schema: schema,
table: userTable,
name: "id",
},
domain.UserFieldUsername: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "username",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmail: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "email",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmailVerified: fieldDescriptor{
schema: schema,
table: userTable,
name: "email_is_verified",
},
}
type textClause[V domain.Text] struct {
clause[domain.TextOperation]
value V
}
var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{
domain.TextOperationEqual: " = ",
domain.TextOperationNotEqual: " <> ",
domain.TextOperationStartsWith: " LIKE ",
domain.TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (tc textClause[V]) Write(stmt *statement) {
placeholder := stmt.appendArg(tc.value)
var (
left, right string
)
switch tc.clause.op {
case domain.TextOperationEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationNotEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationStartsWith:
left = tc.clause.field.String()
right = placeholder + "%"
case domain.TextOperationStartsWithIgnoreCase:
left = tc.clause.field.String()
if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok {
left = "LOWER(" + left + ")"
}
right = "LOWER(" + placeholder + "%)"
}
stmt.builder.WriteString(left)
stmt.builder.WriteString(textOp[tc.clause.op])
stmt.builder.WriteString(right)
}
type boolClause[V domain.Bool] struct {
clause[domain.BoolOperation]
value V
}
func (bc boolClause[V]) Write(stmt *statement) {
if !bc.value {
stmt.builder.WriteString("NOT ")
}
stmt.builder.WriteString(bc.clause.field.String())
}
type numberClause[V domain.Number] struct {
clause[domain.NumberOperation]
value V
}
var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{
domain.NumberOperationEqual: " = ",
domain.NumberOperationNotEqual: " <> ",
domain.NumberOperationLessThan: " < ",
domain.NumberOperationLessThanOrEqual: " <= ",
domain.NumberOperationGreaterThan: " > ",
domain.NumberOperationGreaterThanOrEqual: " >= ",
}
func (nc numberClause[V]) Write(stmt *statement) {
stmt.builder.WriteString(nc.clause.field.String())
stmt.builder.WriteString(numberOp[nc.clause.op])
stmt.builder.WriteString(stmt.appendArg(nc.value))
}

View File

@@ -0,0 +1,45 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
)
type cryptoRepo struct {
database.QueryExecutor
}
func Crypto(db database.QueryExecutor) domain.CryptoRepository {
return &cryptoRepo{
QueryExecutor: db,
}
}
const getEncryptionConfigQuery = "SELECT" +
" length" +
", expiry" +
", should_include_lower_letters" +
", should_include_upper_letters" +
", should_include_digits" +
", should_include_symbols" +
" FROM encryption_config"
func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
var config crypto.GeneratorConfig
row := repo.QueryRow(ctx, getEncryptionConfigQuery)
err := row.Scan(
&config.Length,
&config.Expiry,
&config.IncludeLowerLetters,
&config.IncludeUpperLetters,
&config.IncludeDigits,
&config.IncludeSymbols,
)
if err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,7 @@
// Repository package provides the database repository for the application.
// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database.
// funcs which need to interact with the database should create interfaces which are implemented by the
// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations.
package repository

View File

@@ -0,0 +1,54 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instance struct {
database.QueryExecutor
}
func Instance(client database.QueryExecutor) domain.InstanceRepository {
return &instance{QueryExecutor: client}
}
func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) {
var instance domain.Instance
err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan(
&instance.ID,
&instance.Name,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
return &instance, nil
}
const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at`
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, createInstanceStmt,
instance.ID,
instance.Name,
).Scan(
&instance.CreatedAt,
&instance.UpdatedAt,
)
}
// On implements [domain.InstanceRepository].
func (i *instance) On(id string) domain.InstanceOperation {
return &instanceOperation{
QueryExecutor: i.QueryExecutor,
id: id,
}
}
var _ domain.InstanceRepository = (*instance)(nil)

View File

@@ -0,0 +1,52 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instanceOperation struct {
database.QueryExecutor
id string
}
const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)`
// AddAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles)
}
// Delete implements [domain.InstanceOperation].
func (i *instanceOperation) Delete(ctx context.Context) error {
return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id)
}
const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2`
// RemoveAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error {
return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID)
}
const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3`
// SetAdminRoles implements [domain.InstanceOperation].
func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID)
}
const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at`
// Update implements [domain.InstanceOperation].
func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt,
instance.Name,
instance.UpdatedAt,
i.id,
).Scan(&instance.UpdatedAt)
}
var _ domain.InstanceOperation = (*instanceOperation)(nil)

View File

@@ -0,0 +1,17 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type query struct{ database.Querier }
func Query(querier database.Querier) *query {
return &query{Querier: querier}
}
type executor struct{ database.Executor }
func Execute(exec database.Executor) *executor {
return &executor{Executor: exec}
}

View File

@@ -0,0 +1,21 @@
package repository
import "strings"
type statement struct {
builder strings.Builder
args []any
}
func (s *statement) appendArg(arg any) (placeholder string) {
s.args = append(s.args, arg)
return "$" + string(len(s.args))
}
func (s *statement) appendArgs(args ...any) (placeholders []string) {
placeholders = make([]string, len(args))
for i, arg := range args {
placeholders[i] = s.appendArg(arg)
}
return placeholders
}

View File

@@ -0,0 +1,43 @@
package stmt
import "fmt"
type Column[T any] interface {
fmt.Stringer
statementApplier[T]
scanner(t *T) any
}
type columnDescriptor[T any] struct {
name string
scan func(*T) any
}
func (cd columnDescriptor[T]) scanner(t *T) any {
return cd.scan(t)
}
// Apply implements [Column].
func (f columnDescriptor[T]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(f.String())
}
// String implements [Column].
func (f columnDescriptor[T]) String() string {
return f.name
}
var _ Column[any] = (*columnDescriptor[any])(nil)
type ignoreCaseColumnDescriptor[T any] struct {
columnDescriptor[T]
fieldNameSuffix string
}
func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) {
stmt.builder.WriteString(f.String())
stmt.builder.WriteString(f.fieldNameSuffix)
}
var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil)

View File

@@ -0,0 +1,97 @@
package stmt
import "fmt"
type statementApplier[T any] interface {
// Apply writes the statement to the builder.
Apply(stmt *statement[T])
}
type Condition[T any] interface {
statementApplier[T]
}
type op interface {
TextOperation | NumberOperation | ListOperation
fmt.Stringer
}
type operation[T any, O op] struct {
o O
}
func (o operation[T, O]) String() string {
return o.o.String()
}
func (o operation[T, O]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(o.o.String())
}
type condition[V, T any, OP op] struct {
field Column[T]
op OP
value V
}
func (c *condition[V, T, OP]) Apply(stmt *statement[T]) {
// placeholder := stmt.appendArg(c.value)
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(c.field.String())
// stmt.builder.WriteString(c.op)
// stmt.builder.WriteString(placeholder)
}
type and[T any] struct {
conditions []Condition[T]
}
func And[T any](conditions ...Condition[T]) *and[T] {
return &and[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (a *and[T]) Apply(stmt *statement[T]) {
if len(a.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
stmt.builder.WriteString(" AND ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*and[any])(nil)
type or[T any] struct {
conditions []Condition[T]
}
func Or[T any](conditions ...Condition[T]) *or[T] {
return &or[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (o *or[T]) Apply(stmt *statement[T]) {
if len(o.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
stmt.builder.WriteString(" OR ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*or[any])(nil)

View File

@@ -0,0 +1,71 @@
package stmt
type ListEntry interface {
Number | Text | any
}
type ListCondition[E ListEntry, T any] struct {
condition[[]E, T, ListOperation]
}
func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(lc.value)
switch lc.op {
case ListOperationEqual, ListOperationNotEqual:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationContainsAny, ListOperationContainsAll:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationNotContainsAny, ListOperationNotContainsAll:
stmt.builder.WriteString("NOT (")
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
default:
panic("unknown list operation")
}
}
type ListOperation uint8
const (
// ListOperationEqual checks if the arrays are equal including the order of the elements
ListOperationEqual ListOperation = iota + 1
// ListOperationNotEqual checks if the arrays are not equal including the order of the elements
ListOperationNotEqual
// ListOperationContains checks if the array column contains all the values of the specified array
ListOperationContainsAll
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny
// ListOperationContainsAll checks if the array column contains all the values of the specified array
// ListOperationNotContainsAll checks if the specified array is not contained by the column
ListOperationNotContainsAll
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny
)
var listOperations = map[ListOperation]string{
// ListOperationEqual checks if the lists are equal
ListOperationEqual: " = ",
// ListOperationNotEqual checks if the lists are not equal
ListOperationNotEqual: " <> ",
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny: " && ",
// ListOperationContainsAll checks if the array column contains all the values of the specified array
ListOperationContainsAll: " @> ",
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B)
// ListOperationNotContainsAll checks if the array column is not contained by the specified array
ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B)
}
func (lo ListOperation) String() string {
return listOperations[lo]
}

View File

@@ -0,0 +1,61 @@
package stmt
import (
"time"
"golang.org/x/exp/constraints"
)
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type between[N Number] struct {
min, max N
}
type NumberBetween[V Number, T any] struct {
condition[between[V], T, NumberOperation]
}
func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) {
nb.field.Apply(stmt)
stmt.builder.WriteString(" BETWEEN ")
stmt.builder.WriteString(stmt.appendArg(nb.value.min))
stmt.builder.WriteString(" AND ")
stmt.builder.WriteString(stmt.appendArg(nb.value.max))
}
type NumberCondition[V Number, T any] struct {
condition[V, T, NumberOperation]
}
func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) {
nc.field.Apply(stmt)
operation[T, NumberOperation]{nc.op}.Apply(stmt)
stmt.builder.WriteString(stmt.appendArg(nc.value))
}
type NumberOperation uint8
const (
NumberOperationEqual NumberOperation = iota + 1
NumberOperationNotEqual
NumberOperationLessThan
NumberOperationLessThanOrEqual
NumberOperationGreaterThan
NumberOperationGreaterThanOrEqual
)
var numberOperations = map[NumberOperation]string{
NumberOperationEqual: " = ",
NumberOperationNotEqual: " <> ",
NumberOperationLessThan: " < ",
NumberOperationLessThanOrEqual: " <= ",
NumberOperationGreaterThan: " > ",
NumberOperationGreaterThanOrEqual: " >= ",
}
func (no NumberOperation) String() string {
return numberOperations[no]
}

View File

@@ -0,0 +1,104 @@
package stmt
import (
"fmt"
"strings"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type statement[T any] struct {
builder strings.Builder
client database.QueryExecutor
columns []Column[T]
schema string
table string
alias string
condition Condition[T]
limit uint32
offset uint32
// order by fieldname and sort direction false for asc true for desc
// orderBy SortingColumns[C]
args []any
existingArgs map[any]string
}
func (s *statement[T]) scanners(t *T) []any {
scanners := make([]any, len(s.columns))
for i, column := range s.columns {
scanners[i] = column.scanner(t)
}
return scanners
}
func (s *statement[T]) query() string {
s.builder.WriteString(`SELECT `)
for i, column := range s.columns {
if i > 0 {
s.builder.WriteString(", ")
}
column.Apply(s)
}
s.builder.WriteString(` FROM `)
s.builder.WriteString(s.schema)
s.builder.WriteRune('.')
s.builder.WriteString(s.table)
if s.alias != "" {
s.builder.WriteString(" AS ")
s.builder.WriteString(s.alias)
}
s.builder.WriteString(` WHERE `)
s.condition.Apply(s)
if s.limit > 0 {
s.builder.WriteString(` LIMIT `)
s.builder.WriteString(s.appendArg(s.limit))
}
if s.offset > 0 {
s.builder.WriteString(` OFFSET `)
s.builder.WriteString(s.appendArg(s.offset))
}
return s.builder.String()
}
// func (s *statement[T]) Where(condition Condition[T]) *statement[T] {
// s.condition = condition
// return s
// }
// func (s *statement[T]) Limit(limit uint32) *statement[T] {
// s.limit = limit
// return s
// }
// func (s *statement[T]) Offset(offset uint32) *statement[T] {
// s.offset = offset
// return s
// }
func (s *statement[T]) columnPrefix() string {
if s.alias != "" {
return s.alias + "."
}
return s.schema + "." + s.table + "."
}
func (s *statement[T]) appendArg(arg any) string {
if s.existingArgs == nil {
s.existingArgs = make(map[any]string)
}
if existing, ok := s.existingArgs[arg]; ok {
return existing
}
s.args = append(s.args, arg)
placeholder := fmt.Sprintf("$%d", len(s.args))
s.existingArgs[arg] = placeholder
return placeholder
}

View File

@@ -0,0 +1,18 @@
package stmt_test
import (
"context"
"testing"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt"
)
func Test_Bla(t *testing.T) {
stmt.User(nil).Where(
stmt.Or(
stmt.UserIDCondition("123"),
stmt.UserIDCondition("123"),
stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"),
),
).Limit(1).Offset(1).Get(context.Background())
}

View File

@@ -0,0 +1,72 @@
package stmt
type Text interface {
~string | ~[]byte
}
type TextCondition[V Text, T any] struct {
condition[V, T, TextOperation]
}
func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(tc.value)
switch tc.op {
case TextOperationEqual, TextOperationNotEqual:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case TextOperationEqualIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
case TextOperationStartsWith:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString("|| '%'")
case TextOperationStartsWithIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
stmt.builder.WriteString("|| '%'")
}
}
type TextOperation uint8
const (
TextOperationEqual TextOperation = iota + 1
TextOperationEqualIgnoreCase
TextOperationNotEqual
TextOperationStartsWith
TextOperationStartsWithIgnoreCase
)
var textOperations = map[TextOperation]string{
TextOperationEqual: " = ",
TextOperationEqualIgnoreCase: " = ",
TextOperationNotEqual: " <> ",
TextOperationStartsWith: " LIKE ",
TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (to TextOperation) String() string {
return textOperations[to]
}

View File

@@ -0,0 +1,193 @@
package stmt
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userStatement struct {
statement[domain.User]
}
func User(client database.QueryExecutor) *userStatement {
return &userStatement{
statement: statement[domain.User]{
schema: "zitadel",
table: "users",
alias: "u",
client: client,
columns: []Column[domain.User]{
userColumns[UserInstanceID],
userColumns[UserOrgID],
userColumns[UserColumnID],
userColumns[UserColumnUsername],
userColumns[UserCreatedAt],
userColumns[UserUpdatedAt],
userColumns[UserDeletedAt],
},
},
}
}
func (s *userStatement) Where(condition Condition[domain.User]) *userStatement {
s.condition = condition
return s
}
func (s *userStatement) Limit(limit uint32) *userStatement {
s.limit = limit
return s
}
func (s *userStatement) Offset(offset uint32) *userStatement {
s.offset = offset
return s
}
func (s *userStatement) Get(ctx context.Context) (*domain.User, error) {
var user domain.User
err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
return &user, nil
}
func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) {
var users []*domain.User
rows, err := s.client.Query(ctx, s.query(), s.statement.args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var user domain.User
err = rows.Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
users = append(users, &user)
}
return users, nil
}
func (s *userStatement) SetUsername(ctx context.Context, username string) error {
return nil
}
type UserColumn uint8
var (
userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{
UserInstanceID: columnDescriptor[domain.User]{
name: "instance_id",
scan: func(u *domain.User) any {
return &u.InstanceID
},
},
UserOrgID: columnDescriptor[domain.User]{
name: "org_id",
scan: func(u *domain.User) any {
return &u.OrgID
},
},
UserColumnID: columnDescriptor[domain.User]{
name: "id",
scan: func(u *domain.User) any {
return &u.ID
},
},
UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "username",
scan: func(u *domain.User) any {
return &u.Username
},
},
fieldNameSuffix: "_lower",
},
UserCreatedAt: columnDescriptor[domain.User]{
name: "created_at",
scan: func(u *domain.User) any {
return &u.CreatedAt
},
},
UserUpdatedAt: columnDescriptor[domain.User]{
name: "updated_at",
scan: func(u *domain.User) any {
return &u.UpdatedAt
},
},
UserDeletedAt: columnDescriptor[domain.User]{
name: "deleted_at",
scan: func(u *domain.User) any {
return &u.DeletedAt
},
},
}
humanColumns = map[UserColumn]Column[domain.User]{
UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "email",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.Address
},
},
fieldNameSuffix: "_lower",
},
UserHumanColumnEmailVerified: columnDescriptor[domain.User]{
name: "email_is_verified",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.IsVerified
},
},
}
machineColumns = map[UserColumn]Column[domain.User]{
UserMachineDescription: columnDescriptor[domain.User]{
name: "description",
scan: func(u *domain.User) any {
machine, ok := u.Traits.(*domain.Machine)
if !ok {
return nil
}
if machine == nil {
machine = new(domain.Machine)
}
return &machine.Description
},
},
}
)
const (
UserInstanceID UserColumn = iota + 1
UserOrgID
UserColumnID
UserColumnUsername
UserHumanColumnEmail
UserHumanColumnEmailVerified
UserMachineDescription
UserCreatedAt
UserUpdatedAt
UserDeletedAt
)

View File

@@ -0,0 +1,23 @@
package stmt
import "github.com/zitadel/zitadel/backend/v3/domain"
func UserIDCondition(id string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnID],
op: TextOperationEqual,
value: id,
},
}
}
func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnUsername],
op: op,
value: username,
},
}
}

View File

@@ -0,0 +1,135 @@
package stmt
// type table struct {
// schema string
// name string
// possibleJoins []*join
// columns []*col
// }
// type col struct {
// *table
// name string
// }
// type join struct {
// *table
// on []*joinColumns
// }
// type joinColumns struct {
// left, right *col
// }
// var (
// userTable = &table{
// schema: "zitadel",
// name: "users",
// }
// userColumns = []*col{
// userInstanceIDColumn,
// userOrgIDColumn,
// userIDColumn,
// userUsernameColumn,
// }
// userInstanceIDColumn = &col{
// table: userTable,
// name: "instance_id",
// }
// userOrgIDColumn = &col{
// table: userTable,
// name: "org_id",
// }
// userIDColumn = &col{
// table: userTable,
// name: "id",
// }
// userUsernameColumn = &col{
// table: userTable,
// name: "username",
// }
// userJoins = []*join{
// {
// table: instanceTable,
// on: []*joinColumns{
// {
// left: instanceIDColumn,
// right: userInstanceIDColumn,
// },
// },
// },
// {
// table: orgTable,
// on: []*joinColumns{
// {
// left: orgIDColumn,
// right: userOrgIDColumn,
// },
// },
// },
// }
// )
// var (
// instanceTable = &table{
// schema: "zitadel",
// name: "instances",
// }
// instanceColumns = []*col{
// instanceIDColumn,
// instanceNameColumn,
// }
// instanceIDColumn = &col{
// table: instanceTable,
// name: "id",
// }
// instanceNameColumn = &col{
// table: instanceTable,
// name: "name",
// }
// )
// var (
// orgTable = &table{
// schema: "zitadel",
// name: "orgs",
// }
// orgColumns = []*col{
// orgInstanceIDColumn,
// orgIDColumn,
// orgNameColumn,
// }
// orgInstanceIDColumn = &col{
// table: orgTable,
// name: "instance_id",
// }
// orgIDColumn = &col{
// table: orgTable,
// name: "id",
// }
// orgNameColumn = &col{
// table: orgTable,
// name: "name",
// }
// )
// func init() {
// instanceTable.columns = instanceColumns
// userTable.columns = userColumns
// userTable.possibleJoins = []join{
// {
// table: userTable,
// on: []joinColumns{
// {
// left: userIDColumn,
// right: userIDColumn,
// },
// },
// },
// }
// }

View File

@@ -0,0 +1,55 @@
package v3
type Column interface {
Name() string
Write(builder statementBuilder)
}
type ignoreCaseColumn interface {
Column
WriteIgnoreCase(builder statementBuilder)
}
var (
columnNameID = "id"
columnNameName = "name"
columnNameCreatedAt = "created_at"
columnNameUpdatedAt = "updated_at"
columnNameDeletedAt = "deleted_at"
columnNameInstanceID = "instance_id"
columnNameOrgID = "org_id"
)
type column struct {
table Table
name string
}
// Write implements Column.
func (c *column) Write(builder statementBuilder) {
c.table.writeOn(builder)
builder.writeRune('.')
builder.writeString(c.name)
}
// Name implements [Column].
func (c *column) Name() string {
return c.name
}
var _ Column = (*column)(nil)
type columnIgnoreCase struct {
column
suffix string
}
// WriteIgnoreCase implements ignoreCaseColumn.
func (c *columnIgnoreCase) WriteIgnoreCase(builder statementBuilder) {
c.Write(builder)
builder.writeString(c.suffix)
}
var _ ignoreCaseColumn = (*columnIgnoreCase)(nil)

View File

@@ -0,0 +1,182 @@
package v3
type statementBuilder interface {
write([]byte)
writeString(string)
writeRune(rune)
appendArg(any) (placeholder string)
table() Table
}
type Condition interface {
writeOn(builder statementBuilder)
}
type and struct {
conditions []Condition
}
func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
// writeOn implements [Condition].
func (a *and) writeOn(builder statementBuilder) {
if len(a.conditions) > 1 {
builder.writeString("(")
defer builder.writeString(")")
}
for i, condition := range a.conditions {
if i > 0 {
builder.writeString(" AND ")
}
condition.writeOn(builder)
}
}
var _ Condition = (*and)(nil)
type or struct {
conditions []Condition
}
func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
// writeOn implements [Condition].
func (o *or) writeOn(builder statementBuilder) {
if len(o.conditions) > 1 {
builder.writeString("(")
defer builder.writeString(")")
}
for i, condition := range o.conditions {
if i > 0 {
builder.writeString(" OR ")
}
condition.writeOn(builder)
}
}
var _ Condition = (*or)(nil)
type isNull struct {
column Column
}
func IsNull(column Column) *isNull {
return &isNull{column: column}
}
// writeOn implements [Condition].
func (cond *isNull) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(" IS NULL")
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
column Column
}
func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
// writeOn implements [Condition].
func (cond *isNotNull) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(" IS NOT NULL")
}
var _ Condition = (*isNotNull)(nil)
type condition[Op Operator, V Value] struct {
column Column
operator Op
value V
}
// writeOn implements [Condition].
func (cond condition[Op, V]) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
}
var _ Condition = (*condition[TextOperator, string])(nil)
type textCondition[V Text] struct {
condition[TextOperator, V]
}
func NewTextCondition[V Text](column Column, operator TextOperator, value V) *textCondition[V] {
return &textCondition[V]{
condition: condition[TextOperator, V]{
column: column,
operator: operator,
value: value,
},
}
}
// writeOn implements [Condition].
func (cond *textCondition[V]) writeOn(builder statementBuilder) {
switch cond.operator {
case TextOperatorEqual, TextOperatorNotEqual:
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
if col, ok := cond.column.(ignoreCaseColumn); ok {
col.WriteIgnoreCase(builder)
} else {
builder.writeString("LOWER(")
cond.column.Write(builder)
builder.writeString(")")
}
builder.writeString(cond.operator.String())
builder.writeString("LOWER(")
builder.writeString(builder.appendArg(cond.value))
builder.writeString(")")
case TextOperatorStartsWith:
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
builder.writeString(" || '%'")
case TextOperatorStartsWithIgnoreCase:
if col, ok := cond.column.(ignoreCaseColumn); ok {
col.WriteIgnoreCase(builder)
} else {
builder.writeString("LOWER(")
cond.column.Write(builder)
builder.writeString(")")
}
builder.writeString(cond.operator.String())
builder.writeString("LOWER(")
builder.writeString(builder.appendArg(cond.value))
builder.writeString(") || '%'")
}
}
var _ Condition = (*textCondition[string])(nil)
type numberCondition[V Number] struct {
condition[NumberOperator, V]
}
func NewNumberCondition[V Number](column Column, operator NumberOperator, value V) *numberCondition[V] {
return &numberCondition[V]{
condition: condition[NumberOperator, V]{
column: column,
operator: operator,
value: value,
},
}
}
var _ Condition = (*numberCondition[int])(nil)

View File

@@ -0,0 +1,104 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Instance struct {
id string
name string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (Instance) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameID,
},
&column{
table: table,
name: columnNameName,
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (i Instance) Scan(row database.Scanner) error {
return row.Scan(
&i.id,
&i.name,
&i.createdAt,
&i.updatedAt,
&i.deletedAt,
)
}
type instanceTable struct {
*table
}
func InstanceTable() *instanceTable {
table := &instanceTable{
table: newTable[Instance]("zitadel", "instances"),
}
table.possibleJoins = func(t Table) map[string]Column {
switch on := t.(type) {
case *instanceTable:
return map[string]Column{
columnNameID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameID: on.InstanceIDColumn(),
}
case *userTable:
return map[string]Column{
columnNameID: on.InstanceIDColumn(),
}
default:
return nil
}
}
return table
}
func (i *instanceTable) IDColumn() Column {
return i.columns[columnNameID]
}
func (i *instanceTable) NameColumn() Column {
return i.columns[columnNameName]
}
func (i *instanceTable) CreatedAtColumn() Column {
return i.columns[columnNameCreatedAt]
}
func (i *instanceTable) UpdatedAtColumn() Column {
return i.columns[columnNameUpdatedAt]
}
func (i *instanceTable) DeletedAtColumn() Column {
return i.columns[columnNameDeletedAt]
}

View File

@@ -0,0 +1,11 @@
package v3
type join struct {
table Table
conditions []joinCondition
}
type joinCondition struct {
left Column
right Column
}

View File

@@ -0,0 +1,82 @@
package v3
import (
"fmt"
"time"
"golang.org/x/exp/constraints"
)
type Value interface {
Bool | Number | Text
}
type Text interface {
~string | ~[]byte
}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type Bool interface {
~bool
}
type Operator interface {
fmt.Stringer
}
type TextOperator uint8
// String implements [Operator].
func (t TextOperator) String() string {
return textOperators[t]
}
const (
TextOperatorEqual TextOperator = iota + 1
TextOperatorEqualIgnoreCase
TextOperatorNotEqual
TextOperatorNotEqualIgnoreCase
TextOperatorStartsWith
TextOperatorStartsWithIgnoreCase
)
var textOperators = map[TextOperator]string{
TextOperatorEqual: " = ",
TextOperatorEqualIgnoreCase: " LIKE ",
TextOperatorNotEqual: " <> ",
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
TextOperatorStartsWith: " LIKE ",
TextOperatorStartsWithIgnoreCase: " LIKE ",
}
var _ Operator = TextOperator(0)
type NumberOperator uint8
// String implements Operator.
func (n NumberOperator) String() string {
return numberOperators[n]
}
const (
NumberOperatorEqual NumberOperator = iota + 1
NumberOperatorNotEqual
NumberOperatorLessThan
NumberOperatorLessThanOrEqual
NumberOperatorGreaterThan
NumberOperatorGreaterThanOrEqual
)
var numberOperators = map[NumberOperator]string{
NumberOperatorEqual: " = ",
NumberOperatorNotEqual: " <> ",
NumberOperatorLessThan: " < ",
NumberOperatorLessThanOrEqual: " <= ",
NumberOperatorGreaterThan: " > ",
NumberOperatorGreaterThanOrEqual: " >= ",
}
var _ Operator = NumberOperator(0)

View File

@@ -0,0 +1,117 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Org struct {
instanceID string
id string
name string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (Org) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameInstanceID,
},
&column{
table: table,
name: columnNameID,
},
&column{
table: table,
name: columnNameName,
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (o Org) Scan(row database.Scanner) error {
return row.Scan(
&o.instanceID,
&o.id,
&o.name,
&o.createdAt,
&o.updatedAt,
&o.deletedAt,
)
}
type orgTable struct {
*table
}
func OrgTable() *orgTable {
table := &orgTable{
table: newTable[Org]("zitadel", "orgs"),
}
table.possibleJoins = func(table Table) map[string]Column {
switch on := table.(type) {
case *instanceTable:
return map[string]Column{
columnNameInstanceID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameID: on.IDColumn(),
}
case *userTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameID: on.IDColumn(),
}
default:
return nil
}
}
return table
}
func (o *orgTable) InstanceIDColumn() Column {
return o.columns[columnNameInstanceID]
}
func (o *orgTable) IDColumn() Column {
return o.columns[columnNameID]
}
func (o *orgTable) NameColumn() Column {
return o.columns[columnNameName]
}
func (o *orgTable) CreatedAtColumn() Column {
return o.columns[columnNameCreatedAt]
}
func (o *orgTable) UpdatedAtColumn() Column {
return o.columns[columnNameUpdatedAt]
}
func (o *orgTable) DeletedAtColumn() Column {
return o.columns[columnNameDeletedAt]
}

View File

@@ -0,0 +1,188 @@
package v3
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Query[O object] interface {
Where(condition Condition)
Join(tables ...Table)
Limit(limit uint32)
Offset(offset uint32)
OrderBy(columns ...Column)
Result(ctx context.Context, client database.Querier) (*O, error)
Results(ctx context.Context, client database.Querier) ([]O, error)
fmt.Stringer
statementBuilder
}
type query[O object] struct {
*statement[O]
joins []join
limit uint32
offset uint32
orderBy []Column
}
func NewQuery[O object](table Table) Query[O] {
return &query[O]{
statement: newStatement[O](table),
}
}
// Result implements [Query].
func (q *query[O]) Result(ctx context.Context, client database.Querier) (*O, error) {
var object O
row := client.QueryRow(ctx, q.String(), q.args...)
if err := object.Scan(row); err != nil {
return nil, err
}
return &object, nil
}
// Results implements [Query].
func (q *query[O]) Results(ctx context.Context, client database.Querier) ([]O, error) {
var objects []O
rows, err := client.Query(ctx, q.String(), q.args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var object O
if err := object.Scan(rows); err != nil {
return nil, err
}
objects = append(objects, object)
}
return objects, rows.Err()
}
// Join implements [Query].
func (q *query[O]) Join(tables ...Table) {
for _, tbl := range tables {
cols := q.tbl.(*table).possibleJoins(tbl)
if len(cols) == 0 {
panic(fmt.Sprintf("table %q does not have any possible joins with table %q", q.tbl.Name(), tbl.Name()))
}
q.joins = append(q.joins, join{
table: tbl,
conditions: make([]joinCondition, 0, len(cols)),
})
for colName, col := range cols {
q.joins[len(q.joins)-1].conditions = append(q.joins[len(q.joins)-1].conditions, joinCondition{
left: q.tbl.(*table).columns[colName],
right: col,
})
}
}
}
func (q *query[O]) Limit(limit uint32) {
q.limit = limit
}
func (q *query[O]) Offset(offset uint32) {
q.offset = offset
}
func (q *query[O]) OrderBy(columns ...Column) {
for _, allowedColumn := range q.columns {
for _, column := range columns {
if allowedColumn.Name() == column.Name() {
q.orderBy = append(q.orderBy, column)
}
}
}
}
// String implements [fmt.Stringer] and [Query].
func (q *query[O]) String() string {
q.writeSelectColumns()
q.writeFrom()
q.writeJoins()
q.writeCondition()
q.writeOrderBy()
q.writeLimit()
q.writeOffset()
q.writeGroupBy()
return q.builder.String()
}
func (q *query[O]) writeSelectColumns() {
q.builder.WriteString("SELECT ")
for i, column := range q.columns {
if i > 0 {
q.builder.WriteString(", ")
}
q.builder.WriteString(q.tbl.Alias())
q.builder.WriteRune('.')
q.builder.WriteString(column.Name())
}
}
func (q *query[O]) writeJoins() {
for _, join := range q.joins {
q.builder.WriteString(" JOIN ")
q.builder.WriteString(join.table.Schema())
q.builder.WriteRune('.')
q.builder.WriteString(join.table.Name())
if join.table.Alias() != "" {
q.builder.WriteString(" AS ")
q.builder.WriteString(join.table.Alias())
}
q.builder.WriteString(" ON ")
for i, condition := range join.conditions {
if i > 0 {
q.builder.WriteString(" AND ")
}
q.builder.WriteString(condition.left.Name())
q.builder.WriteString(" = ")
q.builder.WriteString(condition.right.Name())
}
}
}
func (q *query[O]) writeOrderBy() {
if len(q.orderBy) == 0 {
return
}
q.builder.WriteString(" ORDER BY ")
for i, order := range q.orderBy {
if i > 0 {
q.builder.WriteString(", ")
}
order.Write(q)
}
}
func (q *query[O]) writeLimit() {
if q.limit == 0 {
return
}
q.builder.WriteString(" LIMIT ")
q.builder.WriteString(q.appendArg(q.limit))
}
func (q *query[O]) writeOffset() {
if q.offset == 0 {
return
}
q.builder.WriteString(" OFFSET ")
q.builder.WriteString(q.appendArg(q.offset))
}
func (q *query[O]) writeGroupBy() {
q.builder.WriteString(" GROUP BY ")
}

View File

@@ -0,0 +1,85 @@
package v3
import (
"fmt"
"strings"
)
type statement[T object] struct {
tbl Table
columns []Column
condition Condition
builder strings.Builder
args []any
existingArgs map[any]string
}
func newStatement[O object](t Table) *statement[O] {
var o O
return &statement[O]{
tbl: t,
columns: o.Columns(t),
}
}
// Where implements [Query].
func (stmt *statement[T]) Where(condition Condition) {
stmt.condition = condition
}
func (stmt *statement[T]) writeFrom() {
stmt.builder.WriteString(" FROM ")
stmt.builder.WriteString(stmt.tbl.Schema())
stmt.builder.WriteRune('.')
stmt.builder.WriteString(stmt.tbl.Name())
if stmt.tbl.Alias() != "" {
stmt.builder.WriteString(" AS ")
stmt.builder.WriteString(stmt.tbl.Alias())
}
}
func (stmt *statement[T]) writeCondition() {
if stmt.condition == nil {
return
}
stmt.builder.WriteString(" WHERE ")
stmt.condition.writeOn(stmt)
}
// appendArg implements [statementBuilder].
func (stmt *statement[T]) appendArg(arg any) (placeholder string) {
if stmt.existingArgs == nil {
stmt.existingArgs = make(map[any]string)
}
if placeholder, ok := stmt.existingArgs[arg]; ok {
return placeholder
}
stmt.args = append(stmt.args, arg)
placeholder = fmt.Sprintf("$%d", len(stmt.args))
stmt.existingArgs[arg] = placeholder
return placeholder
}
// table implements [statementBuilder].
func (stmt *statement[T]) table() Table {
return stmt.tbl
}
// write implements [statementBuilder].
func (stmt *statement[T]) write(data []byte) {
stmt.builder.Write(data)
}
// writeRune implements [statementBuilder].
func (stmt *statement[T]) writeRune(r rune) {
stmt.builder.WriteRune(r)
}
// writeString implements [statementBuilder].
func (stmt *statement[T]) writeString(s string) {
stmt.builder.WriteString(s)
}
var _ statementBuilder = (*statement[Instance])(nil)

View File

@@ -0,0 +1,84 @@
package v3
import "github.com/zitadel/zitadel/backend/v3/storage/database"
type object interface {
User | Org | Instance
Columns(t Table) []Column
Scan(s database.Scanner) error
}
type Table interface {
Schema() string
Name() string
Alias() string
Columns() []Column
writeOn(builder statementBuilder)
}
type table struct {
schema string
name string
alias string
possibleJoins func(table Table) map[string]Column
columns map[string]Column
colList []Column
}
func newTable[O object](schema, name string) *table {
t := &table{
schema: schema,
name: name,
}
var o O
t.colList = o.Columns(t)
t.columns = make(map[string]Column, len(t.colList))
for _, col := range t.colList {
t.columns[col.Name()] = col
}
return t
}
// Columns implements [Table].
func (t *table) Columns() []Column {
if len(t.colList) > 0 {
return t.colList
}
t.colList = make([]Column, 0, len(t.columns))
for _, column := range t.columns {
t.colList = append(t.colList, column)
}
return t.colList
}
// Name implements [Table].
func (t *table) Name() string {
return t.name
}
// Schema implements [Table].
func (t *table) Schema() string {
return t.schema
}
// Alias implements [Table].
func (t *table) Alias() string {
if t.alias != "" {
return t.alias
}
return t.schema + "." + t.name
}
// writeOn implements [Table].
func (t *table) writeOn(builder statementBuilder) {
builder.writeString(t.Alias())
}
var _ Table = (*table)(nil)

View File

@@ -0,0 +1,170 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type User struct {
instanceID string
orgID string
id string
username string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (u User) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameInstanceID,
},
&column{
table: table,
name: columnNameOrgID,
},
&column{
table: table,
name: columnNameID,
},
&columnIgnoreCase{
column: column{
table: table,
name: userTableUsernameColumn,
},
suffix: "_lower",
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (u User) Scan(row database.Scanner) error {
return row.Scan(
&u.instanceID,
&u.orgID,
&u.id,
&u.username,
&u.createdAt,
&u.updatedAt,
&u.deletedAt,
)
}
type userTable struct {
*table
}
const (
userTableUsernameColumn = "username"
)
func UserTable() *userTable {
table := &userTable{
table: newTable[User]("zitadel", "users"),
}
table.possibleJoins = func(table Table) map[string]Column {
switch on := table.(type) {
case *userTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameOrgID: on.OrgIDColumn(),
columnNameID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameOrgID: on.IDColumn(),
}
case *instanceTable:
return map[string]Column{
columnNameInstanceID: on.IDColumn(),
}
default:
return nil
}
}
return table
}
func (t *userTable) InstanceIDColumn() Column {
return t.columns[columnNameInstanceID]
}
func (t *userTable) OrgIDColumn() Column {
return t.columns[columnNameOrgID]
}
func (t *userTable) IDColumn() Column {
return t.columns[columnNameID]
}
func (t *userTable) UsernameColumn() Column {
return t.columns[userTableUsernameColumn]
}
func (t *userTable) CreatedAtColumn() Column {
return t.columns[columnNameCreatedAt]
}
func (t *userTable) UpdatedAtColumn() Column {
return t.columns[columnNameUpdatedAt]
}
func (t *userTable) DeletedAtColumn() Column {
return t.columns[columnNameDeletedAt]
}
func NewUserQuery() Query[User] {
q := NewQuery[User](UserTable())
return q
}
type userByIDCondition[T Text] struct {
id T
}
func UserByID[T Text](id T) Condition {
return &userByIDCondition[T]{id: id}
}
// writeOn implements Condition.
func (u *userByIDCondition[T]) writeOn(builder statementBuilder) {
NewTextCondition(builder.table().(*userTable).IDColumn(), TextOperatorEqual, u.id).writeOn(builder)
}
var _ Condition = (*userByIDCondition[string])(nil)
type userByUsernameCondition[T Text] struct {
username T
operator TextOperator
}
func UserByUsername[T Text](username T, operator TextOperator) Condition {
return &userByUsernameCondition[T]{username: username, operator: operator}
}
// writeOn implements Condition.
func (u *userByUsernameCondition[T]) writeOn(builder statementBuilder) {
NewTextCondition(builder.table().(*userTable).UsernameColumn(), u.operator, u.username).writeOn(builder)
}
var _ Condition = (*userByUsernameCondition[string])(nil)

View File

@@ -0,0 +1,25 @@
package v3_test
import (
"context"
"testing"
v3 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v3"
)
type user struct{}
func TestUser(t *testing.T) {
query := v3.NewUserQuery()
query.Where(
v3.Or(
v3.UserByID("123"),
v3.UserByUsername("test", v3.TextOperatorStartsWithIgnoreCase),
),
)
query.Limit(10)
query.Offset(5)
// query.OrderBy(
query.Result(context.TODO(), nil)
}

View File

@@ -0,0 +1,78 @@
package v4
type Change interface {
Column
}
type change[V Value] struct {
column Column
value V
}
func newChange[V Value](col Column, value V) Change {
return &change[V]{
column: col,
value: value,
}
}
func newUpdatePtrColumn[V Value](col Column, value *V) Change {
if value == nil {
return newChange(col, nullDBInstruction)
}
return newChange(col, *value)
}
// writeTo implements [Change].
func (c change[V]) writeTo(builder *statementBuilder) {
c.column.writeTo(builder)
builder.WriteString(" = ")
builder.writeArg(c.value)
}
type Changes []Change
func newChanges(cols ...Change) Change {
return Changes(cols)
}
// writeTo implements [Change].
func (m Changes) writeTo(builder *statementBuilder) {
for i, col := range m {
if i > 0 {
builder.WriteString(", ")
}
col.writeTo(builder)
}
}
var _ Change = Changes(nil)
var _ Change = (*change[string])(nil)
type Column interface {
writeTo(builder *statementBuilder)
}
type column struct {
name string
}
func (c column) writeTo(builder *statementBuilder) {
builder.WriteString(c.name)
}
type ignoreCaseColumn interface {
Column
writeIgnoreCaseTo(builder *statementBuilder)
}
type ignoreCaseCol struct {
column
suffix string
}
func (c ignoreCaseCol) writeIgnoreCaseTo(builder *statementBuilder) {
c.column.writeTo(builder)
builder.WriteString(c.suffix)
}

View File

@@ -0,0 +1,112 @@
package v4
type Condition interface {
writeTo(builder *statementBuilder)
}
type and struct {
conditions []Condition
}
// writeTo implements [Condition].
func (a *and) writeTo(builder *statementBuilder) {
if len(a.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
builder.WriteString(" AND ")
}
condition.writeTo(builder)
}
}
func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
var _ Condition = (*and)(nil)
type or struct {
conditions []Condition
}
// writeTo implements [Condition].
func (o *or) writeTo(builder *statementBuilder) {
if len(o.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
builder.WriteString(" OR ")
}
condition.writeTo(builder)
}
}
func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
var _ Condition = (*or)(nil)
type isNull struct {
column Column
}
// writeTo implements [Condition].
func (i *isNull) writeTo(builder *statementBuilder) {
i.column.writeTo(builder)
builder.WriteString(" IS NULL")
}
func IsNull(column Column) *isNull {
return &isNull{column: column}
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
column Column
}
// writeTo implements [Condition].
func (i *isNotNull) writeTo(builder *statementBuilder) {
i.column.writeTo(builder)
builder.WriteString(" IS NOT NULL")
}
func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
var _ Condition = (*isNotNull)(nil)
type valueCondition func(builder *statementBuilder)
func newTextCondition[V Text](col Column, op TextOperator, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeTextOperation(builder, col, op, value)
})
}
func newNumberCondition[V Number](col Column, op NumberOperator, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeNumberOperation(builder, col, op, value)
})
}
func newBooleanCondition[V Boolean](col Column, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeBooleanOperation(builder, col, value)
})
}
// writeTo implements [Condition].
func (c valueCondition) writeTo(builder *statementBuilder) {
c(builder)
}
var _ Condition = (*valueCondition)(nil)

View File

@@ -0,0 +1,2 @@
// this test focuses on queries rather than on tables
package v4

View File

@@ -0,0 +1,149 @@
CREATE TABLE objects (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at TIMESTAMP
);
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TABLE instances(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (id)
) INHERITS (objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON instances
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE instance_objects(
instance_id INT NOT NULL
, PRIMARY KEY (instance_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
--, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (objects);
CREATE TABLE orgs(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, id)
, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON orgs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE org_objects(
org_id INT NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id),
-- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TABLE users (
username VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
-- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (org_objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE human_users(
first_name VARCHAR(50)
, last_name VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id),
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON human_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE machine_users(
description VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id)
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON machine_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
select u.*, hu.first_name, hu.last_name, mu.description from users u
left join human_users hu on u.instance_id = hu.instance_id and u.org_id = hu.org_id and u.id = hu.id
left join machine_users mu on u.instance_id = mu.instance_id and u.org_id = mu.org_id and u.id = mu.id
-- where
-- u.instance_id = 1
-- and u.org_id = 3
-- and u.id = 7
;
create view users_view as (
SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, first_name
, last_name
, description
FROM (
(SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, first_name
, last_name
, NULL AS description
FROM
human_users)
UNION
(SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, NULL AS first_name
, NULL AS last_name
, description
FROM
machine_users)
));

View File

@@ -0,0 +1,139 @@
package v4
import (
"time"
"golang.org/x/exp/constraints"
)
type Value interface {
Boolean | Number | Text | databaseInstruction
}
type Operator interface {
BooleanOperator | NumberOperator | TextOperator
}
type Text interface {
~string | ~[]byte
}
type TextOperator uint8
const (
// TextOperatorEqual compares two strings for equality.
TextOperatorEqual TextOperator = iota + 1
// TextOperatorEqualIgnoreCase compares two strings for equality, ignoring case.
TextOperatorEqualIgnoreCase
// TextOperatorNotEqual compares two strings for inequality.
TextOperatorNotEqual
// TextOperatorNotEqualIgnoreCase compares two strings for inequality, ignoring case.
TextOperatorNotEqualIgnoreCase
// TextOperatorStartsWith checks if the first string starts with the second.
TextOperatorStartsWith
// TextOperatorStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
TextOperatorStartsWithIgnoreCase
)
var textOperators = map[TextOperator]string{
TextOperatorEqual: " = ",
TextOperatorEqualIgnoreCase: " LIKE ",
TextOperatorNotEqual: " <> ",
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
TextOperatorStartsWith: " LIKE ",
TextOperatorStartsWithIgnoreCase: " LIKE ",
}
func writeTextOperation[T Text](builder *statementBuilder, col Column, op TextOperator, value T) {
switch op {
case TextOperatorEqual, TextOperatorNotEqual:
col.writeTo(builder)
builder.WriteString(textOperators[op])
builder.WriteString(builder.appendArg(value))
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
ignoreCaseCol.writeIgnoreCaseTo(builder)
} else {
builder.WriteString("LOWER(")
col.writeTo(builder)
builder.WriteString(")")
}
builder.WriteString(textOperators[op])
builder.WriteString("LOWER(")
builder.WriteString(builder.appendArg(value))
builder.WriteString(")")
case TextOperatorStartsWith:
col.writeTo(builder)
builder.WriteString(textOperators[op])
builder.WriteString(builder.appendArg(value))
builder.WriteString(" || '%'")
case TextOperatorStartsWithIgnoreCase:
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
ignoreCaseCol.writeIgnoreCaseTo(builder)
} else {
builder.WriteString("LOWER(")
col.writeTo(builder)
builder.WriteString(")")
}
builder.WriteString(textOperators[op])
builder.WriteString("LOWER(")
builder.WriteString(builder.appendArg(value))
builder.WriteString(")")
builder.WriteString(" || '%'")
default:
panic("unsupported text operation")
}
}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type NumberOperator uint8
const (
// NumberOperatorEqual compares two numbers for equality.
NumberOperatorEqual NumberOperator = iota + 1
// NumberOperatorNotEqual compares two numbers for inequality.
NumberOperatorNotEqual
// NumberOperatorLessThan compares two numbers to check if the first is less than the second.
NumberOperatorLessThan
// NumberOperatorLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
NumberOperatorAtLeast
// NumberOperatorGreaterThan compares two numbers to check if the first is greater than the second.
NumberOperatorGreaterThan
// NumberOperatorGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
NumberOperatorAtMost
)
var numberOperators = map[NumberOperator]string{
NumberOperatorEqual: " = ",
NumberOperatorNotEqual: " <> ",
NumberOperatorLessThan: " < ",
NumberOperatorAtLeast: " <= ",
NumberOperatorGreaterThan: " > ",
NumberOperatorAtMost: " >= ",
}
func writeNumberOperation[T Number](builder *statementBuilder, col Column, op NumberOperator, value T) {
col.writeTo(builder)
builder.WriteString(numberOperators[op])
builder.WriteString(builder.appendArg(value))
}
type Boolean interface {
~bool
}
type BooleanOperator uint8
const (
BooleanOperatorIsTrue BooleanOperator = iota + 1
BooleanOperatorIsFalse
)
func writeBooleanOperation[T Boolean](builder *statementBuilder, col Column, value T) {
col.writeTo(builder)
builder.WriteString(" IS ")
builder.WriteString(builder.appendArg(value))
}

View File

@@ -0,0 +1,18 @@
package v4
type Org struct {
InstanceID string
ID string
Name string
Dates
}
type GetOrg struct{}
type ListOrgs struct{}
type CreateOrg struct{}
type UpdateOrg struct{}
type DeleteOrg struct{}

View File

@@ -0,0 +1,46 @@
package v4
import (
"strconv"
"strings"
)
type databaseInstruction string
const (
nowDBInstruction databaseInstruction = "NOW()"
nullDBInstruction databaseInstruction = "NULL"
)
type statementBuilder struct {
strings.Builder
args []any
existingArgs map[any]string
}
func (b *statementBuilder) writeArg(arg any) {
b.WriteString(b.appendArg(arg))
}
func (b *statementBuilder) appendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(databaseInstruction); ok {
return string(instruction)
}
b.args = append(b.args, arg)
placeholder = "$" + strconv.Itoa(len(b.args))
b.existingArgs[arg] = placeholder
return placeholder
}
func (b *statementBuilder) appendArgs(args ...any) {
for _, arg := range args {
b.appendArg(arg)
}
}

View File

@@ -0,0 +1,239 @@
package v4
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Dates struct {
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}
type User struct {
InstanceID string
OrgID string
ID string
Username string
Traits userTrait
Dates
}
type UserType string
type userTrait interface {
userTrait()
Type() UserType
}
const userQuery = `SELECT u.instance_id, u.org_id, u.id, u.username, u.type, u.created_at, u.updated_at, u.deleted_at,` +
` h.first_name, h.last_name, h.email_address, h.email_verified_at, h.phone_number, h.phone_verified_at, m.description` +
` FROM users u` +
` LEFT JOIN user_humans h ON u.instance_id = h.instance_id AND u.org_id = h.org_id AND u.id = h.id` +
` LEFT JOIN user_machines m ON u.instance_id = m.instance_id AND u.org_id = m.org_id AND u.id = m.id`
type user struct {
builder statementBuilder
client database.QueryExecutor
condition Condition
}
func UserRepository(client database.QueryExecutor) *user {
return &user{
client: client,
}
}
func (u *user) WithCondition(condition Condition) *user {
u.condition = condition
return u
}
func (u *user) Get(ctx context.Context) (*User, error) {
u.builder.WriteString(userQuery)
u.writeCondition()
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.args...))
}
func (u *user) List(ctx context.Context) (users []*User, err error) {
u.builder.WriteString(userQuery)
u.writeCondition()
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.args...)
if err != nil {
return nil, err
}
defer func() {
closeErr := rows.Close()
if err != nil {
return
}
err = closeErr
}()
for rows.Next() {
user, err := scanUser(rows)
if err != nil {
return nil, err
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
const (
createUserCte = `WITH user AS (` +
`INSERT INTO users (instance_id, org_id, id, username, type) VALUES ($1, $2, $3, $4, $5)` +
` RETURNING *)`
createHumanStmt = createUserCte + ` INSERT INTO user_humans h (instance_id, org_id, user_id, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` +
` SELECT u.instance_id, u.org_id, u.id, $6, $7, $8, $9, $10, $11` +
` FROM user u` +
` RETURNING u.created_at, u.updated_at, u.deleted_at`
createMachineStmt = createUserCte + ` INSERT INTO user_machines (instance_id, org_id, user_id, description)` +
` SELECT u.instance_id, u.org_id, u.id, $6` +
` FROM user u` +
` RETURNING u.created_at, u.updated_at`
)
func (u *user) Create(ctx context.Context, user *User) error {
u.builder.appendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
case *Human:
u.builder.WriteString(createHumanStmt)
u.builder.appendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *Machine:
u.builder.WriteString(createMachineStmt)
u.builder.appendArgs(trait.Description)
}
return u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(user.CreatedAt, user.UpdatedAt)
}
func (u *user) InstanceIDColumn() Column {
return column{name: "u.instance_id"}
}
func (u *user) InstanceIDCondition(instanceID string) Condition {
return newTextCondition(u.InstanceIDColumn(), TextOperatorEqual, instanceID)
}
func (u *user) OrgIDColumn() Column {
return column{name: "u.org_id"}
}
func (u *user) OrgIDCondition(orgID string) Condition {
return newTextCondition(u.OrgIDColumn(), TextOperatorEqual, orgID)
}
func (u *user) IDColumn() Column {
return column{name: "u.id"}
}
func (u *user) IDCondition(userID string) Condition {
return newTextCondition(u.IDColumn(), TextOperatorEqual, userID)
}
func (u *user) UsernameColumn() Column {
return ignoreCaseCol{
column: column{name: "u.username"},
suffix: "_lower",
}
}
func (u user) SetUsername(username string) Change {
return newChange(u.UsernameColumn(), username)
}
func (u *user) UsernameCondition(op TextOperator, username string) Condition {
return newTextCondition(u.UsernameColumn(), op, username)
}
func (u *user) CreatedAtColumn() Column {
return column{name: "u.created_at"}
}
func (u *user) CreatedAtCondition(op NumberOperator, createdAt time.Time) Condition {
return newNumberCondition(u.CreatedAtColumn(), op, createdAt)
}
func (u *user) UpdatedAtColumn() Column {
return column{name: "u.updated_at"}
}
func (u *user) UpdatedAtCondition(op NumberOperator, updatedAt time.Time) Condition {
return newNumberCondition(u.UpdatedAtColumn(), op, updatedAt)
}
func (u *user) DeletedAtColumn() Column {
return column{name: "u.deleted_at"}
}
func (u *user) DeletedCondition(isDeleted bool) Condition {
if isDeleted {
return IsNotNull(u.DeletedAtColumn())
}
return IsNull(u.DeletedAtColumn())
}
func (u *user) DeletedAtCondition(op NumberOperator, deletedAt time.Time) Condition {
return newNumberCondition(u.DeletedAtColumn(), op, deletedAt)
}
func (u *user) writeCondition() {
if u.condition == nil {
return
}
u.builder.WriteString(" WHERE ")
u.condition.writeTo(&u.builder)
}
func scanUser(scanner database.Scanner) (*User, error) {
var (
user User
human Human
email Email
phone Phone
machine Machine
typ UserType
)
err := scanner.Scan(
&user.InstanceID,
&user.OrgID,
&user.ID,
&user.Username,
&typ,
&user.Dates.CreatedAt,
&user.Dates.UpdatedAt,
&user.Dates.DeletedAt,
&human.FirstName,
&human.LastName,
&email.Address,
&email.VerifiedAt,
&phone.Number,
&phone.VerifiedAt,
&machine.Description,
)
if err != nil {
return nil, err
}
switch typ {
case UserTypeHuman:
if email.Address != "" {
human.Email = &email
}
if phone.Number != "" {
human.Phone = &phone
}
user.Traits = &human
case UserTypeMachine:
user.Traits = &machine
}
return &user, nil
}

View File

@@ -0,0 +1,187 @@
package v4
import (
"context"
"time"
)
type Human struct {
FirstName string
LastName string
Email *Email
Phone *Phone
}
const UserTypeHuman UserType = "human"
func (Human) userTrait() {}
func (h Human) Type() UserType {
return UserTypeHuman
}
var _ userTrait = (*Human)(nil)
type Email struct {
Address string
Verification
}
type Phone struct {
Number string
Verification
}
type Verification struct {
VerifiedAt time.Time
}
type userHuman struct {
*user
}
func (u *user) Human() *userHuman {
return &userHuman{user: u}
}
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
func (u *userHuman) GetEmail(ctx context.Context) (*Email, error) {
var email Email
u.builder.WriteString(userEmailQuery)
u.writeCondition()
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(
&email.Address,
&email.Verification.VerifiedAt,
)
if err != nil {
return nil, err
}
return &email, nil
}
func (h userHuman) Update(ctx context.Context, changes ...Change) error {
h.builder.WriteString(`UPDATE human_users h SET `)
Changes(changes).writeTo(&h.builder)
h.writeCondition()
stmt := h.builder.String()
return h.client.Exec(ctx, stmt, h.builder.args...)
}
func (h userHuman) SetFirstName(firstName string) Change {
return newChange(h.FirstNameColumn(), firstName)
}
func (h userHuman) FirstNameColumn() Column {
return column{"h.first_name"}
}
func (h userHuman) FirstNameCondition(op TextOperator, firstName string) Condition {
return newTextCondition(h.FirstNameColumn(), op, firstName)
}
func (h userHuman) SetLastName(lastName string) Change {
return newChange(h.LastNameColumn(), lastName)
}
func (h userHuman) LastNameColumn() Column {
return column{"h.last_name"}
}
func (h userHuman) LastNameCondition(op TextOperator, lastName string) Condition {
return newTextCondition(h.LastNameColumn(), op, lastName)
}
func (h userHuman) EmailAddressColumn() Column {
return ignoreCaseCol{
column: column{"h.email_address"},
suffix: "_lower",
}
}
func (h userHuman) EmailAddressCondition(op TextOperator, email string) Condition {
return newTextCondition(h.EmailAddressColumn(), op, email)
}
func (h userHuman) EmailVerifiedAtColumn() Column {
return column{"h.email_verified_at"}
}
func (h *userHuman) EmailAddressVerifiedCondition(isVerified bool) Condition {
if isVerified {
return IsNotNull(h.EmailVerifiedAtColumn())
}
return IsNull(h.EmailVerifiedAtColumn())
}
func (h userHuman) EmailVerifiedAtCondition(op TextOperator, emailVerifiedAt string) Condition {
return newTextCondition(h.EmailVerifiedAtColumn(), op, emailVerifiedAt)
}
func (h userHuman) SetEmailAddress(address string) Change {
return newChange(h.EmailAddressColumn(), address)
}
// SetEmailVerified sets the verified column of the email
// if at is zero the statement uses the database timestamp
func (h userHuman) SetEmailVerified(at time.Time) Change {
if at.IsZero() {
return newChange(h.EmailVerifiedAtColumn(), nowDBInstruction)
}
return newChange(h.EmailVerifiedAtColumn(), at)
}
func (h userHuman) SetEmail(address string, verified *time.Time) Change {
return newChanges(
h.SetEmailAddress(address),
newUpdatePtrColumn(h.EmailVerifiedAtColumn(), verified),
)
}
func (h userHuman) PhoneNumberColumn() Column {
return column{"h.phone_number"}
}
func (h userHuman) SetPhoneNumber(number string) Change {
return newChange(h.PhoneNumberColumn(), number)
}
func (h userHuman) PhoneNumberCondition(op TextOperator, phoneNumber string) Condition {
return newTextCondition(h.PhoneNumberColumn(), op, phoneNumber)
}
func (h userHuman) PhoneVerifiedAtColumn() Column {
return column{"h.phone_verified_at"}
}
func (h userHuman) PhoneNumberVerifiedCondition(isVerified bool) Condition {
if isVerified {
return IsNotNull(h.PhoneVerifiedAtColumn())
}
return IsNull(h.PhoneVerifiedAtColumn())
}
// SetPhoneVerified sets the verified column of the phone
// if at is zero the statement uses the database timestamp
func (h userHuman) SetPhoneVerified(at time.Time) Change {
if at.IsZero() {
return newChange(h.PhoneVerifiedAtColumn(), nowDBInstruction)
}
return newChange(h.PhoneVerifiedAtColumn(), at)
}
func (h userHuman) PhoneVerifiedAtCondition(op TextOperator, phoneVerifiedAt string) Condition {
return newTextCondition(h.PhoneVerifiedAtColumn(), op, phoneVerifiedAt)
}
func (h userHuman) SetPhone(number string, verifiedAt *time.Time) Change {
return newChanges(
h.SetPhoneNumber(number),
newUpdatePtrColumn(h.PhoneVerifiedAtColumn(), verifiedAt),
)
}

View File

@@ -0,0 +1,41 @@
package v4
import "context"
type Machine struct {
Description string
}
func (Machine) userTrait() {}
func (m Machine) Type() UserType {
return UserTypeMachine
}
const UserTypeMachine UserType = "machine"
var _ userTrait = (*Machine)(nil)
type userMachine struct {
*user
}
func (u *user) Machine() *userMachine {
return &userMachine{user: u}
}
func (m userMachine) Update(ctx context.Context, cols ...Change) (*Machine, error) {
return nil, nil
}
func (userMachine) DescriptionColumn() Column {
return column{"m.description"}
}
func (m userMachine) SetDescription(description string) Change {
return newChange(m.DescriptionColumn(), description)
}
func (m userMachine) DescriptionCondition(op TextOperator, description string) Condition {
return newTextCondition(m.DescriptionColumn(), op, description)
}

View File

@@ -0,0 +1,65 @@
package v4_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
func TestQueryUser(t *testing.T) {
t.Run("User filters", func(t *testing.T) {
user := v4.UserRepository(nil)
user.WithCondition(
v4.And(
v4.Or(
user.IDCondition("test"),
user.IDCondition("2"),
),
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
),
).Get(context.Background())
})
t.Run("machine and human filters", func(t *testing.T) {
user := v4.UserRepository(nil)
machine := user.Machine()
human := user.Human()
user.WithCondition(
v4.And(
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
v4.Or(
machine.DescriptionCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
human.EmailAddressVerifiedCondition(true),
v4.IsNotNull(machine.DescriptionColumn()),
),
),
)
human.GetEmail(context.Background())
})
}
type dbInstruction string
func TestArg(t *testing.T) {
var bla any = "asdf"
instr, ok := bla.(dbInstruction)
assert.False(t, ok)
assert.Empty(t, instr)
bla = dbInstruction("asdf")
instr, ok = bla.(dbInstruction)
assert.True(t, ok)
assert.Equal(t, instr, dbInstruction("asdf"))
}
func TestWriteUser(t *testing.T) {
t.Run("update user", func(t *testing.T) {
user := v4.UserRepository(nil)
user.WithCondition(user.IDCondition("test")).Human().Update(
context.Background(),
user.SetUsername("test"),
)
})
}

View File

@@ -0,0 +1,39 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type user struct {
database.QueryExecutor
}
func User(client database.QueryExecutor) domain.UserRepository {
// return &user{QueryExecutor: client}
return nil
}
// On implements [domain.UserRepository].
func (exec *user) On(clauses ...domain.UserClause) domain.UserOperation {
return &userOperation{
QueryExecutor: exec.QueryExecutor,
clauses: clauses,
}
}
// OnHuman implements [domain.UserRepository].
func (exec *user) OnHuman(clauses ...domain.UserClause) domain.HumanOperation {
return &humanOperation{
userOperation: *exec.On(clauses...).(*userOperation),
}
}
// OnMachine implements [domain.UserRepository].
func (exec *user) OnMachine(clauses ...domain.UserClause) domain.MachineOperation {
return &machineOperation{
userOperation: *exec.On(clauses...).(*userOperation),
}
}
// var _ domain.UserRepository = (*user)(nil)

View File

@@ -0,0 +1,36 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type humanOperation struct {
userOperation
}
// GetEmail implements domain.HumanOperation.
func (h *humanOperation) GetEmail(ctx context.Context) (*domain.Email, error) {
var email domain.Email
err := h.QueryExecutor.QueryRow(ctx, `SELECT email, is_email_verified FROM human_users WHERE id = $1`, h.clauses).Scan(
&email.Address,
&email.IsVerified,
)
if err != nil {
return nil, err
}
return &email, nil
}
// SetEmail implements domain.HumanOperation.
func (h *humanOperation) SetEmail(ctx context.Context, email string) error {
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET email = $1 WHERE id = $2`, email, h.clauses)
}
// SetEmailVerified implements domain.HumanOperation.
func (h *humanOperation) SetEmailVerified(ctx context.Context, email string) error {
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET is_email_verified = $1 WHERE id = $2 AND email = $3`, true, h.clauses, email)
}
var _ domain.HumanOperation = (*humanOperation)(nil)

View File

@@ -0,0 +1,18 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type machineOperation struct {
userOperation
}
// SetDescription implements domain.MachineOperation.
func (m *machineOperation) SetDescription(ctx context.Context, description string) error {
return m.QueryExecutor.Exec(ctx, `UPDATE machines SET description = $1 WHERE id = $2`, description, m.clauses)
}
var _ domain.MachineOperation = (*machineOperation)(nil)

View File

@@ -0,0 +1,68 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userOperation struct {
database.QueryExecutor
clauses []domain.UserClause
}
// Delete implements [domain.UserOperation].
func (u *userOperation) Delete(ctx context.Context) error {
return u.QueryExecutor.Exec(ctx, `DELETE FROM users WHERE id = $1`, u.clauses)
}
// SetUsername implements [domain.UserOperation].
func (u *userOperation) SetUsername(ctx context.Context, username string) error {
var stmt statement
stmt.builder.WriteString(`UPDATE users SET username = $1 WHERE `)
stmt.appendArg(username)
clausesToSQL(&stmt, u.clauses)
return u.QueryExecutor.Exec(ctx, stmt.builder.String(), stmt.args...)
}
var _ domain.UserOperation = (*userOperation)(nil)
func UserIDQuery(id string) domain.UserClause {
return textClause[string]{
clause: clause[domain.TextOperation]{
field: userFields[domain.UserFieldID],
op: domain.TextOperationEqual,
},
value: id,
}
}
func HumanEmailQuery(op domain.TextOperation, email string) domain.UserClause {
return textClause[string]{
clause: clause[domain.TextOperation]{
field: userFields[domain.UserHumanFieldEmail],
op: op,
},
value: email,
}
}
func HumanEmailVerifiedQuery(op domain.BoolOperation) domain.UserClause {
return boolClause[domain.BoolOperation]{
clause: clause[domain.BoolOperation]{
field: userFields[domain.UserHumanFieldEmailVerified],
op: op,
},
}
}
func clausesToSQL(stmt *statement, clauses []domain.UserClause) {
for _, clause := range clauses {
stmt.builder.WriteString(userFields[clause.Field()].String())
stmt.builder.WriteString(clause.Operation().String())
stmt.appendArg(clause.Args()...)
}
}

View File

@@ -0,0 +1,36 @@
package database
import "context"
type Transaction interface {
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
End(ctx context.Context, err error) error
Begin(ctx context.Context) (Transaction, error)
QueryExecutor
}
type Beginner interface {
Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error)
}
type TransactionOptions struct {
IsolationLevel IsolationLevel
AccessMode AccessMode
}
type IsolationLevel uint8
const (
IsolationLevelSerializable IsolationLevel = iota
IsolationLevelReadCommitted
)
type AccessMode uint8
const (
AccessModeReadWrite AccessMode = iota
AccessModeReadOnly
)