mirror of
https://github.com/tailscale/tailscale.git
synced 2025-03-28 12:02:23 +00:00
Merge fb281269187c9a81260d3da93dc4c6e58a4dbeb7 into b3455fa99a5e8d07133d5140017ec7c49f032a07
This commit is contained in:
commit
deca3dbfb3
@ -497,6 +497,7 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ui.Sub = ar.remoteUser.Node.User.String()
|
||||
ui.Name = ar.remoteUser.UserProfile.DisplayName
|
||||
ui.Email = ar.remoteUser.UserProfile.LoginName
|
||||
@ -505,8 +506,29 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO(maisem): not sure if this is the right thing to do
|
||||
ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
|
||||
|
||||
rules, err := tailcfg.UnmarshalCapJSON[capRule](ar.remoteUser.CapMap, tailcfg.PeerCapabilityTsIDP)
|
||||
if err != nil {
|
||||
http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Only keep rules where IncludeInUserInfo is true
|
||||
var filtered []capRule
|
||||
for _, r := range rules {
|
||||
if r.IncludeInUserInfo {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
|
||||
userInfo, err := withExtraClaims(ui, filtered)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Write the final result
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(ui); err != nil {
|
||||
if err := json.NewEncoder(w).Encode(userInfo); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
@ -519,6 +541,137 @@ type userInfo struct {
|
||||
UserName string `json:"username"`
|
||||
}
|
||||
|
||||
type capRule struct {
|
||||
IncludeInUserInfo bool `json:"includeInUserInfo"`
|
||||
ExtraClaims map[string]interface{} `json:"extraClaims,omitempty"` // list of features peer is allowed to edit
|
||||
}
|
||||
|
||||
// flattenExtraClaims merges all ExtraClaims from a slice of capRule into a single map.
|
||||
// It deduplicates values for each claim and preserves the original input type:
|
||||
// scalar values remain scalars, and slices are returned as deduplicated []interface{} slices.
|
||||
func flattenExtraClaims(rules []capRule) map[string]interface{} {
|
||||
// sets stores deduplicated stringified values for each claim key.
|
||||
sets := make(map[string]map[string]struct{})
|
||||
|
||||
// isSlice tracks whether each claim was originally provided as a slice.
|
||||
isSlice := make(map[string]bool)
|
||||
|
||||
for _, rule := range rules {
|
||||
for claim, raw := range rule.ExtraClaims {
|
||||
// Track whether the claim was provided as a slice
|
||||
switch raw.(type) {
|
||||
case []string, []interface{}:
|
||||
isSlice[claim] = true
|
||||
default:
|
||||
// Only mark as scalar if this is the first time we've seen this claim
|
||||
if _, seen := isSlice[claim]; !seen {
|
||||
isSlice[claim] = false
|
||||
}
|
||||
}
|
||||
|
||||
// Add the claim value(s) into the deduplication set
|
||||
addClaimValue(sets, claim, raw)
|
||||
}
|
||||
}
|
||||
|
||||
// Build final result: either scalar or slice depending on original type
|
||||
result := make(map[string]interface{})
|
||||
for claim, valSet := range sets {
|
||||
if isSlice[claim] {
|
||||
// Claim was provided as a slice: output as []interface{}
|
||||
var vals []interface{}
|
||||
for val := range valSet {
|
||||
vals = append(vals, val)
|
||||
}
|
||||
result[claim] = vals
|
||||
} else {
|
||||
// Claim was a scalar: return a single value
|
||||
for val := range valSet {
|
||||
result[claim] = val
|
||||
break // only one value is expected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// addClaimValue adds a claim value to the deduplication set for a given claim key.
|
||||
// It accepts scalars (string, int, float64), slices of strings or interfaces,
|
||||
// and recursively handles nested slices. Unsupported types are ignored with a log message.
|
||||
func addClaimValue(sets map[string]map[string]struct{}, claim string, val interface{}) {
|
||||
switch v := val.(type) {
|
||||
case string, float64, int, int64:
|
||||
// Ensure the claim set is initialized
|
||||
if sets[claim] == nil {
|
||||
sets[claim] = make(map[string]struct{})
|
||||
}
|
||||
// Add the stringified scalar to the set
|
||||
sets[claim][fmt.Sprintf("%v", v)] = struct{}{}
|
||||
|
||||
case []string:
|
||||
// Ensure the claim set is initialized
|
||||
if sets[claim] == nil {
|
||||
sets[claim] = make(map[string]struct{})
|
||||
}
|
||||
// Add each string value to the set
|
||||
for _, s := range v {
|
||||
sets[claim][s] = struct{}{}
|
||||
}
|
||||
|
||||
case []interface{}:
|
||||
// Recursively handle each item in the slice
|
||||
for _, item := range v {
|
||||
addClaimValue(sets, claim, item)
|
||||
}
|
||||
|
||||
default:
|
||||
// Log unsupported types for visibility and debugging
|
||||
log.Printf("Unsupported claim type for %q: %#v (type %T)", claim, val, val)
|
||||
}
|
||||
}
|
||||
|
||||
// withExtraClaims merges flattened extra claims from a list of capRule into the provided struct v,
|
||||
// returning a map[string]interface{} that combines both sources.
|
||||
//
|
||||
// The input struct v is first marshaled to JSON, then unmarshalled into a generic map.
|
||||
// Claims defined in openIDSupportedClaims are considered protected and cannot be overwritten.
|
||||
// If an extra claim attempts to overwrite a protected claim, an error is returned.
|
||||
//
|
||||
// Returns the merged claims map or an error if any protected claim is violated or JSON (un)marshaling fails.
|
||||
func withExtraClaims(v any, rules []capRule) (map[string]interface{}, error) {
|
||||
// Marshal the static struct
|
||||
data, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal into a generic map
|
||||
var claimMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &claimMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Convert views.Slice to a map[string]struct{} for efficient lookup
|
||||
protected := make(map[string]struct{}, len(openIDSupportedClaims.AsSlice()))
|
||||
for _, claim := range openIDSupportedClaims.AsSlice() {
|
||||
protected[claim] = struct{}{}
|
||||
}
|
||||
|
||||
// Merge extra claims
|
||||
extra := flattenExtraClaims(rules)
|
||||
for k, v := range extra {
|
||||
if _, isProtected := protected[k]; isProtected {
|
||||
log.Printf("Skip overwriting of existing claim %q", k)
|
||||
return nil, fmt.Errorf("extra claim %q overwriting existing claim", k)
|
||||
}
|
||||
|
||||
claimMap[k] = v
|
||||
}
|
||||
|
||||
return claimMap, nil
|
||||
}
|
||||
|
||||
func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
|
||||
@ -595,8 +748,21 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
|
||||
tsClaims.Issuer = s.loopbackURL
|
||||
}
|
||||
|
||||
rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP)
|
||||
if err != nil {
|
||||
http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tsClaimsWithExtra, err := withExtraClaims(tsClaims, rules)
|
||||
if err != nil {
|
||||
log.Printf("tsidp: failed to merge extra claims: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Create an OIDC token using this issuer's signer.
|
||||
token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize()
|
||||
token, err := jwt.Signed(signer).Claims(tsClaimsWithExtra).CompactSerialize()
|
||||
if err != nil {
|
||||
log.Printf("Error getting token: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
826
cmd/tsidp/tsidp_test.go
Normal file
826
cmd/tsidp/tsidp_test.go
Normal file
@ -0,0 +1,826 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// normalizeMap recursively sorts []interface{} values in a map[string]interface{}
|
||||
func normalizeMap(t *testing.T, m map[string]interface{}) map[string]interface{} {
|
||||
t.Helper()
|
||||
normalized := make(map[string]interface{}, len(m))
|
||||
for k, v := range m {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
sorted := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
sorted[i] = fmt.Sprintf("%v", item) // convert everything to string for sorting
|
||||
}
|
||||
sort.Strings(sorted)
|
||||
|
||||
// convert back to []interface{}
|
||||
sortedIface := make([]interface{}, len(sorted))
|
||||
for i, s := range sorted {
|
||||
sortedIface[i] = s
|
||||
}
|
||||
normalized[k] = sortedIface
|
||||
|
||||
default:
|
||||
normalized[k] = v
|
||||
}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tailcfg.RawMessage(b)
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey = nil
|
||||
|
||||
func oidcTestingSigner(t *testing.T) jose.Signer {
|
||||
t.Helper()
|
||||
privKey := mustGeneratePrivateKey(t)
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create signer: %v", err)
|
||||
}
|
||||
return sig
|
||||
}
|
||||
|
||||
func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey {
|
||||
t.Helper()
|
||||
privKey := mustGeneratePrivateKey(t)
|
||||
return &privKey.PublicKey
|
||||
}
|
||||
|
||||
func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey {
|
||||
t.Helper()
|
||||
if privateKey != nil {
|
||||
return privateKey
|
||||
}
|
||||
|
||||
var err error
|
||||
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
return privateKey
|
||||
}
|
||||
|
||||
func TestFlattenExtraClaims(t *testing.T) {
|
||||
log.SetOutput(io.Discard) // suppress log output during tests
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []capRule
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "empty extra claims",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{}},
|
||||
},
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "string and number values",
|
||||
input: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"featureA": "read",
|
||||
"featureB": 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"featureA": "read",
|
||||
"featureB": "42",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "slice of strings and ints",
|
||||
input: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"roles": []interface{}{"admin", "user", 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"roles": []interface{}{"admin", "user", "1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate values deduplicated (slice input)",
|
||||
input: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar", "baz"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "qux"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz", "qux"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ignore unsupported map type, keep valid scalar",
|
||||
input: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"invalid": map[string]interface{}{"bad": "yes"},
|
||||
"valid": "ok",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"valid": "ok",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "scalar first, slice second",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{"foo": "bar"}},
|
||||
{ExtraClaims: map[string]interface{}{"foo": []interface{}{"baz"}}},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "baz"}, // since first was scalar, second being a slice forces slice output
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "conflicting scalar and unsupported map",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{"foo": "bar"}},
|
||||
{ExtraClaims: map[string]interface{}{"foo": map[string]interface{}{"bad": "entry"}}},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": "bar", // map should be ignored
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple slices with overlap",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{"roles": []interface{}{"admin", "user"}}},
|
||||
{ExtraClaims: map[string]interface{}{"roles": []interface{}{"admin", "guest"}}},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"roles": []interface{}{"admin", "user", "guest"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "slice with unsupported values",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{
|
||||
"mixed": []interface{}{"ok", 42, map[string]string{"oops": "fail"}},
|
||||
}},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"mixed": []interface{}{"ok", "42"}, // map is ignored
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate scalar value",
|
||||
input: []capRule{
|
||||
{ExtraClaims: map[string]interface{}{"env": "prod"}},
|
||||
{ExtraClaims: map[string]interface{}{"env": "prod"}},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"env": "prod", // not converted to slice
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := flattenExtraClaims(tt.input)
|
||||
|
||||
gotNormalized := normalizeMap(t, got)
|
||||
expectedNormalized := normalizeMap(t, tt.expected)
|
||||
|
||||
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
||||
t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
claim tailscaleClaims
|
||||
extraClaims []capRule
|
||||
expected map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "extra claim",
|
||||
claim: tailscaleClaims{
|
||||
Claims: jwt.Claims{},
|
||||
Nonce: "foobar",
|
||||
Key: key.NodePublic{},
|
||||
Addresses: views.Slice[netip.Prefix]{},
|
||||
NodeID: 0,
|
||||
NodeName: "test-node",
|
||||
Tailnet: "test.ts.net",
|
||||
Email: "test@example.com",
|
||||
UserID: 0,
|
||||
UserName: "test",
|
||||
},
|
||||
extraClaims: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"nonce": "foobar",
|
||||
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"addresses": nil,
|
||||
"nid": float64(0),
|
||||
"node": "test-node",
|
||||
"tailnet": "test.ts.net",
|
||||
"email": "test@example.com",
|
||||
"username": "test",
|
||||
"foo": []interface{}{"bar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate claim distinct values",
|
||||
claim: tailscaleClaims{
|
||||
Claims: jwt.Claims{},
|
||||
Nonce: "foobar",
|
||||
Key: key.NodePublic{},
|
||||
Addresses: views.Slice[netip.Prefix]{},
|
||||
NodeID: 0,
|
||||
NodeName: "test-node",
|
||||
Tailnet: "test.ts.net",
|
||||
Email: "test@example.com",
|
||||
UserID: 0,
|
||||
UserName: "test",
|
||||
},
|
||||
extraClaims: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"foobar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"nonce": "foobar",
|
||||
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"addresses": nil,
|
||||
"nid": float64(0),
|
||||
"node": "test-node",
|
||||
"tailnet": "test.ts.net",
|
||||
"email": "test@example.com",
|
||||
"username": "test",
|
||||
"foo": []interface{}{"foobar", "bar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple extra claims",
|
||||
claim: tailscaleClaims{
|
||||
Claims: jwt.Claims{},
|
||||
Nonce: "foobar",
|
||||
Key: key.NodePublic{},
|
||||
Addresses: views.Slice[netip.Prefix]{},
|
||||
NodeID: 0,
|
||||
NodeName: "test-node",
|
||||
Tailnet: "test.ts.net",
|
||||
Email: "test@example.com",
|
||||
UserID: 0,
|
||||
UserName: "test",
|
||||
},
|
||||
extraClaims: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"bar": []string{"foo"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"nonce": "foobar",
|
||||
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"addresses": nil,
|
||||
"nid": float64(0),
|
||||
"node": "test-node",
|
||||
"tailnet": "test.ts.net",
|
||||
"email": "test@example.com",
|
||||
"username": "test",
|
||||
"foo": []interface{}{"bar"},
|
||||
"bar": []interface{}{"foo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "overwrite claim",
|
||||
claim: tailscaleClaims{
|
||||
Claims: jwt.Claims{},
|
||||
Nonce: "foobar",
|
||||
Key: key.NodePublic{},
|
||||
Addresses: views.Slice[netip.Prefix]{},
|
||||
NodeID: 0,
|
||||
NodeName: "test-node",
|
||||
Tailnet: "test.ts.net",
|
||||
Email: "test@example.com",
|
||||
UserID: 0,
|
||||
UserName: "test",
|
||||
},
|
||||
extraClaims: []capRule{
|
||||
{
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"username": "foobar",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"nonce": "foobar",
|
||||
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"addresses": nil,
|
||||
"nid": float64(0),
|
||||
"node": "test-node",
|
||||
"tailnet": "test.ts.net",
|
||||
"email": "test@example.com",
|
||||
"username": "foobar",
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "empty extra claims",
|
||||
claim: tailscaleClaims{
|
||||
Claims: jwt.Claims{},
|
||||
Nonce: "foobar",
|
||||
Key: key.NodePublic{},
|
||||
Addresses: views.Slice[netip.Prefix]{},
|
||||
NodeID: 0,
|
||||
NodeName: "test-node",
|
||||
Tailnet: "test.ts.net",
|
||||
Email: "test@example.com",
|
||||
UserID: 0,
|
||||
UserName: "test",
|
||||
},
|
||||
extraClaims: []capRule{{ExtraClaims: map[string]interface{}{}}},
|
||||
expected: map[string]interface{}{
|
||||
"nonce": "foobar",
|
||||
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
"addresses": nil,
|
||||
"nid": float64(0),
|
||||
"node": "test-node",
|
||||
"tailnet": "test.ts.net",
|
||||
"email": "test@example.com",
|
||||
"username": "test",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := withExtraClaims(tt.claim, tt.extraClaims)
|
||||
if err != nil {
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error, got nil")
|
||||
}
|
||||
return
|
||||
} else {
|
||||
t.Errorf("claim.withExtraClaims() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal to JSON then unmarshal back to map[string]interface{}
|
||||
gotClaims, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Errorf("json.Marshal(claims) error = %v", err)
|
||||
}
|
||||
|
||||
var gotClaimsMap map[string]interface{}
|
||||
if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil {
|
||||
t.Fatalf("json.Unmarshal(gotClaims) error = %v", err)
|
||||
}
|
||||
|
||||
gotNormalized := normalizeMap(t, gotClaimsMap)
|
||||
expectedNormalized := normalizeMap(t, tt.expected)
|
||||
|
||||
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
||||
t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
caps tailcfg.PeerCapMap
|
||||
method string
|
||||
grantType string
|
||||
code string
|
||||
omitCode bool
|
||||
redirectURI string
|
||||
remoteAddr string
|
||||
expectError bool
|
||||
expected map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "GET not allowed",
|
||||
method: "GET",
|
||||
grantType: "authorization_code",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "unsupported grant type",
|
||||
method: "POST",
|
||||
grantType: "pkcs",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
code: "invalid-code",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "omit code from form",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
omitCode: true,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid redirect uri",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
code: "valid-code",
|
||||
redirectURI: "https://invalid.example.com/callback",
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid remoteAddr",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
redirectURI: "https://rp.example.com/callback",
|
||||
code: "valid-code",
|
||||
remoteAddr: "192.168.0.1:12345",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "extra claim included",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
redirectURI: "https://rp.example.com/callback",
|
||||
code: "valid-code",
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "attempt to overwrite protected claim",
|
||||
method: "POST",
|
||||
grantType: "authorization_code",
|
||||
redirectURI: "https://rp.example.com/callback",
|
||||
code: "valid-code",
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"sub": "should-not-overwrite",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
// Fake user/node
|
||||
profile := &tailcfg.UserProfile{
|
||||
LoginName: "alice@example.com",
|
||||
DisplayName: "Alice Example",
|
||||
ProfilePicURL: "https://example.com/alice.jpg",
|
||||
}
|
||||
node := &tailcfg.Node{
|
||||
ID: 123,
|
||||
Name: "test-node.test.ts.net.",
|
||||
User: 456,
|
||||
Key: key.NodePublic{},
|
||||
Cap: 1,
|
||||
DiscoKey: key.DiscoPublic{},
|
||||
}
|
||||
|
||||
remoteUser := &apitype.WhoIsResponse{
|
||||
Node: node,
|
||||
UserProfile: profile,
|
||||
CapMap: tt.caps,
|
||||
}
|
||||
|
||||
s := &idpServer{
|
||||
code: map[string]*authRequest{
|
||||
"valid-code": {
|
||||
clientID: "client-id",
|
||||
nonce: "nonce123",
|
||||
redirectURI: "https://rp.example.com/callback",
|
||||
validTill: now.Add(5 * time.Minute),
|
||||
remoteUser: remoteUser,
|
||||
localRP: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
// Inject a working signer
|
||||
s.lazySigner.Set(oidcTestingSigner(t))
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("grant_type", tt.grantType)
|
||||
form.Set("redirect_uri", tt.redirectURI)
|
||||
if !tt.omitCode {
|
||||
form.Set("code", tt.code)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode()))
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
s.serveToken(rr, req)
|
||||
|
||||
if tt.expectError {
|
||||
if rr.Code == http.StatusOK {
|
||||
t.Fatalf("expected error, got 200 OK: %s", rr.Body.String())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
tok, err := jwt.ParseSigned(resp.IDToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse ID token: %v", err)
|
||||
}
|
||||
|
||||
out := make(map[string]interface{})
|
||||
if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil {
|
||||
t.Fatalf("failed to extract claims: %v", err)
|
||||
}
|
||||
|
||||
for k, want := range tt.expected {
|
||||
got, ok := out[k]
|
||||
if !ok {
|
||||
t.Errorf("missing expected claim %q", k)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("claim %q: got %v, want %v", k, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtraUserInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
caps tailcfg.PeerCapMap
|
||||
tokenValidTill time.Time
|
||||
expected map[string]interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "extra claim",
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": []interface{}{"bar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "duplicate claim distinct values",
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": []string{"bar", "foobar"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": []interface{}{"bar", "foobar"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple extra claims",
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bar": "foo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bar": "foo",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty extra claims",
|
||||
caps: tailcfg.PeerCapMap{},
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "attempt to overwrite protected claim",
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: true,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"sub": "should-not-overwrite",
|
||||
"foo": "ok",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "extra claim omitted",
|
||||
tokenValidTill: time.Now().Add(1 * time.Minute),
|
||||
caps: tailcfg.PeerCapMap{
|
||||
tailcfg.PeerCapabilityTsIDP: {
|
||||
mustMarshalJSON(t, capRule{
|
||||
IncludeInUserInfo: false,
|
||||
ExtraClaims: map[string]interface{}{
|
||||
"foo": "ok",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expected: map[string]interface{}{},
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
caps: tailcfg.PeerCapMap{},
|
||||
tokenValidTill: time.Now().Add(-1 * time.Minute),
|
||||
expected: map[string]interface{}{},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
token := "valid-token"
|
||||
|
||||
// Create a fake tailscale Node
|
||||
node := &tailcfg.Node{
|
||||
ID: 123,
|
||||
Name: "test-node.test.ts.net.",
|
||||
User: 456,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
// Construct the remote user
|
||||
profile := tailcfg.UserProfile{
|
||||
LoginName: "alice@example.com",
|
||||
DisplayName: "Alice Example",
|
||||
ProfilePicURL: "https://example.com/alice.jpg",
|
||||
}
|
||||
|
||||
remoteUser := &apitype.WhoIsResponse{
|
||||
Node: node,
|
||||
UserProfile: &profile,
|
||||
CapMap: tt.caps,
|
||||
}
|
||||
|
||||
// Insert a valid token into the idpServer
|
||||
s := &idpServer{
|
||||
accessToken: map[string]*authRequest{
|
||||
token: {
|
||||
validTill: tt.tokenValidTill,
|
||||
remoteUser: remoteUser,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Construct request
|
||||
req := httptest.NewRequest("GET", "/userinfo", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Call the method under test
|
||||
s.serveUserInfo(rr, req)
|
||||
|
||||
if tt.expectError {
|
||||
if rr.Code == http.StatusOK {
|
||||
t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Construct expected
|
||||
tt.expected["sub"] = remoteUser.Node.User.String()
|
||||
tt.expected["name"] = profile.DisplayName
|
||||
tt.expected["email"] = profile.LoginName
|
||||
tt.expected["picture"] = profile.ProfilePicURL
|
||||
tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@")
|
||||
|
||||
gotNormalized := normalizeMap(t, resp)
|
||||
expectedNormalized := normalizeMap(t, tt.expected)
|
||||
|
||||
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
||||
t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1462,6 +1462,11 @@ const (
|
||||
// user groups as Kubernetes user groups. This capability is read by
|
||||
// peers that are Tailscale Kubernetes operator instances.
|
||||
PeerCapabilityKubernetes PeerCapability = "tailscale.com/cap/kubernetes"
|
||||
|
||||
// PeerCapabilityTsIDP grants a peer tsidp-specific
|
||||
// capabilities, such as the ability to add user groups to the OIDC
|
||||
// claim
|
||||
PeerCapabilityTsIDP PeerCapability = "tailscale.com/cap/tsidp"
|
||||
)
|
||||
|
||||
// NodeCapMap is a map of capabilities to their optional values. It is valid for
|
||||
|
Loading…
x
Reference in New Issue
Block a user