This commit is contained in:
adlerhurst
2025-03-06 17:12:19 +01:00
parent e31bd14a07
commit 594152581c
3 changed files with 402 additions and 0 deletions

View File

@@ -0,0 +1,181 @@
package bla3
import (
"fmt"
"log/slog"
"reflect"
"github.com/manifoldco/promptui"
"github.com/spf13/viper"
)
var customType = reflect.TypeFor[Custom]()
type Custom interface {
Configure() (value any, err error)
}
type Object struct {
value reflect.Value
tag fieldTag
}
func NewObject(value any) Object {
return Object{
value: reflect.Indirect(reflect.ValueOf(value)),
}
}
func (o *Object) Configure(v *viper.Viper) error {
if o.tag.label() != "" {
fmt.Println("\n", o.tag.label())
}
for i := range o.value.NumField() {
if !o.value.Type().Field(i).IsExported() {
continue
}
structField := o.value.Field(i)
tag, err := newFieldTag(o.value.Type().Field(i), structField.Interface())
if err != nil {
return err
}
if tag.skip {
continue
}
f := Field{
tag: tag,
value: structField,
structField: o.value.Type().Field(i),
}
err = f.Configure(v)
if err != nil {
return err
}
}
return nil
}
type Field struct {
value reflect.Value
structField reflect.StructField
tag fieldTag
}
func NewField(value any) Field {
return Field{
value: reflect.ValueOf(value),
}
}
func (f *Field) validate(input string) (err error) {
configuredValue, err := mapValue(f.value.Type(), input)
if err != nil {
return err
}
f.value.Set(configuredValue)
return nil
}
func (f *Field) callCustom(v *viper.Viper) (ok bool, err error) {
if !f.value.Type().Implements(customType) {
return false, nil
}
if f.value.IsNil() {
f.value.Set(reflect.New(f.value.Type().Elem()))
}
custom := f.value.Interface().(Custom)
value, err := custom.Configure()
if err != nil {
return true, err
}
switch val := value.(type) {
case Field:
val.tag, err = newFieldTag(f.structField, val.value.Interface())
val.structField = f.structField
if err != nil {
return false, err
}
return true, val.Configure(v)
default:
v.Set(f.tag.fieldName, val)
return true, nil
}
}
func (f *Field) Configure(v *viper.Viper) error {
if ok, err := f.callCustom(v); ok || err != nil {
return err
}
f.value = reflect.Indirect(f.value)
if ok, err := f.callCustom(v); ok || err != nil {
return err
}
kind := f.value.Kind()
switch kind {
case reflect.Bool:
prompt := promptui.Prompt{
Label: f.tag.label(),
IsConfirm: true,
}
_, err := prompt.Run()
selected := true
if err != nil {
if err.Error() != "" {
return err
}
selected = false
}
f.value.SetBool(selected)
v.Set(f.tag.fieldName, f.value.Interface())
return nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
prompt := promptui.Prompt{
Label: f.tag.label(),
Validate: f.validate,
Default: f.tag.defaultValue(),
}
_, err := prompt.Run()
if err != nil {
return err
}
v.Set(f.tag.fieldName, f.value.Interface())
return nil
case reflect.Struct:
o := Object{
value: f.value,
tag: f.tag,
}
if !v.IsSet(f.tag.fieldName) {
v.Set(f.tag.fieldName, map[string]any{})
}
sub := v.Sub(f.tag.fieldName)
err := o.Configure(sub)
if err != nil {
return err
}
v.Set(f.tag.fieldName, sub.AllSettings())
return nil
case reflect.Pointer:
if f.value.IsNil() {
f.value = reflect.New(f.value.Type().Elem())
}
f.value = f.value.Elem()
return f.Configure(v)
case reflect.Array, reflect.Slice, reflect.Map:
slog.Warn("skipping because kind is unimplemented", slog.String("field", f.tag.fieldName), slog.String("kind", kind.String()))
return nil
case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer, reflect.Invalid:
slog.Error("skipping because kind is unsupported", slog.String("field", f.tag.fieldName), slog.String("kind", kind.String()))
return nil
}
return nil
}

View File

@@ -0,0 +1,81 @@
package bla3
import (
"fmt"
"reflect"
"strings"
)
type fieldTag struct {
skip bool
fieldName string
description string
currentValue any
}
const (
tagName = "configure"
defaultKey = "default"
descriptionKey = "description"
)
func newFieldTag(field reflect.StructField, current any) (config fieldTag, err error) {
config.fieldName = field.Name
if current != nil {
config.currentValue = current
}
value, ok := field.Tag.Lookup(tagName)
if !ok {
if config.currentValue == nil {
config.currentValue = reflect.New(field.Type).Elem().Interface()
}
return config, nil
}
if value == "-" {
config.skip = true
return config, nil
}
fields := strings.Split(value, ",")
for _, f := range fields {
configSplit := strings.Split(f, "=")
switch strings.ToLower(configSplit[0]) {
case defaultKey:
config.currentValue, err = mapValue(field.Type, configSplit[1])
if err != nil {
return config, err
}
case descriptionKey:
config.description = configSplit[1]
}
}
if config.currentValue == nil {
config.currentValue = reflect.New(field.Type).Elem().Interface()
}
return config, nil
}
func (tag fieldTag) label() string {
if tag.fieldName == "" {
return ""
}
valueType := reflect.TypeOf(tag.currentValue)
typ := valueType.Kind().String()
if typeParsers[valueType] != nil {
typ = valueType.Name()
}
label := fmt.Sprintf("%s (%s)", tag.fieldName, typ)
if tag.description == "" {
return label
}
return fmt.Sprintf("%s (%s)", label, tag.description)
}
func (tag fieldTag) defaultValue() string {
return fmt.Sprintf("%v", tag.currentValue)
}

View File

@@ -0,0 +1,140 @@
package bla3
import (
"fmt"
"os"
"reflect"
"strconv"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
type TestConfig struct {
API APIConfig `configure:""`
Database DatabaseOneOf `configure:"type=oneof"`
}
type APIConfig struct {
Host string `configure:""`
Port uint16 `configure:""`
}
type DatabaseOneOf struct {
ConnectionString *string `configure:""`
Config *DatabaseConfig `configure:""`
}
type DatabaseConfig struct {
Host string `configure:""`
Port uint16 `configure:""`
SSLMode string `configure:""`
}
func Update(v *viper.Viper, config any) func(cmd *cobra.Command, args []string) {
return func(cmd *cobra.Command, args []string) {
s := NewObject(config)
if err := s.Configure(v); err != nil {
fmt.Println(err)
os.Exit(1)
}
err := v.WriteConfig()
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
}
var typeParsers = map[reflect.Type]func(string) (any, error){
reflect.TypeFor[time.Duration](): func(value string) (any, error) {
return time.ParseDuration(value)
},
reflect.TypeFor[time.Time](): func(value string) (any, error) {
if t, err := time.Parse(time.DateTime, value); err == nil {
return t, nil
}
if t, err := time.Parse(time.DateOnly, value); err == nil {
return t, nil
}
if t, err := time.Parse(time.TimeOnly, value); err == nil {
return t, nil
}
return time.Parse(time.RFC3339, value)
},
}
func SetTypeParser[T any](fn func(string) (any, error)) {
typeParsers[reflect.TypeFor[T]()] = fn
}
func mapValue(typ reflect.Type, value string) (v reflect.Value, err error) {
if fn, ok := typeParsers[typ]; ok {
mappedValue, err := fn(value)
if err != nil {
return v, err
}
res := reflect.ValueOf(mappedValue)
if !res.CanConvert(typ) {
return v, fmt.Errorf("cannot convert %T to %s", mappedValue, typ.Kind().String())
}
return res.Convert(typ), nil
}
var val any
switch typ.Kind() {
case reflect.String:
val = value
case reflect.Bool:
val, err = strconv.ParseBool(value)
case reflect.Int:
val, err = strconv.Atoi(value)
case reflect.Int8:
val, err = strconv.ParseInt(value, 10, 8)
val = int8(val.(int64))
case reflect.Int16:
val, err = strconv.ParseInt(value, 10, 16)
val = int16(val.(int64))
case reflect.Int32:
val, err = strconv.ParseInt(value, 10, 32)
val = int32(val.(int64))
case reflect.Int64:
val, err = strconv.ParseInt(value, 10, 64)
case reflect.Uint:
val, err = strconv.ParseUint(value, 10, 0)
val = uint(val.(uint64))
case reflect.Uint8:
val, err = strconv.ParseUint(value, 10, 8)
val = uint8(val.(uint64))
case reflect.Uint16:
val, err = strconv.ParseUint(value, 10, 16)
val = uint16(val.(uint64))
case reflect.Uint32:
val, err = strconv.ParseUint(value, 10, 32)
val = uint32(val.(uint64))
case reflect.Uint64:
val, err = strconv.ParseUint(value, 10, 64)
case reflect.Float32:
val, err = strconv.ParseFloat(value, 32)
val = float32(val.(float64))
case reflect.Float64:
val, err = strconv.ParseFloat(value, 64)
case reflect.Complex64:
val, err = strconv.ParseComplex(value, 64)
val = complex64(val.(complex128))
case reflect.Complex128:
val, err = strconv.ParseComplex(value, 128)
default:
return v, fmt.Errorf("unsupported type: %s", typ.Kind().String())
}
if err != nil {
return v, err
}
res := reflect.ValueOf(val)
if !res.CanConvert(typ) {
return v, fmt.Errorf("cannot convert %T to %s", val, typ.Kind().String())
}
return res.Convert(typ), nil
}