From b07e04e9e687b7f4ce74d2c4132f43cc70e56556 Mon Sep 17 00:00:00 2001 From: Cedric Kienzler Date: Tue, 25 Feb 2025 13:20:07 +0100 Subject: [PATCH 1/3] cmd/tsidp: add groups claim to tsidp This feature adds support for a `groups` claim in tsidp using the grants syntax: ```json { "grants": [ { "src": ["group:admins"], "dst": ["*"], "ip": ["*"], "app": { "tailscale.com/cap/tsidp": [ { "groups": ["admin"] } ] } }, { "src": ["group:reader"], "dst": ["*"], "ip": ["*"], "app": { "tailscale.com/cap/tsidp": [ { "groups": ["reader"] } ] } } ] } ``` For #10263 Signed-off-by: Cedric Kienzler --- cmd/tsidp/tsidp.go | 45 +++++++++++++++++++++++++++++++++++++++------ tailcfg/tailcfg.go | 5 +++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 3eabef245..2b7bc136f 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -494,10 +494,23 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) return } + + rules, err := tailcfg.UnmarshalCapJSON[capRule](ar.remoteUser.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) + return + } + + groups := make([]string, 0) + for _, rule := range rules { + groups = append(groups, rule.Groups...) + } + ui.Sub = ar.remoteUser.Node.User.String() ui.Name = ar.remoteUser.UserProfile.DisplayName ui.Email = ar.remoteUser.UserProfile.LoginName ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL + ui.Groups = groups // TODO(maisem): not sure if this is the right thing to do ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") @@ -509,11 +522,16 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { } type userInfo struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - Picture string `json:"picture"` - UserName string `json:"username"` + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + Picture string `json:"picture"` + UserName string `json:"username"` + Groups []string `json:"groups,omitempty"` +} + +type capRule struct { + Groups []string `json:"groups,omitempty"` // list of features peer is allowed to edit } func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { @@ -566,6 +584,17 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { return } + rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) + return + } + + groups := make([]string, 0) + for _, rule := range rules { + groups = append(groups, rule.Groups...) + } + now := time.Now() _, tcd, _ := strings.Cut(n.Name(), ".") tsClaims := tailscaleClaims{ @@ -587,6 +616,7 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { UserID: n.User(), Email: who.UserProfile.LoginName, UserName: userName, + Groups: groups, } if ar.localRP { tsClaims.Issuer = s.loopbackURL @@ -735,6 +765,9 @@ type tailscaleClaims struct { // It is a temporary (2023-11-15) hack during development. // We should probably let this be configured via grants. UserName string `json:"username,omitempty"` + + // Groups are group memberships controlled via grants + Groups []string `json:"groups,omitempty"` } var ( @@ -743,7 +776,7 @@ var ( "sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email", // Tailscale claims, these correspond to fields in tailscaleClaims. - "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", + "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", "groups", }) // As defined in the OpenID spec this should be "openid". diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index f82c6eb81..4562929a7 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -1461,6 +1461,11 @@ const ( // user groups as Kubernetes user groups. This capability is read by // peers that are Tailscale Kubernetes operator instances. PeerCapabilityKubernetes PeerCapability = "tailscale.com/cap/kubernetes" + + // PeerCapabilityTsIDP grants a peer tsidp-specific + // capabilities, such as the ability to add user groups to the OIDC + // claim + PeerCapabilityTsIDP PeerCapability = "tailscale.com/cap/tsidp" ) // NodeCapMap is a map of capabilities to their optional values. It is valid for From d861d708af42d1663ab4a335a9d2468d60928473 Mon Sep 17 00:00:00 2001 From: Cedric Kienzler Date: Fri, 21 Mar 2025 14:19:56 +0100 Subject: [PATCH 2/3] cmd/tsidp: refactor cap/tsidp to allow extraClaims This commit refactors the `capRule` struct to allow specifying arbitrary extra claims: ```json { "src": ["group:reader"], "dst": ["*"], "ip": ["*"], "app": { "tailscale.com/cap/tsidp": [ { "extraClaims": { "groups": ["reader"], "entitlements": ["read-stuff"], }, } ] } } ``` Overwriting pre-existing claims cannot be modified/overwritten. Also adding more unit-testing Signed-off-by: Cedric Kienzler --- cmd/tsidp/tsidp.go | 205 ++++++++-- cmd/tsidp/tsidp_test.go | 826 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 995 insertions(+), 36 deletions(-) create mode 100644 cmd/tsidp/tsidp_test.go diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 2b7bc136f..e17507a62 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -495,43 +495,178 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { return } + ui.Sub = ar.remoteUser.Node.User.String() + ui.Name = ar.remoteUser.UserProfile.DisplayName + ui.Email = ar.remoteUser.UserProfile.LoginName + ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL + + // TODO(maisem): not sure if this is the right thing to do + ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") + rules, err := tailcfg.UnmarshalCapJSON[capRule](ar.remoteUser.CapMap, tailcfg.PeerCapabilityTsIDP) if err != nil { http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) return } - groups := make([]string, 0) - for _, rule := range rules { - groups = append(groups, rule.Groups...) + // Only keep rules where IncludeInUserInfo is true + var filtered []capRule + for _, r := range rules { + if r.IncludeInUserInfo { + filtered = append(filtered, r) + } } - ui.Sub = ar.remoteUser.Node.User.String() - ui.Name = ar.remoteUser.UserProfile.DisplayName - ui.Email = ar.remoteUser.UserProfile.LoginName - ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL - ui.Groups = groups - - // TODO(maisem): not sure if this is the right thing to do - ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") + userInfo, err := withExtraClaims(ui, filtered) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Write the final result w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(ui); err != nil { + if err := json.NewEncoder(w).Encode(userInfo); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } type userInfo struct { - Sub string `json:"sub"` - Name string `json:"name"` - Email string `json:"email"` - Picture string `json:"picture"` - UserName string `json:"username"` - Groups []string `json:"groups,omitempty"` + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + Picture string `json:"picture"` + UserName string `json:"username"` } type capRule struct { - Groups []string `json:"groups,omitempty"` // list of features peer is allowed to edit + IncludeInUserInfo bool `json:"includeInUserInfo"` + ExtraClaims map[string]interface{} `json:"extraClaims,omitempty"` // list of features peer is allowed to edit +} + +// flattenExtraClaims merges all ExtraClaims from a slice of capRule into a single map. +// It deduplicates values for each claim and preserves the original input type: +// scalar values remain scalars, and slices are returned as deduplicated []interface{} slices. +func flattenExtraClaims(rules []capRule) map[string]interface{} { + // sets stores deduplicated stringified values for each claim key. + sets := make(map[string]map[string]struct{}) + + // isSlice tracks whether each claim was originally provided as a slice. + isSlice := make(map[string]bool) + + for _, rule := range rules { + for claim, raw := range rule.ExtraClaims { + // Track whether the claim was provided as a slice + switch raw.(type) { + case []string, []interface{}: + isSlice[claim] = true + default: + // Only mark as scalar if this is the first time we've seen this claim + if _, seen := isSlice[claim]; !seen { + isSlice[claim] = false + } + } + + // Add the claim value(s) into the deduplication set + addClaimValue(sets, claim, raw) + } + } + + // Build final result: either scalar or slice depending on original type + result := make(map[string]interface{}) + for claim, valSet := range sets { + if isSlice[claim] { + // Claim was provided as a slice: output as []interface{} + var vals []interface{} + for val := range valSet { + vals = append(vals, val) + } + result[claim] = vals + } else { + // Claim was a scalar: return a single value + for val := range valSet { + result[claim] = val + break // only one value is expected + } + } + } + + return result +} + +// addClaimValue adds a claim value to the deduplication set for a given claim key. +// It accepts scalars (string, int, float64), slices of strings or interfaces, +// and recursively handles nested slices. Unsupported types are ignored with a log message. +func addClaimValue(sets map[string]map[string]struct{}, claim string, val interface{}) { + switch v := val.(type) { + case string, float64, int, int64: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add the stringified scalar to the set + sets[claim][fmt.Sprintf("%v", v)] = struct{}{} + + case []string: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add each string value to the set + for _, s := range v { + sets[claim][s] = struct{}{} + } + + case []interface{}: + // Recursively handle each item in the slice + for _, item := range v { + addClaimValue(sets, claim, item) + } + + default: + // Log unsupported types for visibility and debugging + log.Printf("Unsupported claim type for %q: %#v (type %T)", claim, val, val) + } +} + +// withExtraClaims merges flattened extra claims from a list of capRule into the provided struct v, +// returning a map[string]interface{} that combines both sources. +// +// The input struct v is first marshaled to JSON, then unmarshalled into a generic map. +// Claims defined in openIDSupportedClaims are considered protected and cannot be overwritten. +// If an extra claim attempts to overwrite a protected claim, an error is returned. +// +// Returns the merged claims map or an error if any protected claim is violated or JSON (un)marshaling fails. +func withExtraClaims(v any, rules []capRule) (map[string]interface{}, error) { + // Marshal the static struct + data, err := json.Marshal(v) + if err != nil { + return nil, err + } + + // Unmarshal into a generic map + var claimMap map[string]interface{} + if err := json.Unmarshal(data, &claimMap); err != nil { + return nil, err + } + + // Convert views.Slice to a map[string]struct{} for efficient lookup + protected := make(map[string]struct{}, len(openIDSupportedClaims.AsSlice())) + for _, claim := range openIDSupportedClaims.AsSlice() { + protected[claim] = struct{}{} + } + + // Merge extra claims + extra := flattenExtraClaims(rules) + for k, v := range extra { + if _, isProtected := protected[k]; isProtected { + log.Printf("Skip overwriting of existing claim %q", k) + return nil, fmt.Errorf("extra claim %q overwriting existing claim", k) + } + + claimMap[k] = v + } + + return claimMap, nil } func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { @@ -584,17 +719,6 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { return } - rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP) - if err != nil { - http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) - return - } - - groups := make([]string, 0) - for _, rule := range rules { - groups = append(groups, rule.Groups...) - } - now := time.Now() _, tcd, _ := strings.Cut(n.Name(), ".") tsClaims := tailscaleClaims{ @@ -616,14 +740,26 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { UserID: n.User(), Email: who.UserProfile.LoginName, UserName: userName, - Groups: groups, } if ar.localRP { tsClaims.Issuer = s.loopbackURL } + rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) + return + } + + tsClaimsWithExtra, err := withExtraClaims(tsClaims, rules) + if err != nil { + log.Printf("tsidp: failed to merge extra claims: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Create an OIDC token using this issuer's signer. - token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() + token, err := jwt.Signed(signer).Claims(tsClaimsWithExtra).CompactSerialize() if err != nil { log.Printf("Error getting token: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -765,9 +901,6 @@ type tailscaleClaims struct { // It is a temporary (2023-11-15) hack during development. // We should probably let this be configured via grants. UserName string `json:"username,omitempty"` - - // Groups are group memberships controlled via grants - Groups []string `json:"groups,omitempty"` } var ( @@ -776,7 +909,7 @@ var ( "sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email", // Tailscale claims, these correspond to fields in tailscaleClaims. - "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", "groups", + "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", }) // As defined in the OpenID spec this should be "openid". diff --git a/cmd/tsidp/tsidp_test.go b/cmd/tsidp/tsidp_test.go new file mode 100644 index 000000000..752a8b96f --- /dev/null +++ b/cmd/tsidp/tsidp_test.go @@ -0,0 +1,826 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "reflect" + "sort" + "strings" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" + "testing" + "time" +) + +// normalizeMap recursively sorts []interface{} values in a map[string]interface{} +func normalizeMap(t *testing.T, m map[string]interface{}) map[string]interface{} { + t.Helper() + normalized := make(map[string]interface{}, len(m)) + for k, v := range m { + switch val := v.(type) { + case []interface{}: + sorted := make([]string, len(val)) + for i, item := range val { + sorted[i] = fmt.Sprintf("%v", item) // convert everything to string for sorting + } + sort.Strings(sorted) + + // convert back to []interface{} + sortedIface := make([]interface{}, len(sorted)) + for i, s := range sorted { + sortedIface[i] = s + } + normalized[k] = sortedIface + + default: + normalized[k] = v + } + } + return normalized +} + +func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return tailcfg.RawMessage(b) +} + +var privateKey *rsa.PrivateKey = nil + +func oidcTestingSigner(t *testing.T) jose.Signer { + t.Helper() + privKey := mustGeneratePrivateKey(t) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + return sig +} + +func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey { + t.Helper() + privKey := mustGeneratePrivateKey(t) + return &privKey.PublicKey +} + +func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + if privateKey != nil { + return privateKey + } + + var err error + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + return privateKey +} + +func TestFlattenExtraClaims(t *testing.T) { + log.SetOutput(io.Discard) // suppress log output during tests + + tests := []struct { + name string + input []capRule + expected map[string]interface{} + }{ + { + name: "empty extra claims", + input: []capRule{ + {ExtraClaims: map[string]interface{}{}}, + }, + expected: map[string]interface{}{}, + }, + { + name: "string and number values", + input: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "featureA": "read", + "featureB": 42, + }, + }, + }, + expected: map[string]interface{}{ + "featureA": "read", + "featureB": "42", + }, + }, + { + name: "slice of strings and ints", + input: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "roles": []interface{}{"admin", "user", 1}, + }, + }, + }, + expected: map[string]interface{}{ + "roles": []interface{}{"admin", "user", "1"}, + }, + }, + { + name: "duplicate values deduplicated (slice input)", + input: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar", "baz"}, + }, + }, + { + ExtraClaims: map[string]interface{}{ + "foo": []interface{}{"bar", "qux"}, + }, + }, + }, + expected: map[string]interface{}{ + "foo": []interface{}{"bar", "baz", "qux"}, + }, + }, + { + name: "ignore unsupported map type, keep valid scalar", + input: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "invalid": map[string]interface{}{"bad": "yes"}, + "valid": "ok", + }, + }, + }, + expected: map[string]interface{}{ + "valid": "ok", + }, + }, + { + name: "scalar first, slice second", + input: []capRule{ + {ExtraClaims: map[string]interface{}{"foo": "bar"}}, + {ExtraClaims: map[string]interface{}{"foo": []interface{}{"baz"}}}, + }, + expected: map[string]interface{}{ + "foo": []interface{}{"bar", "baz"}, // since first was scalar, second being a slice forces slice output + }, + }, + { + name: "conflicting scalar and unsupported map", + input: []capRule{ + {ExtraClaims: map[string]interface{}{"foo": "bar"}}, + {ExtraClaims: map[string]interface{}{"foo": map[string]interface{}{"bad": "entry"}}}, + }, + expected: map[string]interface{}{ + "foo": "bar", // map should be ignored + }, + }, + { + name: "multiple slices with overlap", + input: []capRule{ + {ExtraClaims: map[string]interface{}{"roles": []interface{}{"admin", "user"}}}, + {ExtraClaims: map[string]interface{}{"roles": []interface{}{"admin", "guest"}}}, + }, + expected: map[string]interface{}{ + "roles": []interface{}{"admin", "user", "guest"}, + }, + }, + { + name: "slice with unsupported values", + input: []capRule{ + {ExtraClaims: map[string]interface{}{ + "mixed": []interface{}{"ok", 42, map[string]string{"oops": "fail"}}, + }}, + }, + expected: map[string]interface{}{ + "mixed": []interface{}{"ok", "42"}, // map is ignored + }, + }, + { + name: "duplicate scalar value", + input: []capRule{ + {ExtraClaims: map[string]interface{}{"env": "prod"}}, + {ExtraClaims: map[string]interface{}{"env": "prod"}}, + }, + expected: map[string]interface{}{ + "env": "prod", // not converted to slice + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := flattenExtraClaims(tt.input) + + gotNormalized := normalizeMap(t, got) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestExtraClaims(t *testing.T) { + tests := []struct { + name string + claim tailscaleClaims + extraClaims []capRule + expected map[string]interface{} + expectError bool + }{ + { + name: "extra claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar"}, + }, + }, + }, + expected: map[string]interface{}{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []interface{}{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]interface{}{ + "foo": []string{"foobar"}, + }, + }, + }, + expected: map[string]interface{}{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []interface{}{"foobar", "bar"}, + }, + }, + { + name: "multiple extra claims", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]interface{}{ + "bar": []string{"foo"}, + }, + }, + }, + expected: map[string]interface{}{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []interface{}{"bar"}, + "bar": []interface{}{"foo"}, + }, + }, + { + name: "overwrite claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]interface{}{ + "username": "foobar", + }, + }, + }, + expected: map[string]interface{}{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "foobar", + }, + expectError: true, + }, + { + name: "empty extra claims", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{{ExtraClaims: map[string]interface{}{}}}, + expected: map[string]interface{}{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims, err := withExtraClaims(tt.claim, tt.extraClaims) + if err != nil { + if tt.expectError { + if err == nil { + t.Fatalf("expected error, got nil") + } + return + } else { + t.Errorf("claim.withExtraClaims() unexpected error = %v", err) + } + } + + // Marshal to JSON then unmarshal back to map[string]interface{} + gotClaims, err := json.Marshal(claims) + if err != nil { + t.Errorf("json.Marshal(claims) error = %v", err) + } + + var gotClaimsMap map[string]interface{} + if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil { + t.Fatalf("json.Unmarshal(gotClaims) error = %v", err) + } + + gotNormalized := normalizeMap(t, gotClaimsMap) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestServeToken(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + method string + grantType string + code string + omitCode bool + redirectURI string + remoteAddr string + expectError bool + expected map[string]interface{} + }{ + { + name: "GET not allowed", + method: "GET", + grantType: "authorization_code", + expectError: true, + }, + { + name: "unsupported grant type", + method: "POST", + grantType: "pkcs", + expectError: true, + }, + { + name: "invalid code", + method: "POST", + grantType: "authorization_code", + code: "invalid-code", + expectError: true, + }, + { + name: "omit code from form", + method: "POST", + grantType: "authorization_code", + omitCode: true, + expectError: true, + }, + { + name: "invalid redirect uri", + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://invalid.example.com/callback", + remoteAddr: "127.0.0.1:12345", + expectError: true, + }, + { + name: "invalid remoteAddr", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "192.168.0.1:12345", + expectError: true, + }, + { + name: "extra claim included", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "127.0.0.1:12345", + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "foo": "bar", + }, + }), + }, + }, + expected: map[string]interface{}{ + "foo": "bar", + }, + }, + { + name: "attempt to overwrite protected claim", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "sub": "should-not-overwrite", + }, + }), + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now() + + // Fake user/node + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tt.caps, + } + + s := &idpServer{ + code: map[string]*authRequest{ + "valid-code": { + clientID: "client-id", + nonce: "nonce123", + redirectURI: "https://rp.example.com/callback", + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: true, + }, + }, + } + // Inject a working signer + s.lazySigner.Set(oidcTestingSigner(t)) + + form := url.Values{} + form.Set("grant_type", tt.grantType) + form.Set("redirect_uri", tt.redirectURI) + if !tt.omitCode { + form.Set("code", tt.code) + } + + req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) + req.RemoteAddr = tt.remoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + s.serveToken(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got 200 OK: %s", rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp struct { + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + tok, err := jwt.ParseSigned(resp.IDToken) + if err != nil { + t.Fatalf("failed to parse ID token: %v", err) + } + + out := make(map[string]interface{}) + if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil { + t.Fatalf("failed to extract claims: %v", err) + } + + for k, want := range tt.expected { + got, ok := out[k] + if !ok { + t.Errorf("missing expected claim %q", k) + continue + } + if !reflect.DeepEqual(got, want) { + t.Errorf("claim %q: got %v, want %v", k, got, want) + } + } + }) + } +} + +func TestExtraUserInfo(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + tokenValidTill time.Time + expected map[string]interface{} + expectError bool + }{ + { + name: "extra claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar"}, + }, + }), + }, + }, + expected: map[string]interface{}{ + "foo": []interface{}{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "foo": []string{"bar", "foobar"}, + }, + }), + }, + }, + expected: map[string]interface{}{ + "foo": []interface{}{"bar", "foobar"}, + }, + }, + { + name: "multiple extra claims", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "foo": "bar", + "bar": "foo", + }, + }), + }, + }, + expected: map[string]interface{}{ + "foo": "bar", + "bar": "foo", + }, + }, + { + name: "empty extra claims", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(1 * time.Minute), + expected: map[string]interface{}{}, + }, + { + name: "attempt to overwrite protected claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]interface{}{ + "sub": "should-not-overwrite", + "foo": "ok", + }, + }), + }, + }, + expectError: true, + }, + { + name: "extra claim omitted", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: false, + ExtraClaims: map[string]interface{}{ + "foo": "ok", + }, + }), + }, + }, + expected: map[string]interface{}{}, + }, + { + name: "expired token", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(-1 * time.Minute), + expected: map[string]interface{}{}, + expectError: true, + }, + } + token := "valid-token" + + // Create a fake tailscale Node + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Construct the remote user + profile := tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: &profile, + CapMap: tt.caps, + } + + // Insert a valid token into the idpServer + s := &idpServer{ + accessToken: map[string]*authRequest{ + token: { + validTill: tt.tokenValidTill, + remoteUser: remoteUser, + }, + }, + } + + // Construct request + req := httptest.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + // Call the method under test + s.serveUserInfo(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + // Construct expected + tt.expected["sub"] = remoteUser.Node.User.String() + tt.expected["name"] = profile.DisplayName + tt.expected["email"] = profile.LoginName + tt.expected["picture"] = profile.ProfilePicURL + tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@") + + gotNormalized := normalizeMap(t, resp) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} From fb281269187c9a81260d3da93dc4c6e58a4dbeb7 Mon Sep 17 00:00:00 2001 From: cedi Date: Fri, 21 Mar 2025 20:57:05 +0100 Subject: [PATCH 3/3] Update cmd/tsidp/tsidp.go Signed-off-by: cedi --- cmd/tsidp/tsidp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index e17507a62..bec26ca14 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -754,7 +754,7 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { tsClaimsWithExtra, err := withExtraClaims(tsClaims, rules) if err != nil { log.Printf("tsidp: failed to merge extra claims: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusBadRequest) return }