mirror of
https://github.com/juanfont/headscale.git
synced 2025-01-10 01:54:33 +00:00
100 lines
2.7 KiB
Go
100 lines
2.7 KiB
Go
|
package db
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
|
||
|
"gorm.io/gorm/schema"
|
||
|
)
|
||
|
|
||
|
// textUnmarshalerType is the reflect.Type describing the
// encoding.TextUnmarshaler interface itself (not a type implementing it).
// Taken from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
|
||
|
|
||
|
func isTextUnmarshaler(rv reflect.Value) bool {
|
||
|
return rv.Type().Implements(textUnmarshalerType)
|
||
|
}
|
||
|
|
||
|
// maybeInstantiatePtr makes a nil pointer usable: when rv is a pointer that is
// currently nil, it is set to a freshly allocated zero value of the type it
// points to. Non-pointer values and non-nil pointers are left untouched.
// rv has to be settable in the nil-pointer case, since reflect.Value.Set is used.
func maybeInstantiatePtr(rv reflect.Value) {
	// Kind is checked first so IsNil is never called on a non-pointer value.
	if rv.Kind() != reflect.Ptr || !rv.IsNil() {
		return
	}

	pointee := rv.Type().Elem()
	rv.Set(reflect.New(pointee))
}
|
||
|
|
||
|
// decodingError wraps err with a message naming the target that was being
// decoded. The original error stays reachable through errors.Is / errors.As
// because it is wrapped with %w.
func decodingError(name string, err error) error {
	const format = "error decoding to %s: %w"

	return fmt.Errorf(format, name, err)
}
|
||
|
|
||
|
// TextSerialiser provides the Scan and Value methods of the Serialiser
// interface for fields stored as text: Scan decodes a database value through
// encoding.TextUnmarshaler, and Value encodes a field through
// encoding.TextMarshaler. The type carries no state.
type TextSerialiser struct{}
|
||
|
|
||
|
func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||
|
fieldValue := reflect.New(field.FieldType)
|
||
|
|
||
|
// If the field is a pointer, we need to dereference it to get the actual type
|
||
|
// so we do not end with a second pointer.
|
||
|
if fieldValue.Elem().Kind() == reflect.Ptr {
|
||
|
fieldValue = fieldValue.Elem()
|
||
|
}
|
||
|
|
||
|
if dbValue != nil {
|
||
|
var bytes []byte
|
||
|
switch v := dbValue.(type) {
|
||
|
case []byte:
|
||
|
bytes = v
|
||
|
case string:
|
||
|
bytes = []byte(v)
|
||
|
default:
|
||
|
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
||
|
}
|
||
|
|
||
|
if isTextUnmarshaler(fieldValue) {
|
||
|
maybeInstantiatePtr(fieldValue)
|
||
|
f := fieldValue.MethodByName("UnmarshalText")
|
||
|
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||
|
ret := f.Call(args)
|
||
|
if !ret[0].IsNil() {
|
||
|
return decodingError(field.Name, ret[0].Interface().(error))
|
||
|
}
|
||
|
|
||
|
// If the underlying field is to a pointer type, we need to
|
||
|
// assign the value as a pointer to it.
|
||
|
// If it is not a pointer, we need to assign the value to the
|
||
|
// field.
|
||
|
dstField := field.ReflectValueOf(ctx, dst)
|
||
|
if dstField.Kind() == reflect.Ptr {
|
||
|
dstField.Set(fieldValue)
|
||
|
} else {
|
||
|
dstField.Set(fieldValue.Elem())
|
||
|
}
|
||
|
return nil
|
||
|
} else {
|
||
|
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||
|
switch v := fieldValue.(type) {
|
||
|
case encoding.TextMarshaler:
|
||
|
// If the value is nil, we return nil, however, go nil values are not
|
||
|
// always comparable, particularly when reflection is involved:
|
||
|
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||
|
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
b, err := v.MarshalText()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return string(b), nil
|
||
|
default:
|
||
|
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||
|
}
|
||
|
}
|