mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 07:57:32 +00:00
chore: move the go code into a subfolder
This commit is contained in:
228
apps/api/internal/database/cockroach/crdb.go
Normal file
228
apps/api/internal/database/cockroach/crdb.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package cockroach
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := new(Config)
|
||||
dialect.Register(config, config, false)
|
||||
}
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin AdminUser
|
||||
// Additional options to be appended as options=<Options>
|
||||
// The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"crdb", "cockroach"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (_ *Config) Decode(configs []any) (dialect.Connector, error) {
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
|
||||
connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
|
||||
|
||||
config, err := pgxpool.ParseConfig(c.String(useAdmin))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(connConfig.AfterConnect) > 0 {
|
||||
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
for _, f := range connConfig.AfterConnect {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(connConfig.BeforeAcquire) > 0 {
|
||||
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
|
||||
for _, f := range connConfig.BeforeAcquire {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(connConfig.AfterRelease) > 0 {
|
||||
config.AfterRelease = func(conn *pgx.Conn) bool {
|
||||
for _, f := range connConfig.AfterRelease {
|
||||
if err := f(conn); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if connConfig.MaxOpenConns != 0 {
|
||||
config.MaxConns = int32(connConfig.MaxOpenConns)
|
||||
}
|
||||
|
||||
config.MaxConnLifetime = c.MaxConnLifetime
|
||||
config.MaxConnIdleTime = c.MaxConnIdleTime
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := pool.Ping(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return stdlib.OpenDBFromPool(pool), pool, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() dialect.DatabaseType {
|
||||
return dialect.DatabaseTypeCockroach
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type AdminUser struct {
|
||||
// ExistingDatabase is the database to connect to before the ZITADEL database exists
|
||||
ExistingDatabase string
|
||||
User `mapstructure:",squash"`
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (c *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin.User
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"dbname=" + c.Database,
|
||||
"application_name=" + dialect.DefaultAppName,
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
} else if c.Admin.ExistingDatabase != "" {
|
||||
fields = append(fields, "dbname="+c.Admin.ExistingDatabase)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
212
apps/api/internal/database/database.go
Normal file
212
apps/api/internal/database/database.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
_ "github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
_ "github.com/zitadel/zitadel/internal/database/postgres"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type ContextQuerier interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type ContextExecuter interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type ContextQueryExecuter interface {
|
||||
ContextQuerier
|
||||
ContextExecuter
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
ContextQueryExecuter
|
||||
Beginner
|
||||
Conn(ctx context.Context) (*sql.Conn, error)
|
||||
}
|
||||
|
||||
type Beginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type Tx interface {
|
||||
ContextQueryExecuter
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
var (
|
||||
_ Client = (*sql.DB)(nil)
|
||||
_ Tx = (*sql.Tx)(nil)
|
||||
)
|
||||
|
||||
func CloseTransaction(tx Tx, err error) error {
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback()
|
||||
logging.OnError(rollbackErr).Error("failed to rollback transaction")
|
||||
return err
|
||||
}
|
||||
|
||||
commitErr := tx.Commit()
|
||||
logging.OnError(commitErr).Error("failed to commit transaction")
|
||||
return commitErr
|
||||
}
|
||||
|
||||
const (
|
||||
PgUniqueConstraintErrorCode = "23505"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
connector dialect.Connector
|
||||
}
|
||||
|
||||
func (c *Config) SetConnector(connector dialect.Connector) {
|
||||
c.connector = connector
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
dialect.Database
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error {
|
||||
return db.QueryContext(context.Background(), scan, query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) (err error) {
|
||||
rows, err := db.DB.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Info("rows.Close failed")
|
||||
}()
|
||||
|
||||
if err = scan(rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (db *DB) QueryRow(scan func(*sql.Row) error, query string, args ...any) (err error) {
|
||||
return db.QueryRowContext(context.Background(), scan, query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRowContext(ctx context.Context, scan func(row *sql.Row) error, query string, args ...any) (err error) {
|
||||
row := db.DB.QueryRowContext(ctx, query, args...)
|
||||
logging.OnError(row.Err()).Error("unexpected query error")
|
||||
|
||||
err = scan(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return row.Err()
|
||||
}
|
||||
|
||||
func QueryJSONObject[T any](ctx context.Context, db *DB, query string, args ...any) (*T, error) {
|
||||
var data []byte
|
||||
err := db.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
return row.Scan(&data)
|
||||
}, query, args...)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DATAB-Oath6", "Errors.Internal")
|
||||
}
|
||||
obj := new(T)
|
||||
if err = json.Unmarshal(data, obj); err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "DATAB-Vohs6", "Errors.Internal")
|
||||
}
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func Connect(config Config, useAdmin bool) (*DB, error) {
|
||||
client, pool, err := config.connector.Connect(useAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := client.Ping(); err != nil {
|
||||
return nil, zerrors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed")
|
||||
}
|
||||
|
||||
return &DB{
|
||||
DB: client,
|
||||
Database: config.connector,
|
||||
Pool: pool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecodeHook(allowCockroach bool) func(from, to reflect.Value) (_ interface{}, err error) {
|
||||
return func(from, to reflect.Value) (_ interface{}, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configuredDialect := dialect.SelectByConfig(config.Dialects)
|
||||
configs := make([]any, 0, len(config.Dialects))
|
||||
|
||||
for name, dialectConfig := range config.Dialects {
|
||||
if !configuredDialect.Matcher.MatchName(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
configs = append(configs, dialectConfig)
|
||||
}
|
||||
|
||||
if !allowCockroach && configuredDialect.Matcher.Type() == dialect.DatabaseTypeCockroach {
|
||||
logging.Info("Cockroach support was removed with Zitadel v3, please refer to https://zitadel.com/docs/self-hosting/manage/cli/mirror to migrate your data to postgres")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "DATAB-0pIWD", "Cockroach support was removed with Zitadel v3")
|
||||
}
|
||||
|
||||
config.connector, err = configuredDialect.Matcher.Decode(configs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) DatabaseName() string {
|
||||
return c.connector.DatabaseName()
|
||||
}
|
||||
|
||||
func (c Config) Username() string {
|
||||
return c.connector.Username()
|
||||
}
|
||||
|
||||
func (c Config) Password() string {
|
||||
return c.connector.Password()
|
||||
}
|
||||
|
||||
func (c Config) Type() dialect.DatabaseType {
|
||||
return c.connector.Type()
|
||||
}
|
||||
|
||||
func EscapeLikeWildcards(value string) string {
|
||||
value = strings.ReplaceAll(value, "%", "\\%")
|
||||
value = strings.ReplaceAll(value, "_", "\\_")
|
||||
return value
|
||||
}
|
87
apps/api/internal/database/database_test.go
Normal file
87
apps/api/internal/database/database_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/mock"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestQueryJSONObject(t *testing.T) {
|
||||
type dst struct {
|
||||
A int `json:"a,omitempty"`
|
||||
}
|
||||
const (
|
||||
query = `select $1;`
|
||||
arg = 1
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mock func(*testing.T) *mock.SQLMock
|
||||
want *dst
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "tx error",
|
||||
mock: func(t *testing.T) *mock.SQLMock {
|
||||
return mock.NewSQLMock(t, mock.ExpectQuery("select $1;", mock.WithQueryErr(sql.ErrConnDone)))
|
||||
},
|
||||
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "DATAB-Oath6", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "no rows",
|
||||
mock: func(t *testing.T) *mock.SQLMock {
|
||||
return mock.NewSQLMock(t,
|
||||
mock.ExpectQuery(query,
|
||||
mock.WithQueryArgs(arg),
|
||||
mock.WithQueryResult([]string{"json"}, [][]driver.Value{}),
|
||||
),
|
||||
)
|
||||
},
|
||||
wantErr: sql.ErrNoRows,
|
||||
},
|
||||
{
|
||||
name: "unmarshal error",
|
||||
mock: func(t *testing.T) *mock.SQLMock {
|
||||
return mock.NewSQLMock(t,
|
||||
mock.ExpectQuery(query,
|
||||
mock.WithQueryArgs(arg),
|
||||
mock.WithQueryResult([]string{"json"}, [][]driver.Value{{`~~~`}}),
|
||||
),
|
||||
)
|
||||
},
|
||||
wantErr: zerrors.ThrowInternal(nil, "DATAB-Vohs6", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
mock: func(t *testing.T) *mock.SQLMock {
|
||||
return mock.NewSQLMock(t,
|
||||
mock.ExpectQuery(query,
|
||||
mock.WithQueryArgs(arg),
|
||||
mock.WithQueryResult([]string{"json"}, [][]driver.Value{{`{"a":1}`}}),
|
||||
),
|
||||
)
|
||||
},
|
||||
want: &dst{A: 1},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := tt.mock(t)
|
||||
defer mock.Assert(t)
|
||||
db := &DB{
|
||||
DB: mock.DB,
|
||||
}
|
||||
got, err := QueryJSONObject[dst](context.Background(), db, query, arg)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
75
apps/api/internal/database/dialect/config.go
Normal file
75
apps/api/internal/database/dialect/config.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type Dialect struct {
|
||||
Matcher Matcher
|
||||
Config Connector
|
||||
IsDefault bool
|
||||
}
|
||||
|
||||
var (
|
||||
dialects []*Dialect
|
||||
defaultDialect *Dialect
|
||||
dialectsMu sync.Mutex
|
||||
)
|
||||
|
||||
type Matcher interface {
|
||||
MatchName(string) bool
|
||||
Decode([]any) (Connector, error)
|
||||
Type() DatabaseType
|
||||
}
|
||||
|
||||
type DatabaseType uint8
|
||||
|
||||
const (
|
||||
DatabaseTypePostgres DatabaseType = iota
|
||||
DatabaseTypeCockroach
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAppName = "zitadel"
|
||||
)
|
||||
|
||||
type Connector interface {
|
||||
Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error)
|
||||
Password() string
|
||||
Database
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
DatabaseName() string
|
||||
Username() string
|
||||
Type() DatabaseType
|
||||
}
|
||||
|
||||
func Register(matcher Matcher, config Connector, isDefault bool) {
|
||||
dialectsMu.Lock()
|
||||
defer dialectsMu.Unlock()
|
||||
|
||||
d := &Dialect{Matcher: matcher, Config: config}
|
||||
|
||||
if isDefault {
|
||||
defaultDialect = d
|
||||
return
|
||||
}
|
||||
|
||||
dialects = append(dialects, d)
|
||||
}
|
||||
|
||||
func SelectByConfig(config map[string]interface{}) *Dialect {
|
||||
for key := range config {
|
||||
for _, d := range dialects {
|
||||
if d.Matcher.MatchName(key) {
|
||||
return d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultDialect
|
||||
}
|
87
apps/api/internal/database/dialect/connections.go
Normal file
87
apps/api/internal/database/dialect/connections.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
pgxdecimal "github.com/jackc/pgx-shopspring-decimal"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIllegalMaxOpenConns = errors.New("MaxOpenConns of the database must be higher than 3 or 0 for unlimited")
|
||||
ErrIllegalMaxIdleConns = errors.New("MaxIdleConns of the database must be higher than 3 or 0 for unlimited")
|
||||
)
|
||||
|
||||
// ConnectionConfig defines the Max Open and Idle connections for a DB connection pool.
|
||||
type ConnectionConfig struct {
|
||||
MaxOpenConns,
|
||||
MaxIdleConns uint32
|
||||
AfterConnect []func(ctx context.Context, c *pgx.Conn) error
|
||||
BeforeAcquire []func(ctx context.Context, c *pgx.Conn) error
|
||||
AfterRelease []func(c *pgx.Conn) error
|
||||
}
|
||||
|
||||
var afterConnectFuncs = []func(ctx context.Context, c *pgx.Conn) error{
|
||||
func(ctx context.Context, c *pgx.Conn) error {
|
||||
pgxdecimal.Register(c.TypeMap())
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) {
|
||||
afterConnectFuncs = append(afterConnectFuncs, f)
|
||||
}
|
||||
|
||||
var beforeAcquireFuncs []func(ctx context.Context, c *pgx.Conn) error
|
||||
|
||||
func RegisterBeforeAcquire(f func(ctx context.Context, c *pgx.Conn) error) {
|
||||
beforeAcquireFuncs = append(beforeAcquireFuncs, f)
|
||||
}
|
||||
|
||||
var afterReleaseFuncs []func(c *pgx.Conn) error
|
||||
|
||||
func RegisterAfterRelease(f func(c *pgx.Conn) error) {
|
||||
afterReleaseFuncs = append(afterReleaseFuncs, f)
|
||||
}
|
||||
|
||||
func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string) {
|
||||
// T
|
||||
var value T
|
||||
m.RegisterDefaultPgType(value, name)
|
||||
|
||||
// *T
|
||||
valueType := reflect.TypeOf(value)
|
||||
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
|
||||
|
||||
// []T
|
||||
sliceType := reflect.SliceOf(valueType)
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
|
||||
|
||||
// []*T
|
||||
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]*T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
|
||||
}
|
||||
|
||||
// NewConnectionConfig calculates [ConnectionConfig] values from the passed ratios
|
||||
// and returns the config applicable for the requested purpose.
|
||||
//
|
||||
// openConns and idleConns must be at least 3 or 0, which means no limit.
|
||||
// The pusherRatio and spoolerRatio must be between 0 and 1.
|
||||
func NewConnectionConfig(openConns, idleConns uint32) *ConnectionConfig {
|
||||
return &ConnectionConfig{
|
||||
MaxOpenConns: openConns,
|
||||
MaxIdleConns: idleConns,
|
||||
AfterConnect: afterConnectFuncs,
|
||||
BeforeAcquire: beforeAcquireFuncs,
|
||||
AfterRelease: afterReleaseFuncs,
|
||||
}
|
||||
}
|
5
apps/api/internal/database/migrate/migrate.go
Normal file
5
apps/api/internal/database/migrate/migrate.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package migrate
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func New(*sql.DB) {}
|
139
apps/api/internal/database/mock/sql_mock.go
Normal file
139
apps/api/internal/database/mock/sql_mock.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
type SQLMock struct {
|
||||
DB *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
type Expectation func(m sqlmock.Sqlmock)
|
||||
|
||||
func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock {
|
||||
db, mock, err := sqlmock.New(
|
||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
|
||||
sqlmock.ValueConverterOption(new(TypeConverter)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal("create mock failed", err)
|
||||
}
|
||||
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
return &SQLMock{
|
||||
DB: db,
|
||||
mock: mock,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SQLMock) Assert(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if err := m.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations not met: %v", err)
|
||||
}
|
||||
|
||||
m.DB.Close()
|
||||
}
|
||||
|
||||
func ExpectBegin(err error) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectBegin()
|
||||
if err != nil {
|
||||
e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectCommit(err error) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectCommit()
|
||||
if err != nil {
|
||||
e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
|
||||
|
||||
func WithExecArgs(args ...driver.Value) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecErr(err error) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecNoRowsAffected() ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(affected)
|
||||
}
|
||||
}
|
||||
|
||||
func ExcpectExec(stmt string, opts ...ExecOpt) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectExec(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||
|
||||
func WithQueryArgs(args ...driver.Value) QueryOpt {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryErr(err error) QueryOpt {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
|
||||
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
mockedRows := m.NewRows(columns)
|
||||
for _, row := range rows {
|
||||
mockedRows = mockedRows.AddRow(row...)
|
||||
}
|
||||
return e.WillReturnRows(mockedRows)
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectQuery(stmt string, opts ...QueryOpt) Expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectQuery(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(m, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type AnyType[T interface{}] struct{}
|
||||
|
||||
// Match satisfies sqlmock.Argument interface
|
||||
func (a AnyType[T]) Match(v driver.Value) bool {
|
||||
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
|
||||
}
|
87
apps/api/internal/database/mock/type_converter.go
Normal file
87
apps/api/internal/database/mock/type_converter.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ driver.ValueConverter = (*TypeConverter)(nil)
|
||||
|
||||
type TypeConverter struct{}
|
||||
|
||||
// ConvertValue converts a value to a driver Value.
|
||||
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
value := reflect.ValueOf(v)
|
||||
|
||||
if rawMessage, ok := v.(json.RawMessage); ok {
|
||||
return convertBytes(rawMessage), nil
|
||||
}
|
||||
|
||||
if value.Kind() == reflect.Slice {
|
||||
//nolint: exhaustive
|
||||
// only defined types
|
||||
switch value.Type().Elem().Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return convertSigned(value), nil
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return convertUnsigned(value), nil
|
||||
case reflect.String:
|
||||
return convertText(value), nil
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertSigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertUnsigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertText(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = array.Index(i).String()
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
func convertBytes(array []byte) string {
|
||||
var builder strings.Builder
|
||||
builder.Grow(hex.EncodedLen(len(array)) + 4)
|
||||
builder.WriteString(`\x`)
|
||||
builder.Write(AppendEncode(nil, array))
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// TODO: remove function after we compile using go 1.22 and use function of hex package `hex.AppendEncode`
|
||||
func AppendEncode(dst, src []byte) []byte {
|
||||
n := hex.EncodedLen(len(src))
|
||||
dst = slices.Grow(dst, n)
|
||||
hex.Encode(dst[len(dst):][:n], src)
|
||||
return dst[:len(dst)+n]
|
||||
}
|
38
apps/api/internal/database/postgres/embedded.go
Normal file
38
apps/api/internal/database/postgres/embedded.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
func StartEmbedded() (embeddedpostgres.Config, func()) {
|
||||
path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*")
|
||||
logging.OnError(err).Fatal("unable to create temp dir")
|
||||
|
||||
port, close := getPort()
|
||||
|
||||
config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path)
|
||||
embedded := embeddedpostgres.NewDatabase(config)
|
||||
|
||||
close()
|
||||
err = embedded.Start()
|
||||
logging.OnError(err).Fatal("unable to start db")
|
||||
|
||||
return config, func() {
|
||||
logging.OnError(embedded.Stop()).Error("unable to stop db")
|
||||
}
|
||||
}
|
||||
|
||||
// getPort returns a free port and locks it until close is called
|
||||
func getPort() (port uint16, close func()) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
logging.OnError(err).Fatal("unable to get port")
|
||||
port = uint16(l.Addr().(*net.TCPAddr).Port)
|
||||
logging.WithFields("port", port).Info("Port is available")
|
||||
return port, func() {
|
||||
logging.OnError(l.Close()).Error("unable to close port listener")
|
||||
}
|
||||
}
|
237
apps/api/internal/database/postgres/pg.go
Normal file
237
apps/api/internal/database/postgres/pg.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := new(Config)
|
||||
dialect.Register(config, config, true)
|
||||
}
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int32
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin AdminUser
|
||||
// Additional options to be appended as options=<Options>
|
||||
// The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"pg", "postgres"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
|
||||
connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
|
||||
|
||||
config, err := pgxpool.ParseConfig(c.String(useAdmin))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(connConfig.AfterConnect) > 0 {
|
||||
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
for _, f := range connConfig.AfterConnect {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(connConfig.BeforeAcquire) > 0 {
|
||||
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
|
||||
for _, f := range connConfig.BeforeAcquire {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(connConfig.AfterRelease) > 0 {
|
||||
config.AfterRelease = func(conn *pgx.Conn) bool {
|
||||
for _, f := range connConfig.AfterRelease {
|
||||
if err := f(conn); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if connConfig.MaxOpenConns != 0 {
|
||||
config.MaxConns = int32(connConfig.MaxOpenConns)
|
||||
}
|
||||
|
||||
config.MaxConnLifetime = c.MaxConnLifetime
|
||||
config.MaxConnIdleTime = c.MaxConnIdleTime
|
||||
|
||||
pool, err := pgxpool.NewWithConfig(
|
||||
context.Background(),
|
||||
config,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := pool.Ping(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return stdlib.OpenDBFromPool(pool), pool, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() dialect.DatabaseType {
|
||||
return dialect.DatabaseTypePostgres
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type AdminUser struct {
|
||||
// ExistingDatabase is the database to connect to before the ZITADEL database exists
|
||||
ExistingDatabase string
|
||||
User `mapstructure:",squash"`
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (s *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin.User
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"application_name=" + dialect.DefaultAppName,
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
} else {
|
||||
defaultDB := c.Admin.ExistingDatabase
|
||||
if defaultDB == "" {
|
||||
defaultDB = "postgres"
|
||||
}
|
||||
fields = append(fields, "dbname="+defaultDB)
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
if user.SSL.RootCert != "" {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
}
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
261
apps/api/internal/database/type.go
Normal file
261
apps/api/internal/database/type.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type TextArray[T ~string] pgtype.FlatArray[T]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *TextArray[T]) Scan(src any) error {
|
||||
var typedArray []string
|
||||
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
(*s) = make(TextArray[T], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*s)[i] = T(value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s TextArray[T]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
typed := make([]string, len(s))
|
||||
|
||||
for i, value := range s {
|
||||
typed[i] = string(value)
|
||||
}
|
||||
|
||||
return []byte("{" + strings.Join(typed, ",") + "}"), nil
|
||||
}
|
||||
|
||||
type ByteArray[T ~byte] pgtype.FlatArray[T]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *ByteArray[T]) Scan(src any) error {
|
||||
var typedArray []byte
|
||||
typedArray, ok := src.([]byte)
|
||||
if !ok {
|
||||
// tests use a different src type
|
||||
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
(*s) = make(ByteArray[T], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*s)[i] = T(value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s ByteArray[T]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
typed := make([]byte, len(s))
|
||||
|
||||
for i, value := range s {
|
||||
typed[i] = byte(value)
|
||||
}
|
||||
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
type numberField interface {
|
||||
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
|
||||
}
|
||||
|
||||
type numberTypeField interface {
|
||||
int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
|
||||
}
|
||||
|
||||
var _ sql.Scanner = (*NumberArray[int8])(nil)
|
||||
|
||||
type NumberArray[F numberField] pgtype.FlatArray[F]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (a *NumberArray[F]) Scan(src any) (err error) {
|
||||
var (
|
||||
mapper func()
|
||||
scanner sql.Scanner
|
||||
)
|
||||
|
||||
//nolint: exhaustive
|
||||
// only defined types
|
||||
switch reflect.TypeOf(*a).Elem().Kind() {
|
||||
case reflect.Int8:
|
||||
mapper, scanner = castedScan[int8](a)
|
||||
case reflect.Uint8:
|
||||
// we provide int16 is a workaround because pgx thinks we want to scan a byte array if we provide uint8
|
||||
mapper, scanner = castedScan[int16](a)
|
||||
case reflect.Int16:
|
||||
mapper, scanner = castedScan[int16](a)
|
||||
case reflect.Uint16:
|
||||
mapper, scanner = castedScan[uint16](a)
|
||||
case reflect.Int32:
|
||||
mapper, scanner = castedScan[int32](a)
|
||||
case reflect.Uint32:
|
||||
mapper, scanner = castedScan[uint32](a)
|
||||
case reflect.Int64:
|
||||
mapper, scanner = castedScan[int64](a)
|
||||
case reflect.Uint64:
|
||||
mapper, scanner = castedScan[uint64](a)
|
||||
case reflect.Int:
|
||||
mapper, scanner = castedScan[int](a)
|
||||
case reflect.Uint:
|
||||
mapper, scanner = castedScan[uint](a)
|
||||
}
|
||||
|
||||
if err = scanner.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
mapper()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
|
||||
var typedArray []T
|
||||
|
||||
mapper = func() {
|
||||
(*a) = make(NumberArray[F], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*a)[i] = F(value)
|
||||
}
|
||||
}
|
||||
scanner = pgtype.NewMap().SQLScanner(&typedArray)
|
||||
|
||||
return mapper, scanner
|
||||
}
|
||||
|
||||
type Map[V any] map[string]V
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (m *Map[V]) Scan(src any) error {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
bytes := src.([]byte)
|
||||
if len(bytes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, &m)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (m Map[V]) Value() (driver.Value, error) {
|
||||
if len(m) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
type Duration time.Duration
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (d *Duration) Scan(src any) error {
|
||||
switch duration := src.(type) {
|
||||
case *time.Duration:
|
||||
*d = Duration(*duration)
|
||||
return nil
|
||||
case time.Duration:
|
||||
*d = Duration(duration)
|
||||
return nil
|
||||
case *pgtype.Interval:
|
||||
*d = intervalToDuration(duration)
|
||||
return nil
|
||||
case pgtype.Interval:
|
||||
*d = intervalToDuration(&duration)
|
||||
return nil
|
||||
case int64:
|
||||
*d = Duration(duration)
|
||||
return nil
|
||||
}
|
||||
interval := new(pgtype.Interval)
|
||||
if err := interval.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*d = intervalToDuration(interval)
|
||||
return nil
|
||||
}
|
||||
|
||||
func intervalToDuration(interval *pgtype.Interval) Duration {
|
||||
return Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
|
||||
}
|
||||
|
||||
// NullDuration can be used for NULL intervals.
|
||||
// If Valid is false, the scanned value was NULL
|
||||
// This behavior is similar to [database/sql.NullString]
|
||||
type NullDuration struct {
|
||||
Valid bool
|
||||
Duration time.Duration
|
||||
}
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (d *NullDuration) Scan(src any) error {
|
||||
if src == nil {
|
||||
d.Duration, d.Valid = 0, false
|
||||
return nil
|
||||
}
|
||||
duration := new(Duration)
|
||||
if err := duration.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Duration, d.Valid = time.Duration(*duration), true
|
||||
return nil
|
||||
}
|
||||
|
||||
// JSONArray allows sending and receiving JSON arrays to and from the database.
|
||||
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
|
||||
// Values are marshaled and unmarshaled using the [encoding/json] package.
|
||||
type JSONArray[T any] []T
|
||||
|
||||
// NewJSONArray wraps an existing slice into a JSONArray.
|
||||
func NewJSONArray[T any](a []T) JSONArray[T] {
|
||||
return JSONArray[T](a)
|
||||
}
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (a *JSONArray[T]) Scan(src any) error {
|
||||
if src == nil {
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
bytes := src.([]byte)
|
||||
if len(bytes) == 0 {
|
||||
*a = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, a)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (a JSONArray[T]) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return json.Marshal(a)
|
||||
}
|
541
apps/api/internal/database/type_test.go
Normal file
541
apps/api/internal/database/type_test.go
Normal file
@@ -0,0 +1,541 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMap_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src []byte
|
||||
}
|
||||
type res[V any] struct {
|
||||
want Map[V]
|
||||
err bool
|
||||
}
|
||||
type testCase[V any] struct {
|
||||
name string
|
||||
m Map[V]
|
||||
args args
|
||||
res[V]
|
||||
}
|
||||
tests := []testCase[string]{
|
||||
{
|
||||
"nil",
|
||||
Map[string]{},
|
||||
args{src: nil},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
err: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"null",
|
||||
Map[string]{},
|
||||
args{src: []byte("invalid")},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"null",
|
||||
Map[string]{},
|
||||
args{src: nil},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
Map[string]{},
|
||||
args{src: []byte(`{}`)},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
Map[string]{},
|
||||
args{src: []byte(`{"key": "value"}`)},
|
||||
res[string]{
|
||||
want: Map[string]{
|
||||
"key": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase[V any] struct {
|
||||
name string
|
||||
m Map[V]
|
||||
res res
|
||||
}
|
||||
tests := []testCase[string]{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
Map[string]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
Map[string]{
|
||||
"key": "value",
|
||||
},
|
||||
res{
|
||||
want: driver.Value([]byte(`{"key":"value"}`)),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.m.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
if !tt.res.err {
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.res.want, got, "Value()")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type typedInt int
|
||||
|
||||
func TestNumberArray_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want any
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "typedInt",
|
||||
m: new(NumberArray[typedInt]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[typedInt]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int8",
|
||||
m: new(NumberArray[int8]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int8]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint8",
|
||||
m: new(NumberArray[uint8]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint8]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int16",
|
||||
m: new(NumberArray[int16]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int16]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint16",
|
||||
m: new(NumberArray[uint16]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint16]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int32",
|
||||
m: new(NumberArray[int32]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int32]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint32",
|
||||
m: new(NumberArray[uint32]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint32]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int64",
|
||||
m: new(NumberArray[int64]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int64]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint64",
|
||||
m: new(NumberArray[uint64]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint64]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
m: new(NumberArray[int]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint",
|
||||
m: new(NumberArray[uint]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint]{1, 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type typedText string
|
||||
|
||||
func TestTextArray_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"string",
|
||||
new(TextArray[string]),
|
||||
args{src: "{asdf,fdas}"},
|
||||
res{
|
||||
want: &TextArray[string]{"asdf", "fdas"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"typedText",
|
||||
new(TextArray[typedText]),
|
||||
args{src: "{asdf,fdas}"},
|
||||
res{
|
||||
want: &TextArray[typedText]{"asdf", "fdas"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextArray_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m driver.Valuer
|
||||
res res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"empty",
|
||||
TextArray[string]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
TextArray[string]{"a", "s", "d", "f"},
|
||||
res{
|
||||
want: driver.Value([]byte("{a,s,d,f}")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.m.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
if !tt.res.err {
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.res.want, got, "Value()")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type typedByte byte
|
||||
|
||||
func TestByteArray_Scan(t *testing.T) {
|
||||
wantedBytes := []byte("asdf")
|
||||
wantedTypedBytes := []typedByte("asdf")
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"bytes",
|
||||
new(ByteArray[byte]),
|
||||
args{src: []byte("asdf")},
|
||||
res{
|
||||
want: (*ByteArray[byte])(&wantedBytes),
|
||||
},
|
||||
},
|
||||
{
|
||||
"typed",
|
||||
new(ByteArray[typedByte]),
|
||||
args{src: []byte("asdf")},
|
||||
res{
|
||||
want: (*ByteArray[typedByte])(&wantedTypedBytes),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteArray_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m driver.Valuer
|
||||
res res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"empty",
|
||||
ByteArray[byte]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
ByteArray[byte]([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
|
||||
res{
|
||||
want: driver.Value([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.m.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
if !tt.res.err {
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.res.want, got, "Value()")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_Scan(t *testing.T) {
|
||||
duration := Duration(10)
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase[V ~string] struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase[string]{
|
||||
{
|
||||
name: "int64",
|
||||
m: new(Duration),
|
||||
args: args{src: int64(duration)},
|
||||
res: res{
|
||||
want: &duration,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONArray_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *JSONArray[string]
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
args: args{src: nil},
|
||||
want: new(JSONArray[string]),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero bytes",
|
||||
args: args{src: []byte("")},
|
||||
want: new(JSONArray[string]),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{src: []byte("[]")},
|
||||
want: gu.Ptr(JSONArray[string]{}),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{src: []byte("[\"a\", \"b\"]")},
|
||||
want: gu.Ptr(JSONArray[string]{"a", "b"}),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "json error",
|
||||
args: args{src: []byte("{\"a\": \"b\"}")},
|
||||
want: new(JSONArray[string]),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := new(JSONArray[string])
|
||||
err := got.Scan(tt.args.src)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONArray_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a []string
|
||||
want driver.Value
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
a: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
a: []string{},
|
||||
want: []byte("[]"),
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
a: []string{"a", "b"},
|
||||
want: []byte("[\"a\",\"b\"]"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewJSONArray(tt.a).Value()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user