diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go
new file mode 100644
index 000000000..3599ab141
--- /dev/null
+++ b/cmd/tsidp/tsidp.go
@@ -0,0 +1,703 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// The tsidp command is an OpenID Connect Identity Provider server.
+//
+// See https://github.com/tailscale/tailscale/issues/10263 for background.
+package main
+
+import (
+ "context"
+ crand "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/pem"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+ "net/url"
+ "os"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "gopkg.in/square/go-jose.v2"
+ "gopkg.in/square/go-jose.v2/jwt"
+ "tailscale.com/client/tailscale"
+ "tailscale.com/client/tailscale/apitype"
+ "tailscale.com/envknob"
+ "tailscale.com/ipn/ipnstate"
+ "tailscale.com/tailcfg"
+ "tailscale.com/tsnet"
+ "tailscale.com/types/key"
+ "tailscale.com/types/lazy"
+ "tailscale.com/types/logger"
+ "tailscale.com/types/views"
+ "tailscale.com/util/mak"
+ "tailscale.com/util/must"
+ "tailscale.com/util/rands"
+)
+
+var (
+ flagVerbose = flag.Bool("verbose", false, "be verbose")
+ flagPort = flag.Int("port", 443, "port to listen on")
+ flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost")
+ flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet")
+)
+
+func main() {
+ flag.Parse()
+ ctx := context.Background()
+ if !envknob.UseWIPCode() {
+ log.Fatal("cmd/tsidp is a work in progress and has not been security reviewed;\nits use requires TAILSCALE_USE_WIP_CODE=1 be set in the environment for now.")
+ }
+
+ var (
+ lc *tailscale.LocalClient
+ st *ipnstate.Status
+ err error
+
+ lns []net.Listener
+ )
+ if *flagUseLocalTailscaled {
+ lc = &tailscale.LocalClient{}
+ st, err = lc.StatusWithoutPeers(ctx)
+ if err != nil {
+ log.Fatalf("getting status: %v", err)
+ }
+ portStr := fmt.Sprint(*flagPort)
+ anySuccess := false
+ for _, ip := range st.TailscaleIPs {
+ ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr))
+ if err != nil {
+ log.Printf("failed to listen on %v: %v", ip, err)
+ continue
+ }
+ anySuccess = true
+ ln = tls.NewListener(ln, &tls.Config{
+ GetCertificate: lc.GetCertificate,
+ })
+ lns = append(lns, ln)
+ }
+ if !anySuccess {
+ log.Fatalf("failed to listen on any of %v", st.TailscaleIPs)
+ }
+ } else {
+ ts := &tsnet.Server{
+ Hostname: "idp",
+ }
+ if !*flagVerbose {
+ ts.Logf = logger.Discard
+ }
+ st, err = ts.Up(ctx)
+ if err != nil {
+ log.Fatal(err)
+ }
+ lc, err = ts.LocalClient()
+ if err != nil {
+ log.Fatalf("getting local client: %v", err)
+ }
+ ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort))
+ if err != nil {
+ log.Fatal(err)
+ }
+ lns = append(lns, ln)
+ }
+
+ srv := &idpServer{
+ lc: lc,
+ }
+ if *flagPort != 443 {
+ srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort)
+ } else {
+ srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, "."))
+ }
+
+ log.Printf("Running tsidp at %s ...", srv.serverURL)
+
+ if *flagLocalPort != -1 {
+ log.Printf("Also running tsidp at %s ...", srv.loopbackURL)
+ srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort)
+ ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort))
+ if err != nil {
+ log.Fatal(err)
+ }
+ lns = append(lns, ln)
+ }
+
+ for _, ln := range lns {
+ go http.Serve(ln, srv)
+ }
+ select {}
+}
+
+type idpServer struct {
+ lc *tailscale.LocalClient
+ loopbackURL string
+ serverURL string // "https://foo.bar.ts.net"
+
+ lazyMux lazy.SyncValue[*http.ServeMux]
+ lazySigningKey lazy.SyncValue[*signingKey]
+ lazySigner lazy.SyncValue[jose.Signer]
+
+ mu sync.Mutex // guards the fields below
+ code map[string]*authRequest // keyed by random hex
+ accessToken map[string]*authRequest // keyed by random hex
+}
+
+type authRequest struct {
+ // localRP is true if the request is from a relying party running on the
+ // same machine as the idp server. It is mutually exclusive with rpNodeID.
+ localRP bool
+
+ // rpNodeID is the NodeID of the relying party (who requested the auth, such
+ // as Proxmox or Synology), not the user node who is being authenticated. It
+ // is mutually exclusive with localRP.
+ rpNodeID tailcfg.NodeID
+
+ // clientID is the "client_id" sent in the authorized request.
+ clientID string
+
+ // nonce presented in the request.
+ nonce string
+
+ // redirectURI is the redirect_uri presented in the request.
+ redirectURI string
+
+ // remoteUser is the user who is being authenticated.
+ remoteUser *apitype.WhoIsResponse
+
+ // validTill is the time until which the token is valid.
+ // As of 2023-11-14, it is 5 minutes.
+ // TODO: add routine to delete expired tokens.
+ validTill time.Time
+}
+
+func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error {
+ if ar.localRP {
+ ra, err := netip.ParseAddrPort(remoteAddr)
+ if err != nil {
+ return err
+ }
+ if !ra.Addr().IsLoopback() {
+ return fmt.Errorf("tsidp: request from non-loopback address")
+ }
+ return nil
+ }
+ who, err := lc.WhoIs(ctx, remoteAddr)
+ if err != nil {
+ return fmt.Errorf("tsidp: error getting WhoIs: %w", err)
+ }
+ if ar.rpNodeID != who.Node.ID {
+ return fmt.Errorf("tsidp: token for different node")
+ }
+ return nil
+}
+
+func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) {
+ who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ uq := r.URL.Query()
+
+ code := rands.HexString(32)
+ ar := &authRequest{
+ nonce: uq.Get("nonce"),
+ remoteUser: who,
+ redirectURI: uq.Get("redirect_uri"),
+ clientID: uq.Get("client_id"),
+ }
+
+ if r.URL.Path == "/authorize/localhost" {
+ ar.localRP = true
+ } else {
+ var ok bool
+ ar.rpNodeID, ok = parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/"))
+ if !ok {
+ http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest)
+ return
+ }
+ }
+
+ s.mu.Lock()
+ mak.Set(&s.code, code, ar)
+ s.mu.Unlock()
+
+ q := make(url.Values)
+ q.Set("code", code)
+ q.Set("state", uq.Get("state"))
+ u := uq.Get("redirect_uri") + "?" + q.Encode()
+ log.Printf("Redirecting to %q", u)
+
+ http.Redirect(w, r, u, http.StatusFound)
+}
+
+func (s *idpServer) newMux() *http.ServeMux {
+ mux := http.NewServeMux()
+ mux.HandleFunc(oidcJWKSPath, s.serveJWKS)
+ mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig)
+ mux.HandleFunc("/authorize/", s.authorize)
+ mux.HandleFunc("/userinfo", s.serveUserInfo)
+ mux.HandleFunc("/token", s.serveToken)
+ mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/" {
+ io.WriteString(w, "
Tailscale OIDC IdP
")
+ return
+ }
+ http.Error(w, "tsidp: not found", http.StatusNotFound)
+ })
+ return mux
+}
+
+func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ log.Printf("%v %v", r.Method, r.URL)
+ s.lazyMux.Get(s.newMux).ServeHTTP(w, r)
+}
+
+func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "GET" {
+ http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ tk, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ")
+ if !ok {
+ http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest)
+ return
+ }
+
+ s.mu.Lock()
+ ar, ok := s.accessToken[tk]
+ s.mu.Unlock()
+ if !ok {
+ http.Error(w, "tsidp: invalid token", http.StatusBadRequest)
+ return
+ }
+ if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
+ log.Printf("Error allowing relying party: %v", err)
+ http.Error(w, err.Error(), http.StatusForbidden)
+ return
+ }
+
+ if ar.validTill.Before(time.Now()) {
+ http.Error(w, "tsidp: token expired", http.StatusBadRequest)
+ s.mu.Lock()
+ delete(s.accessToken, tk)
+ s.mu.Unlock()
+ }
+
+ ui := userInfo{}
+ if ar.remoteUser.Node.IsTagged() {
+ http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
+ return
+ }
+ ui.Sub = ar.remoteUser.Node.User.String()
+ ui.Name = ar.remoteUser.UserProfile.DisplayName
+ ui.Email = ar.remoteUser.UserProfile.LoginName
+ ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL
+
+ // TODO(maisem): not sure if this is the right thing to do
+ ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(ui); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+type userInfo struct {
+ Sub string `json:"sub"`
+ Name string `json:"name"`
+ Email string `json:"email"`
+ Picture string `json:"picture"`
+ UserName string `json:"username"`
+}
+
+func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "POST" {
+ http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ if r.FormValue("grant_type") != "authorization_code" {
+ http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest)
+ return
+ }
+ code := r.FormValue("code")
+ if code == "" {
+ http.Error(w, "tsidp: code is required", http.StatusBadRequest)
+ return
+ }
+ s.mu.Lock()
+ ar, ok := s.code[code]
+ if ok {
+ delete(s.code, code)
+ }
+ s.mu.Unlock()
+ if !ok {
+ http.Error(w, "tsidp: code not found", http.StatusBadRequest)
+ return
+ }
+ if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil {
+ log.Printf("Error allowing relying party: %v", err)
+ http.Error(w, err.Error(), http.StatusForbidden)
+ return
+ }
+ if ar.redirectURI != r.FormValue("redirect_uri") {
+ http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest)
+ return
+ }
+ signer, err := s.oidcSigner()
+ if err != nil {
+ log.Printf("Error getting signer: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ jti := rands.HexString(32)
+ who := ar.remoteUser
+
+ // TODO(maisem): not sure if this is the right thing to do
+ userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@")
+ n := who.Node.View()
+ if n.IsTagged() {
+ http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest)
+ return
+ }
+
+ now := time.Now()
+ _, tcd, _ := strings.Cut(n.Name(), ".")
+ tsClaims := tailscaleClaims{
+ Claims: jwt.Claims{
+ Audience: jwt.Audience{ar.clientID},
+ Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
+ ID: jti,
+ IssuedAt: jwt.NewNumericDate(now),
+ Issuer: s.serverURL,
+ NotBefore: jwt.NewNumericDate(now),
+ Subject: n.User().String(),
+ },
+ Nonce: ar.nonce,
+ Key: n.Key(),
+ Addresses: n.Addresses(),
+ NodeID: n.ID(),
+ NodeName: n.Name(),
+ Tailnet: tcd,
+ UserID: n.User(),
+ Email: who.UserProfile.LoginName,
+ UserName: userName,
+ }
+ if ar.localRP {
+ tsClaims.Issuer = s.loopbackURL
+ }
+
+ // Create an OIDC token using this issuer's signer.
+ token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize()
+ if err != nil {
+ log.Printf("Error getting token: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ at := rands.HexString(32)
+ s.mu.Lock()
+ ar.validTill = now.Add(5 * time.Minute)
+ mak.Set(&s.accessToken, at, ar)
+ s.mu.Unlock()
+
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(oidcTokenResponse{
+ AccessToken: at,
+ TokenType: "Bearer",
+ ExpiresIn: 5 * 60,
+ IDToken: token,
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+type oidcTokenResponse struct {
+ IDToken string `json:"id_token"`
+ TokenType string `json:"token_type"`
+ AccessToken string `json:"access_token"`
+ RefreshToken string `json:"refresh_token"`
+ ExpiresIn int `json:"expires_in"`
+}
+
+const (
+ oidcJWKSPath = "/.well-known/jwks.json"
+ oidcConfigPath = "/.well-known/openid-configuration"
+)
+
+func (s *idpServer) oidcSigner() (jose.Signer, error) {
+ return s.lazySigner.GetErr(func() (jose.Signer, error) {
+ sk, err := s.oidcPrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ return jose.NewSigner(jose.SigningKey{
+ Algorithm: jose.RS256,
+ Key: sk.k,
+ }, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]any{
+ jose.HeaderType: "JWT",
+ "kid": fmt.Sprint(sk.kid),
+ }})
+ })
+}
+
+func (s *idpServer) oidcPrivateKey() (*signingKey, error) {
+ return s.lazySigningKey.GetErr(func() (*signingKey, error) {
+ var sk signingKey
+ b, err := os.ReadFile("oidc-key.json")
+ if err == nil {
+ if err := sk.UnmarshalJSON(b); err == nil {
+ return &sk, nil
+ } else {
+ log.Printf("Error unmarshaling key: %v", err)
+ }
+ }
+ id, k := mustGenRSAKey(2048)
+ sk.k = k
+ sk.kid = id
+ b, err = sk.MarshalJSON()
+ if err != nil {
+ log.Fatalf("Error marshaling key: %v", err)
+ }
+ if err := os.WriteFile("oidc-key.json", b, 0600); err != nil {
+ log.Fatalf("Error writing key: %v", err)
+ }
+ return &sk, nil
+ })
+}
+
+func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != oidcJWKSPath {
+ http.Error(w, "tsidp: not found", http.StatusNotFound)
+ return
+ }
+ w.Header().Set("Content-Type", "application/json")
+ sk, err := s.oidcPrivateKey()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ // TODO(maisem): maybe only marshal this once and reuse?
+ // TODO(maisem): implement key rotation.
+ je := json.NewEncoder(w)
+ je.SetIndent("", " ")
+ if err := je.Encode(jose.JSONWebKeySet{
+ Keys: []jose.JSONWebKey{
+ {
+ Key: sk.k.Public(),
+ Algorithm: string(jose.RS256),
+ Use: "sig",
+ KeyID: fmt.Sprint(sk.kid),
+ },
+ },
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ return
+}
+
+// openIDProviderMetadata is a partial representation of
+// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata.
+type openIDProviderMetadata struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"`
+ TokenEndpoint string `json:"token_endpoint,omitempty"`
+ UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
+ JWKS_URI string `json:"jwks_uri"`
+ ScopesSupported views.Slice[string] `json:"scopes_supported"`
+ ResponseTypesSupported views.Slice[string] `json:"response_types_supported"`
+ SubjectTypesSupported views.Slice[string] `json:"subject_types_supported"`
+ ClaimsSupported views.Slice[string] `json:"claims_supported"`
+ IDTokenSigningAlgValuesSupported views.Slice[string] `json:"id_token_signing_alg_values_supported"`
+ // TODO(maisem): maybe add other fields?
+ // Currently we fill out the REQUIRED fields, scopes_supported and claims_supported.
+}
+
+type tailscaleClaims struct {
+ jwt.Claims `json:",inline"`
+ Nonce string `json:"nonce,omitempty"` // the nonce from the request
+ Key key.NodePublic `json:"key"` // the node public key
+ Addresses views.Slice[netip.Prefix] `json:"addresses"` // the Tailscale IPs of the node
+ NodeID tailcfg.NodeID `json:"nid"` // the stable node ID
+ NodeName string `json:"node"` // name of the node
+ Tailnet string `json:"tailnet"` // tailnet (like tail-scale.ts.net)
+
+ // Email is the "emailish" value with an '@' sign. It might not be a valid email.
+ Email string `json:"email,omitempty"` // user emailish (like "alice@github" or "bob@example.com")
+ UserID tailcfg.UserID `json:"uid,omitempty"`
+
+ // UserName is the local part of Email (without '@' and domain).
+ // It is a temporary (2023-11-15) hack during development.
+ // We should probably let this be configured via grants.
+ UserName string `json:"username,omitempty"`
+}
+
+var (
+ openIDSupportedClaims = views.SliceOf([]string{
+ // Standard claims, these correspond to fields in jwt.Claims.
+ "sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email",
+
+ // Tailscale claims, these correspond to fields in tailscaleClaims.
+ "key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid",
+ })
+
+ // As defined in the OpenID spec this should be "openid".
+ openIDSupportedScopes = views.SliceOf([]string{"openid", "email", "profile"})
+
+ // We only support getting the id_token.
+ openIDSupportedReponseTypes = views.SliceOf([]string{"id_token", "code"})
+
+ // The type of the "sub" field in the JWT, which means it is globally unique identifier.
+ // The other option is "pairwise", which means the identifier is different per receiving 3p.
+ openIDSupportedSubjectTypes = views.SliceOf([]string{"public"})
+
+ // The algo used for signing. The OpenID spec says "The algorithm RS256 MUST be included."
+ // https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+ openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)})
+)
+
+func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != oidcConfigPath {
+ http.Error(w, "tsidp: not found", http.StatusNotFound)
+ return
+ }
+ ap, err := netip.ParseAddrPort(r.RemoteAddr)
+ if err != nil {
+ log.Printf("Error parsing remote addr: %v", err)
+ return
+ }
+ var authorizeEndpoint string
+ rpEndpoint := s.serverURL
+ if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil {
+ authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID)
+ } else if ap.Addr().IsLoopback() {
+ rpEndpoint = s.loopbackURL
+ authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL)
+ } else {
+ log.Printf("Error getting WhoIs: %v", err)
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ je := json.NewEncoder(w)
+ je.SetIndent("", " ")
+ if err := je.Encode(openIDProviderMetadata{
+ AuthorizationEndpoint: authorizeEndpoint,
+ Issuer: rpEndpoint,
+ JWKS_URI: rpEndpoint + oidcJWKSPath,
+ UserInfoEndpoint: rpEndpoint + "/userinfo",
+ TokenEndpoint: rpEndpoint + "/token",
+ ScopesSupported: openIDSupportedScopes,
+ ResponseTypesSupported: openIDSupportedReponseTypes,
+ SubjectTypesSupported: openIDSupportedSubjectTypes,
+ ClaimsSupported: openIDSupportedClaims,
+ IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos,
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+}
+
+const (
+ minimumRSAKeySize = 2048
+)
+
+// mustGenRSAKey generates a new RSA key with the provided number of bits. It
+// panics on failure. bits must be at least minimumRSAKeySizeBytes * 8.
+func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) {
+ if bits < minimumRSAKeySize {
+ panic("request to generate a too-small RSA key")
+ }
+ kid = must.Get(readUint64(crand.Reader))
+ k = must.Get(rsa.GenerateKey(crand.Reader, bits))
+ return
+}
+
+// readUint64 reads from r until 8 bytes represent a non-zero uint64.
+func readUint64(r io.Reader) (uint64, error) {
+ for {
+ var b [8]byte
+ if _, err := io.ReadFull(r, b[:]); err != nil {
+ return 0, err
+ }
+ if v := binary.BigEndian.Uint64(b[:]); v != 0 {
+ return v, nil
+ }
+ }
+}
+
+// rsaPrivateKeyJSONWrapper is the the JSON serialization
+// format used by RSAPrivateKey.
+type rsaPrivateKeyJSONWrapper struct {
+ Key string
+ ID uint64
+}
+
+type signingKey struct {
+ k *rsa.PrivateKey
+ kid uint64
+}
+
+func (sk *signingKey) MarshalJSON() ([]byte, error) {
+ b := pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: x509.MarshalPKCS1PrivateKey(sk.k),
+ }
+ bts := pem.EncodeToMemory(&b)
+ return json.Marshal(rsaPrivateKeyJSONWrapper{
+ Key: base64.URLEncoding.EncodeToString(bts),
+ ID: sk.kid,
+ })
+}
+
+func (sk *signingKey) UnmarshalJSON(b []byte) error {
+ var wrapper rsaPrivateKeyJSONWrapper
+ if err := json.Unmarshal(b, &wrapper); err != nil {
+ return err
+ }
+ if len(wrapper.Key) == 0 {
+ return nil
+ }
+ b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key)
+ if err != nil {
+ return err
+ }
+ blk, _ := pem.Decode(b64dec)
+ k, err := x509.ParsePKCS1PrivateKey(blk.Bytes)
+ if err != nil {
+ return err
+ }
+ sk.k = k
+ sk.kid = wrapper.ID
+ return nil
+}
+
+// parseID takes a string input and returns a typed IntID T and true, or a zero
+// value and false if the input is unhandled syntax or out of a valid range.
+func parseID[T ~int64](input string) (_ T, ok bool) {
+ if input == "" {
+ return 0, false
+ }
+ i, err := strconv.ParseInt(input, 10, 64)
+ if err != nil {
+ return 0, false
+ }
+ if i < 0 {
+ return 0, false
+ }
+ return T(i), true
+}
diff --git a/go.mod b/go.mod
index cafe4eb24..1108df7b9 100644
--- a/go.mod
+++ b/go.mod
@@ -89,6 +89,7 @@ require (
golang.org/x/tools v0.13.0
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
golang.zx2c4.com/wireguard/windows v0.5.3
+ gopkg.in/square/go-jose.v2 v2.6.0
gvisor.dev/gvisor v0.0.0-20230928000133-4fe30062272c
honnef.co/go/tools v0.4.6
inet.af/peercred v0.0.0-20210906144145-0893ea02156a
diff --git a/go.sum b/go.sum
index b1b664559..fc025873f 100644
--- a/go.sum
+++ b/go.sum
@@ -1406,6 +1406,8 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
+gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI=
+gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=