diff --git a/backend/cmd/configure/bla3/prompt.go b/backend/cmd/configure/bla3/prompt.go new file mode 100644 index 0000000000..e45a3a6da5 --- /dev/null +++ b/backend/cmd/configure/bla3/prompt.go @@ -0,0 +1,181 @@ +package bla3 + +import ( + "fmt" + "log/slog" + "reflect" + + "github.com/manifoldco/promptui" + "github.com/spf13/viper" +) + +var customType = reflect.TypeFor[Custom]() + +type Custom interface { + Configure() (value any, err error) +} + +type Object struct { + value reflect.Value + + tag fieldTag +} + +func NewObject(value any) Object { + return Object{ + value: reflect.Indirect(reflect.ValueOf(value)), + } +} + +func (o *Object) Configure(v *viper.Viper) error { + if o.tag.label() != "" { + fmt.Println("\n", o.tag.label()) + } + for i := range o.value.NumField() { + if !o.value.Type().Field(i).IsExported() { + continue + } + structField := o.value.Field(i) + tag, err := newFieldTag(o.value.Type().Field(i), structField.Interface()) + if err != nil { + return err + } + if tag.skip { + continue + } + + f := Field{ + tag: tag, + value: structField, + structField: o.value.Type().Field(i), + } + err = f.Configure(v) + if err != nil { + return err + } + } + + return nil +} + +type Field struct { + value reflect.Value + + structField reflect.StructField + + tag fieldTag +} + +func NewField(value any) Field { + return Field{ + value: reflect.ValueOf(value), + } +} + +func (f *Field) validate(input string) (err error) { + configuredValue, err := mapValue(f.value.Type(), input) + if err != nil { + return err + } + f.value.Set(configuredValue) + return nil +} + +func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) { + if !f.value.Type().Implements(customType) { + return false, nil + } + if f.value.IsNil() { + f.value.Set(reflect.New(f.value.Type().Elem())) + } + custom := f.value.Interface().(Custom) + value, err := custom.Configure() + if err != nil { + return true, err + } + switch val := value.(type) { + case Field: + val.tag, err = newFieldTag(f.structField, val.value.Interface()) + val.structField = f.structField + if err != nil { + return false, err + } + return true, val.Configure(v) + default: + v.Set(f.tag.fieldName, val) + return true, nil + } +} + +func (f *Field) Configure(v *viper.Viper) error { + if ok, err := f.callCustom(v); ok || err != nil { + return err + } + f.value = reflect.Indirect(f.value) + if ok, err := f.callCustom(v); ok || err != nil { + return err + } + + kind := f.value.Kind() + switch kind { + case reflect.Bool: + prompt := promptui.Prompt{ + Label: f.tag.label(), + IsConfirm: true, + } + _, err := prompt.Run() + selected := true + if err != nil { + if err.Error() != "" { + return err + } + selected = false + } + f.value.SetBool(selected) + v.Set(f.tag.fieldName, f.value.Interface()) + return nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, + reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String: + + prompt := promptui.Prompt{ + Label: f.tag.label(), + Validate: f.validate, + Default: f.tag.defaultValue(), + } + _, err := prompt.Run() + if err != nil { + return err + } + v.Set(f.tag.fieldName, f.value.Interface()) + return nil + case reflect.Struct: + o := Object{ + value: f.value, + tag: f.tag, + } + if !v.IsSet(f.tag.fieldName) { + v.Set(f.tag.fieldName, map[string]any{}) + } + sub := v.Sub(f.tag.fieldName) + err := o.Configure(sub) + if err != nil { + return err + } + v.Set(f.tag.fieldName, sub.AllSettings()) + return nil + case reflect.Pointer: + if f.value.IsNil() { + f.value = reflect.New(f.value.Type().Elem()) + } + f.value = f.value.Elem() + return f.Configure(v) + case reflect.Array, reflect.Slice, reflect.Map: + slog.Warn("skipping because kind is unimplemented", slog.String("field", f.tag.fieldName), slog.String("kind", kind.String())) + return nil + case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer, reflect.Invalid: + slog.Error("skipping because kind is unsupported", slog.String("field", f.tag.fieldName), slog.String("kind", kind.String())) + return nil + } + return nil +} diff --git a/backend/cmd/configure/bla3/tag.go b/backend/cmd/configure/bla3/tag.go new file mode 100644 index 0000000000..08fa04565f --- /dev/null +++ b/backend/cmd/configure/bla3/tag.go @@ -0,0 +1,81 @@ +package bla3 + +import ( + "fmt" + "reflect" + "strings" +) + +type fieldTag struct { + skip bool + + fieldName string + description string + + currentValue any +} + +const ( + tagName = "configure" + defaultKey = "default" + descriptionKey = "description" +) + +func newFieldTag(field reflect.StructField, current any) (config fieldTag, err error) { + config.fieldName = field.Name + if current != nil { + config.currentValue = current + } + + value, ok := field.Tag.Lookup(tagName) + if !ok { + if config.currentValue == nil { + config.currentValue = reflect.New(field.Type).Elem().Interface() + } + return config, nil + } + if value == "-" { + config.skip = true + return config, nil + } + + fields := strings.Split(value, ",") + for _, f := range fields { + configSplit := strings.Split(f, "=") + switch strings.ToLower(configSplit[0]) { + case defaultKey: + config.currentValue, err = mapValue(field.Type, configSplit[1]) + if err != nil { + return config, err + } + case descriptionKey: + config.description = configSplit[1] + } + } + if config.currentValue == nil { + config.currentValue = reflect.New(field.Type).Elem().Interface() + } + + return config, nil +} + +func (tag fieldTag) label() string { + if tag.fieldName == "" { + return "" + } + valueType := reflect.TypeOf(tag.currentValue) + typ := valueType.Kind().String() + if typeParsers[valueType] != nil { + typ = valueType.Name() + } + + label := fmt.Sprintf("%s (%s)", tag.fieldName, typ) + if tag.description == "" { + return label + } + return fmt.Sprintf("%s (%s)", label, tag.description) +} + +func (tag fieldTag) defaultValue() string { + return fmt.Sprintf("%v", tag.currentValue) +} diff --git a/backend/cmd/configure/bla3/update_config6.go b/backend/cmd/configure/bla3/update_config6.go new file mode 100644 index 0000000000..126501bb0e --- /dev/null +++ b/backend/cmd/configure/bla3/update_config6.go @@ -0,0 +1,140 @@ +package bla3 + +import ( + "fmt" + "os" + "reflect" + "strconv" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +type TestConfig struct { + API APIConfig `configure:""` + Database DatabaseOneOf `configure:"type=oneof"` +} + +type APIConfig struct { + Host string `configure:""` + Port uint16 `configure:""` +} + +type DatabaseOneOf struct { + ConnectionString *string `configure:""` + Config *DatabaseConfig `configure:""` +} + +type DatabaseConfig struct { + Host string `configure:""` + Port uint16 `configure:""` + SSLMode string `configure:""` +} + +func Update(v *viper.Viper, config any) func(cmd *cobra.Command, args []string) { + return func(cmd *cobra.Command, args []string) { + s := NewObject(config) + if err := s.Configure(v); err != nil { + fmt.Println(err) + os.Exit(1) + } + err := v.WriteConfig() + if err != nil { + fmt.Println(err) + os.Exit(1) + } + } +} + +var typeParsers = map[reflect.Type]func(string) (any, error){ + reflect.TypeFor[time.Duration](): func(value string) (any, error) { + return time.ParseDuration(value) + }, + reflect.TypeFor[time.Time](): func(value string) (any, error) { + if t, err := time.Parse(time.DateTime, value); err == nil { + return t, nil + } + if t, err := time.Parse(time.DateOnly, value); err == nil { + return t, nil + } + if t, err := time.Parse(time.TimeOnly, value); err == nil { + return t, nil + } + return time.Parse(time.RFC3339, value) + }, +} + +func SetTypeParser[T any](fn func(string) (any, error)) { + typeParsers[reflect.TypeFor[T]()] = fn +} + +func mapValue(typ reflect.Type, value string) (v reflect.Value, err error) { + if fn, ok := typeParsers[typ]; ok { + mappedValue, err := fn(value) + if err != nil { + return v, err + } + res := reflect.ValueOf(mappedValue) + if !res.CanConvert(typ) { + return v, fmt.Errorf("cannot convert %T to %s", mappedValue, typ.Kind().String()) + } + return res.Convert(typ), nil + } + + var val any + switch typ.Kind() { + case reflect.String: + val = value + case reflect.Bool: + val, err = strconv.ParseBool(value) + case reflect.Int: + val, err = strconv.Atoi(value) + case reflect.Int8: + val, err = strconv.ParseInt(value, 10, 8) + val = int8(val.(int64)) + case reflect.Int16: + val, err = strconv.ParseInt(value, 10, 16) + val = int16(val.(int64)) + case reflect.Int32: + val, err = strconv.ParseInt(value, 10, 32) + val = int32(val.(int64)) + case reflect.Int64: + val, err = strconv.ParseInt(value, 10, 64) + case reflect.Uint: + val, err = strconv.ParseUint(value, 10, 0) + val = uint(val.(uint64)) + case reflect.Uint8: + val, err = strconv.ParseUint(value, 10, 8) + val = uint8(val.(uint64)) + case reflect.Uint16: + val, err = strconv.ParseUint(value, 10, 16) + val = uint16(val.(uint64)) + case reflect.Uint32: + val, err = strconv.ParseUint(value, 10, 32) + val = uint32(val.(uint64)) + case reflect.Uint64: + val, err = strconv.ParseUint(value, 10, 64) + case reflect.Float32: + val, err = strconv.ParseFloat(value, 32) + val = float32(val.(float64)) + case reflect.Float64: + val, err = strconv.ParseFloat(value, 64) + case reflect.Complex64: + val, err = strconv.ParseComplex(value, 64) + val = complex64(val.(complex128)) + case reflect.Complex128: + val, err = strconv.ParseComplex(value, 128) + default: + return v, fmt.Errorf("unsupported type: %s", typ.Kind().String()) + } + if err != nil { + return v, err + } + + res := reflect.ValueOf(val) + if !res.CanConvert(typ) { + return v, fmt.Errorf("cannot convert %T to %s", val, typ.Kind().String()) + } + return res.Convert(typ), nil +}