cmd/tsidp: persist signing key

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-11-13 16:09:50 -08:00
parent 63c52e48a4
commit 20e7f99570

View File

@ -8,6 +8,7 @@
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"encoding/pem"
"flag"
@ -16,6 +17,7 @@
"log"
"net/http"
"net/netip"
"net/url"
"os"
"strings"
"sync"
@ -23,12 +25,15 @@
"github.com/golang-jwt/jwt"
"gopkg.in/square/go-jose.v2"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/types/key"
"tailscale.com/types/lazy"
"tailscale.com/types/logger"
"tailscale.com/types/views"
"tailscale.com/util/mak"
"tailscale.com/util/must"
)
@ -71,9 +76,12 @@ type idpServer struct {
lc *tailscale.LocalClient
serverURL string // "https://foo.bar.ts.net"
oidcSignerInitOnce sync.Once
oidcSignerLazy jose.Signer
oidcSignerError error
lazySigningKey lazy.SyncValue[*signingKey]
lazySigner lazy.SyncValue[jose.Signer]
mu sync.Mutex // guards the fields below
code map[string]*apitype.WhoIsResponse // code -> whois
}
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -97,11 +105,32 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
if r.URL.Path == "/authorize" {
redir := r.URL.Query().Get("redirect_uri")
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
if err != nil {
log.Printf("Error getting WhoIs: %v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, redir, http.StatusFound)
code := must.Get(readHex())
s.mu.Lock()
mak.Set(&s.code, code, who)
s.mu.Unlock()
q := make(url.Values)
q.Set("code", code)
q.Set("state", r.URL.Query().Get("state"))
u := r.URL.Query().Get("redirect_uri") + "?" + q.Encode()
log.Printf("Redirecting to %q", u)
http.Redirect(w, r, u, http.StatusFound)
return
}
if r.URL.Path == "/token" {
}
http.Error(w, "tsidp: not found", http.StatusNotFound)
}
@ -111,24 +140,44 @@ func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
)
func (s *idpServer) oidcSigner() (jose.Signer, error) {
s.oidcSignerInitOnce.Do(s.oidcSignerInit)
return s.oidcSignerLazy, s.oidcSignerError
return s.lazySigner.GetErr(func() (jose.Signer, error) {
sk, err := s.oidcPrivateKey()
if err != nil {
return nil, err
}
return jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: sk.k,
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
jose.HeaderType: "JWT",
"kid": fmt.Sprint(sk.kid),
}})
})
}
func (s *idpServer) oidcSignerInit() {
id, k := s.oidcPrivateKey()
s.oidcSignerLazy, s.oidcSignerError = jose.NewSigner(jose.SigningKey{
Algorithm: jose.RS256,
Key: k,
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]interface{}{
jose.HeaderType: "JWT",
"kid": fmt.Sprint(id),
}})
}
func (s *idpServer) oidcPrivateKey() (id uint64, k *rsa.PrivateKey) {
id, k = mustGenRSAKey(2048)
return
func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
return s.lazySigningKey.GetErr(func() (*signingKey, error) {
var sk signingKey
b, err := os.ReadFile("oidc-key.json")
if err == nil {
if err := sk.UnmarshalJSON(b); err == nil {
return &sk, nil
} else {
log.Printf("Error unmarshaling key: %v", err)
}
}
id, k := mustGenRSAKey(2048)
sk.k = k
sk.kid = id
b, err = sk.MarshalJSON()
if err != nil {
log.Fatalf("Error marshaling key: %v", err)
}
if err := os.WriteFile("oidc-key.json", b, 0600); err != nil {
log.Fatalf("Error writing key: %v", err)
}
return &sk, nil
})
}
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
@ -136,16 +185,21 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) error {
return tsweb.Error(404, "", nil)
}
w.Header().Set("Content-Type", "application/json")
id, k := s.oidcPrivateKey()
sk, err := s.oidcPrivateKey()
if err != nil {
return tsweb.Error(500, err.Error(), err)
}
// TODO(maisem): maybe only marshal this once and reuse?
// TODO(maisem): implement key rotation.
if err := json.NewEncoder(w).Encode(jose.JSONWebKeySet{
je := json.NewEncoder(w)
je.SetIndent("", " ")
if err := je.Encode(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{
Key: k.Public(),
Key: sk.k.Public(),
Algorithm: string(jose.RS256),
Use: "sig",
KeyID: fmt.Sprint(id),
KeyID: fmt.Sprint(sk.kid),
},
},
}); err != nil {
@ -189,7 +243,7 @@ type tailscaleClaims struct {
var (
openIDSupportedClaims = views.SliceOf([]string{
// Standard claims, these correspond to fields in jwt.Claims.
"sub", "aud", "exp", "iat", "iss", "jti", "nbf",
"sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email",
// Tailscale claims, these correspond to fields in tailscaleClaims.
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
@ -215,8 +269,10 @@ func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) er
return tsweb.Error(404, "", nil)
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(io.MultiWriter(w, os.Stderr)).Encode(openIDProviderMetadata{
Issuer: s.serverURL + "/",
je := json.NewEncoder(io.MultiWriter(w, os.Stderr))
je.SetIndent("", " ")
if err := je.Encode(openIDProviderMetadata{
Issuer: s.serverURL,
JWKS_URI: s.serverURL + oidcJWKSPath,
UserInfoEndpoint: s.serverURL + "/userinfo",
AuthorizationEndpoint: s.serverURL + "/authorize", // TODO: add /<nodeid> suffix
@ -247,6 +303,14 @@ func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) {
return
}
func readHex() (string, error) {
var proxyCred [16]byte
if _, err := crand.Read(proxyCred[:]); err != nil {
return "", err
}
return hex.EncodeToString(proxyCred[:]), nil
}
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
func readUint64(r io.Reader) (uint64, error) {
for {
@ -267,31 +331,41 @@ type rsaPrivateKeyJSONWrapper struct {
ID uint64
}
func marshalKeyJSON(k *rsa.PrivateKey, kid uint64) ([]byte, error) {
type signingKey struct {
k *rsa.PrivateKey
kid uint64
}
func (sk *signingKey) MarshalJSON() ([]byte, error) {
b := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(k),
Bytes: x509.MarshalPKCS1PrivateKey(sk.k),
}
bts := pem.EncodeToMemory(&b)
return json.Marshal(rsaPrivateKeyJSONWrapper{
Key: base64.URLEncoding.EncodeToString(bts),
ID: kid,
ID: sk.kid,
})
}
func unmarshalKeyJSON(b []byte) (*rsa.PrivateKey, uint64, error) {
func (sk *signingKey) UnmarshalJSON(b []byte) error {
var wrapper rsaPrivateKeyJSONWrapper
if err := json.Unmarshal(b, &wrapper); err != nil {
return nil, 0, err
return err
}
if len(wrapper.Key) == 0 {
return nil, 0, nil
return nil
}
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
if err != nil {
return nil, 0, err
return err
}
blk, _ := pem.Decode(b64dec)
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
return k, wrapper.ID, err
if err != nil {
return err
}
sk.k = k
sk.kid = wrapper.ID
return nil
}