cmd/tsidp: accept any client_id, not just 'unused'

Change-Id: Ia0185d6bdf8416fd5fd64559e8f43ec94fa5b7d5
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2023-11-14 19:11:26 -08:00
parent 678a3bf88a
commit d5f7500d83

View File

@ -18,6 +18,7 @@
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -80,16 +81,18 @@ type idpServer struct {
lazySigningKey lazy.SyncValue[*signingKey] lazySigningKey lazy.SyncValue[*signingKey]
lazySigner lazy.SyncValue[jose.Signer] lazySigner lazy.SyncValue[jose.Signer]
mu sync.Mutex // guards the fields below mu sync.Mutex // guards the fields below
code map[string]*authRequest code map[string]*authRequest // keyed by random hex
accessToken map[string]*authRequest accessToken map[string]*authRequest // keyed by random hex
} }
type authRequest struct { type authRequest struct {
// requesterNodeID is the node who requested the auth (say synology), not the node // rpNodeID is the NodeID of the relying party (who requested the auth, such
// who is being authenticated. // as Proxmox or Synology), not the user node who is being authenticated.
// String form of tailcfg.NodeID rpNodeID tailcfg.NodeID
requesterNodeID string
// clientID is the "client_id" sent in the authorized request.
clientID string
// nonce presented in the request. // nonce presented in the request.
nonce string nonce string
@ -114,15 +117,21 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
return return
} }
nodeID := strings.TrimPrefix(r.URL.Path, "/authorize/") rpNodeID, ok := parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/"))
if !ok {
http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest)
return
}
uq := r.URL.Query() uq := r.URL.Query()
code := rands.HexString(32) code := rands.HexString(32)
ar := &authRequest{ ar := &authRequest{
requesterNodeID: nodeID, rpNodeID: rpNodeID,
nonce: uq.Get("nonce"), nonce: uq.Get("nonce"),
remoteUser: who, remoteUser: who,
redirectURI: uq.Get("redirect_uri"), redirectURI: uq.Get("redirect_uri"),
clientID: uq.Get("client_id"),
} }
s.mu.Lock() s.mu.Lock()
@ -184,7 +193,7 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid token", http.StatusBadRequest) http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
return return
} }
if ar.requesterNodeID != who.Node.ID.String() { if ar.rpNodeID != who.Node.ID {
http.Error(w, "tsidp: token for different node", http.StatusForbidden) http.Error(w, "tsidp: token for different node", http.StatusForbidden)
return return
} }
@ -245,13 +254,15 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
} }
s.mu.Lock() s.mu.Lock()
ar, ok := s.code[code] ar, ok := s.code[code]
delete(s.code, code) if ok {
delete(s.code, code)
}
s.mu.Unlock() s.mu.Unlock()
if !ok { if !ok {
http.Error(w, "tsidp: code not found", http.StatusBadRequest) http.Error(w, "tsidp: code not found", http.StatusBadRequest)
return return
} }
if ar.requesterNodeID != caller.Node.ID.String() { if ar.rpNodeID != caller.Node.ID {
http.Error(w, "tsidp: token for different node", http.StatusForbidden) http.Error(w, "tsidp: token for different node", http.StatusForbidden)
return return
} }
@ -280,7 +291,7 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
_, tcd, _ := strings.Cut(n.Name(), ".") _, tcd, _ := strings.Cut(n.Name(), ".")
tsClaims := tailscaleClaims{ tsClaims := tailscaleClaims{
Claims: jwt.Claims{ Claims: jwt.Claims{
Audience: jwt.Audience{"unused"}, Audience: jwt.Audience{ar.clientID},
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
ID: jti, ID: jti,
IssuedAt: jwt.NewNumericDate(now), IssuedAt: jwt.NewNumericDate(now),
@ -479,7 +490,7 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
authorizeEndpoint := fmt.Sprintf("%s/authorize/%s", s.serverURL, who.Node.ID.String()) authorizeEndpoint := fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
je := json.NewEncoder(w) je := json.NewEncoder(w)
@ -573,3 +584,19 @@ func (sk *signingKey) UnmarshalJSON(b []byte) error {
sk.kid = wrapper.ID sk.kid = wrapper.ID
return nil return nil
} }
// parseID takes a string input and returns a typed IntID T and true, or a zero
// value and false if the input is unhandled syntax or out of a valid range.
func parseID[T ~int64](input string) (_ T, ok bool) {
if input == "" {
return 0, false
}
i, err := strconv.ParseInt(input, 10, 64)
if err != nil {
return 0, false
}
if i < 0 {
return 0, false
}
return T(i), true
}