diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index f79f64573..02c032147 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -21,9 +21,10 @@ "os" "strings" "sync" + "time" - "github.com/golang-jwt/jwt" "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" @@ -81,7 +82,16 @@ type idpServer struct { mu sync.Mutex // guards the fields below - code map[string]*apitype.WhoIsResponse // code -> whois + code map[string]*authRequest + accessToken map[string]*authRequest +} + +type authRequest struct { + nonce string + redirectURI string + + who *apitype.WhoIsResponse + validTill time.Time } func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -112,28 +122,196 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + uq := r.URL.Query() code := must.Get(readHex()) + ar := &authRequest{ + nonce: uq.Get("nonce"), + who: who, + redirectURI: uq.Get("redirect_uri"), + } s.mu.Lock() - mak.Set(&s.code, code, who) + mak.Set(&s.code, code, ar) s.mu.Unlock() q := make(url.Values) q.Set("code", code) - q.Set("state", r.URL.Query().Get("state")) - u := r.URL.Query().Get("redirect_uri") + "?" + q.Encode() + q.Set("state", uq.Get("state")) + u := uq.Get("redirect_uri") + "?" + q.Encode() log.Printf("Redirecting to %q", u) http.Redirect(w, r, u, http.StatusFound) return } - if r.URL.Path == "/token" { + if r.URL.Path == "/userinfo" { + s.serveUserInfo(w, r) + return + } + if r.URL.Path == "/token" { + s.serveToken(w, r) + return } http.Error(w, "tsidp: not found", http.StatusNotFound) } +func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) + return + } + tk, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") + if !ok { + http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) + return + } + + s.mu.Lock() + ar, ok := s.accessToken[tk] + s.mu.Unlock() + if !ok { + http.Error(w, "tsidp: invalid token", http.StatusBadRequest) + return + } + if ar.validTill.Before(time.Now()) { + http.Error(w, "tsidp: token expired", http.StatusBadRequest) + s.mu.Lock() + delete(s.accessToken, tk) + s.mu.Unlock() + } + + ui := userInfo{} + if ar.who.Node.IsTagged() { + http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) + return + } + ui.Sub = ar.who.Node.User.String() + ui.Name = ar.who.UserProfile.DisplayName + ui.Email = ar.who.UserProfile.LoginName + ui.Picture = ar.who.UserProfile.ProfilePicURL + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(ui); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +type userInfo struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + Picture string `json:"picture"` +} + +func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) + return + } + // TODO: check who is making the request + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) + return + } + code := r.FormValue("code") + if code == "" { + http.Error(w, "tsidp: code is required", http.StatusBadRequest) + return + } + s.mu.Lock() + ar, ok := s.code[code] + delete(s.code, code) + s.mu.Unlock() + if !ok { + http.Error(w, "tsidp: code not found", http.StatusBadRequest) + return + } + if ar.redirectURI != r.FormValue("redirect_uri") { + http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) + return + } + signer, err := s.oidcSigner() + if err != nil { + log.Printf("Error getting signer: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + jti, err := readHex() + if err != nil { + log.Printf("Error reading hex: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + who := ar.who + n := who.Node.View() + now := time.Now() + tsClaims := tailscaleClaims{ + Claims: jwt.Claims{ + Audience: jwt.Audience{"unused"}, + Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), + ID: jti, + IssuedAt: jwt.NewNumericDate(now), + Issuer: s.serverURL, + NotBefore: jwt.NewNumericDate(now), + }, + Nonce: ar.nonce, + Key: n.Key(), + Addresses: n.Addresses(), + NodeID: n.ID(), + NodeName: n.Name(), + } + + _, tcd, _ := strings.Cut(n.Name(), ".") + tsClaims.Tailnet = tcd + + if n.IsTagged() { + tsClaims.Subject = n.ID().String() + tsClaims.Tags = n.Tags().AsSlice() + } else { + tsClaims.Subject = n.User().String() + tsClaims.UserID = n.User() + tsClaims.User = who.UserProfile.LoginName + } + + // Create an OIDC token using this issuer's signer. + token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() + if err != nil { + log.Printf("Error getting token: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + at, err := readHex() + if err != nil { + log.Printf("Error reading hex: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + s.mu.Lock() + ar.validTill = now.Add(5 * time.Minute) + mak.Set(&s.accessToken, at, ar) + s.mu.Unlock() + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(oidcTokenResponse{ + AccessToken: at, + TokenType: "Bearer", + ExpiresIn: 5 * 60, + IDToken: token, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +type oidcTokenResponse struct { + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` +} + const ( oidcJWKSPath = "/.well-known/jwks.json" oidcConfigPath = "/.well-known/openid-configuration" @@ -213,6 +391,7 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { type openIDProviderMetadata struct { Issuer string `json:"issuer"` AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` + TokenEndpoint string `json:"token_endpoint,omitempty"` UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"` JWKS_URI string `json:"jwks_uri"` ScopesSupported views.Slice[string] `json:"scopes_supported"` @@ -226,11 +405,12 @@ type openIDProviderMetadata struct { type tailscaleClaims struct { jwt.Claims `json:",inline"` - Key key.NodePublic `json:"key"` // the node public key - Addresses views.Slice[netip.Prefix] `json:"addresses"` // the Tailscale IPs of the node - NodeID tailcfg.NodeID `json:"nid"` // the stable node ID - NodeName string `json:"node"` // name of the node - Tailnet string `json:"tailnet"` // tailnet (like tail-scale.ts.net) + Nonce string `json:"nonce,omitempty"` // the nonce from the request + Key key.NodePublic `json:"key"` // the node public key + Addresses views.Slice[netip.Prefix] `json:"addresses"` // the Tailscale IPs of the node + NodeID tailcfg.NodeID `json:"nid"` // the stable node ID + NodeName string `json:"node"` // name of the node + Tailnet string `json:"tailnet"` // tailnet (like tail-scale.ts.net) // Tags is the list of tags the node is tagged with prefixed with the Tailnet name. Tags []string `json:"tags,omitempty"` // the tags on the node (like alice.github:tag:foo or example.com:tag:foo) @@ -269,13 +449,14 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er return tsweb.Error(404, "", nil) } w.Header().Set("Content-Type", "application/json") - je := json.NewEncoder(io.MultiWriter(w, os.Stderr)) + je := json.NewEncoder(w) je.SetIndent("", " ") if err := je.Encode(openIDProviderMetadata{ Issuer: s.serverURL, JWKS_URI: s.serverURL + oidcJWKSPath, UserInfoEndpoint: s.serverURL + "/userinfo", AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add / suffix + TokenEndpoint: s.serverURL + "/token", ScopesSupported: openIDSupportedScopes, ResponseTypesSupported: openIDSupportedReponseTypes, SubjectTypesSupported: openIDSupportedSubjectTypes,