diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go
index 7e666b18e..fe54a65a6 100644
--- a/cmd/tsidp/tsidp.go
+++ b/cmd/tsidp/tsidp.go
@@ -28,7 +28,6 @@ import (
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
- "tailscale.com/tsweb"
"tailscale.com/types/key"
"tailscale.com/types/lazy"
"tailscale.com/types/logger"
@@ -70,7 +69,10 @@ func main() {
if err != nil {
log.Fatal(err)
}
- log.Fatal(http.Serve(ln, srv))
+ mux := http.NewServeMux()
+ srv.Register(mux)
+
+ log.Fatal(http.Serve(ln, mux))
}
type idpServer struct {
@@ -87,72 +89,63 @@ type idpServer struct {
}
type authRequest struct {
+ forNodeID string // string form nodeid:abcd
nonce string
redirectURI string
- who *apitype.WhoIsResponse
- validTill time.Time
+ remoteUser *apitype.WhoIsResponse
+ validTill time.Time
+}
+
+func (s *idpServer) Register(mux *http.ServeMux) {
+ mux.Handle(oidcJWKSPath, http.HandlerFunc(s.serveJWKS))
+ mux.Handle(oidcConfigPath, http.HandlerFunc(s.serveOpenIDConfig))
+ mux.Handle("/authorize/", http.HandlerFunc(s.authorize))
+ mux.Handle("/userinfo", http.HandlerFunc(s.serveUserInfo))
+ mux.Handle("/token", http.HandlerFunc(s.serveToken))
+ mux.Handle("/", s)
+}
+
+func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
+ who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ nodeID := strings.TrimPrefix(r.URL.Path, "/authorize/")
+
+ uq := r.URL.Query()
+ code := rands.HexString(32)
+ ar := &authRequest{
+ forNodeID: nodeID,
+ nonce: uq.Get("nonce"),
+ remoteUser: who,
+ redirectURI: uq.Get("redirect_uri"),
+ }
+
+ s.mu.Lock()
+ mak.Set(&s.code, code, ar)
+ s.mu.Unlock()
+
+ q := make(url.Values)
+ q.Set("code", code)
+ q.Set("state", uq.Get("state"))
+ u := uq.Get("redirect_uri") + "?" + q.Encode()
+ log.Printf("Redirecting to %q", u)
+
+ http.Redirect(w, r, u, http.StatusFound)
}
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Printf("%v %v", r.Method, r.URL)
- if r.URL.Path == oidcJWKSPath {
- if err := s.serveJWKS(w, r); err != nil {
- log.Printf("Error serving JWKS: %v", err)
- }
- return
- }
- if r.URL.Path == oidcConfigPath {
- if err := s.serveOpenIDConfig(w, r); err != nil {
- log.Printf("Error serving OpenID config: %v", err)
- }
- return
- }
if r.URL.Path == "/" {
io.WriteString(w, "
Tailscale OIDC IdP
")
return
}
- if r.URL.Path == "/authorize" {
- who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
- if err != nil {
- log.Printf("Error getting WhoIs: %v", err)
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- uq := r.URL.Query()
- code := rands.HexString(32)
- ar := &authRequest{
- nonce: uq.Get("nonce"),
- who: who,
- redirectURI: uq.Get("redirect_uri"),
- }
-
- s.mu.Lock()
- mak.Set(&s.code, code, ar)
- s.mu.Unlock()
-
- q := make(url.Values)
- q.Set("code", code)
- q.Set("state", uq.Get("state"))
- u := uq.Get("redirect_uri") + "?" + q.Encode()
- log.Printf("Redirecting to %q", u)
-
- http.Redirect(w, r, u, http.StatusFound)
- return
- }
-
- if r.URL.Path == "/userinfo" {
- s.serveUserInfo(w, r)
- return
- }
-
- if r.URL.Path == "/token" {
- s.serveToken(w, r)
- return
- }
http.Error(w, "tsidp: not found", http.StatusNotFound)
}
@@ -166,6 +159,12 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest)
return
}
+ who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
s.mu.Lock()
ar, ok := s.accessToken[tk]
@@ -174,6 +173,10 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
return
}
+ if ar.forNodeID != who.Node.ID.String() {
+ http.Error(w, "tsidp: token for different node", http.StatusForbidden)
+ return
+ }
if ar.validTill.Before(time.Now()) {
http.Error(w, "tsidp: token expired", http.StatusBadRequest)
s.mu.Lock()
@@ -182,17 +185,17 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
}
ui := userInfo{}
- if ar.who.Node.IsTagged() {
+ if ar.remoteUser.Node.IsTagged() {
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
return
}
- ui.Sub = ar.who.Node.User.String()
- ui.Name = ar.who.UserProfile.DisplayName
- ui.Email = ar.who.UserProfile.LoginName
- ui.Picture = ar.who.UserProfile.ProfilePicURL
+ ui.Sub = ar.remoteUser.Node.User.String()
+ ui.Name = ar.remoteUser.UserProfile.DisplayName
+ ui.Email = ar.remoteUser.UserProfile.LoginName
+ ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL
// TODO(maisem): not sure if this is the right thing to do
- ui.UserName, _, _ = strings.Cut(ar.who.UserProfile.LoginName, "@")
+ ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(ui); err != nil {
@@ -213,7 +216,13 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
return
}
- // TODO: check who is making the request
+ caller, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest)
return
@@ -231,6 +240,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
http.Error(w, "tsidp: code not found", http.StatusBadRequest)
return
}
+ if ar.forNodeID != caller.Node.ID.String() {
+ http.Error(w, "tsidp: token for different node", http.StatusForbidden)
+ return
+ }
if ar.redirectURI != r.FormValue("redirect_uri") {
http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest)
return
@@ -242,10 +255,10 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
return
}
jti := rands.HexString(32)
- who := ar.who
+ who := ar.remoteUser
// TODO(maisem): not sure if this is the right thing to do
- userName, _, _ := strings.Cut(ar.who.UserProfile.LoginName, "@")
+ userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
n := who.Node.View()
if n.IsTagged() {
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
@@ -354,14 +367,16 @@ func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
})
}
-func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
+func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != oidcJWKSPath {
- return tsweb.Error(404, "", nil)
+ http.Error(w, "tsidp: not found", http.StatusNotFound)
+ return
}
w.Header().Set("Content-Type", "application/json")
sk, err := s.oidcPrivateKey()
if err != nil {
- return tsweb.Error(500, err.Error(), err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
}
// TODO(maisem): maybe only marshal this once and reuse?
// TODO(maisem): implement key rotation.
@@ -377,9 +392,9 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
},
},
}); err != nil {
- return tsweb.Error(500, err.Error(), err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
}
- return nil
+ return
}
// openIDProviderMetadata is a partial representation of
@@ -442,10 +457,19 @@ var (
openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)})
)
-func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) error {
+func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != oidcConfigPath {
- return tsweb.Error(404, "", nil)
+ http.Error(w, "tsidp: not found", http.StatusNotFound)
+ return
}
+ who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ authorizeEndpoint := fmt.Sprintf("%s/authorize/%s", s.serverURL, who.Node.ID.String())
+
w.Header().Set("Content-Type", "application/json")
je := json.NewEncoder(w)
je.SetIndent("", " ")
@@ -453,7 +477,7 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
Issuer: s.serverURL,
JWKS_URI: s.serverURL + oidcJWKSPath,
UserInfoEndpoint: s.serverURL + "/userinfo",
- AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add / suffix
+ AuthorizationEndpoint: authorizeEndpoint,
TokenEndpoint: s.serverURL + "/token",
ScopesSupported: openIDSupportedScopes,
ResponseTypesSupported: openIDSupportedReponseTypes,
@@ -461,10 +485,8 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
ClaimsSupported: openIDSupportedClaims,
IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos,
}); err != nil {
- log.Printf("Error encoding JSON: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
- return nil
}
const (