mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 16:17:41 +00:00
cmd/tsidp: persist signing key
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
63c52e48a4
commit
20e7f99570
@ -8,6 +8,7 @@
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
@ -16,6 +17,7 @@
|
||||
"log"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -23,12 +25,15 @@
|
||||
"github.com/golang-jwt/jwt"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"tailscale.com/client/tailscale"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tsnet"
|
||||
"tailscale.com/tsweb"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/lazy"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
@ -71,9 +76,12 @@ type idpServer struct {
|
||||
lc *tailscale.LocalClient
|
||||
serverURL string // "https://foo.bar.ts.net"
|
||||
|
||||
oidcSignerInitOnce sync.Once
|
||||
oidcSignerLazy jose.Signer
|
||||
oidcSignerError error
|
||||
lazySigningKey lazy.SyncValue[*signingKey]
|
||||
lazySigner lazy.SyncValue[jose.Signer]
|
||||
|
||||
mu sync.Mutex // guards the fields below
|
||||
|
||||
code map[string]*apitype.WhoIsResponse // code -> whois
|
||||
}
|
||||
|
||||
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@ -97,11 +105,32 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if r.URL.Path == "/authorize" {
|
||||
redir := r.URL.Query().Get("redirect_uri")
|
||||
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
|
||||
if err != nil {
|
||||
log.Printf("Error getting WhoIs: %v", err)
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, redir, http.StatusFound)
|
||||
code := must.Get(readHex())
|
||||
|
||||
s.mu.Lock()
|
||||
mak.Set(&s.code, code, who)
|
||||
s.mu.Unlock()
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("code", code)
|
||||
q.Set("state", r.URL.Query().Get("state"))
|
||||
u := r.URL.Query().Get("redirect_uri") + "?" + q.Encode()
|
||||
log.Printf("Redirecting to %q", u)
|
||||
|
||||
http.Redirect(w, r, u, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/token" {
|
||||
|
||||
}
|
||||
http.Error(w, "tsidp: not found", http.StatusNotFound)
|
||||
}
|
||||
|
||||
@ -111,24 +140,44 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
func (s *idpServer) oidcSigner() (jose.Signer, error) {
|
||||
s.oidcSignerInitOnce.Do(s.oidcSignerInit)
|
||||
return s.oidcSignerLazy, s.oidcSignerError
|
||||
return s.lazySigner.GetErr(func() (jose.Signer, error) {
|
||||
sk, err := s.oidcPrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: sk.k,
|
||||
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
jose.HeaderType: "JWT",
|
||||
"kid": fmt.Sprint(sk.kid),
|
||||
}})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *idpServer) oidcSignerInit() {
|
||||
id, k := s.oidcPrivateKey()
|
||||
s.oidcSignerLazy, s.oidcSignerError = jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.RS256,
|
||||
Key: k,
|
||||
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
jose.HeaderType: "JWT",
|
||||
"kid": fmt.Sprint(id),
|
||||
}})
|
||||
}
|
||||
|
||||
func (s *idpServer) oidcPrivateKey() (id uint64, k *rsa.PrivateKey) {
|
||||
id, k = mustGenRSAKey(2048)
|
||||
return
|
||||
func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
|
||||
return s.lazySigningKey.GetErr(func() (*signingKey, error) {
|
||||
var sk signingKey
|
||||
b, err := os.ReadFile("oidc-key.json")
|
||||
if err == nil {
|
||||
if err := sk.UnmarshalJSON(b); err == nil {
|
||||
return &sk, nil
|
||||
} else {
|
||||
log.Printf("Error unmarshaling key: %v", err)
|
||||
}
|
||||
}
|
||||
id, k := mustGenRSAKey(2048)
|
||||
sk.k = k
|
||||
sk.kid = id
|
||||
b, err = sk.MarshalJSON()
|
||||
if err != nil {
|
||||
log.Fatalf("Error marshaling key: %v", err)
|
||||
}
|
||||
if err := os.WriteFile("oidc-key.json", b, 0600); err != nil {
|
||||
log.Fatalf("Error writing key: %v", err)
|
||||
}
|
||||
return &sk, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
|
||||
@ -136,16 +185,21 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
|
||||
return tsweb.Error(404, "", nil)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
id, k := s.oidcPrivateKey()
|
||||
sk, err := s.oidcPrivateKey()
|
||||
if err != nil {
|
||||
return tsweb.Error(500, err.Error(), err)
|
||||
}
|
||||
// TODO(maisem): maybe only marshal this once and reuse?
|
||||
// TODO(maisem): implement key rotation.
|
||||
if err := json.NewEncoder(w).Encode(jose.JSONWebKeySet{
|
||||
je := json.NewEncoder(w)
|
||||
je.SetIndent("", " ")
|
||||
if err := je.Encode(jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{
|
||||
Key: k.Public(),
|
||||
Key: sk.k.Public(),
|
||||
Algorithm: string(jose.RS256),
|
||||
Use: "sig",
|
||||
KeyID: fmt.Sprint(id),
|
||||
KeyID: fmt.Sprint(sk.kid),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
@ -189,7 +243,7 @@ type tailscaleClaims struct {
|
||||
var (
|
||||
openIDSupportedClaims = views.SliceOf([]string{
|
||||
// Standard claims, these correspond to fields in jwt.Claims.
|
||||
"sub", "aud", "exp", "iat", "iss", "jti", "nbf",
|
||||
"sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email",
|
||||
|
||||
// Tailscale claims, these correspond to fields in tailscaleClaims.
|
||||
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
|
||||
@ -215,8 +269,10 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
|
||||
return tsweb.Error(404, "", nil)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(io.MultiWriter(w, os.Stderr)).Encode(openIDProviderMetadata{
|
||||
Issuer: s.serverURL + "/",
|
||||
je := json.NewEncoder(io.MultiWriter(w, os.Stderr))
|
||||
je.SetIndent("", " ")
|
||||
if err := je.Encode(openIDProviderMetadata{
|
||||
Issuer: s.serverURL,
|
||||
JWKS_URI: s.serverURL + oidcJWKSPath,
|
||||
UserInfoEndpoint: s.serverURL + "/userinfo",
|
||||
AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add /<nodeid> suffix
|
||||
@ -247,6 +303,14 @@ func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) {
|
||||
return
|
||||
}
|
||||
|
||||
func readHex() (string, error) {
|
||||
var proxyCred [16]byte
|
||||
if _, err := crand.Read(proxyCred[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(proxyCred[:]), nil
|
||||
}
|
||||
|
||||
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
|
||||
func readUint64(r io.Reader) (uint64, error) {
|
||||
for {
|
||||
@ -267,31 +331,41 @@ type rsaPrivateKeyJSONWrapper struct {
|
||||
ID uint64
|
||||
}
|
||||
|
||||
func marshalKeyJSON(k *rsa.PrivateKey, kid uint64) ([]byte, error) {
|
||||
type signingKey struct {
|
||||
k *rsa.PrivateKey
|
||||
kid uint64
|
||||
}
|
||||
|
||||
func (sk *signingKey) MarshalJSON() ([]byte, error) {
|
||||
b := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(k),
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(sk.k),
|
||||
}
|
||||
bts := pem.EncodeToMemory(&b)
|
||||
return json.Marshal(rsaPrivateKeyJSONWrapper{
|
||||
Key: base64.URLEncoding.EncodeToString(bts),
|
||||
ID: kid,
|
||||
ID: sk.kid,
|
||||
})
|
||||
}
|
||||
|
||||
func unmarshalKeyJSON(b []byte) (*rsa.PrivateKey, uint64, error) {
|
||||
func (sk *signingKey) UnmarshalJSON(b []byte) error {
|
||||
var wrapper rsaPrivateKeyJSONWrapper
|
||||
if err := json.Unmarshal(b, &wrapper); err != nil {
|
||||
return nil, 0, err
|
||||
return err
|
||||
}
|
||||
if len(wrapper.Key) == 0 {
|
||||
return nil, 0, nil
|
||||
return nil
|
||||
}
|
||||
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return err
|
||||
}
|
||||
blk, _ := pem.Decode(b64dec)
|
||||
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
|
||||
return k, wrapper.ID, err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sk.k = k
|
||||
sk.kid = wrapper.ID
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user