diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index a3e867265..5c62eea4d 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -5,6 +5,7 @@ import ( "context" crand "crypto/rand" "crypto/rsa" + "crypto/tls" "crypto/x509" "encoding/base64" "encoding/binary" @@ -14,6 +15,7 @@ import ( "fmt" "io" "log" + "net" "net/http" "net/netip" "net/url" @@ -27,6 +29,7 @@ import ( "gopkg.in/square/go-jose.v2/jwt" "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/types/key" @@ -39,43 +42,76 @@ import ( ) var ( - flagVerbose = flag.Bool("verbose", false, "be verbose") + flagVerbose = flag.Bool("verbose", false, "be verbose") + flagPort = flag.Int("port", 443, "port to listen on") + flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") + flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") ) func main() { flag.Parse() ctx := context.Background() - ts := &tsnet.Server{ - Hostname: "idp", - } - if !*flagVerbose { - ts.Logf = logger.Discard - } - st, err := ts.Up(ctx) - if err != nil { - log.Fatal(err) - } - lc, err := ts.LocalClient() - if err != nil { - log.Fatalf("getting local client: %v", err) + + var ( + lc *tailscale.LocalClient + st *ipnstate.Status + err error + ) + if *flagUseLocalTailscaled { + lc = &tailscale.LocalClient{} + st, err = lc.StatusWithoutPeers(ctx) + if err != nil { + log.Fatalf("getting status: %v", err) + } + } else { + ts := &tsnet.Server{ + Hostname: "idp", + } + if !*flagVerbose { + ts.Logf = logger.Discard + } + st, err = ts.Up(ctx) + if err != nil { + log.Fatal(err) + } + lc, err = ts.LocalClient() + if err != nil { + log.Fatalf("getting local client: %v", err) + } } srv := &idpServer{ lc: lc, - serverURL: "https://" + strings.TrimSuffix(st.Self.DNSName, "."), + serverURL: fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort), } log.Printf("Running tsidp at %s ...", srv.serverURL) - ln, err := ts.ListenTLS("tcp", ":443") + if *flagLocalPort != -1 { + srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort) + go func() { + ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort)) + if err != nil { + log.Fatal(err) + } + log.Printf("Also running tsidp at %s ...", srv.loopbackURL) + http.Serve(ln, srv) + }() + } + + ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", st.TailscaleIPs[0], *flagPort)) if err != nil { log.Fatal(err) } + ln = tls.NewListener(ln, &tls.Config{ + GetCertificate: lc.GetCertificate, + }) log.Fatal(http.Serve(ln, srv)) } type idpServer struct { - lc *tailscale.LocalClient - serverURL string // "https://foo.bar.ts.net" + lc *tailscale.LocalClient + loopbackURL string + serverURL string // "https://foo.bar.ts.net" lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] @@ -87,8 +123,13 @@ type idpServer struct { } type authRequest struct { + // localRP is true if the request is from a relying party running on the + // same machine as the idp server. It is mutually exclusive with rpNodeID. + localRP bool + // rpNodeID is the NodeID of the relying party (who requested the auth, such - // as Proxmox or Synology), not the user node who is being authenticated. + // as Proxmox or Synology), not the user node who is being authenticated. It + // is mutually exclusive with localRP. rpNodeID tailcfg.NodeID // clientID is the "client_id" sent in the authorized request. @@ -109,6 +150,27 @@ type authRequest struct { validTill time.Time } +func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error { + if ar.localRP { + ra, err := netip.ParseAddrPort(remoteAddr) + if err != nil { + return err + } + if !ra.Addr().IsLoopback() { + return fmt.Errorf("tsidp: request from non-loopback address") + } + return nil + } + who, err := lc.WhoIs(ctx, remoteAddr) + if err != nil { + return fmt.Errorf("tsidp: error getting WhoIs: %w", err) + } + if ar.rpNodeID != who.Node.ID { + return fmt.Errorf("tsidp: token for different node") + } + return nil +} + func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { @@ -117,23 +179,27 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { return } - rpNodeID, ok := parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/")) - if !ok { - http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest) - return - } - uq := r.URL.Query() code := rands.HexString(32) ar := &authRequest{ - rpNodeID: rpNodeID, nonce: uq.Get("nonce"), remoteUser: who, redirectURI: uq.Get("redirect_uri"), clientID: uq.Get("client_id"), } + if r.URL.Path == "/authorize/localhost" { + ar.localRP = true + } else { + var ok bool + ar.rpNodeID, ok = parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/")) + if !ok { + http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest) + return + } + } + s.mu.Lock() mak.Set(&s.code, code, ar) s.mu.Unlock() @@ -179,12 +245,6 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) return } - who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - log.Printf("Error getting WhoIs: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } s.mu.Lock() ar, ok := s.accessToken[tk] @@ -193,10 +253,12 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: invalid token", http.StatusBadRequest) return } - if ar.rpNodeID != who.Node.ID { - http.Error(w, "tsidp: token for different node", http.StatusForbidden) + if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil { + log.Printf("Error allowing relying party: %v", err) + http.Error(w, err.Error(), http.StatusForbidden) return } + if ar.validTill.Before(time.Now()) { http.Error(w, "tsidp: token expired", http.StatusBadRequest) s.mu.Lock() @@ -236,13 +298,6 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) return } - caller, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) - if err != nil { - log.Printf("Error getting WhoIs: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if r.FormValue("grant_type") != "authorization_code" { http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) return @@ -262,8 +317,9 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: code not found", http.StatusBadRequest) return } - if ar.rpNodeID != caller.Node.ID { - http.Error(w, "tsidp: token for different node", http.StatusForbidden) + if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil { + log.Printf("Error allowing relying party: %v", err) + http.Error(w, err.Error(), http.StatusForbidden) return } if ar.redirectURI != r.FormValue("redirect_uri") { @@ -309,6 +365,9 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { Email: who.UserProfile.LoginName, UserName: userName, } + if ar.localRP { + tsClaims.Issuer = s.loopbackURL + } // Create an OIDC token using this issuer's signer. token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() @@ -484,23 +543,33 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: not found", http.StatusNotFound) return } - who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) + ap, err := netip.ParseAddrPort(r.RemoteAddr) if err != nil { + log.Printf("Error parsing remote addr: %v", err) + return + } + var authorizeEndpoint string + rpEndpoint := s.serverURL + if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { + authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) + } else if ap.Addr().IsLoopback() { + rpEndpoint = s.loopbackURL + authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) + } else { log.Printf("Error getting WhoIs: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - authorizeEndpoint := fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) w.Header().Set("Content-Type", "application/json") je := json.NewEncoder(w) je.SetIndent("", " ") if err := je.Encode(openIDProviderMetadata{ - Issuer: s.serverURL, - JWKS_URI: s.serverURL + oidcJWKSPath, - UserInfoEndpoint: s.serverURL + "/userinfo", AuthorizationEndpoint: authorizeEndpoint, - TokenEndpoint: s.serverURL + "/token", + Issuer: rpEndpoint, + JWKS_URI: rpEndpoint + oidcJWKSPath, + UserInfoEndpoint: rpEndpoint + "/userinfo", + TokenEndpoint: rpEndpoint + "/token", ScopesSupported: openIDSupportedScopes, ResponseTypesSupported: openIDSupportedReponseTypes, SubjectTypesSupported: openIDSupportedSubjectTypes,