mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-23 01:11:40 +00:00
cmd/tsidp: persist signing key
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
63c52e48a4
commit
20e7f99570
@ -8,6 +8,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"flag"
|
"flag"
|
||||||
@ -16,6 +17,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -23,12 +25,15 @@ import (
|
|||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
"tailscale.com/client/tailscale"
|
"tailscale.com/client/tailscale"
|
||||||
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
"tailscale.com/tsweb"
|
"tailscale.com/tsweb"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/lazy"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/types/views"
|
"tailscale.com/types/views"
|
||||||
|
"tailscale.com/util/mak"
|
||||||
"tailscale.com/util/must"
|
"tailscale.com/util/must"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -71,9 +76,12 @@ type idpServer struct {
|
|||||||
lc *tailscale.LocalClient
|
lc *tailscale.LocalClient
|
||||||
serverURL string // "https://foo.bar.ts.net"
|
serverURL string // "https://foo.bar.ts.net"
|
||||||
|
|
||||||
oidcSignerInitOnce sync.Once
|
lazySigningKey lazy.SyncValue[*signingKey]
|
||||||
oidcSignerLazy jose.Signer
|
lazySigner lazy.SyncValue[jose.Signer]
|
||||||
oidcSignerError error
|
|
||||||
|
mu sync.Mutex // guards the fields below
|
||||||
|
|
||||||
|
code map[string]*apitype.WhoIsResponse // code -> whois
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -97,11 +105,32 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == "/authorize" {
|
if r.URL.Path == "/authorize" {
|
||||||
redir := r.URL.Query().Get("redirect_uri")
|
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error getting WhoIs: %v", err)
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
http.Redirect(w, r, redir, http.StatusFound)
|
code := must.Get(readHex())
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
mak.Set(&s.code, code, who)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
q := make(url.Values)
|
||||||
|
q.Set("code", code)
|
||||||
|
q.Set("state", r.URL.Query().Get("state"))
|
||||||
|
u := r.URL.Query().Get("redirect_uri") + "?" + q.Encode()
|
||||||
|
log.Printf("Redirecting to %q", u)
|
||||||
|
|
||||||
|
http.Redirect(w, r, u, http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.URL.Path == "/token" {
|
||||||
|
|
||||||
|
}
|
||||||
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,24 +140,44 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *idpServer) oidcSigner() (jose.Signer, error) {
|
func (s *idpServer) oidcSigner() (jose.Signer, error) {
|
||||||
s.oidcSignerInitOnce.Do(s.oidcSignerInit)
|
return s.lazySigner.GetErr(func() (jose.Signer, error) {
|
||||||
return s.oidcSignerLazy, s.oidcSignerError
|
sk, err := s.oidcPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.RS256,
|
||||||
|
Key: sk.k,
|
||||||
|
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
|
jose.HeaderType: "JWT",
|
||||||
|
"kid": fmt.Sprint(sk.kid),
|
||||||
|
}})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *idpServer) oidcSignerInit() {
|
func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
|
||||||
id, k := s.oidcPrivateKey()
|
return s.lazySigningKey.GetErr(func() (*signingKey, error) {
|
||||||
s.oidcSignerLazy, s.oidcSignerError = jose.NewSigner(jose.SigningKey{
|
var sk signingKey
|
||||||
Algorithm: jose.RS256,
|
b, err := os.ReadFile("oidc-key.json")
|
||||||
Key: k,
|
if err == nil {
|
||||||
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
|
if err := sk.UnmarshalJSON(b); err == nil {
|
||||||
jose.HeaderType: "JWT",
|
return &sk, nil
|
||||||
"kid": fmt.Sprint(id),
|
} else {
|
||||||
}})
|
log.Printf("Error unmarshaling key: %v", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
func (s *idpServer) oidcPrivateKey() (id uint64, k *rsa.PrivateKey) {
|
id, k := mustGenRSAKey(2048)
|
||||||
id, k = mustGenRSAKey(2048)
|
sk.k = k
|
||||||
return
|
sk.kid = id
|
||||||
|
b, err = sk.MarshalJSON()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error marshaling key: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile("oidc-key.json", b, 0600); err != nil {
|
||||||
|
log.Fatalf("Error writing key: %v", err)
|
||||||
|
}
|
||||||
|
return &sk, nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
|
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
|
||||||
@ -136,16 +185,21 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
|
|||||||
return tsweb.Error(404, "", nil)
|
return tsweb.Error(404, "", nil)
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
id, k := s.oidcPrivateKey()
|
sk, err := s.oidcPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
return tsweb.Error(500, err.Error(), err)
|
||||||
|
}
|
||||||
// TODO(maisem): maybe only marshal this once and reuse?
|
// TODO(maisem): maybe only marshal this once and reuse?
|
||||||
// TODO(maisem): implement key rotation.
|
// TODO(maisem): implement key rotation.
|
||||||
if err := json.NewEncoder(w).Encode(jose.JSONWebKeySet{
|
je := json.NewEncoder(w)
|
||||||
|
je.SetIndent("", " ")
|
||||||
|
if err := je.Encode(jose.JSONWebKeySet{
|
||||||
Keys: []jose.JSONWebKey{
|
Keys: []jose.JSONWebKey{
|
||||||
{
|
{
|
||||||
Key: k.Public(),
|
Key: sk.k.Public(),
|
||||||
Algorithm: string(jose.RS256),
|
Algorithm: string(jose.RS256),
|
||||||
Use: "sig",
|
Use: "sig",
|
||||||
KeyID: fmt.Sprint(id),
|
KeyID: fmt.Sprint(sk.kid),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@ -189,7 +243,7 @@ type tailscaleClaims struct {
|
|||||||
var (
|
var (
|
||||||
openIDSupportedClaims = views.SliceOf([]string{
|
openIDSupportedClaims = views.SliceOf([]string{
|
||||||
// Standard claims, these correspond to fields in jwt.Claims.
|
// Standard claims, these correspond to fields in jwt.Claims.
|
||||||
"sub", "aud", "exp", "iat", "iss", "jti", "nbf",
|
"sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email",
|
||||||
|
|
||||||
// Tailscale claims, these correspond to fields in tailscaleClaims.
|
// Tailscale claims, these correspond to fields in tailscaleClaims.
|
||||||
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
|
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
|
||||||
@ -215,8 +269,10 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
|
|||||||
return tsweb.Error(404, "", nil)
|
return tsweb.Error(404, "", nil)
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(io.MultiWriter(w, os.Stderr)).Encode(openIDProviderMetadata{
|
je := json.NewEncoder(io.MultiWriter(w, os.Stderr))
|
||||||
Issuer: s.serverURL + "/",
|
je.SetIndent("", " ")
|
||||||
|
if err := je.Encode(openIDProviderMetadata{
|
||||||
|
Issuer: s.serverURL,
|
||||||
JWKS_URI: s.serverURL + oidcJWKSPath,
|
JWKS_URI: s.serverURL + oidcJWKSPath,
|
||||||
UserInfoEndpoint: s.serverURL + "/userinfo",
|
UserInfoEndpoint: s.serverURL + "/userinfo",
|
||||||
AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add /<nodeid> suffix
|
AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add /<nodeid> suffix
|
||||||
@ -247,6 +303,14 @@ func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readHex() (string, error) {
|
||||||
|
var proxyCred [16]byte
|
||||||
|
if _, err := crand.Read(proxyCred[:]); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(proxyCred[:]), nil
|
||||||
|
}
|
||||||
|
|
||||||
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
|
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
|
||||||
func readUint64(r io.Reader) (uint64, error) {
|
func readUint64(r io.Reader) (uint64, error) {
|
||||||
for {
|
for {
|
||||||
@ -267,31 +331,41 @@ type rsaPrivateKeyJSONWrapper struct {
|
|||||||
ID uint64
|
ID uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalKeyJSON(k *rsa.PrivateKey, kid uint64) ([]byte, error) {
|
type signingKey struct {
|
||||||
|
k *rsa.PrivateKey
|
||||||
|
kid uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sk *signingKey) MarshalJSON() ([]byte, error) {
|
||||||
b := pem.Block{
|
b := pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(k),
|
Bytes: x509.MarshalPKCS1PrivateKey(sk.k),
|
||||||
}
|
}
|
||||||
bts := pem.EncodeToMemory(&b)
|
bts := pem.EncodeToMemory(&b)
|
||||||
return json.Marshal(rsaPrivateKeyJSONWrapper{
|
return json.Marshal(rsaPrivateKeyJSONWrapper{
|
||||||
Key: base64.URLEncoding.EncodeToString(bts),
|
Key: base64.URLEncoding.EncodeToString(bts),
|
||||||
ID: kid,
|
ID: sk.kid,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func unmarshalKeyJSON(b []byte) (*rsa.PrivateKey, uint64, error) {
|
func (sk *signingKey) UnmarshalJSON(b []byte) error {
|
||||||
var wrapper rsaPrivateKeyJSONWrapper
|
var wrapper rsaPrivateKeyJSONWrapper
|
||||||
if err := json.Unmarshal(b, &wrapper); err != nil {
|
if err := json.Unmarshal(b, &wrapper); err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
if len(wrapper.Key) == 0 {
|
if len(wrapper.Key) == 0 {
|
||||||
return nil, 0, nil
|
return nil
|
||||||
}
|
}
|
||||||
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
|
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return err
|
||||||
}
|
}
|
||||||
blk, _ := pem.Decode(b64dec)
|
blk, _ := pem.Decode(b64dec)
|
||||||
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
|
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
|
||||||
return k, wrapper.ID, err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sk.k = k
|
||||||
|
sk.kid = wrapper.ID
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user