DERP JWT POC

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann 2024-10-22 21:34:50 -05:00
parent b2665d9b89
commit 62445931b6
No known key found for this signature in database
GPG Key ID: 29D8CDEB4C13D48B
6 changed files with 97 additions and 18 deletions

View File

@ -168,7 +168,7 @@ func main() {
serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual" serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual"
s := derp.NewServer(cfg.PrivateKey, log.Printf) s := derp.NewServer(cfg.PrivateKey, nil, log.Printf)
s.SetVerifyClient(*verifyClients) s.SetVerifyClient(*verifyClients)
s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURL(*verifyClientURL)
s.SetVerifyClientURLFailOpen(*verifyFailOpen) s.SetVerifyClientURLFailOpen(*verifyFailOpen)

View File

@ -27,6 +27,7 @@ type Client struct {
serverKey key.NodePublic // of the DERP server; not a machine or node key serverKey key.NodePublic // of the DERP server; not a machine or node key
privateKey key.NodePrivate privateKey key.NodePrivate
publicKey key.NodePublic // of privateKey publicKey key.NodePublic // of privateKey
jwt string
logf logger.Logf logf logger.Logf
nc Conn nc Conn
br *bufio.Reader br *bufio.Reader
@ -84,7 +85,7 @@ func CanAckPings(v bool) ClientOpt {
return clientOptFunc(func(o *clientOpt) { o.CanAckPings = v }) return clientOptFunc(func(o *clientOpt) { o.CanAckPings = v })
} }
func NewClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) { func NewClient(privateKey key.NodePrivate, jwt string, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opts ...ClientOpt) (*Client, error) {
var opt clientOpt var opt clientOpt
for _, o := range opts { for _, o := range opts {
if o == nil { if o == nil {
@ -92,13 +93,14 @@ func NewClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf
} }
o.update(&opt) o.update(&opt)
} }
return newClient(privateKey, nc, brw, logf, opt) return newClient(privateKey, jwt, nc, brw, logf, opt)
} }
func newClient(privateKey key.NodePrivate, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) { func newClient(privateKey key.NodePrivate, jwt string, nc Conn, brw *bufio.ReadWriter, logf logger.Logf, opt clientOpt) (*Client, error) {
c := &Client{ c := &Client{
privateKey: privateKey, privateKey: privateKey,
publicKey: privateKey.Public(), publicKey: privateKey.Public(),
jwt: jwt,
logf: logf, logf: logf,
nc: nc, nc: nc,
br: brw.Reader, br: brw.Reader,
@ -177,6 +179,9 @@ type clientInfo struct {
// IsProber is whether this client is a prober. // IsProber is whether this client is a prober.
IsProber bool `json:",omitempty"` IsProber bool `json:",omitempty"`
// JWT is a JSON web token with authorization grants for this client.
JWT string `json:",omitempty"`
} }
func (c *Client) sendClientKey() error { func (c *Client) sendClientKey() error {
@ -185,6 +190,7 @@ func (c *Client) sendClientKey() error {
MeshKey: c.meshKey, MeshKey: c.meshKey,
CanAckPings: c.canAckPings, CanAckPings: c.canAckPings,
IsProber: c.isProber, IsProber: c.isProber,
JWT: c.jwt,
}) })
if err != nil { if err != nil {
return err return err

View File

@ -35,6 +35,7 @@
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/golang-jwt/jwt"
"go4.org/mem" "go4.org/mem"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"tailscale.com/client/tailscale" "tailscale.com/client/tailscale"
@ -114,6 +115,7 @@ type Server struct {
privateKey key.NodePrivate privateKey key.NodePrivate
publicKey key.NodePublic publicKey key.NodePublic
jwtSigner ed25519.PublicKey
logf logger.Logf logf logger.Logf
memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish) memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish)
meshKey string meshKey string
@ -342,7 +344,7 @@ type Conn interface {
// NewServer returns a new DERP server. It doesn't listen on its own. // NewServer returns a new DERP server. It doesn't listen on its own.
// Connections are given to it via Server.Accept. // Connections are given to it via Server.Accept.
func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { func NewServer(privateKey key.NodePrivate, jwtSigner ed25519.PublicKey, logf logger.Logf) *Server {
var ms runtime.MemStats var ms runtime.MemStats
runtime.ReadMemStats(&ms) runtime.ReadMemStats(&ms)
@ -350,6 +352,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
debug: envknob.Bool("DERP_DEBUG_LOGS"), debug: envknob.Bool("DERP_DEBUG_LOGS"),
privateKey: privateKey, privateKey: privateKey,
publicKey: privateKey.Public(), publicKey: privateKey.Public(),
jwtSigner: jwtSigner,
logf: logf, logf: logf,
limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100),
packetsRecvByKind: metrics.LabelMap{Label: "kind"}, packetsRecvByKind: metrics.LabelMap{Label: "kind"},
@ -1450,9 +1453,53 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info
if err := json.Unmarshal(msg, info); err != nil { if err := json.Unmarshal(msg, info); err != nil {
return zpub, nil, fmt.Errorf("msg: %v", err) return zpub, nil, fmt.Errorf("msg: %v", err)
} }
if info.JWT == "" {
fmt.Println("ZZZZ No JWT, maybe old client")
} else if err := s.authorizeJWT(info.JWT, clientKey); err != nil {
return clientKey, info, fmt.Errorf("failed to authorize JWT: %w", err)
}
return clientKey, info, nil return clientKey, info, nil
} }
func (s *Server) authorizeJWT(tokenString string, clientKey key.NodePublic) error {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok {
return nil, fmt.Errorf("Unexpected signing method: %s", token.Header["alg"])
}
return s.jwtSigner, nil
})
if err != nil {
return fmt.Errorf("error verifying provided JWT: %w", err)
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return errors.New("invalid type of JWT claims provided")
}
_expires, ok := claims["expires"]
if !ok {
return errors.New("JWT missing expires")
}
expires, err := time.Parse(time.RFC3339, _expires.(string))
if err != nil {
return fmt.Errorf("failed to parse expires: %w", err)
}
if expires.Before(time.Now()) {
return errors.New("JWT expired")
}
pkHex, ok := claims["publicKeyHex"]
if !ok {
return errors.New("JWT missing publicKeyHex")
}
var clientKeyFromJWT key.NodePublic
if err := clientKeyFromJWT.UnmarshalText([]byte(pkHex.(string))); err != nil {
return fmt.Errorf("Failed to unmarshal publicKeyHex: %w", err)
}
if clientKey != clientKeyFromJWT {
return fmt.Errorf("client key in JWT does not match client's key")
}
return nil
}
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) { func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) {
if frameLen < keyLen { if frameLen < keyLen {
return zpub, nil, errors.New("short send packet frame") return zpub, nil, errors.New("short send packet frame")

View File

@ -7,6 +7,8 @@
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
@ -23,6 +25,7 @@
"testing" "testing"
"time" "time"
"github.com/golang-jwt/jwt"
"go4.org/mem" "go4.org/mem"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"tailscale.com/disco" "tailscale.com/disco"
@ -49,17 +52,37 @@ func TestClientInfoUnmarshal(t *testing.T) {
} }
func TestSendRecv(t *testing.T) { func TestSendRecv(t *testing.T) {
signerPub, signerPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
serverPrivateKey := key.NewNode() serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, t.Logf) s := NewServer(serverPrivateKey, signerPub, t.Logf)
defer s.Close() defer s.Close()
const numClients = 3 const numClients = 3
var clientPrivateKeys []key.NodePrivate var clientPrivateKeys []key.NodePrivate
var clientKeys []key.NodePublic var clientKeys []key.NodePublic
var clientJWTs []string
for range numClients { for range numClients {
priv := key.NewNode() priv := key.NewNode()
clientPrivateKeys = append(clientPrivateKeys, priv) clientPrivateKeys = append(clientPrivateKeys, priv)
clientKeys = append(clientKeys, priv.Public()) clientKeys = append(clientKeys, priv.Public())
pkHex, err := priv.Public().MarshalText()
if err != nil {
t.Fatal(err)
}
// The below would typically be done by the control server
jt := jwt.NewWithClaims(jwt.SigningMethodEdDSA, jwt.MapClaims{
"publicKeyHex": string(pkHex),
"expires": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
})
sjt, err := jt.SignedString(signerPriv)
if err != nil {
t.Fatal(err)
}
clientJWTs = append(clientJWTs, sjt)
} }
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
@ -96,7 +119,7 @@ func TestSendRecv(t *testing.T) {
key := clientPrivateKeys[i] key := clientPrivateKeys[i]
brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
c, err := NewClient(key, cout, brw, t.Logf) c, err := NewClient(key, clientJWTs[i], cout, brw, t.Logf)
if err != nil { if err != nil {
t.Fatalf("client %d: %v", i, err) t.Fatalf("client %d: %v", i, err)
} }
@ -269,7 +292,7 @@ func TestSendRecv(t *testing.T) {
func TestSendFreeze(t *testing.T) { func TestSendFreeze(t *testing.T) {
serverPrivateKey := key.NewNode() serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, t.Logf) s := NewServer(serverPrivateKey, nil, t.Logf)
defer s.Close() defer s.Close()
s.WriteTimeout = 100 * time.Millisecond s.WriteTimeout = 100 * time.Millisecond
@ -287,7 +310,7 @@ func TestSendFreeze(t *testing.T) {
go s.Accept(ctx, c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) go s.Accept(ctx, c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name)
brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2)) brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2))
c, err := NewClient(k, c2, brw, t.Logf) c, err := NewClient(k, "", c2, brw, t.Logf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -511,7 +534,7 @@ func (ts *testServer) close(t *testing.T) error {
func newTestServer(t *testing.T, ctx context.Context) *testServer { func newTestServer(t *testing.T, ctx context.Context) *testServer {
t.Helper() t.Helper()
logf := logger.WithPrefix(t.Logf, "derp-server: ") logf := logger.WithPrefix(t.Logf, "derp-server: ")
s := NewServer(key.NewNode(), logf) s := NewServer(key.NewNode(), nil, logf)
s.SetMeshKey("mesh-key") s.SetMeshKey("mesh-key")
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
@ -576,7 +599,7 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net
func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) {
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
c, err := NewClient(priv, nc, brw, logf) c, err := NewClient(priv, "", nc, brw, logf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -589,7 +612,7 @@ func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient {
return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) {
brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key")) c, err := NewClient(priv, "", nc, brw, logf, MeshKey("mesh-key"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -918,7 +941,7 @@ func TestMultiForwarder(t *testing.T) {
func TestMetaCert(t *testing.T) { func TestMetaCert(t *testing.T) {
priv := key.NewNode() priv := key.NewNode()
pub := priv.Public() pub := priv.Public()
s := NewServer(priv, t.Logf) s := NewServer(priv, nil, t.Logf)
certBytes := s.MetaCert() certBytes := s.MetaCert()
cert, err := x509.ParseCertificate(certBytes) cert, err := x509.ParseCertificate(certBytes)
@ -1065,7 +1088,7 @@ func TestServerDupClients(t *testing.T) {
// run starts a new test case and resets clients back to their zero values. // run starts a new test case and resets clients back to their zero values.
run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) {
s = NewServer(serverPriv, t.Logf) s = NewServer(serverPriv, nil, t.Logf)
s.dupPolicy = dupPolicy s.dupPolicy = dupPolicy
c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")}
c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")}
@ -1315,7 +1338,7 @@ func TestLimiter(t *testing.T) {
// single Server instance with multiple concurrent client flows. // single Server instance with multiple concurrent client flows.
func BenchmarkConcurrentStreams(b *testing.B) { func BenchmarkConcurrentStreams(b *testing.B) {
serverPrivateKey := key.NewNode() serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, logger.Discard) s := NewServer(serverPrivateKey, nil, logger.Discard)
defer s.Close() defer s.Close()
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
@ -1354,7 +1377,7 @@ func BenchmarkConcurrentStreams(b *testing.B) {
k := key.NewNode() k := key.NewNode()
brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut))
client, err := NewClient(k, connOut, brw, logger.Discard) client, err := NewClient(k, "", connOut, brw, logger.Discard)
if err != nil { if err != nil {
b.Fatalf("client: %v", err) b.Fatalf("client: %v", err)
} }
@ -1385,7 +1408,7 @@ func BenchmarkSendRecv(b *testing.B) {
func benchmarkSendRecvSize(b *testing.B, packetSize int) { func benchmarkSendRecvSize(b *testing.B, packetSize int) {
serverPrivateKey := key.NewNode() serverPrivateKey := key.NewNode()
s := NewServer(serverPrivateKey, logger.Discard) s := NewServer(serverPrivateKey, nil, logger.Discard)
defer s.Close() defer s.Close()
k := key.NewNode() k := key.NewNode()
@ -1416,7 +1439,7 @@ func benchmarkSendRecvSize(b *testing.B, packetSize int) {
go s.Accept(ctx, connIn, brwServer, "test-client") go s.Accept(ctx, connIn, brwServer, "test-client")
brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut))
client, err := NewClient(k, connOut, brw, logger.Discard) client, err := NewClient(k, "", connOut, brw, logger.Discard)
if err != nil { if err != nil {
b.Fatalf("client: %v", err) b.Fatalf("client: %v", err)
} }

1
go.mod
View File

@ -146,6 +146,7 @@ require (
github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect
github.com/gobuffalo/flect v1.0.2 // indirect github.com/gobuffalo/flect v1.0.2 // indirect
github.com/goccy/go-yaml v1.12.0 // indirect github.com/goccy/go-yaml v1.12.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/golangci/plugin-module-register v0.1.1 // indirect
github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect

2
go.sum
View File

@ -420,6 +420,8 @@ github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14j
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=