From 1f029180c74c63b922858026b17bc0a3b8c2ee70 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Wed, 23 Apr 2025 11:08:45 -0700 Subject: [PATCH] types/jsonx: add package for json/v2 helpers (#15756) The typical way to implement union types in Go is to use an interface where the set of types is limited. However, there historically has been poor support in v1 "encoding/json" with interface types where you can marshal such values, but fail to unmarshal them since type information about the concrete type is lost. The MakeInterfaceCoders function constructs custom marshal/unmarshal functions such that the type name is encoded in the JSON representation. The set of valid concrete types for an interface must be statically specified for this to function. Updates tailscale/corp#22024 Signed-off-by: Joe Tsai --- types/jsonx/json.go | 171 +++++++++++++++++++++++++++++++++++++++ types/jsonx/json_test.go | 140 ++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 types/jsonx/json.go create mode 100644 types/jsonx/json_test.go diff --git a/types/jsonx/json.go b/types/jsonx/json.go new file mode 100644 index 000000000..3f01ea358 --- /dev/null +++ b/types/jsonx/json.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonx contains helper types and functionality to use with +// [github.com/go-json-experiment/json], which is positioned to be +// merged into the Go standard library as [encoding/json/v2]. +// +// See https://go.dev/issues/71497 +package jsonx + +import ( + "errors" + "fmt" + "reflect" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" +) + +var ( + errUnknownTypeName = errors.New("unknown type name") + errNonSingularValue = errors.New("dynamic value must only have exactly one member") +) + +// MakeInterfaceCoders constructs a pair of marshal and unmarshal functions +// to serialize a Go interface type T. A bijective mapping for the set +// of concrete types that implement T is provided, +// where the key is a stable type name to use in the JSON representation, +// while the value is any value of a concrete type that implements T. +// By convention, only the zero value of concrete types is passed. +// +// The JSON representation for a dynamic value is a JSON object +// with a single member, where the member name is the type name, +// and the value is the JSON representation for the Go value. +// For example, the JSON serialization for a concrete type named Foo +// would be {"Foo": ...}, where ... is the JSON representation +// of the concrete value of the Foo type. +// +// Example instantiation: +// +// // Interface is a union type implemented by [FooType] and [BarType]. +// type Interface interface { ... } +// +// var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ +// "FooType": FooType{}, +// "BarType": (*BarType)(nil), +// }) +// +// The pair of Marshal and Unmarshal functions can be used with the [json] +// package with either type-specified or caller-specified serialization. +// The result of this constructor is usually stored into a global variable. +// +// Example usage with type-specified serialization: +// +// // InterfaceWrapper is a concrete type that wraps [Interface]. +// // It extends [Interface] to implement +// // [json.MarshalerTo] and [json.UnmarshalerFrom]. +// type InterfaceWrapper struct{ Interface } +// +// func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { +// return interfaceCoders.Marshal(enc, &w.Interface) +// } +// +// func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { +// return interfaceCoders.Unmarshal(dec, &w.Interface) +// } +// +// Example usage with caller-specified serialization: +// +// var opts json.Options = json.JoinOptions( +// json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), +// json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), +// ) +// +// var v Interface +// ... := json.Marshal(v, opts) +// ... := json.Unmarshal(&v, opts) +// +// The function panics if T is not a named interface kind, +// or if valuesByName contains distinct entries with the same concrete type. +func MakeInterfaceCoders[T any](valuesByName map[string]T) (c struct { + Marshal func(*jsontext.Encoder, *T) error + Unmarshal func(*jsontext.Decoder, *T) error +}) { + // Verify that T is a named interface. + switch t := reflect.TypeFor[T](); { + case t.Kind() != reflect.Interface: + panic(fmt.Sprintf("%v must be an interface kind", t)) + case t.Name() == "": + panic(fmt.Sprintf("%v must be a named type", t)) + } + + // Construct a bijective mapping of names to types. + typesByName := make(map[string]reflect.Type) + namesByType := make(map[reflect.Type]string) + for name, value := range valuesByName { + t := reflect.TypeOf(value) + if t == nil { + panic(fmt.Sprintf("nil value for %s", name)) + } + if name2, ok := namesByType[t]; ok { + panic(fmt.Sprintf("type %v cannot have multiple names %s and %v", t, name, name2)) + } + typesByName[name] = t + namesByType[t] = name + } + + // Construct the marshal and unmarshal functions. + c.Marshal = func(enc *jsontext.Encoder, val *T) error { + t := reflect.TypeOf(*val) + if t == nil { + return enc.WriteToken(jsontext.Null) + } + name := namesByType[t] + if name == "" { + return fmt.Errorf("Go type %v: %w", t, errUnknownTypeName) + } + + if err := enc.WriteToken(jsontext.BeginObject); err != nil { + return err + } + if err := enc.WriteToken(jsontext.String(name)); err != nil { + return err + } + if err := json.MarshalEncode(enc, *val); err != nil { + return err + } + if err := enc.WriteToken(jsontext.EndObject); err != nil { + return err + } + return nil + } + c.Unmarshal = func(dec *jsontext.Decoder, val *T) error { + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() == 'n': + var zero T + *val = zero // store nil interface value for JSON null + return nil + case tok.Kind() != '{': + return &json.SemanticError{JSONKind: tok.Kind(), GoType: reflect.TypeFor[T]()} + } + var v reflect.Value + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '"': + return errNonSingularValue + default: + t := typesByName[tok.String()] + if t == nil { + return errUnknownTypeName + } + v = reflect.New(t) + } + if err := json.UnmarshalDecode(dec, v.Interface()); err != nil { + return err + } + *val = v.Elem().Interface().(T) + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '}': + return errNonSingularValue + } + return nil + } + + return c +} diff --git a/types/jsonx/json_test.go b/types/jsonx/json_test.go new file mode 100644 index 000000000..0f2a646c4 --- /dev/null +++ b/types/jsonx/json_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonx + +import ( + "errors" + "testing" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "github.com/google/go-cmp/cmp" + "tailscale.com/types/ptr" +) + +type Interface interface { + implementsInterface() +} + +type Foo string + +func (Foo) implementsInterface() {} + +type Bar int + +func (Bar) implementsInterface() {} + +type Baz struct{ Fizz, Buzz string } + +func (*Baz) implementsInterface() {} + +var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ + "Foo": Foo(""), + "Bar": (*Bar)(nil), + "Baz": (*Baz)(nil), +}) + +type InterfaceWrapper struct{ Interface } + +func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { + return interfaceCoders.Marshal(enc, &w.Interface) +} + +func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + return interfaceCoders.Unmarshal(dec, &w.Interface) +} + +func TestInterfaceCoders(t *testing.T) { + var opts json.Options = json.JoinOptions( + json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), + json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), + ) + + errSkipMarshal := errors.New("skip marshal") + makeFiller := func() InterfaceWrapper { + return InterfaceWrapper{&Baz{"fizz", "buzz"}} + } + + for _, tt := range []struct { + label string + wantVal InterfaceWrapper + wantJSON string + wantMarshalError error + wantUnmarshalError error + }{{ + label: "Null", + wantVal: InterfaceWrapper{}, + wantJSON: `null`, + }, { + label: "Foo", + wantVal: InterfaceWrapper{Foo("hello")}, + wantJSON: `{"Foo":"hello"}`, + }, { + label: "BarPointer", + wantVal: InterfaceWrapper{ptr.To(Bar(5))}, + wantJSON: `{"Bar":5}`, + }, { + label: "BarValue", + wantVal: InterfaceWrapper{Bar(5)}, + // NOTE: We could handle BarValue just like BarPointer, + // but round-trip marshal/unmarshal would not be identical. + wantMarshalError: errUnknownTypeName, + }, { + label: "Baz", + wantVal: InterfaceWrapper{&Baz{"alpha", "omega"}}, + wantJSON: `{"Baz":{"Fizz":"alpha","Buzz":"omega"}}`, + }, { + label: "Unknown", + wantVal: makeFiller(), + wantJSON: `{"Unknown":[1,2,3]}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errUnknownTypeName, + }, { + label: "Empty", + wantVal: makeFiller(), + wantJSON: `{}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }, { + label: "Duplicate", + wantVal: InterfaceWrapper{Foo("hello")}, // first entry wins + wantJSON: `{"Foo":"hello","Bar":5}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }} { + t.Run(tt.label, func(t *testing.T) { + if tt.wantMarshalError != errSkipMarshal { + switch gotJSON, err := json.Marshal(&tt.wantVal); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + switch gotJSON, err := json.Marshal(&tt.wantVal.Interface, opts); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + } + + if tt.wantJSON != "" { + gotVal := makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + gotVal = makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal.Interface, opts); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + } + }) + } +}