cmd/tsidp: add flag to run with tailscaled

This allows it to run using the local tailscaled instead
of tsnet.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-11-14 20:54:11 -08:00
parent d5f7500d83
commit 93664ac8dc

View File

@ -5,6 +5,7 @@ import (
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
@ -14,6 +15,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url" "net/url"
@ -27,6 +29,7 @@ import (
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
"tailscale.com/client/tailscale" "tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/types/key" "tailscale.com/types/key"
@ -39,43 +42,76 @@ import (
) )
var ( var (
flagVerbose = flag.Bool("verbose", false, "be verbose") flagVerbose = flag.Bool("verbose", false, "be verbose")
flagPort = flag.Int("port", 443, "port to listen on")
flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost")
flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet")
) )
func main() { func main() {
flag.Parse() flag.Parse()
ctx := context.Background() ctx := context.Background()
ts := &tsnet.Server{
Hostname: "idp", var (
} lc *tailscale.LocalClient
if !*flagVerbose { st *ipnstate.Status
ts.Logf = logger.Discard err error
} )
st, err := ts.Up(ctx) if *flagUseLocalTailscaled {
if err != nil { lc = &tailscale.LocalClient{}
log.Fatal(err) st, err = lc.StatusWithoutPeers(ctx)
} if err != nil {
lc, err := ts.LocalClient() log.Fatalf("getting status: %v", err)
if err != nil { }
log.Fatalf("getting local client: %v", err) } else {
ts := &tsnet.Server{
Hostname: "idp",
}
if !*flagVerbose {
ts.Logf = logger.Discard
}
st, err = ts.Up(ctx)
if err != nil {
log.Fatal(err)
}
lc, err = ts.LocalClient()
if err != nil {
log.Fatalf("getting local client: %v", err)
}
} }
srv := &idpServer{ srv := &idpServer{
lc: lc, lc: lc,
serverURL: "https://" + strings.TrimSuffix(st.Self.DNSName, "."), serverURL: fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort),
} }
log.Printf("Running tsidp at %s ...", srv.serverURL) log.Printf("Running tsidp at %s ...", srv.serverURL)
ln, err := ts.ListenTLS("tcp", ":443") if *flagLocalPort != -1 {
srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort)
go func() {
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort))
if err != nil {
log.Fatal(err)
}
log.Printf("Also running tsidp at %s ...", srv.loopbackURL)
http.Serve(ln, srv)
}()
}
ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", st.TailscaleIPs[0], *flagPort))
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
ln = tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate,
})
log.Fatal(http.Serve(ln, srv)) log.Fatal(http.Serve(ln, srv))
} }
type idpServer struct { type idpServer struct {
lc *tailscale.LocalClient lc *tailscale.LocalClient
serverURL string // "https://foo.bar.ts.net" loopbackURL string
serverURL string // "https://foo.bar.ts.net"
lazyMux lazy.SyncValue[*http.ServeMux] lazyMux lazy.SyncValue[*http.ServeMux]
lazySigningKey lazy.SyncValue[*signingKey] lazySigningKey lazy.SyncValue[*signingKey]
@ -87,8 +123,13 @@ type idpServer struct {
} }
type authRequest struct { type authRequest struct {
// localRP is true if the request is from a relying party running on the
// same machine as the idp server. It is mutually exclusive with rpNodeID.
localRP bool
// rpNodeID is the NodeID of the relying party (who requested the auth, such // rpNodeID is the NodeID of the relying party (who requested the auth, such
// as Proxmox or Synology), not the user node who is being authenticated. // as Proxmox or Synology), not the user node who is being authenticated. It
// is mutually exclusive with localRP.
rpNodeID tailcfg.NodeID rpNodeID tailcfg.NodeID
// clientID is the "client_id" sent in the authorized request. // clientID is the "client_id" sent in the authorized request.
@ -109,6 +150,27 @@ type authRequest struct {
validTill time.Time validTill time.Time
} }
func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error {
if ar.localRP {
ra, err := netip.ParseAddrPort(remoteAddr)
if err != nil {
return err
}
if !ra.Addr().IsLoopback() {
return fmt.Errorf("tsidp: request from non-loopback address")
}
return nil
}
who, err := lc.WhoIs(ctx, remoteAddr)
if err != nil {
return fmt.Errorf("tsidp: error getting WhoIs: %w", err)
}
if ar.rpNodeID != who.Node.ID {
return fmt.Errorf("tsidp: token for different node")
}
return nil
}
func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil { if err != nil {
@ -117,23 +179,27 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
return return
} }
rpNodeID, ok := parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/"))
if !ok {
http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest)
return
}
uq := r.URL.Query() uq := r.URL.Query()
code := rands.HexString(32) code := rands.HexString(32)
ar := &authRequest{ ar := &authRequest{
rpNodeID: rpNodeID,
nonce: uq.Get("nonce"), nonce: uq.Get("nonce"),
remoteUser: who, remoteUser: who,
redirectURI: uq.Get("redirect_uri"), redirectURI: uq.Get("redirect_uri"),
clientID: uq.Get("client_id"), clientID: uq.Get("client_id"),
} }
if r.URL.Path == "/authorize/localhost" {
ar.localRP = true
} else {
var ok bool
ar.rpNodeID, ok = parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/"))
if !ok {
http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest)
return
}
}
s.mu.Lock() s.mu.Lock()
mak.Set(&s.code, code, ar) mak.Set(&s.code, code, ar)
s.mu.Unlock() s.mu.Unlock()
@ -179,12 +245,6 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest)
return return
} }
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.mu.Lock() s.mu.Lock()
ar, ok := s.accessToken[tk] ar, ok := s.accessToken[tk]
@ -193,10 +253,12 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid token", http.StatusBadRequest) http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
return return
} }
if ar.rpNodeID != who.Node.ID { if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
http.Error(w, "tsidp: token for different node", http.StatusForbidden) log.Printf("Error allowing relying party: %v", err)
http.Error(w, err.Error(), http.StatusForbidden)
return return
} }
if ar.validTill.Before(time.Now()) { if ar.validTill.Before(time.Now()) {
http.Error(w, "tsidp: token expired", http.StatusBadRequest) http.Error(w, "tsidp: token expired", http.StatusBadRequest)
s.mu.Lock() s.mu.Lock()
@ -236,13 +298,6 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
return return
} }
caller, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if r.FormValue("grant_type") != "authorization_code" { if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest)
return return
@ -262,8 +317,9 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: code not found", http.StatusBadRequest) http.Error(w, "tsidp: code not found", http.StatusBadRequest)
return return
} }
if ar.rpNodeID != caller.Node.ID { if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
http.Error(w, "tsidp: token for different node", http.StatusForbidden) log.Printf("Error allowing relying party: %v", err)
http.Error(w, err.Error(), http.StatusForbidden)
return return
} }
if ar.redirectURI != r.FormValue("redirect_uri") { if ar.redirectURI != r.FormValue("redirect_uri") {
@ -309,6 +365,9 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
Email: who.UserProfile.LoginName, Email: who.UserProfile.LoginName,
UserName: userName, UserName: userName,
} }
if ar.localRP {
tsClaims.Issuer = s.loopbackURL
}
// Create an OIDC token using this issuer's signer. // Create an OIDC token using this issuer's signer.
token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize()
@ -484,23 +543,33 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: not found", http.StatusNotFound) http.Error(w, "tsidp: not found", http.StatusNotFound)
return return
} }
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) ap, err := netip.ParseAddrPort(r.RemoteAddr)
if err != nil { if err != nil {
log.Printf("Error parsing remote addr: %v", err)
return
}
var authorizeEndpoint string
rpEndpoint := s.serverURL
if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil {
authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID)
} else if ap.Addr().IsLoopback() {
rpEndpoint = s.loopbackURL
authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL)
} else {
log.Printf("Error getting WhoIs: %v", err) log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
authorizeEndpoint := fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
je := json.NewEncoder(w) je := json.NewEncoder(w)
je.SetIndent("", " ") je.SetIndent("", " ")
if err := je.Encode(openIDProviderMetadata{ if err := je.Encode(openIDProviderMetadata{
Issuer: s.serverURL,
JWKS_URI: s.serverURL + oidcJWKSPath,
UserInfoEndpoint: s.serverURL + "/userinfo",
AuthorizationEndpoint: authorizeEndpoint, AuthorizationEndpoint: authorizeEndpoint,
TokenEndpoint: s.serverURL + "/token", Issuer: rpEndpoint,
JWKS_URI: rpEndpoint + oidcJWKSPath,
UserInfoEndpoint: rpEndpoint + "/userinfo",
TokenEndpoint: rpEndpoint + "/token",
ScopesSupported: openIDSupportedScopes, ScopesSupported: openIDSupportedScopes,
ResponseTypesSupported: openIDSupportedReponseTypes, ResponseTypesSupported: openIDSupportedReponseTypes,
SubjectTypesSupported: openIDSupportedSubjectTypes, SubjectTypesSupported: openIDSupportedSubjectTypes,