From 62445931b62d22a60e8bd006b2c716dd2a4931a3 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Tue, 22 Oct 2024 21:34:50 -0500 Subject: [PATCH] DERP JWT POC Signed-off-by: Percy Wegmann --- cmd/derper/derper.go | 2 +- derp/derp_client.go | 12 ++++++++--- derp/derp_server.go | 49 +++++++++++++++++++++++++++++++++++++++++++- derp/derp_test.go | 49 ++++++++++++++++++++++++++++++++------------ go.mod | 1 + go.sum | 2 ++ 6 files changed, 97 insertions(+), 18 deletions(-) diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 80c9dc44f..557fccf11 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -168,7 +168,7 @@ func main() { serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual" - s := derp.NewServer(cfg.PrivateKey, log.Printf) + s := derp.NewServer(cfg.PrivateKey, nil, log.Printf) s.SetVerifyClient(*verifyClients) s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) diff --git a/derp/derp_client.go b/derp/derp_client.go index 7a646fa51..d384e89c5 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -27,6 +27,7 @@ type Client struct { serverKey key.NodePublic // of the DERP server; not a machine or node key privateKey key.NodePrivate publicKey key.NodePublic // of privateKey + jwt string logf logger.Logf nc Conn br *bufio.Reader @@ -84,7 +85,7 @@ func CanAckPings(v bool) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.CanAckPings = v }) } -func NewClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) { +func NewClient(privateKey key.NodePrivate, jwt string, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) { var opt clientOpt for _, o := range opts { if o == nil { @@ -92,13 +93,14 @@ func NewClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf } o.update(&opt) } - return newClient(privateKey, nc, brw, logf, opt) + return newClient(privateKey, jwt, nc, brw, logf, opt) } -func newClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) { +func newClient(privateKey key.NodePrivate, jwt string, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) { c := &Client{ privateKey: privateKey, publicKey: privateKey.Public(), + jwt: jwt, logf: logf, nc: nc, br: brw.Reader, @@ -177,6 +179,9 @@ type clientInfo struct { // IsProber is whether this client is a prober. IsProber bool `json:",omitempty"` + + // JWT is a JSON web token with authorization grants for this client. + JWT string `json:",omitempty"` } func (c *Client) sendClientKey() error { @@ -185,6 +190,7 @@ func (c *Client) sendClientKey() error { MeshKey: c.meshKey, CanAckPings: c.canAckPings, IsProber: c.isProber, + JWT: c.jwt, }) if err != nil { return err diff --git a/derp/derp_server.go b/derp/derp_server.go index ab0ab0a90..cbb135120 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -35,6 +35,7 @@ "sync/atomic" "time" + "github.com/golang-jwt/jwt" "go4.org/mem" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale" @@ -114,6 +115,7 @@ type Server struct { privateKey key.NodePrivate publicKey key.NodePublic + jwtSigner ed25519.PublicKey logf logger.Logf memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish) meshKey string @@ -342,7 +344,7 @@ type Conn interface { // NewServer returns a new DERP server. It doesn't listen on its own. // Connections are given to it via Server.Accept. -func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { +func NewServer(privateKey key.NodePrivate, jwtSigner ed25519.PublicKey, logf logger.Logf) *Server { var ms runtime.MemStats runtime.ReadMemStats(&ms) @@ -350,6 +352,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { debug: envknob.Bool("DERP_DEBUG_LOGS"), privateKey: privateKey, publicKey: privateKey.Public(), + jwtSigner: jwtSigner, logf: logf, limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), packetsRecvByKind: metrics.LabelMap{Label: "kind"}, @@ -1450,9 +1453,53 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info if err := json.Unmarshal(msg, info); err != nil { return zpub, nil, fmt.Errorf("msg: %v", err) } + if info.JWT == "" { + fmt.Println("ZZZZ No JWT, maybe old client") + } else if err := s.authorizeJWT(info.JWT, clientKey); err != nil { + return clientKey, info, fmt.Errorf("failed to authorize JWT: %w", err) + } return clientKey, info, nil } +func (s *Server) authorizeJWT(tokenString string, clientKey key.NodePublic) error { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { + return nil, fmt.Errorf("Unexpected signing method: %s", token.Header["alg"]) + } + return s.jwtSigner, nil + }) + if err != nil { + return fmt.Errorf("error verifying provided JWT: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return errors.New("invalid type of JWT claims provided") + } + _expires, ok := claims["expires"] + if !ok { + return errors.New("JWT missing expires") + } + expires, err := time.Parse(time.RFC3339, _expires.(string)) + if err != nil { + return fmt.Errorf("failed to parse expires: %w", err) + } + if expires.Before(time.Now()) { + return errors.New("JWT expired") + } + pkHex, ok := claims["publicKeyHex"] + if !ok { + return errors.New("JWT missing publicKeyHex") + } + var clientKeyFromJWT key.NodePublic + if err := clientKeyFromJWT.UnmarshalText([]byte(pkHex.(string))); err != nil { + return fmt.Errorf("Failed to unmarshal publicKeyHex: %w", err) + } + if clientKey != clientKeyFromJWT { + return fmt.Errorf("client key in JWT does not match client's key") + } + return nil +} + func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) { if frameLen < keyLen { return zpub, nil, errors.New("short send packet frame") diff --git a/derp/derp_test.go b/derp/derp_test.go index 9185194dd..a84ce583d 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -7,6 +7,8 @@ "bufio" "bytes" "context" + "crypto/ed25519" + "crypto/rand" "crypto/x509" "encoding/asn1" "encoding/json" @@ -23,6 +25,7 @@ "testing" "time" + "github.com/golang-jwt/jwt" "go4.org/mem" "golang.org/x/time/rate" "tailscale.com/disco" @@ -49,17 +52,37 @@ func TestClientInfoUnmarshal(t *testing.T) { } func TestSendRecv(t *testing.T) { + signerPub, signerPriv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, t.Logf) + s := NewServer(serverPrivateKey, signerPub, t.Logf) defer s.Close() const numClients = 3 var clientPrivateKeys []key.NodePrivate var clientKeys []key.NodePublic + var clientJWTs []string for range numClients { priv := key.NewNode() clientPrivateKeys = append(clientPrivateKeys, priv) clientKeys = append(clientKeys, priv.Public()) + pkHex, err := priv.Public().MarshalText() + if err != nil { + t.Fatal(err) + } + // The below would typically be done by the control server + jt := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jwt.MapClaims{ + "publicKeyHex": string(pkHex), + "expires": time.Now().Add(1 * time.Hour).Format(time.RFC3339), + }) + sjt, err := jt.SignedString(signerPriv) + if err != nil { + t.Fatal(err) + } + clientJWTs = append(clientJWTs, sjt) } ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -96,7 +119,7 @@ func TestSendRecv(t *testing.T) { key := clientPrivateKeys[i] brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) - c, err := NewClient(key, cout, brw, t.Logf) + c, err := NewClient(key, clientJWTs[i], cout, brw, t.Logf) if err != nil { t.Fatalf("client %d: %v", i, err) } @@ -269,7 +292,7 @@ func TestSendRecv(t *testing.T) { func TestSendFreeze(t *testing.T) { serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, t.Logf) + s := NewServer(serverPrivateKey, nil, t.Logf) defer s.Close() s.WriteTimeout = 100 * time.Millisecond @@ -287,7 +310,7 @@ func TestSendFreeze(t *testing.T) { go s.Accept(ctx, c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2)) - c, err := NewClient(k, c2, brw, t.Logf) + c, err := NewClient(k, "", c2, brw, t.Logf) if err != nil { t.Fatal(err) } @@ -511,7 +534,7 @@ func (ts *testServer) close(t *testing.T) error { func newTestServer(t *testing.T, ctx context.Context) *testServer { t.Helper() logf := logger.WithPrefix(t.Logf, "derp-server: ") - s := NewServer(key.NewNode(), logf) + s := NewServer(key.NewNode(), nil, logf) s.SetMeshKey("mesh-key") ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -576,7 +599,7 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) - c, err := NewClient(priv, nc, brw, logf) + c, err := NewClient(priv, "", nc, brw, logf) if err != nil { return nil, err } @@ -589,7 +612,7 @@ func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) - c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key")) + c, err := NewClient(priv, "", nc, brw, logf, MeshKey("mesh-key")) if err != nil { return nil, err } @@ -918,7 +941,7 @@ func TestMultiForwarder(t *testing.T) { func TestMetaCert(t *testing.T) { priv := key.NewNode() pub := priv.Public() - s := NewServer(priv, t.Logf) + s := NewServer(priv, nil, t.Logf) certBytes := s.MetaCert() cert, err := x509.ParseCertificate(certBytes) @@ -1065,7 +1088,7 @@ func TestServerDupClients(t *testing.T) { // run starts a new test case and resets clients back to their zero values. run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { - s = NewServer(serverPriv, t.Logf) + s = NewServer(serverPriv, nil, t.Logf) s.dupPolicy = dupPolicy c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} @@ -1315,7 +1338,7 @@ func TestLimiter(t *testing.T) { // single Server instance with multiple concurrent client flows. func BenchmarkConcurrentStreams(b *testing.B) { serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, logger.Discard) + s := NewServer(serverPrivateKey, nil, logger.Discard) defer s.Close() ln, err := net.Listen("tcp", "127.0.0.1:0") @@ -1354,7 +1377,7 @@ func BenchmarkConcurrentStreams(b *testing.B) { k := key.NewNode() brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) - client, err := NewClient(k, connOut, brw, logger.Discard) + client, err := NewClient(k, "", connOut, brw, logger.Discard) if err != nil { b.Fatalf("client: %v", err) } @@ -1385,7 +1408,7 @@ func BenchmarkSendRecv(b *testing.B) { func benchmarkSendRecvSize(b *testing.B, packetSize int) { serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, logger.Discard) + s := NewServer(serverPrivateKey, nil, logger.Discard) defer s.Close() k := key.NewNode() @@ -1416,7 +1439,7 @@ func benchmarkSendRecvSize(b *testing.B, packetSize int) { go s.Accept(ctx, connIn, brwServer, "test-client") brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) - client, err := NewClient(k, connOut, brw, logger.Discard) + client, err := NewClient(k, "", connOut, brw, logger.Discard) if err != nil { b.Fatalf("client: %v", err) } diff --git a/go.mod b/go.mod index 464db8313..5ee11060b 100644 --- a/go.mod +++ b/go.mod @@ -146,6 +146,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect github.com/gobuffalo/flect v1.0.2 // indirect github.com/goccy/go-yaml v1.12.0 // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect diff --git a/go.sum b/go.sum index 549f559d0..27f3f4a05 100644 --- a/go.sum +++ b/go.sum @@ -420,6 +420,8 @@ github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14j github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=