diff --git a/tstest/natlab/firewall.go b/tstest/natlab/firewall.go index 1030e0e27..584d466df 100644 --- a/tstest/natlab/firewall.go +++ b/tstest/natlab/firewall.go @@ -49,8 +49,11 @@ func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { if f.seen == nil { f.seen = map[session]time.Time{} } + if f.SessionTimeout == 0 { + f.SessionTimeout = 30 * time.Second + } - if inIf == f.TrustedInterface { + if inIf == f.TrustedInterface || inIf == nil { sess := session{ src: p.Src, dst: p.Dst, diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index b4ea783ed..664ec1913 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -279,8 +279,13 @@ type Machine struct { Name string // HandlePacket, if not nil, is a function that gets invoked for - // every packet this Machine receives. Returns a verdict for how - // the packet should continue to be handled (or not). + // every packet this Machine receives, and every packet sent by a + // local PacketConn. Returns a verdict for how the packet should + // continue to be handled (or not). + // + // HandlePacket's interface parameter is the interface on which + // the packet was received, or nil for a packet sent by a local + // PacketConn or Inject call. // // The packet provided to HandlePacket can safely be mutated and // Inject()ed if desired. This can be used to implement things @@ -450,6 +455,15 @@ func (m *Machine) writePacket(p *Packet) (n int, err error) { return 0, err } + if m.HandlePacket != nil { + p.Trace("Machine.HandlePacket") + verdict := m.HandlePacket(p.Clone(), nil) + p.Trace("Machine.HandlePacket verdict=%s", verdict) + if verdict == Drop { + return len(p.Payload), nil + } + } + p.Trace("-> net=%s if=%s", iface.net.Name, iface) return iface.net.write(p) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9fb23b368..1f7814cbf 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -368,6 +368,34 @@ func TestTwoDevicePing(t *testing.T) { } testTwoDevicePing(t, n) }) + + t.Run("facing firewalls", func(t *testing.T) { + mstun := &natlab.Machine{Name: "stun"} + f1 := &natlab.Firewall{} + f2 := &natlab.Firewall{} + m1 := &natlab.Machine{ + Name: "m1", + HandlePacket: f1.HandlePacket, + } + m2 := &natlab.Machine{ + Name: "m2", + HandlePacket: f2.HandlePacket, + } + inet := natlab.NewInternet() + sif := mstun.Attach("eth0", inet) + m1if := m1.Attach("eth0", inet) + m2if := m2.Attach("eth0", inet) + + n := &devices{ + m1: m1, + m1IP: m1if.V4(), + m2: m2, + m2IP: m2if.V4(), + stun: mstun, + stunIP: sif.V4(), + } + testTwoDevicePing(t, n) + }) }) }