cmd/tsidp: use mux, add node id

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-11-14 11:36:37 -08:00
parent 26e925fc48
commit 017a2ed349

View File

@ -28,7 +28,6 @@ import (
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/lazy" "tailscale.com/types/lazy"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -70,7 +69,10 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
log.Fatal(http.Serve(ln, srv)) mux := http.NewServeMux()
srv.Register(mux)
log.Fatal(http.Serve(ln, mux))
} }
type idpServer struct { type idpServer struct {
@ -87,34 +89,24 @@ type idpServer struct {
} }
type authRequest struct { type authRequest struct {
forNodeID string // string form nodeid:abcd
nonce string nonce string
redirectURI string redirectURI string
who *apitype.WhoIsResponse remoteUser *apitype.WhoIsResponse
validTill time.Time validTill time.Time
} }
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *idpServer) Register(mux *http.ServeMux) {
log.Printf("%v %v", r.Method, r.URL) mux.Handle(oidcJWKSPath, http.HandlerFunc(s.serveJWKS))
mux.Handle(oidcConfigPath, http.HandlerFunc(s.serveOpenIDConfig))
if r.URL.Path == oidcJWKSPath { mux.Handle("/authorize/", http.HandlerFunc(s.authorize))
if err := s.serveJWKS(w, r); err != nil { mux.Handle("/userinfo", http.HandlerFunc(s.serveUserInfo))
log.Printf("Error serving JWKS: %v", err) mux.Handle("/token", http.HandlerFunc(s.serveToken))
} mux.Handle("/", s)
return
}
if r.URL.Path == oidcConfigPath {
if err := s.serveOpenIDConfig(w, r); err != nil {
log.Printf("Error serving OpenID config: %v", err)
}
return
}
if r.URL.Path == "/" {
io.WriteString(w, "<html><body><h1>Tailscale OIDC IdP</h1>")
return
} }
if r.URL.Path == "/authorize" { func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil { if err != nil {
log.Printf("Error getting WhoIs: %v", err) log.Printf("Error getting WhoIs: %v", err)
@ -122,11 +114,14 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
nodeID := strings.TrimPrefix(r.URL.Path, "/authorize/")
uq := r.URL.Query() uq := r.URL.Query()
code := rands.HexString(32) code := rands.HexString(32)
ar := &authRequest{ ar := &authRequest{
forNodeID: nodeID,
nonce: uq.Get("nonce"), nonce: uq.Get("nonce"),
who: who, remoteUser: who,
redirectURI: uq.Get("redirect_uri"), redirectURI: uq.Get("redirect_uri"),
} }
@ -141,18 +136,16 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("Redirecting to %q", u) log.Printf("Redirecting to %q", u)
http.Redirect(w, r, u, http.StatusFound) http.Redirect(w, r, u, http.StatusFound)
}
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("%v %v", r.Method, r.URL)
if r.URL.Path == "/" {
io.WriteString(w, "<html><body><h1>Tailscale OIDC IdP</h1>")
return return
} }
if r.URL.Path == "/userinfo" {
s.serveUserInfo(w, r)
return
}
if r.URL.Path == "/token" {
s.serveToken(w, r)
return
}
http.Error(w, "tsidp: not found", http.StatusNotFound) http.Error(w, "tsidp: not found", http.StatusNotFound)
} }
@ -166,6 +159,12 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest)
return return
} }
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
s.mu.Lock() s.mu.Lock()
ar, ok := s.accessToken[tk] ar, ok := s.accessToken[tk]
@ -174,6 +173,10 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid token", http.StatusBadRequest) http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
return return
} }
if ar.forNodeID != who.Node.ID.String() {
http.Error(w, "tsidp: token for different node", http.StatusForbidden)
return
}
if ar.validTill.Before(time.Now()) { if ar.validTill.Before(time.Now()) {
http.Error(w, "tsidp: token expired", http.StatusBadRequest) http.Error(w, "tsidp: token expired", http.StatusBadRequest)
s.mu.Lock() s.mu.Lock()
@ -182,17 +185,17 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
} }
ui := userInfo{} ui := userInfo{}
if ar.who.Node.IsTagged() { if ar.remoteUser.Node.IsTagged() {
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
return return
} }
ui.Sub = ar.who.Node.User.String() ui.Sub = ar.remoteUser.Node.User.String()
ui.Name = ar.who.UserProfile.DisplayName ui.Name = ar.remoteUser.UserProfile.DisplayName
ui.Email = ar.who.UserProfile.LoginName ui.Email = ar.remoteUser.UserProfile.LoginName
ui.Picture = ar.who.UserProfile.ProfilePicURL ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL
// TODO(maisem): not sure if this is the right thing to do // TODO(maisem): not sure if this is the right thing to do
ui.UserName, _, _ = strings.Cut(ar.who.UserProfile.LoginName, "@") ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(ui); err != nil { if err := json.NewEncoder(w).Encode(ui); err != nil {
@ -213,7 +216,13 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
return return
} }
// TODO: check who is making the request caller, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if r.FormValue("grant_type") != "authorization_code" { if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest)
return return
@ -231,6 +240,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: code not found", http.StatusBadRequest) http.Error(w, "tsidp: code not found", http.StatusBadRequest)
return return
} }
if ar.forNodeID != caller.Node.ID.String() {
http.Error(w, "tsidp: token for different node", http.StatusForbidden)
return
}
if ar.redirectURI != r.FormValue("redirect_uri") { if ar.redirectURI != r.FormValue("redirect_uri") {
http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest)
return return
@ -242,10 +255,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
return return
} }
jti := rands.HexString(32) jti := rands.HexString(32)
who := ar.who who := ar.remoteUser
// TODO(maisem): not sure if this is the right thing to do // TODO(maisem): not sure if this is the right thing to do
userName, _, _ := strings.Cut(ar.who.UserProfile.LoginName, "@") userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
n := who.Node.View() n := who.Node.View()
if n.IsTagged() { if n.IsTagged() {
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
@ -354,14 +367,16 @@ func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
}) })
} }
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error { func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != oidcJWKSPath { if r.URL.Path != oidcJWKSPath {
return tsweb.Error(404, "", nil) http.Error(w, "tsidp: not found", http.StatusNotFound)
return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
sk, err := s.oidcPrivateKey() sk, err := s.oidcPrivateKey()
if err != nil { if err != nil {
return tsweb.Error(500, err.Error(), err) http.Error(w, err.Error(), http.StatusInternalServerError)
return
} }
// TODO(maisem): maybe only marshal this once and reuse? // TODO(maisem): maybe only marshal this once and reuse?
// TODO(maisem): implement key rotation. // TODO(maisem): implement key rotation.
@ -377,9 +392,9 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
}, },
}, },
}); err != nil { }); err != nil {
return tsweb.Error(500, err.Error(), err) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
return nil return
} }
// openIDProviderMetadata is a partial representation of // openIDProviderMetadata is a partial representation of
@ -442,10 +457,19 @@ var (
openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)}) openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)})
) )
func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) error { func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != oidcConfigPath { if r.URL.Path != oidcConfigPath {
return tsweb.Error(404, "", nil) http.Error(w, "tsidp: not found", http.StatusNotFound)
return
} }
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
authorizeEndpoint := fmt.Sprintf("%s/authorize/%s", s.serverURL, who.Node.ID.String())
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
je := json.NewEncoder(w) je := json.NewEncoder(w)
je.SetIndent("", " ") je.SetIndent("", " ")
@ -453,7 +477,7 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
Issuer: s.serverURL, Issuer: s.serverURL,
JWKS_URI: s.serverURL + oidcJWKSPath, JWKS_URI: s.serverURL + oidcJWKSPath,
UserInfoEndpoint: s.serverURL + "/userinfo", UserInfoEndpoint: s.serverURL + "/userinfo",
AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add /<nodeid> suffix AuthorizationEndpoint: authorizeEndpoint,
TokenEndpoint: s.serverURL + "/token", TokenEndpoint: s.serverURL + "/token",
ScopesSupported: openIDSupportedScopes, ScopesSupported: openIDSupportedScopes,
ResponseTypesSupported: openIDSupportedReponseTypes, ResponseTypesSupported: openIDSupportedReponseTypes,
@ -461,10 +485,8 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
ClaimsSupported: openIDSupportedClaims, ClaimsSupported: openIDSupportedClaims,
IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos, IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos,
}); err != nil { }); err != nil {
log.Printf("Error encoding JSON: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
return nil
} }
const ( const (