diff --git a/derp/derp.go b/derp/derp.go index 9e537ba53..f3f73c455 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -209,13 +209,6 @@ func writeFrame(bw *bufio.Writer, t frameType, b []byte) error { return bw.Flush() } -func minInt(a, b int) int { - if a < b { - return a - } - return b -} - func minUint32(a, b uint32) uint32 { if a < b { return a diff --git a/derp/derp_client.go b/derp/derp_client.go index d4cbb583d..7291efbec 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -21,14 +21,13 @@ import ( // Client is a DERP client. type Client struct { - serverKey key.Public // of the DERP server; not a machine or node key - privateKey key.Private - publicKey key.Public // of privateKey - protoVersion int // min of server+client - logf logger.Logf - nc Conn - br *bufio.Reader - meshKey string + serverKey key.Public // of the DERP server; not a machine or node key + privateKey key.Private + publicKey key.Public // of privateKey + logf logger.Logf + nc Conn + br *bufio.Reader + meshKey string wmu sync.Mutex // hold while writing to bw bw *bufio.Writer @@ -85,11 +84,6 @@ func newClient(privateKey key.Private, nc Conn, brw *bufio.ReadWriter, logf logg if err := c.sendClientKey(); err != nil { return nil, fmt.Errorf("derp.Client: failed to send client key: %v", err) } - info, err := c.recvServerInfo() - if err != nil { - return nil, fmt.Errorf("derp.Client: failed to receive server info: %v", err) - } - c.protoVersion = minInt(protocolVersion, info.Version) return c, nil } @@ -110,12 +104,9 @@ func (c *Client) recvServerKey() error { return nil } -func (c *Client) recvServerInfo() (*serverInfo, error) { - fl, err := readFrameTypeHeader(c.br, frameServerInfo) - if err != nil { - return nil, err - } +func (c *Client) parseServerInfo(b []byte) (*serverInfo, error) { const maxLength = nonceLen + maxInfoLen + fl := len(b) if fl < nonceLen { return nil, fmt.Errorf("short serverInfo frame") } @@ -124,21 +115,15 @@ func (c *Client) recvServerInfo() (*serverInfo, error) { } // TODO: add a read-nonce-and-box helper var nonce [nonceLen]byte - if _, err := io.ReadFull(c.br, nonce[:]); err != nil { - return nil, fmt.Errorf("nonce: %v", err) - } - msgLen := fl - nonceLen - msgbox := make([]byte, msgLen) - if _, err := io.ReadFull(c.br, msgbox); err != nil { - return nil, fmt.Errorf("msgbox: %v", err) - } + copy(nonce[:], b) + msgbox := b[nonceLen:] msg, ok := box.Open(nil, msgbox, &nonce, c.serverKey.B32(), c.privateKey.B32()) if !ok { - return nil, fmt.Errorf("msgbox: cannot open len=%d with server key %x", msgLen, c.serverKey[:]) + return nil, fmt.Errorf("failed to open naclbox from server key %x", c.serverKey[:]) } info := new(serverInfo) if err := json.Unmarshal(msg, info); err != nil { - return nil, fmt.Errorf("msg: %v", err) + return nil, fmt.Errorf("invalid JSON: %v", err) } return info, nil } @@ -318,6 +303,11 @@ type PeerPresentMessage key.Public func (PeerPresentMessage) msg() {} +// ServerInfoMessage is sent by the server upon first connect. +type ServerInfoMessage struct{} + +func (ServerInfoMessage) msg() {} + // Recv reads a message from the DERP server. // // The returned message may alias memory owned by the Client; it @@ -364,7 +354,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro // If the frame fits in our bufio.Reader buffer, just use it. // In practice it's 4KB (from derphttp.Client's bufio.NewReader(httpConn)) and // in practive, WireGuard packets (and thus DERP frames) are under 1.5KB. - // So This is the common path. + // So this is the common path. if int(n) <= c.br.Size() { b, err = c.br.Peek(int(n)) c.peeked = int(n) @@ -382,6 +372,19 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro switch t { default: continue + case frameServerInfo: + // Server sends this at start-up. Currently unused. + // Just has a JSON message saying "version: 2", + // but the protocol seems extensible enough as-is without + // needing to wait an RTT to discover the version at startup. + // We'd prefer to give the connection to the client (magicsock) + // to start writing as soon as possible. + _, err := c.parseServerInfo(b) + if err != nil { + return nil, fmt.Errorf("invalid server info frame: %v", err) + } + // TODO: add the results of parseServerInfo to ServerInfoMessage if we ever need it. + return ServerInfoMessage{}, nil case frameKeepAlive: // TODO: eventually we'll have server->client pings that // require ack pongs. @@ -406,16 +409,12 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro case frameRecvPacket: var rp ReceivedPacket - if c.protoVersion < protocolSrcAddrs { - rp.Data = b[:n] - } else { - if n < keyLen { - c.logf("[unexpected] dropping short packet from DERP server") - continue - } - copy(rp.Source[:], b[:keyLen]) - rp.Data = b[keyLen:n] + if n < keyLen { + c.logf("[unexpected] dropping short packet from DERP server") + continue } + copy(rp.Source[:], b[:keyLen]) + rp.Data = b[keyLen:n] return rp, nil } } diff --git a/derp/derp_test.go b/derp/derp_test.go index 203c5138b..1ca32128d 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -80,6 +80,8 @@ func TestSendRecv(t *testing.T) { if err != nil { t.Fatalf("client %d: %v", i, err) } + waitConnect(t, c) + clients = append(clients, c) recvChs = append(recvChs, make(chan []byte)) t.Logf("Connected client %d.", i) @@ -119,7 +121,7 @@ func TestSendRecv(t *testing.T) { if got := string(b); got != want { t.Errorf("client1.Recv=%q, want %q", got, want) } - case <-time.After(1 * time.Second): + case <-time.After(5 * time.Second): t.Errorf("client%d.Recv, got nothing, want %q", i, want) } } @@ -225,6 +227,7 @@ func TestSendFreeze(t *testing.T) { if err != nil { t.Fatal(err) } + waitConnect(t, c) return c, c2 } @@ -503,7 +506,13 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.Private, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) - return NewClient(priv, nc, brw, logf) + c, err := NewClient(priv, nc, brw, logf) + if err != nil { + return nil, err + } + waitConnect(t, c) + return c, nil + }) } @@ -514,6 +523,7 @@ func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { if err != nil { return nil, err } + waitConnect(t, c) if err := c.WatchConnectionChanges(); err != nil { return nil, err } @@ -834,3 +844,12 @@ func BenchmarkReadUint32(b *testing.B) { } } } + +func waitConnect(t testing.TB, c *Client) { + t.Helper() + if m, err := c.Recv(); err != nil { + t.Fatalf("client first Recv: %v", err) + } else if v, ok := m.(ServerInfoMessage); !ok { + t.Fatalf("client first Recv was unexpected type %T", v) + } +} diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index 3356c890c..d5290b8a9 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -6,7 +6,6 @@ package derphttp import ( "context" - crand "crypto/rand" "crypto/tls" "net" "net/http" @@ -19,22 +18,15 @@ import ( ) func TestSendRecv(t *testing.T) { + serverPrivateKey := key.NewPrivate() + const numClients = 3 - var serverPrivateKey key.Private - if _, err := crand.Read(serverPrivateKey[:]); err != nil { - t.Fatal(err) - } var clientPrivateKeys []key.Private - for i := 0; i < numClients; i++ { - var key key.Private - if _, err := crand.Read(key[:]); err != nil { - t.Fatal(err) - } - clientPrivateKeys = append(clientPrivateKeys, key) - } var clientKeys []key.Public - for _, privKey := range clientPrivateKeys { - clientKeys = append(clientKeys, privKey.Public()) + for i := 0; i < numClients; i++ { + priv := key.NewPrivate() + clientPrivateKeys = append(clientPrivateKeys, priv) + clientKeys = append(clientKeys, priv.Public()) } s := derp.NewServer(serverPrivateKey, t.Logf) @@ -81,6 +73,7 @@ func TestSendRecv(t *testing.T) { if err := c.Connect(context.Background()); err != nil { t.Fatalf("client %d Connect: %v", i, err) } + waitConnect(t, c) clients = append(clients, c) recvChs = append(recvChs, make(chan []byte)) @@ -95,6 +88,11 @@ func TestSendRecv(t *testing.T) { } m, err := c.Recv() if err != nil { + select { + case <-done: + return + default: + } t.Logf("client%d: %v", i, err) break } @@ -118,7 +116,7 @@ func TestSendRecv(t *testing.T) { if got := string(b); got != want { t.Errorf("client1.Recv=%q, want %q", got, want) } - case <-time.After(1 * time.Second): + case <-time.After(5 * time.Second): t.Errorf("client%d.Recv, got nothing, want %q", i, want) } } @@ -146,5 +144,13 @@ func TestSendRecv(t *testing.T) { recv(2, string(msg2)) recvNothing(0) recvNothing(1) - +} + +func waitConnect(t testing.TB, c *Client) { + t.Helper() + if m, err := c.Recv(); err != nil { + t.Fatalf("client first Recv: %v", err) + } else if v, ok := m.(derp.ServerInfoMessage); !ok { + t.Fatalf("client first Recv was unexpected type %T", v) + } }