tailscale/derp/derp_test.go
David Crawshaw 72dbf26f63 derp: test that client a->b and a->c relaying do not interfere
Without the recent write deadline introduction, this test fails.

They still do interfere, but the interference is now bound by
the write deadline. Many improvements are possible.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
2020-03-12 14:42:48 -04:00

346 lines
7.9 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package derp
import (
"bufio"
"context"
crand "crypto/rand"
"errors"
"expvar"
"fmt"
"io"
"net"
"testing"
"time"
"tailscale.com/net/nettest"
"tailscale.com/types/key"
)
// newPrivateKey returns a key.Private filled with bytes read from
// crypto/rand. It fails the test immediately if the random source
// returns an error.
func newPrivateKey(t *testing.T) (k key.Private) {
	t.Helper()
	_, err := crand.Read(k[:])
	if err != nil {
		t.Fatal(err)
	}
	return k
}
// TestSendRecv connects numClients clients to a Server over loopback TCP,
// sends packets between specific pairs of clients, and checks that only the
// addressed client receives each packet. It then checks that the server's
// curClients and curHomeClients counters follow NotePreferred calls and
// client disconnects.
func TestSendRecv(t *testing.T) {
	serverPrivateKey := newPrivateKey(t)
	s := NewServer(serverPrivateKey, t.Logf)
	defer s.Close()

	const numClients = 3
	var clientPrivateKeys []key.Private
	var clientKeys []key.Public
	for i := 0; i < numClients; i++ {
		priv := newPrivateKey(t)
		clientPrivateKeys = append(clientPrivateKeys, priv)
		clientKeys = append(clientKeys, priv.Public())
	}

	ln, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	var clients []*Client
	var connsOut []Conn
	var recvChs []chan []byte
	// Buffered with one slot per read loop so that each loop can report its
	// terminal Recv error and exit even though nothing drains the channel.
	errCh := make(chan error, numClients)

	for i := 0; i < numClients; i++ {
		t.Logf("Connecting client %d ...", i)
		cout, err := net.Dial("tcp", ln.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		defer cout.Close()
		connsOut = append(connsOut, cout)

		cin, err := ln.Accept()
		if err != nil {
			t.Fatal(err)
		}
		defer cin.Close()
		brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))
		go s.Accept(cin, brwServer, fmt.Sprintf("test-client-%d", i))

		// Named priv rather than key so the imported key package is not
		// shadowed inside the loop body.
		priv := clientPrivateKeys[i]
		brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
		c, err := NewClient(priv, cout, brw, t.Logf)
		if err != nil {
			t.Fatalf("client %d: %v", i, err)
		}
		clients = append(clients, c)
		recvChs = append(recvChs, make(chan []byte))
		t.Logf("Connected client %d.", i)
	}

	t.Logf("Starting read loops")
	for i := 0; i < numClients; i++ {
		go func(i int) {
			for {
				b := make([]byte, 1<<16)
				m, err := clients[i].Recv(b)
				if err != nil {
					errCh <- err
					return
				}
				switch m := m.(type) {
				default:
					t.Errorf("unexpected message type %T", m)
					continue
				case ReceivedPacket:
					if m.Source.IsZero() {
						t.Errorf("zero Source address in ReceivedPacket")
					}
					recvChs[i] <- m.Data
				}
			}
		}(i)
	}

	// recv waits up to one second for client i to receive a packet and
	// reports an error if the payload differs from want or nothing arrives.
	recv := func(i int, want string) {
		t.Helper()
		select {
		case b := <-recvChs[i]:
			if got := string(b); got != want {
				t.Errorf("client%d.Recv=%q, want %q", i, got, want)
			}
		case <-time.After(1 * time.Second):
			t.Errorf("client%d.Recv, got nothing, want %q", i, want)
		}
	}
	// recvNothing reports an error if client i's read loop currently has a
	// packet ready to hand over. It does not wait.
	recvNothing := func(i int) {
		t.Helper()
		select {
		case b := <-recvChs[i]:
			t.Errorf("client%d.Recv=%q, want nothing", i, string(b))
		default:
		}
	}

	// wantActive polls the server's curClients and curHomeClients counters
	// every 10ms for up to 5 seconds until they equal total and home, and
	// reports an error with the last observed values if they never do.
	wantActive := func(total, home int64) {
		t.Helper()
		dl := time.Now().Add(5 * time.Second)
		var gotTotal, gotHome int64
		for time.Now().Before(dl) {
			gotTotal, gotHome = s.curClients.Value(), s.curHomeClients.Value()
			if gotTotal == total && gotHome == home {
				return
			}
			time.Sleep(10 * time.Millisecond)
		}
		t.Errorf("total/home=%v/%v; want %v/%v", gotTotal, gotHome, total, home)
	}

	msg1 := []byte("hello 0->1\n")
	if err := clients[0].Send(clientKeys[1], msg1); err != nil {
		t.Fatal(err)
	}
	recv(1, string(msg1))
	recvNothing(0)
	recvNothing(2)

	msg2 := []byte("hello 1->2\n")
	if err := clients[1].Send(clientKeys[2], msg2); err != nil {
		t.Fatal(err)
	}
	recv(2, string(msg2))
	recvNothing(0)
	recvNothing(1)

	// Repeating NotePreferred with the same value is expected to leave the
	// home count unchanged.
	wantActive(3, 0)
	clients[0].NotePreferred(true)
	wantActive(3, 1)
	clients[0].NotePreferred(true)
	wantActive(3, 1)
	clients[0].NotePreferred(false)
	wantActive(3, 0)
	clients[0].NotePreferred(false)
	wantActive(3, 0)
	clients[1].NotePreferred(true)
	wantActive(3, 1)
	// Closing a preferred client's connection is expected to drop both the
	// total and the home count.
	connsOut[1].Close()
	wantActive(2, 0)
	clients[2].NotePreferred(true)
	wantActive(2, 1)
	clients[2].NotePreferred(false)
	wantActive(2, 0)
	connsOut[2].Close()
	wantActive(1, 0)

	t.Logf("passed")
	s.Close()
}
// TestSendFreeze checks that a client that stops reading does not freeze
// relaying to other clients. With the server's WriteTimeout set to 100ms,
// alice sends a steady stream of packets to both bob and cathy; cathy's
// connection is then read-blocked, and the test verifies that bob keeps
// receiving packets.
func TestSendFreeze(t *testing.T) {
	serverPrivateKey := newPrivateKey(t)
	s := NewServer(serverPrivateKey, t.Logf)
	defer s.Close()
	s.WriteTimeout = 100 * time.Millisecond

	// We send two streams of messages:
	//
	//	alice --> bob
	//	alice --> cathy
	//
	// Then cathy stops processing messages.
	// That should not interfere with alice talking to bob.

	// newClient creates an in-memory nettest connection pair with the given
	// name, hands one end to the server's Accept in a new goroutine, and
	// builds a Client on the other end. It returns the Client together with
	// the client-side connection.
	newClient := func(name string, k key.Private) (c *Client, clientConn nettest.Conn) {
		t.Helper()
		c1, c2 := nettest.NewConn(name, 1024)
		go s.Accept(c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name)

		brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2))
		c, err := NewClient(k, c2, brw, t.Logf)
		if err != nil {
			t.Fatal(err)
		}
		return c, c2
	}

	aliceKey := newPrivateKey(t)
	aliceClient, aliceConn := newClient("alice", aliceKey)

	bobKey := newPrivateKey(t)
	bobClient, bobConn := newClient("bob", bobKey)

	cathyKey := newPrivateKey(t)
	cathyClient, cathyConn := newClient("cathy", cathyKey)

	var aliceCount, bobCount, cathyCount expvar.Int

	// Four slots: each of the three receive loops sends exactly one value
	// before returning, and the sender goroutine sends exactly one value
	// (nil on cancellation, or an error).
	errCh := make(chan error, 4)

	// recvAndCount reads packets from client until Recv fails, adding one to
	// count for every ReceivedPacket. The terminating error, or a description
	// of an unexpected message, is sent on errCh.
	recvAndCount := func(count *expvar.Int, name string, client *Client) {
		for {
			b := make([]byte, 1<<9)
			m, err := client.Recv(b)
			if err != nil {
				errCh <- fmt.Errorf("%s: %w", name, err)
				return
			}
			switch m := m.(type) {
			default:
				errCh <- fmt.Errorf("%s: unexpected message type %T", name, m)
				return
			case ReceivedPacket:
				if m.Source.IsZero() {
					errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name)
					return
				}
				count.Add(1)
			}
		}
	}
	go recvAndCount(&aliceCount, "alice", aliceClient)
	go recvAndCount(&bobCount, "bob", bobClient)
	go recvAndCount(&cathyCount, "cathy", cathyClient)

	// The context is created before the sender goroutine starts so that
	// cancel is never written by one goroutine and read by another.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		ticker := time.NewTicker(2 * time.Millisecond)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
			case <-ctx.Done():
				errCh <- nil
				return
			}

			msg1 := []byte("hello alice->bob\n")
			if err := aliceClient.Send(bobKey.Public(), msg1); err != nil {
				errCh <- fmt.Errorf("alice send to bob: %w", err)
				return
			}
			msg2 := []byte("hello alice->cathy\n")
			// TODO: an error is expected here.
			// We ignore it, maybe we should log it somehow?
			_ = aliceClient.Send(cathyKey.Public(), msg2)
		}
	}()

	// countSnapshot holds the totals seen by the previous loadCounts call,
	// in the order alice, bob, cathy.
	var countSnapshot [3]int64
	// loadCounts returns how many packets each client received since the
	// previous call, updates the snapshot, and logs the differences.
	loadCounts := func() (adiff, bdiff, cdiff int64) {
		t.Helper()

		atotal := aliceCount.Value()
		btotal := bobCount.Value()
		ctotal := cathyCount.Value()

		adiff = atotal - countSnapshot[0]
		bdiff = btotal - countSnapshot[1]
		cdiff = ctotal - countSnapshot[2]

		countSnapshot[0] = atotal
		countSnapshot[1] = btotal
		countSnapshot[2] = ctotal

		t.Logf("count diffs: alice=%d, bob=%d, cathy=%d", adiff, bdiff, cdiff)
		return adiff, bdiff, cdiff
	}

	t.Run("initial send", func(t *testing.T) {
		time.Sleep(10 * time.Millisecond)
		a, b, c := loadCounts()
		if a != 0 {
			t.Errorf("alice diff=%d, want 0", a)
		}
		if b == 0 {
			t.Errorf("no bob diff, want positive value")
		}
		if c == 0 {
			t.Errorf("no cathy diff, want positive value")
		}
	})

	t.Run("block cathy", func(t *testing.T) {
		// Block cathy. Now the cathyConn buffer will fill up quickly,
		// and the derp server will back up.
		cathyConn.SetReadBlock(true)
		time.Sleep(2 * s.WriteTimeout)

		a, b, _ := loadCounts()
		if a != 0 {
			t.Errorf("alice diff=%d, want 0", a)
		}
		if b == 0 {
			t.Errorf("no bob diff, want positive value")
		}

		// Now wait a little longer, and ensure packets still flow to bob
		time.Sleep(10 * time.Millisecond)
		if _, b, _ := loadCounts(); b == 0 {
			t.Errorf("connection alice->bob frozen by alice->cathy")
		}
	})

	// Cleanup, make sure we process all errors.
	t.Logf("TEST COMPLETE, cancelling sender")
	cancel()
	t.Logf("closing connections")
	aliceConn.Close()
	bobConn.Close()
	cathyConn.Close()

	// Collect one value from every goroutine that reports on errCh. io.EOF
	// is tolerated; presumably it is what Recv returns once the connections
	// above are closed.
	for i := 0; i < cap(errCh); i++ {
		err := <-errCh
		if err != nil {
			if errors.Is(err, io.EOF) {
				continue
			}
			t.Error(err)
		}
	}
}