chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,228 @@
package cockroach
import (
"context"
"database/sql"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/dialect"
)
func init() {
config := new(Config)
dialect.Register(config, config, false)
}
const (
sslDisabledMode = "disable"
sslRequireMode = "require"
sslAllowMode = "allow"
sslPreferMode = "prefer"
)
type Config struct {
Host string
Port uint16
Database string
MaxOpenConns uint32
MaxIdleConns uint32
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
User User
Admin AdminUser
// Additional options to be appended as options=<Options>
// The value will be taken as is. Multiple options are space separated.
Options string
}
func (c *Config) MatchName(name string) bool {
for _, key := range []string{"crdb", "cockroach"} {
if strings.TrimSpace(strings.ToLower(name)) == key {
return true
}
}
return false
}
func (_ *Config) Decode(configs []any) (dialect.Connector, error) {
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
for _, config := range configs {
if err = decoder.Decode(config); err != nil {
return nil, err
}
}
return connector, nil
}
func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
config, err := pgxpool.ParseConfig(c.String(useAdmin))
if err != nil {
return nil, nil, err
}
if len(connConfig.AfterConnect) > 0 {
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
for _, f := range connConfig.AfterConnect {
if err := f(ctx, conn); err != nil {
return err
}
}
return nil
}
}
if len(connConfig.BeforeAcquire) > 0 {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
for _, f := range connConfig.BeforeAcquire {
if err := f(ctx, conn); err != nil {
return false
}
}
return true
}
}
if len(connConfig.AfterRelease) > 0 {
config.AfterRelease = func(conn *pgx.Conn) bool {
for _, f := range connConfig.AfterRelease {
if err := f(conn); err != nil {
return false
}
}
return true
}
}
if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns)
}
config.MaxConnLifetime = c.MaxConnLifetime
config.MaxConnIdleTime = c.MaxConnIdleTime
pool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return nil, nil, err
}
if err := pool.Ping(context.Background()); err != nil {
return nil, nil, err
}
return stdlib.OpenDBFromPool(pool), pool, nil
}
func (c *Config) DatabaseName() string {
return c.Database
}
func (c *Config) Username() string {
return c.User.Username
}
func (c *Config) Password() string {
return c.User.Password
}
func (c *Config) Type() dialect.DatabaseType {
return dialect.DatabaseTypeCockroach
}
type User struct {
Username string
Password string
SSL SSL
}
type AdminUser struct {
// ExistingDatabase is the database to connect to before the ZITADEL database exists
ExistingDatabase string
User `mapstructure:",squash"`
}
type SSL struct {
// type of connection security
Mode string
// RootCert Path to the CA certificate
RootCert string
// Cert Path to the client certificate
Cert string
// Key Path to the client private key
Key string
}
func (c *Config) checkSSL(user User) {
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
user.SSL = SSL{Mode: sslDisabledMode}
return
}
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
return
}
if user.SSL.RootCert == "" {
logging.WithFields(
"cert set", user.SSL.Cert != "",
"key set", user.SSL.Key != "",
"rootCert set", user.SSL.RootCert != "",
).Fatal("at least ssl root cert has to be set")
}
}
func (c Config) String(useAdmin bool) string {
user := c.User
if useAdmin {
user = c.Admin.User
}
c.checkSSL(user)
fields := []string{
"host=" + c.Host,
"port=" + strconv.Itoa(int(c.Port)),
"user=" + user.Username,
"dbname=" + c.Database,
"application_name=" + dialect.DefaultAppName,
"sslmode=" + user.SSL.Mode,
}
if c.Options != "" {
fields = append(fields, "options="+c.Options)
}
if !useAdmin {
fields = append(fields, "dbname="+c.Database)
} else if c.Admin.ExistingDatabase != "" {
fields = append(fields, "dbname="+c.Admin.ExistingDatabase)
}
if user.Password != "" {
fields = append(fields, "password="+user.Password)
}
if user.SSL.Mode != sslDisabledMode {
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
if user.SSL.Cert != "" {
fields = append(fields, "sslcert="+user.SSL.Cert)
}
if user.SSL.Key != "" {
fields = append(fields, "sslkey="+user.SSL.Key)
}
}
return strings.Join(fields, " ")
}

View File

@@ -0,0 +1,212 @@
package database
import (
"context"
"database/sql"
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"
_ "github.com/zitadel/zitadel/internal/database/cockroach"
"github.com/zitadel/zitadel/internal/database/dialect"
_ "github.com/zitadel/zitadel/internal/database/postgres"
"github.com/zitadel/zitadel/internal/zerrors"
)
type ContextQuerier interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
type ContextExecuter interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type ContextQueryExecuter interface {
ContextQuerier
ContextExecuter
}
type Client interface {
ContextQueryExecuter
Beginner
Conn(ctx context.Context) (*sql.Conn, error)
}
type Beginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type Tx interface {
ContextQueryExecuter
Commit() error
Rollback() error
}
var (
_ Client = (*sql.DB)(nil)
_ Tx = (*sql.Tx)(nil)
)
func CloseTransaction(tx Tx, err error) error {
if err != nil {
rollbackErr := tx.Rollback()
logging.OnError(rollbackErr).Error("failed to rollback transaction")
return err
}
commitErr := tx.Commit()
logging.OnError(commitErr).Error("failed to commit transaction")
return commitErr
}
const (
PgUniqueConstraintErrorCode = "23505"
)
type Config struct {
Dialects map[string]interface{} `mapstructure:",remain"`
connector dialect.Connector
}
func (c *Config) SetConnector(connector dialect.Connector) {
c.connector = connector
}
type DB struct {
*sql.DB
dialect.Database
Pool *pgxpool.Pool
}
func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error {
return db.QueryContext(context.Background(), scan, query, args...)
}
func (db *DB) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) (err error) {
rows, err := db.DB.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Info("rows.Close failed")
}()
if err = scan(rows); err != nil {
return err
}
return rows.Err()
}
func (db *DB) QueryRow(scan func(*sql.Row) error, query string, args ...any) (err error) {
return db.QueryRowContext(context.Background(), scan, query, args...)
}
func (db *DB) QueryRowContext(ctx context.Context, scan func(row *sql.Row) error, query string, args ...any) (err error) {
row := db.DB.QueryRowContext(ctx, query, args...)
logging.OnError(row.Err()).Error("unexpected query error")
err = scan(row)
if err != nil {
return err
}
return row.Err()
}
func QueryJSONObject[T any](ctx context.Context, db *DB, query string, args ...any) (*T, error) {
var data []byte
err := db.QueryRowContext(ctx, func(row *sql.Row) error {
return row.Scan(&data)
}, query, args...)
if errors.Is(err, sql.ErrNoRows) {
return nil, err
}
if err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Oath6", "Errors.Internal")
}
obj := new(T)
if err = json.Unmarshal(data, obj); err != nil {
return nil, zerrors.ThrowInternal(err, "DATAB-Vohs6", "Errors.Internal")
}
return obj, nil
}
func Connect(config Config, useAdmin bool) (*DB, error) {
client, pool, err := config.connector.Connect(useAdmin)
if err != nil {
return nil, err
}
if err := client.Ping(); err != nil {
return nil, zerrors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed")
}
return &DB{
DB: client,
Database: config.connector,
Pool: pool,
}, nil
}
func DecodeHook(allowCockroach bool) func(from, to reflect.Value) (_ interface{}, err error) {
return func(from, to reflect.Value) (_ interface{}, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
configuredDialect := dialect.SelectByConfig(config.Dialects)
configs := make([]any, 0, len(config.Dialects))
for name, dialectConfig := range config.Dialects {
if !configuredDialect.Matcher.MatchName(name) {
continue
}
configs = append(configs, dialectConfig)
}
if !allowCockroach && configuredDialect.Matcher.Type() == dialect.DatabaseTypeCockroach {
logging.Info("Cockroach support was removed with Zitadel v3, please refer to https://zitadel.com/docs/self-hosting/manage/cli/mirror to migrate your data to postgres")
return nil, zerrors.ThrowPreconditionFailed(nil, "DATAB-0pIWD", "Cockroach support was removed with Zitadel v3")
}
config.connector, err = configuredDialect.Matcher.Decode(configs)
if err != nil {
return nil, err
}
return config, nil
}
}
func (c Config) DatabaseName() string {
return c.connector.DatabaseName()
}
func (c Config) Username() string {
return c.connector.Username()
}
func (c Config) Password() string {
return c.connector.Password()
}
func (c Config) Type() dialect.DatabaseType {
return c.connector.Type()
}
func EscapeLikeWildcards(value string) string {
value = strings.ReplaceAll(value, "%", "\\%")
value = strings.ReplaceAll(value, "_", "\\_")
return value
}

View File

@@ -0,0 +1,87 @@
package database
import (
"context"
"database/sql"
"database/sql/driver"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestQueryJSONObject(t *testing.T) {
type dst struct {
A int `json:"a,omitempty"`
}
const (
query = `select $1;`
arg = 1
)
tests := []struct {
name string
mock func(*testing.T) *mock.SQLMock
want *dst
wantErr error
}{
{
name: "tx error",
mock: func(t *testing.T) *mock.SQLMock {
return mock.NewSQLMock(t, mock.ExpectQuery("select $1;", mock.WithQueryErr(sql.ErrConnDone)))
},
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "DATAB-Oath6", "Errors.Internal"),
},
{
name: "no rows",
mock: func(t *testing.T) *mock.SQLMock {
return mock.NewSQLMock(t,
mock.ExpectQuery(query,
mock.WithQueryArgs(arg),
mock.WithQueryResult([]string{"json"}, [][]driver.Value{}),
),
)
},
wantErr: sql.ErrNoRows,
},
{
name: "unmarshal error",
mock: func(t *testing.T) *mock.SQLMock {
return mock.NewSQLMock(t,
mock.ExpectQuery(query,
mock.WithQueryArgs(arg),
mock.WithQueryResult([]string{"json"}, [][]driver.Value{{`~~~`}}),
),
)
},
wantErr: zerrors.ThrowInternal(nil, "DATAB-Vohs6", "Errors.Internal"),
},
{
name: "success",
mock: func(t *testing.T) *mock.SQLMock {
return mock.NewSQLMock(t,
mock.ExpectQuery(query,
mock.WithQueryArgs(arg),
mock.WithQueryResult([]string{"json"}, [][]driver.Value{{`{"a":1}`}}),
),
)
},
want: &dst{A: 1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := tt.mock(t)
defer mock.Assert(t)
db := &DB{
DB: mock.DB,
}
got, err := QueryJSONObject[dst](context.Background(), db, query, arg)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,75 @@
package dialect
import (
"database/sql"
"sync"
"github.com/jackc/pgx/v5/pgxpool"
)
type Dialect struct {
Matcher Matcher
Config Connector
IsDefault bool
}
var (
dialects []*Dialect
defaultDialect *Dialect
dialectsMu sync.Mutex
)
type Matcher interface {
MatchName(string) bool
Decode([]any) (Connector, error)
Type() DatabaseType
}
type DatabaseType uint8
const (
DatabaseTypePostgres DatabaseType = iota
DatabaseTypeCockroach
)
const (
DefaultAppName = "zitadel"
)
type Connector interface {
Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error)
Password() string
Database
}
type Database interface {
DatabaseName() string
Username() string
Type() DatabaseType
}
func Register(matcher Matcher, config Connector, isDefault bool) {
dialectsMu.Lock()
defer dialectsMu.Unlock()
d := &Dialect{Matcher: matcher, Config: config}
if isDefault {
defaultDialect = d
return
}
dialects = append(dialects, d)
}
func SelectByConfig(config map[string]interface{}) *Dialect {
for key := range config {
for _, d := range dialects {
if d.Matcher.MatchName(key) {
return d
}
}
}
return defaultDialect
}

View File

@@ -0,0 +1,87 @@
package dialect
import (
"context"
"errors"
"reflect"
pgxdecimal "github.com/jackc/pgx-shopspring-decimal"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
var (
ErrIllegalMaxOpenConns = errors.New("MaxOpenConns of the database must be higher than 3 or 0 for unlimited")
ErrIllegalMaxIdleConns = errors.New("MaxIdleConns of the database must be higher than 3 or 0 for unlimited")
)
// ConnectionConfig defines the Max Open and Idle connections for a DB connection pool.
type ConnectionConfig struct {
MaxOpenConns,
MaxIdleConns uint32
AfterConnect []func(ctx context.Context, c *pgx.Conn) error
BeforeAcquire []func(ctx context.Context, c *pgx.Conn) error
AfterRelease []func(c *pgx.Conn) error
}
var afterConnectFuncs = []func(ctx context.Context, c *pgx.Conn) error{
func(ctx context.Context, c *pgx.Conn) error {
pgxdecimal.Register(c.TypeMap())
return nil
},
}
func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) {
afterConnectFuncs = append(afterConnectFuncs, f)
}
var beforeAcquireFuncs []func(ctx context.Context, c *pgx.Conn) error
func RegisterBeforeAcquire(f func(ctx context.Context, c *pgx.Conn) error) {
beforeAcquireFuncs = append(beforeAcquireFuncs, f)
}
var afterReleaseFuncs []func(c *pgx.Conn) error
func RegisterAfterRelease(f func(c *pgx.Conn) error) {
afterReleaseFuncs = append(afterReleaseFuncs, f)
}
func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string) {
// T
var value T
m.RegisterDefaultPgType(value, name)
// *T
valueType := reflect.TypeOf(value)
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
// []T
sliceType := reflect.SliceOf(valueType)
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
// *[]T
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
// []*T
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
// *[]*T
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
}
// NewConnectionConfig calculates [ConnectionConfig] values from the passed ratios
// and returns the config applicable for the requested purpose.
//
// openConns and idleConns must be at least 3 or 0, which means no limit.
// The pusherRatio and spoolerRatio must be between 0 and 1.
func NewConnectionConfig(openConns, idleConns uint32) *ConnectionConfig {
return &ConnectionConfig{
MaxOpenConns: openConns,
MaxIdleConns: idleConns,
AfterConnect: afterConnectFuncs,
BeforeAcquire: beforeAcquireFuncs,
AfterRelease: afterReleaseFuncs,
}
}

View File

@@ -0,0 +1,5 @@
package migrate
import "database/sql"
func New(*sql.DB) {}

View File

@@ -0,0 +1,139 @@
package mock
import (
"database/sql"
"database/sql/driver"
"reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
)
type SQLMock struct {
DB *sql.DB
mock sqlmock.Sqlmock
}
type Expectation func(m sqlmock.Sqlmock)
func NewSQLMock(t *testing.T, expectations ...Expectation) *SQLMock {
db, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
sqlmock.ValueConverterOption(new(TypeConverter)),
)
if err != nil {
t.Fatal("create mock failed", err)
}
for _, expectation := range expectations {
expectation(mock)
}
return &SQLMock{
DB: db,
mock: mock,
}
}
func (m *SQLMock) Assert(t *testing.T) {
t.Helper()
if err := m.mock.ExpectationsWereMet(); err != nil {
t.Errorf("expectations not met: %v", err)
}
m.DB.Close()
}
func ExpectBegin(err error) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectBegin()
if err != nil {
e.WillReturnError(err)
}
}
}
func ExpectCommit(err error) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectCommit()
if err != nil {
e.WillReturnError(err)
}
}
}
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
func WithExecArgs(args ...driver.Value) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WithArgs(args...)
}
}
func WithExecErr(err error) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnError(err)
}
}
func WithExecNoRowsAffected() ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnResult(driver.ResultNoRows)
}
}
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
return e.WillReturnResult(affected)
}
}
func ExcpectExec(stmt string, opts ...ExecOpt) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectExec(stmt)
for _, opt := range opts {
e = opt(e)
}
}
}
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
func WithQueryArgs(args ...driver.Value) QueryOpt {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WithArgs(args...)
}
}
func WithQueryErr(err error) QueryOpt {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WillReturnError(err)
}
}
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
mockedRows := m.NewRows(columns)
for _, row := range rows {
mockedRows = mockedRows.AddRow(row...)
}
return e.WillReturnRows(mockedRows)
}
}
func ExpectQuery(stmt string, opts ...QueryOpt) Expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectQuery(stmt)
for _, opt := range opts {
e = opt(m, e)
}
}
}
type AnyType[T interface{}] struct{}
// Match satisfies sqlmock.Argument interface
func (a AnyType[T]) Match(v driver.Value) bool {
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
}

View File

@@ -0,0 +1,87 @@
package mock
import (
"database/sql/driver"
"encoding/hex"
"encoding/json"
"reflect"
"slices"
"strconv"
"strings"
)
var _ driver.ValueConverter = (*TypeConverter)(nil)
type TypeConverter struct{}
// ConvertValue converts a value to a driver Value.
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
value := reflect.ValueOf(v)
if rawMessage, ok := v.(json.RawMessage); ok {
return convertBytes(rawMessage), nil
}
if value.Kind() == reflect.Slice {
//nolint: exhaustive
// only defined types
switch value.Type().Elem().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return convertSigned(value), nil
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return convertUnsigned(value), nil
case reflect.String:
return convertText(value), nil
}
}
return v, nil
}
// converts a text array to valid pgx v5 representation
func convertSigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertUnsigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertText(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = array.Index(i).String()
}
return "{" + strings.Join(slice, ",") + "}"
}
func convertBytes(array []byte) string {
var builder strings.Builder
builder.Grow(hex.EncodedLen(len(array)) + 4)
builder.WriteString(`\x`)
builder.Write(AppendEncode(nil, array))
return builder.String()
}
// TODO: remove function after we compile using go 1.22 and use function of hex package `hex.AppendEncode`
func AppendEncode(dst, src []byte) []byte {
n := hex.EncodedLen(len(src))
dst = slices.Grow(dst, n)
hex.Encode(dst[len(dst):][:n], src)
return dst[:len(dst)+n]
}

View File

@@ -0,0 +1,38 @@
package postgres
import (
"net"
"os"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/zitadel/logging"
)
func StartEmbedded() (embeddedpostgres.Config, func()) {
path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*")
logging.OnError(err).Fatal("unable to create temp dir")
port, close := getPort()
config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path)
embedded := embeddedpostgres.NewDatabase(config)
close()
err = embedded.Start()
logging.OnError(err).Fatal("unable to start db")
return config, func() {
logging.OnError(embedded.Stop()).Error("unable to stop db")
}
}
// getPort returns a free port and locks it until close is called
func getPort() (port uint16, close func()) {
l, err := net.Listen("tcp", ":0")
logging.OnError(err).Fatal("unable to get port")
port = uint16(l.Addr().(*net.TCPAddr).Port)
logging.WithFields("port", port).Info("Port is available")
return port, func() {
logging.OnError(l.Close()).Error("unable to close port listener")
}
}

View File

@@ -0,0 +1,237 @@
package postgres
import (
"context"
"database/sql"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/dialect"
)
func init() {
config := new(Config)
dialect.Register(config, config, true)
}
const (
sslDisabledMode = "disable"
sslRequireMode = "require"
sslAllowMode = "allow"
sslPreferMode = "prefer"
)
type Config struct {
Host string
Port int32
Database string
MaxOpenConns uint32
MaxIdleConns uint32
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
User User
Admin AdminUser
// Additional options to be appended as options=<Options>
// The value will be taken as is. Multiple options are space separated.
Options string
}
func (c *Config) MatchName(name string) bool {
for _, key := range []string{"pg", "postgres"} {
if strings.TrimSpace(strings.ToLower(name)) == key {
return true
}
}
return false
}
func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
for _, config := range configs {
if err = decoder.Decode(config); err != nil {
return nil, err
}
}
return connector, nil
}
func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) {
connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns)
config, err := pgxpool.ParseConfig(c.String(useAdmin))
if err != nil {
return nil, nil, err
}
if len(connConfig.AfterConnect) > 0 {
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
for _, f := range connConfig.AfterConnect {
if err := f(ctx, conn); err != nil {
return err
}
}
return nil
}
}
if len(connConfig.BeforeAcquire) > 0 {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
for _, f := range connConfig.BeforeAcquire {
if err := f(ctx, conn); err != nil {
return false
}
}
return true
}
}
if len(connConfig.AfterRelease) > 0 {
config.AfterRelease = func(conn *pgx.Conn) bool {
for _, f := range connConfig.AfterRelease {
if err := f(conn); err != nil {
return false
}
}
return true
}
}
if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns)
}
config.MaxConnLifetime = c.MaxConnLifetime
config.MaxConnIdleTime = c.MaxConnIdleTime
pool, err := pgxpool.NewWithConfig(
context.Background(),
config,
)
if err != nil {
return nil, nil, err
}
if err := pool.Ping(context.Background()); err != nil {
return nil, nil, err
}
return stdlib.OpenDBFromPool(pool), pool, nil
}
func (c *Config) DatabaseName() string {
return c.Database
}
func (c *Config) Username() string {
return c.User.Username
}
func (c *Config) Password() string {
return c.User.Password
}
func (c *Config) Type() dialect.DatabaseType {
return dialect.DatabaseTypePostgres
}
type User struct {
Username string
Password string
SSL SSL
}
type AdminUser struct {
// ExistingDatabase is the database to connect to before the ZITADEL database exists
ExistingDatabase string
User `mapstructure:",squash"`
}
type SSL struct {
// type of connection security
Mode string
// RootCert Path to the CA certificate
RootCert string
// Cert Path to the client certificate
Cert string
// Key Path to the client private key
Key string
}
func (s *Config) checkSSL(user User) {
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
user.SSL = SSL{Mode: sslDisabledMode}
return
}
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
return
}
if user.SSL.RootCert == "" {
logging.WithFields(
"cert set", user.SSL.Cert != "",
"key set", user.SSL.Key != "",
"rootCert set", user.SSL.RootCert != "",
).Fatal("at least ssl root cert has to be set")
}
}
func (c Config) String(useAdmin bool) string {
user := c.User
if useAdmin {
user = c.Admin.User
}
c.checkSSL(user)
fields := []string{
"host=" + c.Host,
"port=" + strconv.Itoa(int(c.Port)),
"user=" + user.Username,
"application_name=" + dialect.DefaultAppName,
"sslmode=" + user.SSL.Mode,
}
if c.Options != "" {
fields = append(fields, "options="+c.Options)
}
if user.Password != "" {
fields = append(fields, "password="+user.Password)
}
if !useAdmin {
fields = append(fields, "dbname="+c.Database)
} else {
defaultDB := c.Admin.ExistingDatabase
if defaultDB == "" {
defaultDB = "postgres"
}
fields = append(fields, "dbname="+defaultDB)
}
if user.SSL.Mode != sslDisabledMode {
if user.SSL.RootCert != "" {
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
}
if user.SSL.Cert != "" {
fields = append(fields, "sslcert="+user.SSL.Cert)
}
if user.SSL.Key != "" {
fields = append(fields, "sslkey="+user.SSL.Key)
}
}
return strings.Join(fields, " ")
}

View File

@@ -0,0 +1,261 @@
package database
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"strings"
"time"
"github.com/jackc/pgx/v5/pgtype"
)
type TextArray[T ~string] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
func (s *TextArray[T]) Scan(src any) error {
var typedArray []string
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
if err != nil {
return err
}
(*s) = make(TextArray[T], len(typedArray))
for i, value := range typedArray {
(*s)[i] = T(value)
}
return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
func (s TextArray[T]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
}
typed := make([]string, len(s))
for i, value := range s {
typed[i] = string(value)
}
return []byte("{" + strings.Join(typed, ",") + "}"), nil
}
type ByteArray[T ~byte] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
func (s *ByteArray[T]) Scan(src any) error {
var typedArray []byte
typedArray, ok := src.([]byte)
if !ok {
// tests use a different src type
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
if err != nil {
return err
}
}
(*s) = make(ByteArray[T], len(typedArray))
for i, value := range typedArray {
(*s)[i] = T(value)
}
return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
func (s ByteArray[T]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
}
typed := make([]byte, len(s))
for i, value := range s {
typed[i] = byte(value)
}
return typed, nil
}
type numberField interface {
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
}
type numberTypeField interface {
int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
}
var _ sql.Scanner = (*NumberArray[int8])(nil)
type NumberArray[F numberField] pgtype.FlatArray[F]
// Scan implements the [database/sql.Scanner] interface.
func (a *NumberArray[F]) Scan(src any) (err error) {
var (
mapper func()
scanner sql.Scanner
)
//nolint: exhaustive
// only defined types
switch reflect.TypeOf(*a).Elem().Kind() {
case reflect.Int8:
mapper, scanner = castedScan[int8](a)
case reflect.Uint8:
// we provide int16 is a workaround because pgx thinks we want to scan a byte array if we provide uint8
mapper, scanner = castedScan[int16](a)
case reflect.Int16:
mapper, scanner = castedScan[int16](a)
case reflect.Uint16:
mapper, scanner = castedScan[uint16](a)
case reflect.Int32:
mapper, scanner = castedScan[int32](a)
case reflect.Uint32:
mapper, scanner = castedScan[uint32](a)
case reflect.Int64:
mapper, scanner = castedScan[int64](a)
case reflect.Uint64:
mapper, scanner = castedScan[uint64](a)
case reflect.Int:
mapper, scanner = castedScan[int](a)
case reflect.Uint:
mapper, scanner = castedScan[uint](a)
}
if err = scanner.Scan(src); err != nil {
return err
}
mapper()
return nil
}
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
var typedArray []T
mapper = func() {
(*a) = make(NumberArray[F], len(typedArray))
for i, value := range typedArray {
(*a)[i] = F(value)
}
}
scanner = pgtype.NewMap().SQLScanner(&typedArray)
return mapper, scanner
}
type Map[V any] map[string]V
// Scan implements the [database/sql.Scanner] interface.
func (m *Map[V]) Scan(src any) error {
if src == nil {
return nil
}
bytes := src.([]byte)
if len(bytes) == 0 {
return nil
}
return json.Unmarshal(bytes, &m)
}
// Value implements the [database/sql/driver.Valuer] interface.
func (m Map[V]) Value() (driver.Value, error) {
if len(m) == 0 {
return nil, nil
}
return json.Marshal(m)
}
type Duration time.Duration
// Scan implements the [database/sql.Scanner] interface.
func (d *Duration) Scan(src any) error {
switch duration := src.(type) {
case *time.Duration:
*d = Duration(*duration)
return nil
case time.Duration:
*d = Duration(duration)
return nil
case *pgtype.Interval:
*d = intervalToDuration(duration)
return nil
case pgtype.Interval:
*d = intervalToDuration(&duration)
return nil
case int64:
*d = Duration(duration)
return nil
}
interval := new(pgtype.Interval)
if err := interval.Scan(src); err != nil {
return err
}
*d = intervalToDuration(interval)
return nil
}
func intervalToDuration(interval *pgtype.Interval) Duration {
return Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
}
// NullDuration can be used for NULL intervals.
// If Valid is false, the scanned value was NULL
// This behavior is similar to [database/sql.NullString]
type NullDuration struct {
Valid bool
Duration time.Duration
}
// Scan implements the [database/sql.Scanner] interface.
func (d *NullDuration) Scan(src any) error {
if src == nil {
d.Duration, d.Valid = 0, false
return nil
}
duration := new(Duration)
if err := duration.Scan(src); err != nil {
return err
}
d.Duration, d.Valid = time.Duration(*duration), true
return nil
}
// JSONArray allows sending and receiving JSON arrays to and from the database.
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
// Values are marshaled and unmarshaled using the [encoding/json] package.
type JSONArray[T any] []T
// NewJSONArray wraps an existing slice into a JSONArray.
func NewJSONArray[T any](a []T) JSONArray[T] {
return JSONArray[T](a)
}
// Scan implements the [database/sql.Scanner] interface.
func (a *JSONArray[T]) Scan(src any) error {
if src == nil {
*a = nil
return nil
}
bytes := src.([]byte)
if len(bytes) == 0 {
*a = nil
return nil
}
return json.Unmarshal(bytes, a)
}
// Value implements the [database/sql/driver.Valuer] interface.
func (a JSONArray[T]) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
return json.Marshal(a)
}

View File

@@ -0,0 +1,541 @@
package database
import (
"database/sql"
"database/sql/driver"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMap_Scan(t *testing.T) {
type args struct {
src []byte
}
type res[V any] struct {
want Map[V]
err bool
}
type testCase[V any] struct {
name string
m Map[V]
args args
res[V]
}
tests := []testCase[string]{
{
"nil",
Map[string]{},
args{src: nil},
res[string]{
want: Map[string]{},
err: false,
},
},
{
"null",
Map[string]{},
args{src: []byte("invalid")},
res[string]{
want: Map[string]{},
err: true,
},
},
{
"null",
Map[string]{},
args{src: nil},
res[string]{
want: Map[string]{},
},
},
{
"empty",
Map[string]{},
args{src: []byte(`{}`)},
res[string]{
want: Map[string]{},
},
},
{
"set",
Map[string]{},
args{src: []byte(`{"key": "value"}`)},
res[string]{
want: Map[string]{
"key": "value",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestMap_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase[V any] struct {
name string
m Map[V]
res res
}
tests := []testCase[string]{
{
"nil",
nil,
res{
want: nil,
},
},
{
"empty",
Map[string]{},
res{
want: nil,
},
},
{
"set",
Map[string]{
"key": "value",
},
res{
want: driver.Value([]byte(`{"key":"value"}`)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
if !tt.res.err {
require.NoError(t, err)
assert.Equalf(t, tt.res.want, got, "Value()")
}
})
}
}
type typedInt int
func TestNumberArray_Scan(t *testing.T) {
type args struct {
src any
}
type res struct {
want any
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res res
}
tests := []testCase{
{
name: "typedInt",
m: new(NumberArray[typedInt]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[typedInt]{1, 2},
},
},
{
name: "int8",
m: new(NumberArray[int8]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int8]{1, 2},
},
},
{
name: "uint8",
m: new(NumberArray[uint8]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint8]{1, 2},
},
},
{
name: "int16",
m: new(NumberArray[int16]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int16]{1, 2},
},
},
{
name: "uint16",
m: new(NumberArray[uint16]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint16]{1, 2},
},
},
{
name: "int32",
m: new(NumberArray[int32]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int32]{1, 2},
},
},
{
name: "uint32",
m: new(NumberArray[uint32]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint32]{1, 2},
},
},
{
name: "int64",
m: new(NumberArray[int64]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int64]{1, 2},
},
},
{
name: "uint64",
m: new(NumberArray[uint64]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint64]{1, 2},
},
},
{
name: "int",
m: new(NumberArray[int]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int]{1, 2},
},
},
{
name: "uint",
m: new(NumberArray[uint]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint]{1, 2},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
type typedText string
func TestTextArray_Scan(t *testing.T) {
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase{
{
"string",
new(TextArray[string]),
args{src: "{asdf,fdas}"},
res{
want: &TextArray[string]{"asdf", "fdas"},
},
},
{
"typedText",
new(TextArray[typedText]),
args{src: "{asdf,fdas}"},
res{
want: &TextArray[typedText]{"asdf", "fdas"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestTextArray_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase struct {
name string
m driver.Valuer
res res
}
tests := []testCase{
{
"empty",
TextArray[string]{},
res{
want: nil,
},
},
{
"set",
TextArray[string]{"a", "s", "d", "f"},
res{
want: driver.Value([]byte("{a,s,d,f}")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
if !tt.res.err {
require.NoError(t, err)
assert.Equalf(t, tt.res.want, got, "Value()")
}
})
}
}
type typedByte byte
func TestByteArray_Scan(t *testing.T) {
wantedBytes := []byte("asdf")
wantedTypedBytes := []typedByte("asdf")
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase{
{
"bytes",
new(ByteArray[byte]),
args{src: []byte("asdf")},
res{
want: (*ByteArray[byte])(&wantedBytes),
},
},
{
"typed",
new(ByteArray[typedByte]),
args{src: []byte("asdf")},
res{
want: (*ByteArray[typedByte])(&wantedTypedBytes),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestByteArray_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase struct {
name string
m driver.Valuer
res res
}
tests := []testCase{
{
"empty",
ByteArray[byte]{},
res{
want: nil,
},
},
{
"set",
ByteArray[byte]([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
res{
want: driver.Value([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
if !tt.res.err {
require.NoError(t, err)
assert.Equalf(t, tt.res.want, got, "Value()")
}
})
}
}
func TestDuration_Scan(t *testing.T) {
duration := Duration(10)
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase[V ~string] struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase[string]{
{
name: "int64",
m: new(Duration),
args: args{src: int64(duration)},
res: res{
want: &duration,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestJSONArray_Scan(t *testing.T) {
type args struct {
src any
}
tests := []struct {
name string
args args
want *JSONArray[string]
wantErr bool
}{
{
name: "nil",
args: args{src: nil},
want: new(JSONArray[string]),
wantErr: false,
},
{
name: "zero bytes",
args: args{src: []byte("")},
want: new(JSONArray[string]),
wantErr: false,
},
{
name: "empty",
args: args{src: []byte("[]")},
want: gu.Ptr(JSONArray[string]{}),
wantErr: false,
},
{
name: "ok",
args: args{src: []byte("[\"a\", \"b\"]")},
want: gu.Ptr(JSONArray[string]{"a", "b"}),
wantErr: false,
},
{
name: "json error",
args: args{src: []byte("{\"a\": \"b\"}")},
want: new(JSONArray[string]),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := new(JSONArray[string])
err := got.Scan(tt.args.src)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestJSONArray_Value(t *testing.T) {
tests := []struct {
name string
a []string
want driver.Value
}{
{
name: "nil",
a: nil,
want: nil,
},
{
name: "empty",
a: []string{},
want: []byte("[]"),
},
{
name: "ok",
a: []string{"a", "b"},
want: []byte("[\"a\",\"b\"]"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewJSONArray(tt.a).Value()
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}