diff --git a/derp/derp_server.go b/derp/derp_server.go index 87b8aa549..53d0bc262 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -32,6 +32,7 @@ type Server struct { logf logger.Logf mu sync.Mutex + closed bool netConns map[net.Conn]chan struct{} // chan is closed when conn closes clients map[key.Public]*sclient } @@ -51,6 +52,14 @@ func NewServer(privateKey key.Private, logf logger.Logf) *Server { // Close closes the server and waits for the connections to disconnect. func (s *Server) Close() error { + s.mu.Lock() + wasClosed := s.closed + s.closed = true + s.mu.Unlock() + if wasClosed { + return nil + } + var closedChs []chan struct{} s.mu.Lock() @@ -67,6 +76,12 @@ func (s *Server) Close() error { return nil } +func (s *Server) isClosed() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.closed +} + // Accept adds a new connection to the server. // The provided bufio ReadWriter must be already connected to nc. // Accept blocks until the Server is closed or the connection closes @@ -87,7 +102,7 @@ func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) { s.mu.Unlock() }() - if err := s.accept(nc, brw); err != nil { + if err := s.accept(nc, brw); err != nil && !s.isClosed() { s.logf("derp: %s: %v", nc.RemoteAddr(), err) } } @@ -136,10 +151,6 @@ func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error { // At this point we trust the client so we don't time out. nc.SetDeadline(time.Time{}) - if err := s.sendServerInfo(bw, clientKey); err != nil { - return fmt.Errorf("send server info: %v", err) - } - c := &sclient{ key: clientKey, nc: nc, @@ -150,7 +161,18 @@ func (s *Server) accept(nc net.Conn, brw *bufio.ReadWriter) error { c.info = *clientInfo } + // Once the client is registered, it can start receiving + // traffic, but we want to make sure the first thing it + // receives after its frameClientInfo is our frameServerInfo, + // so acquire the c.mu lock (which guards writing to c.bw) + // while we register. + c.mu.Lock() s.registerClient(c) + err = s.sendServerInfo(bw, clientKey) + c.mu.Unlock() + if err != nil { + return fmt.Errorf("send server info: %v", err) + } defer s.unregisterClient(c) ctx, cancel := context.WithCancel(context.Background()) diff --git a/derp/derp_test.go b/derp/derp_test.go index 0b39abb0d..6be3b4bfe 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -14,67 +14,71 @@ import ( "tailscale.com/types/key" ) -func TestSendRecv(t *testing.T) { - const numClients = 3 - var serverPrivateKey key.Private - if _, err := crand.Read(serverPrivateKey[:]); err != nil { +func newPrivateKey(t *testing.T) (k key.Private) { + if _, err := crand.Read(k[:]); err != nil { t.Fatal(err) } + return +} + +func TestSendRecv(t *testing.T) { + serverPrivateKey := newPrivateKey(t) + s := NewServer(serverPrivateKey, t.Logf) + defer s.Close() + + const numClients = 3 var clientPrivateKeys []key.Private - for i := 0; i < numClients; i++ { - var key key.Private - if _, err := crand.Read(key[:]); err != nil { - t.Fatal(err) - } - clientPrivateKeys = append(clientPrivateKeys, key) - } var clientKeys []key.Public - for _, privKey := range clientPrivateKeys { - clientKeys = append(clientKeys, privKey.Public()) + for i := 0; i < numClients; i++ { + priv := newPrivateKey(t) + clientPrivateKeys = append(clientPrivateKeys, priv) + clientKeys = append(clientKeys, priv.Public()) } ln, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } + defer ln.Close() var clientConns []net.Conn - for i := 0; i < numClients; i++ { - conn, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - clientConns = append(clientConns, conn) - } - s := NewServer(serverPrivateKey, t.Logf) - defer s.Close() - for i := 0; i < numClients; i++ { - netConn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) - go s.Accept(netConn, conn) - } - var clients []*Client var recvChs []chan []byte errCh := make(chan error, 3) + for i := 0; i < numClients; i++ { + t.Logf("Connecting client %d ...", i) + cout, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cout.Close() + clientConns = append(clientConns, cout) + + cin, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer cin.Close() + go s.Accept(cin, bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))) + key := clientPrivateKeys[i] - netConn := clientConns[i] - conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) - c, err := NewClient(key, netConn, conn, t.Logf) + brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) + c, err := NewClient(key, cout, brw, t.Logf) if err != nil { t.Fatalf("client %d: %v", i, err) } clients = append(clients, c) recvChs = append(recvChs, make(chan []byte)) + t.Logf("Connected client %d.", i) + } + t.Logf("Starting read loops") + for i := 0; i < numClients; i++ { go func(i int) { for { b := make([]byte, 1<<16) - n, err := c.Recv(b) + n, err := clients[i].Recv(b) if err != nil { errCh <- err return @@ -120,4 +124,7 @@ func TestSendRecv(t *testing.T) { recv(2, string(msg2)) recvNothing(0) recvNothing(1) + + t.Logf("passed") + s.Close() }