From 3669296cefe47df3f7c3cd88cb377b2b1210338a Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 24 Jul 2020 21:19:20 +0000 Subject: [PATCH] wgengine/magicsock: refactor twoDevicePing to make stack construction cleaner. Signed-off-by: David Anderson --- wgengine/magicsock/magicsock_test.go | 326 +++++++++++++-------------- 1 file changed, 163 insertions(+), 163 deletions(-) diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index dcad71745..306ce755a 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -60,6 +60,136 @@ func (c *Conn) WaitReady(t *testing.T) { } } +func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, stunIP netaddr.IP) (derpMap *tailcfg.DERPMap, cleanup func()) { + var serverPrivateKey key.Private + if _, err := crand.Read(serverPrivateKey[:]); err != nil { + t.Fatal(err) + } + d := derp.NewServer(serverPrivateKey, logf) + if l != (nettype.Std{}) { + // When using virtual networking, only allow DERP to forward + // discovery traffic, not actual packets. + d.OnlyDisco = true + } + + httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) + httpsrv.Config.ErrorLog = logger.StdLogger(logf) + httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + httpsrv.StartTLS() + + stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, l) + + m := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: &tailcfg.DERPRegion{ + RegionID: 1, + RegionCode: "test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + HostName: "test-node.unused", + IPv4: "127.0.0.1", + IPv6: "none", + STUNPort: stunAddr.Port, + DERPTestPort: httpsrv.Listener.Addr().(*net.TCPAddr).Port, + STUNTestIP: stunIP.String(), + }, + }, + }, + }, + } + + cleanup = func() { + httpsrv.CloseClientConnections() + httpsrv.Close() + d.Close() + stunCleanup() + } + + return m, cleanup +} + +// magicStack is a magicsock, plus all the stuff around it that's +// necessary to send and receive packets to test e2e wireguard +// happiness. +type magicStack struct { + privateKey wgcfg.PrivateKey + epCh chan []string // endpoint updates produced by this peer + conn *Conn // the magicsock itself + tun *tuntest.ChannelTUN // tuntap device to send/receive packets + tsTun *tstun.TUN // wrapped tun that implements filtering and wgengine hooks + dev *device.Device // the wireguard-go Device that connects the previous things +} + +// newMagicStack builds and initializes an idle magicsock and +// friends. You need to call conn.SetNetworkMap and dev.Reconfig +// before anything interesting happens. +func newMagicStack(t *testing.T, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack { + t.Helper() + + privateKey, err := wgcfg.NewPrivateKey() + if err != nil { + t.Fatalf("generating private key: %v", err) + } + + epCh := make(chan []string, 100) // arbitrary + conn, err := NewConn(Options{ + Logf: logf, + PacketListener: l, + EndpointsFunc: func(eps []string) { + epCh <- eps + }, + }) + if err != nil { + t.Fatalf("constructing magicsock: %v", err) + } + conn.Start() + conn.SetDERPMap(derpMap) + if err := conn.SetPrivateKey(privateKey); err != nil { + t.Fatalf("setting private key in magicsock: %v", err) + } + + tun := tuntest.NewChannelTUN() + tsTun := tstun.WrapTUN(logf, tun.TUN()) + tsTun.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf)) + + dev := device.NewDevice(tsTun, &device.DeviceOptions{ + Logger: &device.Logger{ + Debug: logger.StdLogger(logf), + Info: logger.StdLogger(logf), + Error: logger.StdLogger(logf), + }, + CreateEndpoint: conn.CreateEndpoint, + CreateBind: conn.CreateBind, + SkipBindUpdate: true, + }) + dev.Up() + + // Wait for magicsock to connect up to DERP. + conn.WaitReady(t) + + // Wait for first endpoint update to be available + deadline := time.Now().Add(2 * time.Second) + for len(epCh) == 0 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + + return &magicStack{ + privateKey: privateKey, + epCh: epCh, + conn: conn, + tun: tun, + tsTun: tsTun, + dev: dev, + } +} + +func (s *magicStack) Close() { + s.dev.Close() + s.conn.Close() +} + func TestNewConn(t *testing.T) { tstest.PanicOnLog() rc := tstest.NewResourceCheck() @@ -243,45 +373,6 @@ func parseCIDR(t *testing.T, addr string) wgcfg.CIDR { return cidr } -func runDERP(t *testing.T, logf logger.Logf, onlyDisco bool) (s *derp.Server, addr *net.TCPAddr, cleanupFn func()) { - var serverPrivateKey key.Private - if _, err := crand.Read(serverPrivateKey[:]); err != nil { - t.Fatal(err) - } - - s = derp.NewServer(serverPrivateKey, logf) - s.OnlyDisco = onlyDisco - - httpsrv := httptest.NewUnstartedServer(derphttp.Handler(s)) - httpsrv.Config.ErrorLog = logger.StdLogger(logf) - httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) - httpsrv.StartTLS() - logf("DERP server URL: %s (onlyDisco=%v)", httpsrv.URL, onlyDisco) - - cleanupFn = func() { - httpsrv.CloseClientConnections() - httpsrv.Close() - s.Close() - } - - return s, httpsrv.Listener.Addr().(*net.TCPAddr), cleanupFn -} - -// devLogger returns a wireguard-go device.Logger that writes -// wireguard logs to the test logger. -func devLogger(t *testing.T, prefix string, logfx logger.Logf) *device.Logger { - pfx := []interface{}{prefix} - logf := func(format string, args ...interface{}) { - t.Helper() - logfx("%s: "+format, append(pfx, args...)...) - } - return &device.Logger{ - Debug: logger.StdLogger(logf), - Info: logger.StdLogger(logf), - Error: logger.StdLogger(logf), - } -} - // TestDeviceStartStop exercises the startup and shutdown logic of // wireguard-go, which is intimately intertwined with magicsock's own // lifecycle. We seem to be good at generating deadlocks here, so if @@ -305,7 +396,11 @@ func TestDeviceStartStop(t *testing.T) { tun := tuntest.NewChannelTUN() dev := device.NewDevice(tun.TUN(), &device.DeviceOptions{ - Logger: devLogger(t, "dev", t.Logf), + Logger: &device.Logger{ + Debug: logger.StdLogger(t.Logf), + Info: logger.StdLogger(t.Logf), + Error: logger.StdLogger(t.Logf), + }, CreateEndpoint: conn.CreateEndpoint, CreateBind: conn.CreateBind, SkipBindUpdate: true, @@ -414,127 +509,37 @@ func testTwoDevicePing(t *testing.T, d *devices) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - usingNatLab := d.m1 != (nettype.Std{}) - // This gets reassigned inside every test, so that the connections // all log using the "current" t.Logf function. Sigh. logf, setT := makeNestable(t) - derpServer, derpAddr, derpCleanupFn := runDERP(t, logf, usingNatLab) - defer derpCleanupFn() + derpMap, cleanup := runDERPAndStun(t, logf, d.stun, d.stunIP) + defer cleanup() - stunAddr, stunCleanupFn := stuntest.ServeWithPacketListener(t, d.stun) - defer stunCleanupFn() - - derpMap := &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: &tailcfg.DERPRegion{ - RegionID: 1, - RegionCode: "test", - Nodes: []*tailcfg.DERPNode{ - { - Name: "t1", - RegionID: 1, - HostName: "test-node.unused", - IPv4: "127.0.0.1", - IPv6: "none", - STUNPort: stunAddr.Port, - DERPTestPort: derpAddr.Port, - STUNTestIP: d.stunIP.String(), - }, - }, - }, - }, - } - - epCh1 := make(chan []string, 16) - conn1, err := NewConn(Options{ - Logf: logger.WithPrefix(logf, "conn1: "), - PacketListener: d.m1, - EndpointsFunc: func(eps []string) { - epCh1 <- eps - }, - }) - if err != nil { - t.Fatal(err) - } - defer conn1.Close() - conn1.Start() - conn1.SetDERPMap(derpMap) - - epCh2 := make(chan []string, 16) - conn2, err := NewConn(Options{ - Logf: logger.WithPrefix(logf, "conn2: "), - PacketListener: d.m2, - EndpointsFunc: func(eps []string) { - epCh2 <- eps - }, - }) - if err != nil { - t.Fatal(err) - } - defer conn2.Close() - conn2.Start() - conn2.SetDERPMap(derpMap) + m1 := newMagicStack(t, logf, d.m1, derpMap) + defer m1.Close() + m2 := newMagicStack(t, logf, d.m2, derpMap) + defer m2.Close() addrs := []netaddr.IPPort{ - {IP: d.m1IP, Port: conn1.LocalPort()}, - {IP: d.m2IP, Port: conn2.LocalPort()}, + {IP: d.m1IP, Port: m1.conn.LocalPort()}, + {IP: d.m2IP, Port: m2.conn.LocalPort()}, } cfgs := makeConfigs(t, addrs) - if err := conn1.SetPrivateKey(cfgs[0].PrivateKey); err != nil { + if err := m1.dev.Reconfig(&cfgs[0]); err != nil { t.Fatal(err) } - if err := conn2.SetPrivateKey(cfgs[1].PrivateKey); err != nil { + if err := m2.dev.Reconfig(&cfgs[1]); err != nil { t.Fatal(err) } - //uapi1, _ := cfgs[0].ToUAPI() - //logf("cfg0: %v", uapi1) - //uapi2, _ := cfgs[1].ToUAPI() - //logf("cfg1: %v", uapi2) - - tun1 := tuntest.NewChannelTUN() - tstun1 := tstun.WrapTUN(logf, tun1.TUN()) - tstun1.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf)) - dev1 := device.NewDevice(tstun1, &device.DeviceOptions{ - Logger: devLogger(t, "dev1", logf), - CreateEndpoint: conn1.CreateEndpoint, - CreateBind: conn1.CreateBind, - SkipBindUpdate: true, - }) - dev1.Up() - if err := dev1.Reconfig(&cfgs[0]); err != nil { - t.Fatal(err) - } - defer dev1.Close() - - tun2 := tuntest.NewChannelTUN() - tstun2 := tstun.WrapTUN(logf, tun2.TUN()) - tstun2.SetFilter(filter.NewAllowAll([]filter.Net{filter.NetAny}, logf)) - dev2 := device.NewDevice(tstun2, &device.DeviceOptions{ - Logger: devLogger(t, "dev2", logf), - CreateEndpoint: conn2.CreateEndpoint, - CreateBind: conn2.CreateBind, - SkipBindUpdate: true, - }) - dev2.Up() - defer dev2.Close() - - if err := dev2.Reconfig(&cfgs[1]); err != nil { - t.Fatal(err) - } - - conn1.WaitReady(t) - conn2.WaitReady(t) - ping1 := func(t *testing.T) { msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun2.Outbound <- msg2to1 + m2.tun.Outbound <- msg2to1 t.Log("ping1 sent") select { - case msgRecv := <-tun1.Inbound: + case msgRecv := <-m1.tun.Inbound: if !bytes.Equal(msg2to1, msgRecv) { t.Error("ping did not transit correctly") } @@ -544,10 +549,10 @@ func testTwoDevicePing(t *testing.T, d *devices) { } ping2 := func(t *testing.T) { msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun1.Outbound <- msg1to2 + m1.tun.Outbound <- msg1to2 t.Log("ping2 sent") select { - case msgRecv := <-tun2.Inbound: + case msgRecv := <-m2.tun.Inbound: if !bytes.Equal(msg1to2, msgRecv) { t.Error("return ping did not transit correctly") } @@ -573,12 +578,12 @@ func testTwoDevicePing(t *testing.T, d *devices) { setT(t) defer setT(outerT) msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - if err := tstun1.InjectOutbound(msg1to2); err != nil { + if err := m1.tsTun.InjectOutbound(msg1to2); err != nil { t.Fatal(err) } t.Log("SendPacket sent") select { - case msgRecv := <-tun2.Inbound: + case msgRecv := <-m2.tun.Inbound: if !bytes.Equal(msg1to2, msgRecv) { t.Error("return ping did not transit correctly") } @@ -590,7 +595,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { t.Run("no-op dev1 reconfig", func(t *testing.T) { setT(t) defer setT(outerT) - if err := dev1.Reconfig(&cfgs[0]); err != nil { + if err := m1.dev.Reconfig(&cfgs[0]); err != nil { t.Fatal(err) } ping1(t) @@ -632,14 +637,14 @@ func testTwoDevicePing(t *testing.T, d *devices) { for i := 0; i < count; i++ { b := msg(i) - tun1.Outbound <- b + m1.tun.Outbound <- b time.Sleep(interPacketGap) } for i := 0; i < count; i++ { b := msg(i) select { - case msgRecv := <-tun2.Inbound: + case msgRecv := <-m2.tun.Inbound: if !bytes.Equal(b, msgRecv) { if strict { t.Errorf("return ping %d did not transit correctly: %s", i, cmp.Diff(b, msgRecv)) @@ -651,7 +656,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { } } } - } t.Run("ping 1.0.0.1 x50", func(t *testing.T) { @@ -668,29 +672,26 @@ func testTwoDevicePing(t *testing.T, d *devices) { ep1 := cfgs[1].Peers[0].Endpoints ep1 = append([]wgcfg.Endpoint{derpEp}, ep1...) cfgs[1].Peers[0].Endpoints = ep1 - if err := dev1.Reconfig(&cfgs[0]); err != nil { + if err := m1.dev.Reconfig(&cfgs[0]); err != nil { t.Fatal(err) } - if err := dev2.Reconfig(&cfgs[1]); err != nil { + if err := m2.dev.Reconfig(&cfgs[1]); err != nil { t.Fatal(err) } t.Run("add DERP", func(t *testing.T) { setT(t) defer setT(outerT) - defer func() { - logf("DERP vars: %s", derpServer.ExpVar().String()) - }() pingSeq(t, 20, 0, true) }) // Disable real route. cfgs[0].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp} cfgs[1].Peers[0].Endpoints = []wgcfg.Endpoint{derpEp} - if err := dev1.Reconfig(&cfgs[0]); err != nil { + if err := m1.dev.Reconfig(&cfgs[0]); err != nil { t.Fatal(err) } - if err := dev2.Reconfig(&cfgs[1]); err != nil { + if err := m2.dev.Reconfig(&cfgs[1]); err != nil { t.Fatal(err) } time.Sleep(250 * time.Millisecond) // TODO remove @@ -699,7 +700,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { setT(t) defer setT(outerT) defer func() { - logf("DERP vars: %s", derpServer.ExpVar().String()) if t.Failed() || true { uapi1, _ := cfgs[0].ToUAPI() logf("cfg0: %v", uapi1) @@ -710,8 +710,8 @@ func testTwoDevicePing(t *testing.T, d *devices) { pingSeq(t, 20, 0, true) }) - dev1.RemoveAllPeers() - dev2.RemoveAllPeers() + m1.dev.RemoveAllPeers() + m2.dev.RemoveAllPeers() // Give one peer a non-DERP endpoint. We expect the other to // accept it via roamAddr. @@ -719,10 +719,10 @@ func testTwoDevicePing(t *testing.T, d *devices) { if ep2 := cfgs[1].Peers[0].Endpoints; len(ep2) != 1 { t.Errorf("unexpected peer endpoints in dev2: %v", ep2) } - if err := dev2.Reconfig(&cfgs[1]); err != nil { + if err := m2.dev.Reconfig(&cfgs[1]); err != nil { t.Fatal(err) } - if err := dev1.Reconfig(&cfgs[0]); err != nil { + if err := m1.dev.Reconfig(&cfgs[0]); err != nil { t.Fatal(err) } // Dear future human debugging a test failure here: this test is @@ -736,7 +736,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { defer setT(outerT) pingSeq(t, 50, 700*time.Millisecond, false) - ep2 := dev2.Config().Peers[0].Endpoints + ep2 := m2.dev.Config().Peers[0].Endpoints if len(ep2) != 2 { t.Error("handshake spray failed to find real route") }