This commit is contained in:
adlerhurst
2025-03-06 17:12:23 +01:00
parent 594152581c
commit 67b22ef9c4
10 changed files with 582 additions and 180 deletions

View File

@@ -36,20 +36,12 @@ func (o *Object) Configure(v *viper.Viper) error {
continue continue
} }
structField := o.value.Field(i) structField := o.value.Field(i)
tag, err := newFieldTag(o.value.Type().Field(i), structField.Interface())
if err != nil {
return err
}
if tag.skip {
continue
}
f := Field{ f := Field{
tag: tag,
value: structField, value: structField,
structField: o.value.Type().Field(i), structField: o.value.Type().Field(i),
} }
err = f.Configure(v) err := f.Configure(v)
if err != nil { if err != nil {
return err return err
} }
@@ -85,9 +77,7 @@ func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) {
if !f.value.Type().Implements(customType) { if !f.value.Type().Implements(customType) {
return false, nil return false, nil
} }
if f.value.IsNil() {
f.value.Set(reflect.New(f.value.Type().Elem()))
}
custom := f.value.Interface().(Custom) custom := f.value.Interface().(Custom)
value, err := custom.Configure() value, err := custom.Configure()
if err != nil { if err != nil {
@@ -108,6 +98,18 @@ func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) {
} }
func (f *Field) Configure(v *viper.Viper) error { func (f *Field) Configure(v *viper.Viper) error {
if f.value.IsNil() {
f.value.Set(reflect.New(f.value.Type().Elem()))
}
tag, err := newFieldTag(f.structField, f.value.Interface())
if err != nil {
return err
}
if tag.skip {
return nil
}
if ok, err := f.callCustom(v); ok || err != nil { if ok, err := f.callCustom(v); ok || err != nil {
return err return err
} }

View File

@@ -0,0 +1,115 @@
package bla4
import (
"fmt"
"reflect"
"strings"
"github.com/spf13/viper"
)
type field struct {
info reflect.StructField
viper *viper.Viper
tag
}
func (f *field) label() string {
var builder strings.Builder
builder.WriteString(f.info.Name)
builder.WriteString(" (")
builder.WriteString(f.info.Type.Kind().String())
builder.WriteString(")")
if f.description != "" {
builder.WriteString(": ")
builder.WriteString(f.description)
}
return builder.String()
}
func (f *field) sub() *viper.Viper {
if !f.viper.IsSet(f.fieldName) {
f.viper.Set(f.fieldName, map[string]any{})
}
return f.viper.Sub(f.fieldName)
}
func (f *field) printStructInfo() {
var builder strings.Builder
builder.WriteString("------- ")
builder.WriteString(f.info.Name)
builder.WriteString(" -------")
if f.description != "" {
builder.WriteString(": ")
builder.WriteString(f.description)
}
fmt.Println(builder.String())
}
type tag struct {
skip bool
fieldName string
description string
value reflect.Value
}
const (
tagName = "configure"
defaultKey = "default"
descriptionKey = "description"
)
func newTag(field reflect.StructField, current reflect.Value) (config tag, err error) {
config.fieldName = field.Name
defer func() {
if !config.value.IsValid() {
if field.Type.Kind() == reflect.Pointer {
config.value = reflect.New(field.Type.Elem())
return
}
config.value = reflect.New(field.Type).Elem()
}
}()
if !current.IsZero() {
config.value = current
}
value, ok := field.Tag.Lookup(tagName)
if !ok {
return config, nil
}
if value == "-" {
config.skip = true
return config, nil
}
fields := strings.Split(value, ",")
for _, f := range fields {
configSplit := strings.Split(f, "=")
switch strings.ToLower(configSplit[0]) {
case defaultKey:
if !config.value.IsZero() {
continue
}
value, err := kindMapper(field.Type.Kind())(configSplit[1])
if err != nil {
return config, err
}
config.value.Set(reflect.ValueOf(value))
case descriptionKey:
config.description = configSplit[1]
}
}
return config, nil
}
func (tag tag) defaultValue() string {
if tag.value.IsZero() {
return ""
}
return fmt.Sprintf("%v", tag.value.Interface())
}

View File

@@ -0,0 +1,121 @@
package bla4
import (
"reflect"
"strconv"
"sync"
"time"
)
var writeMu sync.RWMutex
func SetTypeMapper(typ reflect.Type, mapper func(input string) (any, error)) {
writeMu.Lock()
defer writeMu.Unlock()
typeMappers[typ] = mapper
}
func SetTypeMapperFor[T any](mapper func(input string) (any, error)) {
writeMu.Lock()
defer writeMu.Unlock()
typeMappers[reflect.TypeFor[T]()] = mapper
}
func typeMapper(typ reflect.Type) func(input string) (any, error) {
writeMu.RLock()
defer writeMu.RUnlock()
return typeMappers[typ]
}
var typeMappers = map[reflect.Type]func(string) (any, error){
reflect.TypeFor[time.Duration](): func(input string) (any, error) {
return time.ParseDuration(input)
},
}
// SetKindMapper overwrites the mapper for the given kind.
func SetKindMapper(kind reflect.Kind, mapper func(input string) (any, error)) {
writeMu.Lock()
defer writeMu.Unlock()
kindMappers[kind] = mapper
}
func kindMapper(kind reflect.Kind) func(input string) (any, error) {
writeMu.RLock()
defer writeMu.RUnlock()
return kindMappers[kind]
}
var kindMappers = map[reflect.Kind]func(input string) (any, error){
reflect.String: func(input string) (any, error) {
return input, nil
},
reflect.Bool: func(input string) (any, error) {
return strconv.ParseBool(input)
},
reflect.Int: func(input string) (any, error) {
return strconv.Atoi(input)
},
reflect.Int8: func(input string) (val any, err error) {
val, err = strconv.ParseInt(input, 10, 8)
val = int8(val.(int64))
return val, err
},
reflect.Int16: func(input string) (val any, err error) {
val, err = strconv.ParseInt(input, 10, 16)
val = int16(val.(int64))
return val, err
},
reflect.Int32: func(input string) (val any, err error) {
val, err = strconv.ParseInt(input, 10, 32)
val = int32(val.(int64))
return val, err
},
reflect.Int64: func(input string) (any, error) {
return strconv.ParseInt(input, 10, 64)
},
reflect.Uint: func(input string) (val any, err error) {
val, err = strconv.ParseUint(input, 10, 0)
val = uint(val.(uint64))
return val, err
},
reflect.Uint8: func(input string) (val any, err error) {
val, err = strconv.ParseUint(input, 10, 8)
val = uint8(val.(uint64))
return val, err
},
reflect.Uint16: func(input string) (val any, err error) {
val, err = strconv.ParseUint(input, 10, 16)
val = uint16(val.(uint64))
return val, err
},
reflect.Uint32: func(input string) (val any, err error) {
val, err = strconv.ParseUint(input, 10, 32)
val = uint32(val.(uint64))
return val, err
},
reflect.Uint64: func(input string) (any, error) {
return strconv.ParseUint(input, 10, 64)
},
reflect.Float32: func(input string) (val any, err error) {
val, err = strconv.ParseFloat(input, 32)
val = float32(val.(float64))
return val, err
},
reflect.Float64: func(input string) (any, error) {
return strconv.ParseFloat(input, 64)
},
reflect.Complex64: func(input string) (val any, err error) {
val, err = strconv.ParseComplex(input, 64)
val = complex64(val.(complex128))
return val, err
},
reflect.Complex128: func(input string) (any, error) {
return strconv.ParseComplex(input, 128)
},
}

View File

@@ -0,0 +1,189 @@
package bla4
import (
"fmt"
"log/slog"
"os"
"reflect"
"github.com/manifoldco/promptui"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}))
type Configure func() (value any, err error)
type Configurer interface {
// Configure is called to configure the value.
// It must return the same type as itself. Otherwise [Update] will panic because it is not able to set the value.
Configure() (value any, err error)
}
func Update(v *viper.Viper, config any) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
value := reflect.ValueOf(config)
structConfigures := structToConfigureMap(Logger, v, value)
for key, configure := range structConfigures {
result, err := configure()
if err != nil {
Logger.Error("error configuring field", slog.String("field", key), slog.Any("cause", err))
return
}
value.Elem().FieldByName(key).Set(reflect.ValueOf(result))
}
err := v.WriteConfig()
if err != nil {
Logger.Error("error writing config", slog.Any("cause", err))
os.Exit(1)
}
}
}
func structToConfigureMap(l *slog.Logger, v *viper.Viper, object reflect.Value) map[string]Configure {
if object.Kind() == reflect.Pointer {
if object.IsNil() {
l.Debug("initialize object")
object = reflect.New(object.Type().Elem())
}
return structToConfigureMap(l, v, object.Elem())
}
if object.Kind() != reflect.Struct {
panic("config must be a struct")
}
fields := make(map[string]Configure, object.NumField())
for i := range object.NumField() {
if !object.Type().Field(i).IsExported() {
continue
}
tag, err := newTag(object.Type().Field(i), object.Field(i))
if err != nil {
l.Error("failed to parse field tag", slog.Any("cause", err))
continue
}
if tag.skip {
l.Debug("skipping field", slog.String("field", object.Type().Field(i).Name))
continue
}
fields[object.Type().Field(i).Name] = fieldFunction(
l.With(slog.String("field", object.Type().Field(i).Name)),
&field{
info: object.Type().Field(i),
tag: tag,
viper: v,
},
)
}
return fields
}
func fieldFunction(l *slog.Logger, f *field) (configure Configure) {
for _, mapper := range []func(*slog.Logger, *field) Configure{
fieldFunctionByImplementation,
fieldFunctionByReflection,
} {
if configure = mapper(l, f); configure != nil {
return configure
}
}
panic(fmt.Sprintf("unsupported field type: %s", f.info.Type.String()))
}
func fieldFunctionByImplementation(l *slog.Logger, f *field) Configure {
if f.value.Type().Implements(reflect.TypeFor[Configurer]()) {
l.Debug("field is a custom implementation")
return func() (value any, err error) {
f.printStructInfo()
res, err := f.value.Interface().(Configurer).Configure()
if err != nil {
return nil, err
}
f.viper.Set(f.tag.fieldName, res)
return res, nil
}
}
return nil
}
func fieldFunctionByReflection(l *slog.Logger, f *field) Configure {
kind := f.value.Kind()
//nolint:exhaustive // only types that require special treatment are covered
switch kind {
case reflect.Pointer:
if f.value.IsNil() {
f.value.Set(reflect.New(f.value.Type().Elem()))
}
sub := f.sub()
m := structToConfigureMap(l, sub, f.value)
return func() (value any, err error) {
f.printStructInfo()
for key, configure := range m {
value, err = configure()
if err != nil {
return nil, err
}
f.value.Elem().FieldByName(key).Set(reflect.ValueOf(value))
}
f.viper.Set(f.tag.fieldName, sub.AllSettings())
return f.value.Interface(), nil
}
case reflect.Struct:
sub := f.sub()
m := structToConfigureMap(l, sub, f.value)
return func() (value any, err error) {
f.printStructInfo()
for key, configure := range m {
value, err = configure()
if err != nil {
return nil, err
}
f.value.FieldByName(key).Set(reflect.ValueOf(value))
}
f.viper.Set(f.tag.fieldName, sub.AllSettings())
return f.value.Interface(), nil
}
case reflect.Array, reflect.Slice, reflect.Map:
l.Warn("skipping because kind is unimplemented", slog.String("kind", kind.String()))
return nil
case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer, reflect.Invalid:
slog.Error("skipping because kind is unsupported", slog.String("kind", kind.String()))
return nil
}
mapper := typeMapper(f.info.Type)
if mapper == nil {
mapper = kindMapper(kind)
}
if mapper == nil {
l.Error("unsupported kind", slog.String("kind", kind.String()))
panic(fmt.Sprintf("unsupported kind: %s", kind.String()))
}
return func() (value any, err error) {
prompt := promptui.Prompt{
Label: f.label(),
Validate: func(input string) error {
value, err = mapper(input)
return err
},
Default: f.defaultValue(),
}
_, err = prompt.Run()
if err != nil {
return nil, err
}
f.viper.Set(f.tag.fieldName, value)
return value, nil
}
}

View File

@@ -6,7 +6,7 @@ import (
"github.com/zitadel/zitadel/backend/cmd/config" "github.com/zitadel/zitadel/backend/cmd/config"
"github.com/zitadel/zitadel/backend/cmd/configure" "github.com/zitadel/zitadel/backend/cmd/configure"
"github.com/zitadel/zitadel/backend/cmd/configure/bla2" "github.com/zitadel/zitadel/backend/cmd/configure/bla4"
step001 "github.com/zitadel/zitadel/backend/cmd/prepare/001" step001 "github.com/zitadel/zitadel/backend/cmd/prepare/001"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect" "github.com/zitadel/zitadel/backend/storage/database/dialect"
@@ -42,7 +42,7 @@ var (
// configuration.Fields(), // configuration.Fields(),
// ), // ),
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
bla2.Update(viper.GetViper(), &configuration)(cmd, args) bla4.Update(viper.GetViper(), &configuration)(cmd, args)
}, },
PreRun: configure.ReadConfigPreRun(viper.GetViper(), &configuration), PreRun: configure.ReadConfigPreRun(viper.GetViper(), &configuration),
} }
@@ -51,20 +51,21 @@ var (
type Config struct { type Config struct {
config.Config `mapstructure:",squash" configure:"-"` config.Config `mapstructure:",squash" configure:"-"`
Database dialect.Config `configure:"-"` Database *dialect.Config // `configure:"-"`
Step001 step001.Step001 Step001 step001.Step001
Step002 *step001.Step001
// runtime config // runtime config
Client database.Pool `mapstructure:"-" configure:"-"` Client database.Pool `mapstructure:"-" configure:"-"`
} }
func (c *Config) Hooks() (decoders []viper.DecoderConfigOption) { func (c *Config) Hooks() (decoders []viper.DecoderConfigOption) {
for _, hooks := range []configure.Unmarshaller{ // for _, hooks := range []configure.Unmarshaller{
c.Config, // c.Config,
c.Database, // c.Database,
} { // } {
decoders = append(decoders, hooks.Hooks()...) // decoders = append(decoders, hooks.Hooks()...)
} // }
return decoders return decoders
} }

View File

@@ -1,6 +1,8 @@
configuredversion: 2025.2.23
database: database:
postgres: host=local postgres: host=localhost user=zitadel password= dbname=zitadel sslmode=disable test=test
step001: step001:
databasename: zx;lvkj databasename: qwer
username: z.;nv.,mvnzx username: asdf
step002:
databasename: yuio
username: hjkl

View File

@@ -1,7 +1,12 @@
package database package database
import "context" import (
"context"
"github.com/zitadel/zitadel/backend/cmd/configure/bla4"
)
type Connector interface { type Connector interface {
Connect(ctx context.Context) (Pool, error) Connect(ctx context.Context) (Pool, error)
bla4.Configurer
} }

View File

@@ -5,23 +5,20 @@ import (
"errors" "errors"
"reflect" "reflect"
"github.com/manifoldco/promptui"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/cmd/config" "github.com/zitadel/zitadel/backend/cmd/configure/bla4"
"github.com/zitadel/zitadel/backend/cmd/configure"
"github.com/zitadel/zitadel/backend/cmd/configure/bla"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect/gosql"
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres" "github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
) )
type Hook struct { type Hook struct {
Match func(string) bool Match func(string) bool
Decode func(name string, config any) (database.Connector, error) Decode func(config any) (database.Connector, error)
Name string Name string
Field configure.Updater Constructor func() database.Connector
Constructor func() any
} }
var hooks = []Hook{ var hooks = []Hook{
@@ -29,24 +26,65 @@ var hooks = []Hook{
Match: postgres.NameMatcher, Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig, Decode: postgres.DecodeConfig,
Name: postgres.Name, Name: postgres.Name,
Field: postgres.Field, Constructor: func() database.Connector { return new(postgres.Config) },
Constructor: func() any { return new(postgres.Config) },
},
{
Match: gosql.NameMatcher,
Decode: gosql.DecodeConfig,
Name: gosql.Name,
Field: gosql.Field,
Constructor: func() any { return new(gosql.Config) },
}, },
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
} }
type Config struct { type Config struct {
Dialects dialects `mapstructure:",remain"` Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector connector database.Connector
} }
// Configure implements [configure.Configurer].
func (c *Config) Configure() (any, error) {
possibilities := make([]string, len(hooks))
var cursor int
for i, hook := range hooks {
if _, ok := c.Dialects[hook.Name]; ok {
cursor = i
}
possibilities[i] = hook.Name
}
prompt := promptui.Select{
Label: "Select a dialect",
Items: possibilities,
CursorPos: cursor,
}
i, _, err := prompt.Run()
if err != nil {
return nil, err
}
var config bla4.Configurer
if dialect, ok := c.Dialects[hooks[i].Name]; ok {
config, err = hooks[i].Decode(dialect)
if err != nil {
return nil, err
}
} else {
clear(c.Dialects)
config = hooks[i].Constructor()
}
if c.Dialects == nil {
c.Dialects = make(map[string]any)
}
c.Dialects[hooks[i].Name], err = config.Configure()
if err != nil {
return nil, err
}
return c, nil
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) { func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 { if len(c.Dialects) != 1 {
return nil, errors.New("Exactly one dialect must be configured") return nil, errors.New("Exactly one dialect must be configured")
@@ -62,12 +100,6 @@ func (c Config) Hooks() []viper.DecoderConfigOption {
} }
} }
// var _ configure.StructUpdater = (*Config)(nil)
func (c Config) Configure(v *viper.Viper, currentVersion config.Version) Config {
return c
}
func (c *Config) decodeDialect() error { func (c *Config) decodeDialect() error {
for _, hook := range hooks { for _, hook := range hooks {
for name, config := range c.Dialects { for name, config := range c.Dialects {
@@ -75,7 +107,7 @@ func (c *Config) decodeDialect() error {
continue continue
} }
connector, err := hook.Decode(name, config) connector, err := hook.Decode(config)
if err != nil { if err != nil {
return err return err
} }
@@ -87,7 +119,7 @@ func (c *Config) decodeDialect() error {
return errors.New("no dialect found") return errors.New("no dialect found")
} }
func decodeHook(from, to reflect.Value) (_ interface{}, err error) { func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) { if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil return from.Interface(), nil
} }
@@ -103,21 +135,3 @@ func decodeHook(from, to reflect.Value) (_ interface{}, err error) {
return config, nil return config, nil
} }
type dialects map[string]any
// ConfigForIndex implements [bla.OneOfField].
func (d dialects) ConfigForIndex(i int) any {
return hooks[i].Constructor()
}
// Possibilities implements [bla.OneOfField].
func (d dialects) Possibilities() []string {
possibilities := make([]string, len(hooks))
for i, hook := range hooks {
possibilities[i] = hook.Name
}
return possibilities
}
var _ bla.OneOfField = (dialects)(nil)

View File

@@ -6,10 +6,6 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/Masterminds/semver/v3"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/cmd/configure"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
) )
@@ -17,14 +13,6 @@ var (
_ database.Connector = (*Config)(nil) _ database.Connector = (*Config)(nil)
Name = "gosql" Name = "gosql"
Field = &configure.Field[string]{
Description: "Connection string",
Version: semver.MustParse("v3"),
Validate: func(s string) error {
_, err := pgxpool.ParseConfig(s)
return err
},
}
) )
type Config struct { type Config struct {

View File

@@ -6,117 +6,70 @@ import (
"slices" "slices"
"strings" "strings"
"github.com/Masterminds/semver/v3"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/manifoldco/promptui"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/cmd/configure"
"github.com/zitadel/zitadel/backend/cmd/configure/bla"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
) )
var ( var (
_ database.Connector = (*Config)(nil) _ database.Connector = (*Config)(nil)
Name = "postgres" Name = "postgres"
Field = &configure.OneOf{
Description: "Configuring postgres using one of the following options",
SubFields: []configure.Updater{
&configure.Field[string]{
Description: "Connection string",
Version: semver.MustParse("v3"),
Validate: func(s string) error {
_, err := pgxpool.ParseConfig(s)
return err
},
},
&configure.Struct{
Description: "Configuration for the connection",
SubFields: []configure.Updater{
&configure.Field[string]{
FieldName: "host",
Value: "localhost",
Description: "The host to connect to",
Version: semver.MustParse("3"),
},
&configure.Field[uint32]{
FieldName: "port",
Value: 5432,
Description: "The port to connect to",
Version: semver.MustParse("3"),
},
&configure.Field[string]{
FieldName: "database",
Value: "zitadel",
Description: "The database to connect to",
Version: semver.MustParse("3"),
},
&configure.Field[string]{
FieldName: "user",
Description: "The user to connect as",
Value: "zitadel",
Version: semver.MustParse("3"),
},
&configure.Field[string]{
FieldName: "password",
Description: "The password to connect with",
Version: semver.MustParse("3"),
HideInput: true,
},
&configure.OneOf{
FieldName: "sslMode",
Description: "The SSL mode to use",
SubFields: []configure.Updater{
&configure.Constant[string]{
Description: "Disable",
Constant: "disable",
Version: semver.MustParse("3"),
},
&configure.Constant[string]{
Description: "Require",
Constant: "require",
Version: semver.MustParse("3"),
},
&configure.Constant[string]{
Description: "Verify CA",
Constant: "verify-ca",
Version: semver.MustParse("3"),
},
&configure.Constant[string]{
Description: "Verify Full",
Constant: "verify-full",
Version: semver.MustParse("3"),
},
},
},
},
},
},
}
) )
type Config struct{ pgxpool.Config } type Config struct {
config *pgxpool.Config
// ConfigForIndex implements bla.OneOfField. // Host string
func (c Config) ConfigForIndex(i int) any { // Port int32
switch i { // Database string
case 0: // EventPushConnRatio float64
return new(string) // MaxOpenConns uint32
case 1: // MaxIdleConns uint32
return &c.Config // MaxConnLifetime time.Duration
} // MaxConnIdleTime time.Duration
return nil // User User
// Admin AdminUser
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
} }
// Possibilities implements bla.OneOfField. // Configure implements bla3.Custom.
func (c Config) Possibilities() []string { func (c *Config) Configure() (value any, err error) {
return []string{"connection string", "fields"} typeSelect := promptui.Select{
Label: "Configure the database connection",
Items: []string{"connection string", "fields"},
}
i, _, err := typeSelect.Run()
if err != nil {
return nil, err
}
if i > 0 {
return nil, nil
} }
var _ bla.OneOfField = (*Config)(nil) if c.config == nil {
c.config, _ = pgxpool.ParseConfig("host=localhost user=zitadel password= dbname=zitadel sslmode=disable")
}
prompt := promptui.Prompt{
Label: "Connection string",
Default: c.config.ConnString(),
AllowEdit: c.config.ConnString() != "",
Validate: func(input string) error {
_, err := pgxpool.ParseConfig(input)
return err
},
}
return prompt.Run()
}
// Connect implements [database.Connector]. // Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) { func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := pgxpool.NewWithConfig(ctx, &c.Config) pool, err := pgxpool.NewWithConfig(ctx, c.config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -130,17 +83,29 @@ func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name)) return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
} }
func DecodeConfig(_ string, config any) (database.Connector, error) { func DecodeConfig(input any) (database.Connector, error) {
switch c := config.(type) { switch c := input.(type) {
case string: case string:
config, err := pgxpool.ParseConfig(c) config, err := pgxpool.ParseConfig(c)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Config{Config: *config}, nil return &Config{config: config}, nil
case map[string]any: case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{ return &Config{
Config: pgxpool.Config{}, config: &pgxpool.Config{},
}, nil }, nil
} }
return nil, errors.New("invalid configuration") return nil, errors.New("invalid configuration")