From 017a2ed3498c50554d65698e4fc2d6b7cae7a664 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 14 Nov 2023 11:36:37 -0800 Subject: [PATCH] cmd/tsidp: use mux, add node id Signed-off-by: Maisem Ali --- cmd/tsidp/tsidp.go | 170 +++++++++++++++++++++++++-------------------- 1 file changed, 96 insertions(+), 74 deletions(-) diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 7e666b18e..fe54a65a6 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -28,7 +28,6 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" "tailscale.com/tsnet" - "tailscale.com/tsweb" "tailscale.com/types/key" "tailscale.com/types/lazy" "tailscale.com/types/logger" @@ -70,7 +69,10 @@ func main() { if err != nil { log.Fatal(err) } - log.Fatal(http.Serve(ln, srv)) + mux := http.NewServeMux() + srv.Register(mux) + + log.Fatal(http.Serve(ln, mux)) } type idpServer struct { @@ -87,72 +89,63 @@ type idpServer struct { } type authRequest struct { + forNodeID string // string form nodeid:abcd nonce string redirectURI string - who *apitype.WhoIsResponse - validTill time.Time + remoteUser *apitype.WhoIsResponse + validTill time.Time +} + +func (s *idpServer) Register(mux *http.ServeMux) { + mux.Handle(oidcJWKSPath, http.HandlerFunc(s.serveJWKS)) + mux.Handle(oidcConfigPath, http.HandlerFunc(s.serveOpenIDConfig)) + mux.Handle("/authorize/", http.HandlerFunc(s.authorize)) + mux.Handle("/userinfo", http.HandlerFunc(s.serveUserInfo)) + mux.Handle("/token", http.HandlerFunc(s.serveToken)) + mux.Handle("/", s) +} + +func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { + who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + nodeID := strings.TrimPrefix(r.URL.Path, "/authorize/") + + uq := r.URL.Query() + code := rands.HexString(32) + ar := &authRequest{ + forNodeID: nodeID, + nonce: uq.Get("nonce"), + remoteUser: who, + redirectURI: uq.Get("redirect_uri"), + } + + s.mu.Lock() + mak.Set(&s.code, code, ar) + s.mu.Unlock() + + q := make(url.Values) + q.Set("code", code) + q.Set("state", uq.Get("state")) + u := uq.Get("redirect_uri") + "?" + q.Encode() + log.Printf("Redirecting to %q", u) + + http.Redirect(w, r, u, http.StatusFound) } func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("%v %v", r.Method, r.URL) - if r.URL.Path == oidcJWKSPath { - if err := s.serveJWKS(w, r); err != nil { - log.Printf("Error serving JWKS: %v", err) - } - return - } - if r.URL.Path == oidcConfigPath { - if err := s.serveOpenIDConfig(w, r); err != nil { - log.Printf("Error serving OpenID config: %v", err) - } - return - } if r.URL.Path == "/" { io.WriteString(w, "

Tailscale OIDC IdP

") return } - if r.URL.Path == "/authorize" { - who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - log.Printf("Error getting WhoIs: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - uq := r.URL.Query() - code := rands.HexString(32) - ar := &authRequest{ - nonce: uq.Get("nonce"), - who: who, - redirectURI: uq.Get("redirect_uri"), - } - - s.mu.Lock() - mak.Set(&s.code, code, ar) - s.mu.Unlock() - - q := make(url.Values) - q.Set("code", code) - q.Set("state", uq.Get("state")) - u := uq.Get("redirect_uri") + "?" + q.Encode() - log.Printf("Redirecting to %q", u) - - http.Redirect(w, r, u, http.StatusFound) - return - } - - if r.URL.Path == "/userinfo" { - s.serveUserInfo(w, r) - return - } - - if r.URL.Path == "/token" { - s.serveToken(w, r) - return - } http.Error(w, "tsidp: not found", http.StatusNotFound) } @@ -166,6 +159,12 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) return } + who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } s.mu.Lock() ar, ok := s.accessToken[tk] @@ -174,6 +173,10 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: invalid token", http.StatusBadRequest) return } + if ar.forNodeID != who.Node.ID.String() { + http.Error(w, "tsidp: token for different node", http.StatusForbidden) + return + } if ar.validTill.Before(time.Now()) { http.Error(w, "tsidp: token expired", http.StatusBadRequest) s.mu.Lock() @@ -182,17 +185,17 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { } ui := userInfo{} - if ar.who.Node.IsTagged() { + if ar.remoteUser.Node.IsTagged() { http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) return } - ui.Sub = ar.who.Node.User.String() - ui.Name = ar.who.UserProfile.DisplayName - ui.Email = ar.who.UserProfile.LoginName - ui.Picture = ar.who.UserProfile.ProfilePicURL + ui.Sub = ar.remoteUser.Node.User.String() + ui.Name = ar.remoteUser.UserProfile.DisplayName + ui.Email = ar.remoteUser.UserProfile.LoginName + ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL // TODO(maisem): not sure if this is the right thing to do - ui.UserName, _, _ = strings.Cut(ar.who.UserProfile.LoginName, "@") + ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(ui); err != nil { @@ -213,7 +216,13 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) return } - // TODO: check who is making the request + caller, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if r.FormValue("grant_type") != "authorization_code" { http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) return @@ -231,6 +240,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: code not found", http.StatusBadRequest) return } + if ar.forNodeID != caller.Node.ID.String() { + http.Error(w, "tsidp: token for different node", http.StatusForbidden) + return + } if ar.redirectURI != r.FormValue("redirect_uri") { http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) return @@ -242,10 +255,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { return } jti := rands.HexString(32) - who := ar.who + who := ar.remoteUser // TODO(maisem): not sure if this is the right thing to do - userName, _, _ := strings.Cut(ar.who.UserProfile.LoginName, "@") + userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") n := who.Node.View() if n.IsTagged() { http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) @@ -354,14 +367,16 @@ func (s *idpServer) oidcPrivateKey() (*signingKey, error) { }) } -func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { +func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) { if r.URL.Path != oidcJWKSPath { - return tsweb.Error(404, "", nil) + http.Error(w, "tsidp: not found", http.StatusNotFound) + return } w.Header().Set("Content-Type", "application/json") sk, err := s.oidcPrivateKey() if err != nil { - return tsweb.Error(500, err.Error(), err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return } // TODO(maisem): maybe only marshal this once and reuse? // TODO(maisem): implement key rotation. @@ -377,9 +392,9 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { }, }, }); err != nil { - return tsweb.Error(500, err.Error(), err) + http.Error(w, err.Error(), http.StatusInternalServerError) } - return nil + return } // openIDProviderMetadata is a partial representation of @@ -442,10 +457,19 @@ var ( openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)}) ) -func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) error { +func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { if r.URL.Path != oidcConfigPath { - return tsweb.Error(404, "", nil) + http.Error(w, "tsidp: not found", http.StatusNotFound) + return } + who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + authorizeEndpoint := fmt.Sprintf("%s/authorize/%s", s.serverURL, who.Node.ID.String()) + w.Header().Set("Content-Type", "application/json") je := json.NewEncoder(w) je.SetIndent("", " ") @@ -453,7 +477,7 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er Issuer: s.serverURL, JWKS_URI: s.serverURL + oidcJWKSPath, UserInfoEndpoint: s.serverURL + "/userinfo", - AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add / suffix + AuthorizationEndpoint: authorizeEndpoint, TokenEndpoint: s.serverURL + "/token", ScopesSupported: openIDSupportedScopes, ResponseTypesSupported: openIDSupportedReponseTypes, @@ -461,10 +485,8 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er ClaimsSupported: openIDSupportedClaims, IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos, }); err != nil { - log.Printf("Error encoding JSON: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) } - return nil } const (