From 04a3118d45c993fbed896f696e7d911de2724537 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 21 Apr 2023 10:27:15 -0400 Subject: [PATCH] net/tstun: add tests for captureHook Signed-off-by: Andrew Dunham Change-Id: I630f852d9f16c951c721b34f2bc4128e68fe9475 --- net/tstun/wrap.go | 28 +++++++++---- net/tstun/wrap_test.go | 95 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 8 deletions(-) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 2c4dd0f57..450885fbf 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -93,6 +93,9 @@ type Wrapper struct { destMACAtomic syncs.AtomicValue[[6]byte] discoKey syncs.AtomicValue[key.DiscoPublic] + // timeNow, if non-nil, will be used to obtain the current time. + timeNow func() time.Time + // natV4Config stores the current NAT configuration. natV4Config atomic.Pointer[natV4Config] @@ -258,6 +261,15 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool) *Wrapper { return w } +// now returns the current time, either by calling t.timeNow if set or time.Now +// if not. +func (t *Wrapper) now() time.Time { + if t.timeNow != nil { + return t.timeNow() + } + return time.Now() +} + // SetDestIPActivityFuncs sets a map of funcs to run per packet // destination (the map keys). // @@ -724,7 +736,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { } } if captHook != nil { - captHook(capture.FromLocal, time.Now(), p.Buffer(), p.CaptureMeta) + captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } if !t.disableFilter { response := t.filterPacketOutboundToWireGuard(p) @@ -791,7 +803,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, buf []byte, offset int) (int func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback) filter.Response { if captHook != nil { - captHook(capture.FromPeer, time.Now(), p.Buffer(), p.CaptureMeta) + captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) } if p.IPProto == ipproto.TSMP { @@ -959,7 +971,7 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt stack.PacketBufferPtr) error { p.Decode(buf[PacketStartOffset:]) captHook := t.captureHook.Load() if captHook != nil { - captHook(capture.SynthesizedToLocal, time.Now(), p.Buffer(), p.CaptureMeta) + captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) } t.dnatV4(p) @@ -1037,14 +1049,14 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque // It does not block, but takes ownership of the packet. // The injected packet will not pass through outbound filters. // Injecting an empty packet is a no-op. -func (t *Wrapper) InjectOutbound(packet []byte) error { - if len(packet) > MaxPacketSize { +func (t *Wrapper) InjectOutbound(pkt []byte) error { + if len(pkt) > MaxPacketSize { return errPacketTooBig } - if len(packet) == 0 { + if len(pkt) == 0 { return nil } - t.injectOutbound(tunInjectedRead{data: packet}) + t.injectOutbound(tunInjectedRead{data: pkt}) return nil } @@ -1063,7 +1075,7 @@ func (t *Wrapper) InjectOutboundPacketBuffer(pkt stack.PacketBufferPtr) error { } if capt := t.captureHook.Load(); capt != nil { b := pkt.ToBuffer() - capt(capture.SynthesizedToPeer, time.Now(), b.Flatten(), packet.CaptureMeta{}) + capt(capture.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{}) } t.injectOutbound(tunInjectedRead{packet: pkt}) diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index baf0eff56..f2ae0a614 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -10,9 +10,11 @@ import ( "encoding/hex" "fmt" "net/netip" + "reflect" "strconv" "strings" "testing" + "time" "unicode" "unsafe" @@ -21,6 +23,8 @@ import ( "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" "go4.org/netipx" + "gvisor.dev/gvisor/pkg/bufferv2" + "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/disco" "tailscale.com/net/connstats" "tailscale.com/net/netaddr" @@ -33,6 +37,7 @@ import ( "tailscale.com/types/netlogtype" "tailscale.com/types/ptr" "tailscale.com/util/must" + "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" ) @@ -766,3 +771,93 @@ func TestNATCfg(t *testing.T) { }) } } + +// TestCaptureHook verifies that the Wrapper.captureHook callback is called +// with the correct parameters when various packet operations are performed. +func TestCaptureHook(t *testing.T) { + type captureRecord struct { + path capture.Path + now time.Time + pkt []byte + meta packet.CaptureMeta + } + + var captured []captureRecord + hook := func(path capture.Path, now time.Time, pkt []byte, meta packet.CaptureMeta) { + captured = append(captured, captureRecord{ + path: path, + now: now, + pkt: pkt, + meta: meta, + }) + } + + now := time.Unix(1682085856, 0) + + _, w := newFakeTUN(t.Logf, true) + w.timeNow = func() time.Time { + return now + } + w.InstallCaptureHook(hook) + defer w.Close() + + // Loop reading and discarding packets; this ensures that we don't have + // packets stuck in vectorOutbound + go func() { + var ( + buf [MaxPacketSize]byte + sizes = make([]int, 1) + ) + for { + _, err := w.Read([][]byte{buf[:]}, sizes, 0) + if err != nil { + return + } + } + }() + + // Do operations that should result in a packet being captured. + w.Write([][]byte{ + []byte("Write1"), + []byte("Write2"), + }, 0) + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData([]byte("InjectInboundPacketBuffer")), + }) + w.InjectInboundPacketBuffer(packetBuf) + + packetBuf = stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: bufferv2.MakeWithData([]byte("InjectOutboundPacketBuffer")), + }) + w.InjectOutboundPacketBuffer(packetBuf) + + // TODO: test Read + // TODO: determine if we want InjectOutbound to log + + // Assert that the right packets are captured. + want := []captureRecord{ + { + path: capture.FromPeer, + pkt: []byte("Write1"), + }, + { + path: capture.FromPeer, + pkt: []byte("Write2"), + }, + { + path: capture.SynthesizedToLocal, + pkt: []byte("InjectInboundPacketBuffer"), + }, + { + path: capture.SynthesizedToPeer, + pkt: []byte("InjectOutboundPacketBuffer"), + }, + } + for i := 0; i < len(want); i++ { + want[i].now = now + } + if !reflect.DeepEqual(captured, want) { + t.Errorf("mismatch between captured and expected packets\ngot: %+v\nwant: %+v", + captured, want) + } +}