diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index b1f87c7db..534385dbb 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -5,6 +5,7 @@ package magicsock import ( + "bytes" "fmt" "log" "net" @@ -13,6 +14,9 @@ import ( "testing" "time" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun/tuntest" + "github.com/tailscale/wireguard-go/wgcfg" "tailscale.com/stun" ) @@ -179,3 +183,174 @@ func runSTUN(pc net.PacketConn, stats *stunStats) { } } } + +func makeConfigs(t *testing.T, ports []uint16) []wgcfg.Config { + t.Helper() + + var privKeys []wgcfg.PrivateKey + var addresses [][]wgcfg.CIDR + + for i := range ports { + privKey, err := wgcfg.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + privKeys = append(privKeys, privKey) + + addresses = append(addresses, []wgcfg.CIDR{ + parseCIDR(t, fmt.Sprintf("1.0.0.%d/32", i+1)), + }) + } + + var cfgs []wgcfg.Config + for i, port := range ports { + cfg := wgcfg.Config{ + Name: fmt.Sprintf("peer%d", i+1), + PrivateKey: privKeys[i], + Addresses: addresses[i], + ListenPort: port, + } + for peerNum, port := range ports { + if peerNum == i { + continue + } + peer := wgcfg.Peer{ + PublicKey: privKeys[peerNum].Public(), + AllowedIPs: addresses[peerNum], + Endpoints: []wgcfg.Endpoint{{ + Host: "127.0.0.1", + Port: port, + }}, + } + cfg.Peers = append(cfg.Peers, peer) + } + cfgs = append(cfgs, cfg) + } + return cfgs +} + +func parseCIDR(t *testing.T, addr string) wgcfg.CIDR { + t.Helper() + cidr, err := wgcfg.ParseCIDR(addr) + if err != nil { + t.Fatal(err) + } + return *cidr +} + +func TestTwoDevicePing(t *testing.T) { + stunAddr, stunCleanupFn := serveSTUN(t) + defer stunCleanupFn() + + epCh1 := make(chan []string, 16) + conn1, err := Listen(Options{ + STUN: []string{stunAddr.String()}, + EndpointsFunc: func(eps []string) { + epCh1 <- eps + }, + }) + if err != nil { + t.Fatal(err) + } + defer conn1.Close() + + epCh2 := make(chan []string, 16) + conn2, err := Listen(Options{ + STUN: []string{stunAddr.String()}, + EndpointsFunc: func(eps []string) { + epCh2 <- eps + }, + }) + if err != nil { + t.Fatal(err) + } + defer conn2.Close() + + ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} + cfgs := makeConfigs(t, ports) + + uapi1, _ := cfgs[0].ToUAPI() + t.Logf("cfg0: %v", uapi1) + uapi2, _ := cfgs[1].ToUAPI() + t.Logf("cfg1: %v", uapi2) + + tun1 := tuntest.NewChannelTUN() + dev1 := device.NewDevice(tun1.TUN(), &device.DeviceOptions{ + Logger: device.NewLogger(device.LogLevelDebug, "dev1: "), + CreateEndpoint: conn1.CreateEndpoint, + CreateBind: conn1.CreateBind, + SkipBindUpdate: true, + }) + dev1.Up() + //defer dev1.Close() TODO(crawshaw): this hangs + if err := dev1.Reconfig(&cfgs[0]); err != nil { + t.Fatal(err) + } + + tun2 := tuntest.NewChannelTUN() + dev2 := device.NewDevice(tun2.TUN(), &device.DeviceOptions{ + Logger: device.NewLogger(device.LogLevelDebug, "dev2: "), + CreateEndpoint: conn2.CreateEndpoint, + CreateBind: conn2.CreateBind, + SkipBindUpdate: true, + }) + dev2.Up() + //defer dev2.Close() TODO(crawshaw): this hangs + if err := dev2.Reconfig(&cfgs[1]); err != nil { + t.Fatal(err) + } + + ping1 := func(t *testing.T) { + t.Helper() + + msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) + tun2.Outbound <- msg2to1 + select { + case msgRecv := <-tun1.Inbound: + if !bytes.Equal(msg2to1, msgRecv) { + t.Error("ping did not transit correctly") + } + case <-time.After(1 * time.Second): + t.Error("ping did not transit") + } + } + ping2 := func(t *testing.T) { + t.Helper() + + msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) + tun1.Outbound <- msg1to2 + select { + case msgRecv := <-tun2.Inbound: + if !bytes.Equal(msg1to2, msgRecv) { + t.Error("return ping did not transit correctly") + } + case <-time.After(1 * time.Second): + t.Error("return ping did not transit") + } + } + + t.Run("ping 1.0.0.1", func(t *testing.T) { ping1(t) }) + t.Run("ping 1.0.0.2", func(t *testing.T) { ping2(t) }) + t.Run("ping 1.0.0.2 via SendPacket", func(t *testing.T) { + msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) + if err := dev1.SendPacket(msg1to2); err != nil { + t.Fatal(err) + } + select { + case msgRecv := <-tun2.Inbound: + if !bytes.Equal(msg1to2, msgRecv) { + t.Error("return ping did not transit correctly") + } + case <-time.After(1 * time.Second): + t.Error("return ping did not transit") + } + }) + + t.Run("no-op dev1 reconfig", func(t *testing.T) { + if err := dev1.Reconfig(&cfgs[0]); err != nil { + t.Fatal(err) + } + ping1(t) + ping2(t) + }) +}