From 67b22ef9c4b762b6a6a0daadd94fd1f8d6e27c31 Mon Sep 17 00:00:00 2001 From: adlerhurst <27845747+adlerhurst@users.noreply.github.com> Date: Thu, 6 Mar 2025 17:12:23 +0100 Subject: [PATCH] v7 --- backend/cmd/configure/bla3/prompt.go | 26 +-- backend/cmd/configure/bla4/field_tag.go | 115 +++++++++++ backend/cmd/configure/bla4/mapper.go | 121 +++++++++++ backend/cmd/configure/bla4/update_config7.go | 189 ++++++++++++++++++ backend/cmd/prepare/config.go | 19 +- backend/cmd/test.yaml | 10 +- backend/storage/database/config.go | 7 +- backend/storage/database/dialect/config.go | 100 +++++---- .../storage/database/dialect/gosql/config.go | 14 +- .../database/dialect/postgres/config.go | 161 ++++++--------- 10 files changed, 582 insertions(+), 180 deletions(-) create mode 100644 backend/cmd/configure/bla4/field_tag.go create mode 100644 backend/cmd/configure/bla4/mapper.go create mode 100644 backend/cmd/configure/bla4/update_config7.go diff --git a/backend/cmd/configure/bla3/prompt.go b/backend/cmd/configure/bla3/prompt.go index e45a3a6da5..1ba0fd8f79 100644 --- a/backend/cmd/configure/bla3/prompt.go +++ b/backend/cmd/configure/bla3/prompt.go @@ -36,20 +36,12 @@ func (o *Object) Configure(v *viper.Viper) error { continue } structField := o.value.Field(i) - tag, err := newFieldTag(o.value.Type().Field(i), structField.Interface()) - if err != nil { - return err - } - if tag.skip { - continue - } f := Field{ - tag: tag, value: structField, structField: o.value.Type().Field(i), } - err = f.Configure(v) + err := f.Configure(v) if err != nil { return err } @@ -85,9 +77,7 @@ func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) { if !f.value.Type().Implements(customType) { return false, nil } - if f.value.IsNil() { - f.value.Set(reflect.New(f.value.Type().Elem())) - } + custom := f.value.Interface().(Custom) value, err := custom.Configure() if err != nil { @@ -108,6 +98,18 @@ func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) { } func (f *Field) Configure(v *viper.Viper) error { + if f.value.IsNil() { + f.value.Set(reflect.New(f.value.Type().Elem())) + } + + tag, err := newFieldTag(f.structField, f.value.Interface()) + if err != nil { + return err + } + if tag.skip { + return nil + } + if ok, err := f.callCustom(v); ok || err != nil { return err } diff --git a/backend/cmd/configure/bla4/field_tag.go b/backend/cmd/configure/bla4/field_tag.go new file mode 100644 index 0000000000..1a2ccbc9a0 --- /dev/null +++ b/backend/cmd/configure/bla4/field_tag.go @@ -0,0 +1,115 @@ +package bla4 + +import ( + "fmt" + "reflect" + "strings" + + "github.com/spf13/viper" +) + +type field struct { + info reflect.StructField + viper *viper.Viper + + tag +} + +func (f *field) label() string { + var builder strings.Builder + builder.WriteString(f.info.Name) + builder.WriteString(" (") + builder.WriteString(f.info.Type.Kind().String()) + builder.WriteString(")") + if f.description != "" { + builder.WriteString(": ") + builder.WriteString(f.description) + } + return builder.String() +} + +func (f *field) sub() *viper.Viper { + if !f.viper.IsSet(f.fieldName) { + f.viper.Set(f.fieldName, map[string]any{}) + } + return f.viper.Sub(f.fieldName) +} + +func (f *field) printStructInfo() { + var builder strings.Builder + builder.WriteString("------- ") + builder.WriteString(f.info.Name) + builder.WriteString(" -------") + if f.description != "" { + builder.WriteString(": ") + builder.WriteString(f.description) + } + fmt.Println(builder.String()) +} + +type tag struct { + skip bool + + fieldName string + description string + + value reflect.Value +} + +const ( + tagName = "configure" + defaultKey = "default" + descriptionKey = "description" +) + +func newTag(field reflect.StructField, current reflect.Value) (config tag, err error) { + config.fieldName = field.Name + defer func() { + if !config.value.IsValid() { + if field.Type.Kind() == reflect.Pointer { + config.value = reflect.New(field.Type.Elem()) + return + } + config.value = reflect.New(field.Type).Elem() + } + }() + if !current.IsZero() { + config.value = current + } + + value, ok := field.Tag.Lookup(tagName) + if !ok { + return config, nil + } + if value == "-" { + config.skip = true + return config, nil + } + + fields := strings.Split(value, ",") + for _, f := range fields { + configSplit := strings.Split(f, "=") + switch strings.ToLower(configSplit[0]) { + case defaultKey: + if !config.value.IsZero() { + continue + } + value, err := kindMapper(field.Type.Kind())(configSplit[1]) + if err != nil { + return config, err + } + config.value.Set(reflect.ValueOf(value)) + case descriptionKey: + config.description = configSplit[1] + } + } + + return config, nil +} + +func (tag tag) defaultValue() string { + if tag.value.IsZero() { + return "" + } + return fmt.Sprintf("%v", tag.value.Interface()) +} diff --git a/backend/cmd/configure/bla4/mapper.go b/backend/cmd/configure/bla4/mapper.go new file mode 100644 index 0000000000..8776e6122c --- /dev/null +++ b/backend/cmd/configure/bla4/mapper.go @@ -0,0 +1,121 @@ +package bla4 + +import ( + "reflect" + "strconv" + "sync" + "time" +) + +var writeMu sync.RWMutex + +func SetTypeMapper(typ reflect.Type, mapper func(input string) (any, error)) { + writeMu.Lock() + defer writeMu.Unlock() + + typeMappers[typ] = mapper +} + +func SetTypeMapperFor[T any](mapper func(input string) (any, error)) { + writeMu.Lock() + defer writeMu.Unlock() + + typeMappers[reflect.TypeFor[T]()] = mapper +} + +func typeMapper(typ reflect.Type) func(input string) (any, error) { + writeMu.RLock() + defer writeMu.RUnlock() + + return typeMappers[typ] +} + +var typeMappers = map[reflect.Type]func(string) (any, error){ + reflect.TypeFor[time.Duration](): func(input string) (any, error) { + return time.ParseDuration(input) + }, +} + +// SetKindMapper overwrites the mapper for the given kind. +func SetKindMapper(kind reflect.Kind, mapper func(input string) (any, error)) { + writeMu.Lock() + defer writeMu.Unlock() + + kindMappers[kind] = mapper +} + +func kindMapper(kind reflect.Kind) func(input string) (any, error) { + writeMu.RLock() + defer writeMu.RUnlock() + + return kindMappers[kind] +} + +var kindMappers = map[reflect.Kind]func(input string) (any, error){ + reflect.String: func(input string) (any, error) { + return input, nil + }, + reflect.Bool: func(input string) (any, error) { + return strconv.ParseBool(input) + }, + reflect.Int: func(input string) (any, error) { + return strconv.Atoi(input) + }, + reflect.Int8: func(input string) (val any, err error) { + val, err = strconv.ParseInt(input, 10, 8) + val = int8(val.(int64)) + return val, err + }, + reflect.Int16: func(input string) (val any, err error) { + val, err = strconv.ParseInt(input, 10, 16) + val = int16(val.(int64)) + return val, err + }, + reflect.Int32: func(input string) (val any, err error) { + val, err = strconv.ParseInt(input, 10, 32) + val = int32(val.(int64)) + return val, err + }, + reflect.Int64: func(input string) (any, error) { + return strconv.ParseInt(input, 10, 64) + }, + reflect.Uint: func(input string) (val any, err error) { + val, err = strconv.ParseUint(input, 10, 0) + val = uint(val.(uint64)) + return val, err + }, + reflect.Uint8: func(input string) (val any, err error) { + val, err = strconv.ParseUint(input, 10, 8) + val = uint8(val.(uint64)) + return val, err + }, + reflect.Uint16: func(input string) (val any, err error) { + val, err = strconv.ParseUint(input, 10, 16) + val = uint16(val.(uint64)) + return val, err + }, + reflect.Uint32: func(input string) (val any, err error) { + val, err = strconv.ParseUint(input, 10, 32) + val = uint32(val.(uint64)) + return val, err + }, + reflect.Uint64: func(input string) (any, error) { + return strconv.ParseUint(input, 10, 64) + }, + reflect.Float32: func(input string) (val any, err error) { + val, err = strconv.ParseFloat(input, 32) + val = float32(val.(float64)) + return val, err + }, + reflect.Float64: func(input string) (any, error) { + return strconv.ParseFloat(input, 64) + }, + reflect.Complex64: func(input string) (val any, err error) { + val, err = strconv.ParseComplex(input, 64) + val = complex64(val.(complex128)) + return val, err + }, + reflect.Complex128: func(input string) (any, error) { + return strconv.ParseComplex(input, 128) + }, +} diff --git a/backend/cmd/configure/bla4/update_config7.go b/backend/cmd/configure/bla4/update_config7.go new file mode 100644 index 0000000000..57ca00bb2d --- /dev/null +++ b/backend/cmd/configure/bla4/update_config7.go @@ -0,0 +1,189 @@ +package bla4 + +import ( + "fmt" + "log/slog" + "os" + "reflect" + + "github.com/manifoldco/promptui" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, +})) + +type Configure func() (value any, err error) + +type Configurer interface { + // Configure is called to configure the value. + // It must return the same type as itself. Otherwise [Update] will panic because it is not able to set the value. + Configure() (value any, err error) +} + +func Update(v *viper.Viper, config any) func(cmd *cobra.Command, args []string) { + return func(cmd *cobra.Command, args []string) { + value := reflect.ValueOf(config) + structConfigures := structToConfigureMap(Logger, v, value) + for key, configure := range structConfigures { + result, err := configure() + if err != nil { + Logger.Error("error configuring field", slog.String("field", key), slog.Any("cause", err)) + return + } + value.Elem().FieldByName(key).Set(reflect.ValueOf(result)) + } + + err := v.WriteConfig() + if err != nil { + Logger.Error("error writing config", slog.Any("cause", err)) + os.Exit(1) + } + } +} + +func structToConfigureMap(l *slog.Logger, v *viper.Viper, object reflect.Value) map[string]Configure { + if object.Kind() == reflect.Pointer { + if object.IsNil() { + l.Debug("initialize object") + object = reflect.New(object.Type().Elem()) + } + return structToConfigureMap(l, v, object.Elem()) + } + if object.Kind() != reflect.Struct { + panic("config must be a struct") + } + + fields := make(map[string]Configure, object.NumField()) + + for i := range object.NumField() { + if !object.Type().Field(i).IsExported() { + continue + } + + tag, err := newTag(object.Type().Field(i), object.Field(i)) + if err != nil { + l.Error("failed to parse field tag", slog.Any("cause", err)) + continue + } + if tag.skip { + l.Debug("skipping field", slog.String("field", object.Type().Field(i).Name)) + continue + } + + fields[object.Type().Field(i).Name] = fieldFunction( + l.With(slog.String("field", object.Type().Field(i).Name)), + &field{ + info: object.Type().Field(i), + tag: tag, + viper: v, + }, + ) + } + return fields +} + +func fieldFunction(l *slog.Logger, f *field) (configure Configure) { + for _, mapper := range []func(*slog.Logger, *field) Configure{ + fieldFunctionByImplementation, + fieldFunctionByReflection, + } { + if configure = mapper(l, f); configure != nil { + return configure + } + } + panic(fmt.Sprintf("unsupported field type: %s", f.info.Type.String())) +} + +func fieldFunctionByImplementation(l *slog.Logger, f *field) Configure { + if f.value.Type().Implements(reflect.TypeFor[Configurer]()) { + l.Debug("field is a custom implementation") + return func() (value any, err error) { + f.printStructInfo() + res, err := f.value.Interface().(Configurer).Configure() + if err != nil { + return nil, err + } + f.viper.Set(f.tag.fieldName, res) + return res, nil + } + } + return nil +} + +func fieldFunctionByReflection(l *slog.Logger, f *field) Configure { + kind := f.value.Kind() + + //nolint:exhaustive // only types that require special treatment are covered + switch kind { + case reflect.Pointer: + if f.value.IsNil() { + f.value.Set(reflect.New(f.value.Type().Elem())) + } + sub := f.sub() + m := structToConfigureMap(l, sub, f.value) + return func() (value any, err error) { + f.printStructInfo() + for key, configure := range m { + value, err = configure() + if err != nil { + return nil, err + } + f.value.Elem().FieldByName(key).Set(reflect.ValueOf(value)) + } + f.viper.Set(f.tag.fieldName, sub.AllSettings()) + return f.value.Interface(), nil + } + case reflect.Struct: + sub := f.sub() + m := structToConfigureMap(l, sub, f.value) + return func() (value any, err error) { + f.printStructInfo() + for key, configure := range m { + value, err = configure() + if err != nil { + return nil, err + } + f.value.FieldByName(key).Set(reflect.ValueOf(value)) + } + f.viper.Set(f.tag.fieldName, sub.AllSettings()) + return f.value.Interface(), nil + } + + case reflect.Array, reflect.Slice, reflect.Map: + l.Warn("skipping because kind is unimplemented", slog.String("kind", kind.String())) + return nil + case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer, reflect.Invalid: + slog.Error("skipping because kind is unsupported", slog.String("kind", kind.String())) + return nil + } + + mapper := typeMapper(f.info.Type) + if mapper == nil { + mapper = kindMapper(kind) + } + if mapper == nil { + l.Error("unsupported kind", slog.String("kind", kind.String())) + panic(fmt.Sprintf("unsupported kind: %s", kind.String())) + } + + return func() (value any, err error) { + prompt := promptui.Prompt{ + Label: f.label(), + Validate: func(input string) error { + value, err = mapper(input) + return err + }, + Default: f.defaultValue(), + } + _, err = prompt.Run() + if err != nil { + return nil, err + } + f.viper.Set(f.tag.fieldName, value) + return value, nil + } +} diff --git a/backend/cmd/prepare/config.go b/backend/cmd/prepare/config.go index a7e9c34d15..3c22efea14 100644 --- a/backend/cmd/prepare/config.go +++ b/backend/cmd/prepare/config.go @@ -6,7 +6,7 @@ import ( "github.com/zitadel/zitadel/backend/cmd/config" "github.com/zitadel/zitadel/backend/cmd/configure" - "github.com/zitadel/zitadel/backend/cmd/configure/bla2" + "github.com/zitadel/zitadel/backend/cmd/configure/bla4" step001 "github.com/zitadel/zitadel/backend/cmd/prepare/001" "github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database/dialect" @@ -42,7 +42,7 @@ var ( // configuration.Fields(), // ), Run: func(cmd *cobra.Command, args []string) { - bla2.Update(viper.GetViper(), &configuration)(cmd, args) + bla4.Update(viper.GetViper(), &configuration)(cmd, args) }, PreRun: configure.ReadConfigPreRun(viper.GetViper(), &configuration), } @@ -51,20 +51,21 @@ var ( type Config struct { config.Config `mapstructure:",squash" configure:"-"` - Database dialect.Config `configure:"-"` + Database *dialect.Config // `configure:"-"` Step001 step001.Step001 + Step002 *step001.Step001 // runtime config Client database.Pool `mapstructure:"-" configure:"-"` } func (c *Config) Hooks() (decoders []viper.DecoderConfigOption) { - for _, hooks := range []configure.Unmarshaller{ - c.Config, - c.Database, - } { - decoders = append(decoders, hooks.Hooks()...) - } + // for _, hooks := range []configure.Unmarshaller{ + // c.Config, + // c.Database, + // } { + // decoders = append(decoders, hooks.Hooks()...) + // } return decoders } diff --git a/backend/cmd/test.yaml b/backend/cmd/test.yaml index 65b25d458a..094bc2899c 100644 --- a/backend/cmd/test.yaml +++ b/backend/cmd/test.yaml @@ -1,6 +1,8 @@ -configuredversion: 2025.2.23 database: - postgres: host=local + postgres: host=localhost user=zitadel password= dbname=zitadel sslmode=disable test=test step001: - databasename: zx;lvkj - username: z.;nv.,mvnzx + databasename: qwer + username: asdf +step002: + databasename: yuio + username: hjkl diff --git a/backend/storage/database/config.go b/backend/storage/database/config.go index 56c86b2b2d..cb2b28ab94 100644 --- a/backend/storage/database/config.go +++ b/backend/storage/database/config.go @@ -1,7 +1,12 @@ package database -import "context" +import ( + "context" + + "github.com/zitadel/zitadel/backend/cmd/configure/bla4" +) type Connector interface { Connect(ctx context.Context) (Pool, error) + bla4.Configurer } diff --git a/backend/storage/database/dialect/config.go b/backend/storage/database/dialect/config.go index 26a40bfe83..4a7ac35c77 100644 --- a/backend/storage/database/dialect/config.go +++ b/backend/storage/database/dialect/config.go @@ -5,23 +5,20 @@ import ( "errors" "reflect" + "github.com/manifoldco/promptui" "github.com/mitchellh/mapstructure" "github.com/spf13/viper" - "github.com/zitadel/zitadel/backend/cmd/config" - "github.com/zitadel/zitadel/backend/cmd/configure" - "github.com/zitadel/zitadel/backend/cmd/configure/bla" + "github.com/zitadel/zitadel/backend/cmd/configure/bla4" "github.com/zitadel/zitadel/backend/storage/database" - "github.com/zitadel/zitadel/backend/storage/database/dialect/gosql" "github.com/zitadel/zitadel/backend/storage/database/dialect/postgres" ) type Hook struct { Match func(string) bool - Decode func(name string, config any) (database.Connector, error) + Decode func(config any) (database.Connector, error) Name string - Field configure.Updater - Constructor func() any + Constructor func() database.Connector } var hooks = []Hook{ @@ -29,24 +26,65 @@ var hooks = []Hook{ Match: postgres.NameMatcher, Decode: postgres.DecodeConfig, Name: postgres.Name, - Field: postgres.Field, - Constructor: func() any { return new(postgres.Config) }, - }, - { - Match: gosql.NameMatcher, - Decode: gosql.DecodeConfig, - Name: gosql.Name, - Field: gosql.Field, - Constructor: func() any { return new(gosql.Config) }, + Constructor: func() database.Connector { return new(postgres.Config) }, }, + // { + // Match: gosql.NameMatcher, + // Decode: gosql.DecodeConfig, + // Name: gosql.Name, + // Constructor: func() database.Connector { return new(gosql.Config) }, + // }, } type Config struct { - Dialects dialects `mapstructure:",remain"` + Dialects map[string]any `mapstructure:",remain" yaml:",inline"` connector database.Connector } +// Configure implements [configure.Configurer]. +func (c *Config) Configure() (any, error) { + possibilities := make([]string, len(hooks)) + var cursor int + for i, hook := range hooks { + if _, ok := c.Dialects[hook.Name]; ok { + cursor = i + } + possibilities[i] = hook.Name + } + + prompt := promptui.Select{ + Label: "Select a dialect", + Items: possibilities, + CursorPos: cursor, + } + i, _, err := prompt.Run() + if err != nil { + return nil, err + } + + var config bla4.Configurer + + if dialect, ok := c.Dialects[hooks[i].Name]; ok { + config, err = hooks[i].Decode(dialect) + if err != nil { + return nil, err + } + } else { + clear(c.Dialects) + config = hooks[i].Constructor() + } + if c.Dialects == nil { + c.Dialects = make(map[string]any) + } + c.Dialects[hooks[i].Name], err = config.Configure() + if err != nil { + return nil, err + } + + return c, nil +} + func (c Config) Connect(ctx context.Context) (database.Pool, error) { if len(c.Dialects) != 1 { return nil, errors.New("Exactly one dialect must be configured") @@ -62,12 +100,6 @@ func (c Config) Hooks() []viper.DecoderConfigOption { } } -// var _ configure.StructUpdater = (*Config)(nil) - -func (c Config) Configure(v *viper.Viper, currentVersion config.Version) Config { - return c -} - func (c *Config) decodeDialect() error { for _, hook := range hooks { for name, config := range c.Dialects { @@ -75,7 +107,7 @@ func (c *Config) decodeDialect() error { continue } - connector, err := hook.Decode(name, config) + connector, err := hook.Decode(config) if err != nil { return err } @@ -87,7 +119,7 @@ func (c *Config) decodeDialect() error { return errors.New("no dialect found") } -func decodeHook(from, to reflect.Value) (_ interface{}, err error) { +func decodeHook(from, to reflect.Value) (_ any, err error) { if to.Type() != reflect.TypeOf(Config{}) { return from.Interface(), nil } @@ -103,21 +135,3 @@ func decodeHook(from, to reflect.Value) (_ interface{}, err error) { return config, nil } - -type dialects map[string]any - -// ConfigForIndex implements [bla.OneOfField]. -func (d dialects) ConfigForIndex(i int) any { - return hooks[i].Constructor() -} - -// Possibilities implements [bla.OneOfField]. -func (d dialects) Possibilities() []string { - possibilities := make([]string, len(hooks)) - for i, hook := range hooks { - possibilities[i] = hook.Name - } - return possibilities -} - -var _ bla.OneOfField = (dialects)(nil) diff --git a/backend/storage/database/dialect/gosql/config.go b/backend/storage/database/dialect/gosql/config.go index 7e2d9b12c7..77648ee236 100644 --- a/backend/storage/database/dialect/gosql/config.go +++ b/backend/storage/database/dialect/gosql/config.go @@ -6,25 +6,13 @@ import ( "errors" "strings" - "github.com/Masterminds/semver/v3" - "github.com/jackc/pgx/v5/pgxpool" - - "github.com/zitadel/zitadel/backend/cmd/configure" "github.com/zitadel/zitadel/backend/storage/database" ) var ( _ database.Connector = (*Config)(nil) - Name = "gosql" - Field = &configure.Field[string]{ - Description: "Connection string", - Version: semver.MustParse("v3"), - Validate: func(s string) error { - _, err := pgxpool.ParseConfig(s) - return err - }, - } + Name = "gosql" ) type Config struct { diff --git a/backend/storage/database/dialect/postgres/config.go b/backend/storage/database/dialect/postgres/config.go index 30d3b7580e..e5ee319f58 100644 --- a/backend/storage/database/dialect/postgres/config.go +++ b/backend/storage/database/dialect/postgres/config.go @@ -6,117 +6,70 @@ import ( "slices" "strings" - "github.com/Masterminds/semver/v3" "github.com/jackc/pgx/v5/pgxpool" + "github.com/manifoldco/promptui" + "github.com/mitchellh/mapstructure" - "github.com/zitadel/zitadel/backend/cmd/configure" - "github.com/zitadel/zitadel/backend/cmd/configure/bla" "github.com/zitadel/zitadel/backend/storage/database" ) var ( _ database.Connector = (*Config)(nil) Name = "postgres" - - Field = &configure.OneOf{ - Description: "Configuring postgres using one of the following options", - SubFields: []configure.Updater{ - &configure.Field[string]{ - Description: "Connection string", - Version: semver.MustParse("v3"), - Validate: func(s string) error { - _, err := pgxpool.ParseConfig(s) - return err - }, - }, - &configure.Struct{ - Description: "Configuration for the connection", - SubFields: []configure.Updater{ - &configure.Field[string]{ - FieldName: "host", - Value: "localhost", - Description: "The host to connect to", - Version: semver.MustParse("3"), - }, - &configure.Field[uint32]{ - FieldName: "port", - Value: 5432, - Description: "The port to connect to", - Version: semver.MustParse("3"), - }, - &configure.Field[string]{ - FieldName: "database", - Value: "zitadel", - Description: "The database to connect to", - Version: semver.MustParse("3"), - }, - &configure.Field[string]{ - FieldName: "user", - Description: "The user to connect as", - Value: "zitadel", - Version: semver.MustParse("3"), - }, - &configure.Field[string]{ - FieldName: "password", - Description: "The password to connect with", - Version: semver.MustParse("3"), - HideInput: true, - }, - &configure.OneOf{ - FieldName: "sslMode", - Description: "The SSL mode to use", - SubFields: []configure.Updater{ - &configure.Constant[string]{ - Description: "Disable", - Constant: "disable", - Version: semver.MustParse("3"), - }, - &configure.Constant[string]{ - Description: "Require", - Constant: "require", - Version: semver.MustParse("3"), - }, - &configure.Constant[string]{ - Description: "Verify CA", - Constant: "verify-ca", - Version: semver.MustParse("3"), - }, - &configure.Constant[string]{ - Description: "Verify Full", - Constant: "verify-full", - Version: semver.MustParse("3"), - }, - }, - }, - }, - }, - }, - } ) -type Config struct{ pgxpool.Config } +type Config struct { + config *pgxpool.Config -// ConfigForIndex implements bla.OneOfField. -func (c Config) ConfigForIndex(i int) any { - switch i { - case 0: - return new(string) - case 1: - return &c.Config + // Host string + // Port int32 + // Database string + // EventPushConnRatio float64 + // MaxOpenConns uint32 + // MaxIdleConns uint32 + // MaxConnLifetime time.Duration + // MaxConnIdleTime time.Duration + // User User + // Admin AdminUser + // // Additional options to be appended as options= + // // The value will be taken as is. Multiple options are space separated. + // Options string +} + +// Configure implements bla3.Custom. +func (c *Config) Configure() (value any, err error) { + typeSelect := promptui.Select{ + Label: "Configure the database connection", + Items: []string{"connection string", "fields"}, + } + i, _, err := typeSelect.Run() + if err != nil { + return nil, err + } + if i > 0 { + return nil, nil } - return nil -} -// Possibilities implements bla.OneOfField. -func (c Config) Possibilities() []string { - return []string{"connection string", "fields"} -} + if c.config == nil { + c.config, _ = pgxpool.ParseConfig("host=localhost user=zitadel password= dbname=zitadel sslmode=disable") + } -var _ bla.OneOfField = (*Config)(nil) + prompt := promptui.Prompt{ + Label: "Connection string", + Default: c.config.ConnString(), + AllowEdit: c.config.ConnString() != "", + Validate: func(input string) error { + _, err := pgxpool.ParseConfig(input) + return err + }, + } + + return prompt.Run() +} // Connect implements [database.Connector]. func (c *Config) Connect(ctx context.Context) (database.Pool, error) { - pool, err := pgxpool.NewWithConfig(ctx, &c.Config) + pool, err := pgxpool.NewWithConfig(ctx, c.config) if err != nil { return nil, err } @@ -130,17 +83,29 @@ func NameMatcher(name string) bool { return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name)) } -func DecodeConfig(_ string, config any) (database.Connector, error) { - switch c := config.(type) { +func DecodeConfig(input any) (database.Connector, error) { + switch c := input.(type) { case string: config, err := pgxpool.ParseConfig(c) if err != nil { return nil, err } - return &Config{Config: *config}, nil + return &Config{config: config}, nil case map[string]any: + connector := new(Config) + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + WeaklyTypedInput: true, + Result: connector, + }) + if err != nil { + return nil, err + } + if err = decoder.Decode(c); err != nil { + return nil, err + } return &Config{ - Config: pgxpool.Config{}, + config: &pgxpool.Config{}, }, nil } return nil, errors.New("invalid configuration")