diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index d2b86ff4..55cd8fb1 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -3,6 +3,8 @@ package types import ( "cmp" "database/sql" + "encoding/json" + "fmt" "net/mail" "strconv" @@ -119,18 +121,49 @@ func (u *User) Proto() *v1.User { } } +// JumpCloud returns a JSON where email_verified is returned as a +// string "true" or "false" instead of a boolean. +// This maps bool to a specific type with a custom unmarshaler to +// ensure we can decode it from a string. +// https://github.com/juanfont/headscale/issues/2293 +type FlexibleBoolean bool + +func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { + var val interface{} + err := json.Unmarshal(data, &val) + if err != nil { + return fmt.Errorf("could not unmarshal data: %w", err) + } + + switch v := val.(type) { + case bool: + *bit = FlexibleBoolean(v) + case string: + pv, err := strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("could not parse %s as boolean: %w", v, err) + } + *bit = FlexibleBoolean(pv) + + default: + return fmt.Errorf("could not parse %v as boolean", v) + } + + return nil +} + type OIDCClaims struct { // Sub is the user's unique identifier at the provider. Sub string `json:"sub"` Iss string `json:"iss"` // Name is the user's full name. - Name string `json:"name,omitempty"` - Groups []string `json:"groups,omitempty"` - Email string `json:"email,omitempty"` - EmailVerified bool `json:"email_verified,omitempty"` - ProfilePictureURL string `json:"picture,omitempty"` - Username string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Groups []string `json:"groups,omitempty"` + Email string `json:"email,omitempty"` + EmailVerified FlexibleBoolean `json:"email_verified,omitempty"` + ProfilePictureURL string `json:"picture,omitempty"` + Username string `json:"preferred_username,omitempty"` } func (c *OIDCClaims) Identifier() string { diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go new file mode 100644 index 00000000..dad1d814 --- /dev/null +++ b/hscontrol/types/users_test.go @@ -0,0 +1,75 @@ +package types + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestUnmarshallOIDCClaims(t *testing.T) { + tests := []struct { + name string + jsonstr string + want OIDCClaims + }{ + { + name: "normal-bool", + jsonstr: ` +{ + "sub": "test", + "email": "test@test.no", + "email_verified": true +} + `, + want: OIDCClaims{ + Sub: "test", + Email: "test@test.no", + EmailVerified: true, + }, + }, + { + name: "string-bool-true", + jsonstr: ` +{ + "sub": "test2", + "email": "test2@test.no", + "email_verified": "true" +} + `, + want: OIDCClaims{ + Sub: "test2", + Email: "test2@test.no", + EmailVerified: true, + }, + }, + { + name: "string-bool-false", + jsonstr: ` +{ + "sub": "test3", + "email": "test3@test.no", + "email_verified": "false" +} + `, + want: OIDCClaims{ + Sub: "test3", + Email: "test3@test.no", + EmailVerified: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got OIDCClaims + if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + t.Errorf("UnmarshallOIDCClaims() error = %v", err) + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) + } + }) + } +}