From 6b80bcf112d8cdc2321ac8c0e1ecf86538100c9c Mon Sep 17 00:00:00 2001
From: Brad Fitzpatrick <bradfitz@tailscale.com>
Date: Mon, 17 Aug 2020 16:14:07 -0700
Subject: [PATCH] derp: remove a client round-trip waiting on serverInfo

It just has a version number in it and it's not really needed.
Instead just return it as a normal Recv message type for those
that care (currently only tests).

Updates #150 (in that it shares the same goal: initial DERP latency)
Updates #199 (in that it removes some DERP versioning)
---
 derp/derp.go                   |  7 ----
 derp/derp_client.go            | 75 +++++++++++++++++-----------------
 derp/derp_test.go              | 23 ++++++++++-
 derp/derphttp/derphttp_test.go | 38 +++++++++--------
 4 files changed, 80 insertions(+), 63 deletions(-)

diff --git a/derp/derp.go b/derp/derp.go
index 9e537ba53..f3f73c455 100644
--- a/derp/derp.go
+++ b/derp/derp.go
@@ -209,13 +209,6 @@ func writeFrame(bw *bufio.Writer, t frameType, b []byte) error {
 	return bw.Flush()
 }
 
-func minInt(a, b int) int {
-	if a < b {
-		return a
-	}
-	return b
-}
-
 func minUint32(a, b uint32) uint32 {
 	if a < b {
 		return a
diff --git a/derp/derp_client.go b/derp/derp_client.go
index d4cbb583d..7291efbec 100644
--- a/derp/derp_client.go
+++ b/derp/derp_client.go
@@ -21,14 +21,13 @@ import (
 
 // Client is a DERP client.
 type Client struct {
-	serverKey    key.Public // of the DERP server; not a machine or node key
-	privateKey   key.Private
-	publicKey    key.Public // of privateKey
-	protoVersion int        // min of server+client
-	logf         logger.Logf
-	nc           Conn
-	br           *bufio.Reader
-	meshKey      string
+	serverKey  key.Public // of the DERP server; not a machine or node key
+	privateKey key.Private
+	publicKey  key.Public // of privateKey
+	logf       logger.Logf
+	nc         Conn
+	br         *bufio.Reader
+	meshKey    string
 
 	wmu sync.Mutex // hold while writing to bw
 	bw  *bufio.Writer
@@ -85,11 +84,6 @@ func newClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logg
 	if err := c.sendClientKey(); err != nil {
 		return nil, fmt.Errorf("derp.Client: failed to send client key: %v", err)
 	}
-	info, err := c.recvServerInfo()
-	if err != nil {
-		return nil, fmt.Errorf("derp.Client: failed to receive server info: %v", err)
-	}
-	c.protoVersion = minInt(protocolVersion, info.Version)
 	return c, nil
 }
 
@@ -110,12 +104,9 @@ func (c *Client) recvServerKey() error {
 	return nil
 }
 
-func (c *Client) recvServerInfo() (*serverInfo, error) {
-	fl, err := readFrameTypeHeader(c.br, frameServerInfo)
-	if err != nil {
-		return nil, err
-	}
+func (c *Client) parseServerInfo(b []byte) (*serverInfo, error) {
 	const maxLength = nonceLen + maxInfoLen
+	fl := len(b)
 	if fl < nonceLen {
 		return nil, fmt.Errorf("short serverInfo frame")
 	}
@@ -124,21 +115,15 @@ func (c *Client) recvServerInfo() (*serverInfo, error) {
 	}
 	// TODO: add a read-nonce-and-box helper
 	var nonce [nonceLen]byte
-	if _, err := io.ReadFull(c.br, nonce[:]); err != nil {
-		return nil, fmt.Errorf("nonce: %v", err)
-	}
-	msgLen := fl - nonceLen
-	msgbox := make([]byte, msgLen)
-	if _, err := io.ReadFull(c.br, msgbox); err != nil {
-		return nil, fmt.Errorf("msgbox: %v", err)
-	}
+	copy(nonce[:], b)
+	msgbox := b[nonceLen:]
 	msg, ok := box.Open(nil, msgbox, &nonce, c.serverKey.B32(), c.privateKey.B32())
 	if !ok {
-		return nil, fmt.Errorf("msgbox: cannot open len=%d with server key %x", msgLen, c.serverKey[:])
+		return nil, fmt.Errorf("failed to open naclbox from server key %x", c.serverKey[:])
 	}
 	info := new(serverInfo)
 	if err := json.Unmarshal(msg, info); err != nil {
-		return nil, fmt.Errorf("msg: %v", err)
+		return nil, fmt.Errorf("invalid JSON: %v", err)
 	}
 	return info, nil
 }
@@ -318,6 +303,11 @@ type PeerPresentMessage key.Public
 
 func (PeerPresentMessage) msg() {}
 
+// ServerInfoMessage is sent by the server upon first connect.
+type ServerInfoMessage struct{}
+
+func (ServerInfoMessage) msg() {}
+
 // Recv reads a message from the DERP server.
 //
 // The returned message may alias memory owned by the Client; it
@@ -364,7 +354,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 		// If the frame fits in our bufio.Reader buffer, just use it.
 		// In practice it's 4KB (from derphttp.Client's bufio.NewReader(httpConn)) and
 		// in practive, WireGuard packets (and thus DERP frames) are under 1.5KB.
-		// So This is the common path.
+		// So this is the common path.
 		if int(n) <= c.br.Size() {
 			b, err = c.br.Peek(int(n))
 			c.peeked = int(n)
@@ -382,6 +372,19 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 		switch t {
 		default:
 			continue
+		case frameServerInfo:
+			// Server sends this at start-up. Currently unused.
+			// Just has a JSON message saying "version: 2",
+			// but the protocol seems extensible enough as-is without
+			// needing to wait an RTT to discover the version at startup.
+			// We'd prefer to give the connection to the client (magicsock)
+			// to start writing as soon as possible.
+			_, err := c.parseServerInfo(b)
+			if err != nil {
+				return nil, fmt.Errorf("invalid server info frame: %v", err)
+			}
+			// TODO: add the results of parseServerInfo to ServerInfoMessage if we ever need it.
+			return ServerInfoMessage{}, nil
 		case frameKeepAlive:
 			// TODO: eventually we'll have server->client pings that
 			// require ack pongs.
@@ -406,16 +409,12 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
 
 		case frameRecvPacket:
 			var rp ReceivedPacket
-			if c.protoVersion < protocolSrcAddrs {
-				rp.Data = b[:n]
-			} else {
-				if n < keyLen {
-					c.logf("[unexpected] dropping short packet from DERP server")
-					continue
-				}
-				copy(rp.Source[:], b[:keyLen])
-				rp.Data = b[keyLen:n]
+			if n < keyLen {
+				c.logf("[unexpected] dropping short packet from DERP server")
+				continue
 			}
+			copy(rp.Source[:], b[:keyLen])
+			rp.Data = b[keyLen:n]
 			return rp, nil
 		}
 	}
diff --git a/derp/derp_test.go b/derp/derp_test.go
index 203c5138b..1ca32128d 100644
--- a/derp/derp_test.go
+++ b/derp/derp_test.go
@@ -80,6 +80,8 @@ func TestSendRecv(t *testing.T) {
 		if err != nil {
 			t.Fatalf("client %d: %v", i, err)
 		}
+		waitConnect(t, c)
+
 		clients = append(clients, c)
 		recvChs = append(recvChs, make(chan []byte))
 		t.Logf("Connected client %d.", i)
@@ -119,7 +121,7 @@ func TestSendRecv(t *testing.T) {
 			if got := string(b); got != want {
 				t.Errorf("client1.Recv=%q, want %q", got, want)
 			}
-		case <-time.After(1 * time.Second):
+		case <-time.After(5 * time.Second):
 			t.Errorf("client%d.Recv, got nothing, want %q", i, want)
 		}
 	}
@@ -225,6 +227,7 @@ func TestSendFreeze(t *testing.T) {
 		if err != nil {
 			t.Fatal(err)
 		}
+		waitConnect(t, c)
 		return c, c2
 	}
 
@@ -503,7 +506,13 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net
 func newRegularClient(t *testing.T, ts *testServer, name string) *testClient {
 	return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) {
 		brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc))
-		return NewClient(priv, nc, brw, logf)
+		c, err := NewClient(priv, nc, brw, logf)
+		if err != nil {
+			return nil, err
+		}
+		waitConnect(t, c)
+		return c, nil
+
 	})
 }
 
@@ -514,6 +523,7 @@ func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient {
 		if err != nil {
 			return nil, err
 		}
+		waitConnect(t, c)
 		if err := c.WatchConnectionChanges(); err != nil {
 			return nil, err
 		}
@@ -834,3 +844,12 @@ func BenchmarkReadUint32(b *testing.B) {
 		}
 	}
 }
+
+func waitConnect(t testing.TB, c *Client) {
+	t.Helper()
+	if m, err := c.Recv(); err != nil {
+		t.Fatalf("client first Recv: %v", err)
+	} else if v, ok := m.(ServerInfoMessage); !ok {
+		t.Fatalf("client first Recv was unexpected type %T", v)
+	}
+}
diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go
index 3356c890c..d5290b8a9 100644
--- a/derp/derphttp/derphttp_test.go
+++ b/derp/derphttp/derphttp_test.go
@@ -6,7 +6,6 @@ package derphttp
 
 import (
 	"context"
-	crand "crypto/rand"
 	"crypto/tls"
 	"net"
 	"net/http"
@@ -19,22 +18,15 @@ import (
 )
 
 func TestSendRecv(t *testing.T) {
+	serverPrivateKey := key.NewPrivate()
+
 	const numClients = 3
-	var serverPrivateKey key.Private
-	if _, err := crand.Read(serverPrivateKey[:]); err != nil {
-		t.Fatal(err)
-	}
 	var clientPrivateKeys []key.Private
-	for i := 0; i < numClients; i++ {
-		var key key.Private
-		if _, err := crand.Read(key[:]); err != nil {
-			t.Fatal(err)
-		}
-		clientPrivateKeys = append(clientPrivateKeys, key)
-	}
 	var clientKeys []key.Public
-	for _, privKey := range clientPrivateKeys {
-		clientKeys = append(clientKeys, privKey.Public())
+	for i := 0; i < numClients; i++ {
+		priv := key.NewPrivate()
+		clientPrivateKeys = append(clientPrivateKeys, priv)
+		clientKeys = append(clientKeys, priv.Public())
 	}
 
 	s := derp.NewServer(serverPrivateKey, t.Logf)
@@ -81,6 +73,7 @@ func TestSendRecv(t *testing.T) {
 		if err := c.Connect(context.Background()); err != nil {
 			t.Fatalf("client %d Connect: %v", i, err)
 		}
+		waitConnect(t, c)
 		clients = append(clients, c)
 		recvChs = append(recvChs, make(chan []byte))
 
@@ -95,6 +88,11 @@ func TestSendRecv(t *testing.T) {
 				}
 				m, err := c.Recv()
 				if err != nil {
+					select {
+					case <-done:
+						return
+					default:
+					}
 					t.Logf("client%d: %v", i, err)
 					break
 				}
@@ -118,7 +116,7 @@ func TestSendRecv(t *testing.T) {
 			if got := string(b); got != want {
 				t.Errorf("client1.Recv=%q, want %q", got, want)
 			}
-		case <-time.After(1 * time.Second):
+		case <-time.After(5 * time.Second):
 			t.Errorf("client%d.Recv, got nothing, want %q", i, want)
 		}
 	}
@@ -146,5 +144,13 @@ func TestSendRecv(t *testing.T) {
 	recv(2, string(msg2))
 	recvNothing(0)
 	recvNothing(1)
-
+}
+
+func waitConnect(t testing.TB, c *Client) {
+	t.Helper()
+	if m, err := c.Recv(); err != nil {
+		t.Fatalf("client first Recv: %v", err)
+	} else if v, ok := m.(derp.ServerInfoMessage); !ok {
+		t.Fatalf("client first Recv was unexpected type %T", v)
+	}
 }