tstest/natlab: refactor PacketHandler into a larger interface.

The new interface lets implementors more precisely distinguish
local traffic from forwarded traffic, and applies different
forwarding logic within Machines for each type. This allows
Machines to be packet forwarders, which didn't quite work
with the implementation of Inject.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2020-07-14 21:01:52 +00:00
committed by Dave Anderson
parent 723b9eecb0
commit 45578b47f3
5 changed files with 331 additions and 221 deletions

View File

@@ -101,32 +101,57 @@ func (f *Firewall) timeNow() time.Time {
return time.Now() return time.Now()
} }
// HandlePacket implements the PacketHandler type. func (f *Firewall) init() {
func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict {
f.mu.Lock()
defer f.mu.Unlock()
if f.seen == nil { if f.seen == nil {
f.seen = map[fwKey]time.Time{} f.seen = map[fwKey]time.Time{}
} }
if f.SessionTimeout == 0 { }
f.SessionTimeout = 30 * time.Second
} func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet {
f.mu.Lock()
defer f.mu.Unlock()
f.init()
if inIf == f.TrustedInterface || inIf == nil {
k := f.Type.key(p.Src, p.Dst) k := f.Type.key(p.Src, p.Dst)
f.seen[k] = f.timeNow().Add(f.SessionTimeout) f.seen[k] = f.timeNow().Add(f.sessionTimeoutLocked())
p.Trace("firewall out ok") p.Trace("firewall out ok")
return Continue return p
} else { }
// reverse src and dst because the session table is from the
// POV of outbound packets. func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet {
f.mu.Lock()
defer f.mu.Unlock()
f.init()
// reverse src and dst because the session table is from the POV
// of outbound packets.
k := f.Type.key(p.Dst, p.Src) k := f.Type.key(p.Dst, p.Src)
now := f.timeNow() now := f.timeNow()
if now.After(f.seen[k]) { if now.After(f.seen[k]) {
p.Trace("firewall drop") p.Trace("firewall drop")
return Drop return nil
} }
p.Trace("firewall in ok") p.Trace("firewall in ok")
return Continue return p
} }
func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet {
if iif == f.TrustedInterface {
// Treat just like a locally originated packet
return f.HandleOut(p, oif)
}
if oif != f.TrustedInterface {
// Not a possible return packet from our trusted interface, drop.
p.Trace("firewall drop, unexpected oif")
return nil
}
// Otherwise, a session must exist, same as HandleIn.
return f.HandleIn(p, iif)
}
func (f *Firewall) sessionTimeoutLocked() time.Duration {
if f.SessionTimeout == 0 {
return DefaultSessionTimeout
}
return f.SessionTimeout
} }

View File

@@ -99,11 +99,6 @@ type SNAT44 struct {
// nil, time.Now is used. // nil, time.Now is used.
TimeNow func() time.Time TimeNow func() time.Time
// inject, if not nil, will be invoked instead of Machine.Inject
// to inject NATed packets into the network. It is used for tests
// only.
inject func(*Packet) error
mu sync.Mutex mu sync.Mutex
byLAN map[natKey]*mapping // lookup by outbound packet tuple byLAN map[natKey]*mapping // lookup by outbound packet tuple
byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only
@@ -131,64 +126,72 @@ func (n *SNAT44) initLocked() {
if n.ExternalInterface.Machine() != n.Machine { if n.ExternalInterface.Machine() != n.Machine {
panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name))
} }
if n.inject == nil {
n.inject = n.Machine.Inject
}
} }
func (n *SNAT44) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet {
// NATs don't affect locally originated packets.
if n.Firewall != nil {
return n.Firewall.HandleOut(p, oif)
}
return p
}
func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet {
if iif != n.ExternalInterface {
// NAT can't apply, defer to firewall.
if n.Firewall != nil {
return n.Firewall.HandleIn(p, iif)
}
return p
}
n.mu.Lock() n.mu.Lock()
defer n.mu.Unlock() defer n.mu.Unlock()
n.initLocked() n.initLocked()
if inIf == n.ExternalInterface {
return n.processInboundLocked(p, inIf)
} else {
return n.processOutboundLocked(p, inIf)
}
}
func (n *SNAT44) processInboundLocked(p *Packet, inIf *Interface) PacketVerdict {
// TODO: packets to local addrs should fall through to local
// socket processing.
now := n.timeNow() now := n.timeNow()
mapping := n.byWAN[p.Dst] mapping := n.byWAN[p.Dst]
if mapping == nil || now.After(mapping.deadline) { if mapping == nil || now.After(mapping.deadline) {
p.Trace("nat drop, no mapping/expired mapping") // NAT didn't hit, defer to firewall or allow in for local
return Drop // socket handling.
}
p.Dst = mapping.lanSrc
if n.Firewall != nil { if n.Firewall != nil {
if verdict := n.Firewall(p.Clone(), inIf); verdict == Drop { return n.Firewall.HandleIn(p, iif)
return Drop
} }
return p
} }
if err := n.inject(p); err != nil { p.Dst = mapping.lanSrc
p.Trace("inject failed: %v", err) p.Trace("dnat to %v", p.Dst)
} // Don't process firewall here. We mutated the packet such that
return Drop // it's no longer destined locally, so we'll get reinvoked as
// HandleForward and need to process the altered packet there.
return p
} }
func (n *SNAT44) processOutboundLocked(p *Packet, inIf *Interface) PacketVerdict { func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet {
switch {
case oif == n.ExternalInterface:
if p.Src.IP == oif.V4() {
// Packet already NATed and is just retraversing Forward,
// don't touch it again.
return p
}
if n.Firewall != nil { if n.Firewall != nil {
if verdict := n.Firewall(p, inIf); verdict == Drop { p2 := n.Firewall.HandleForward(p, iif, oif)
return Drop if p2 == nil {
// firewall dropped, done
return nil
}
if !p.Equivalent(p2) {
// firewall mutated packet? Weird, but okay.
return p2
} }
} }
if inIf == nil {
// Technically, we don't need to process the outbound firewall n.mu.Lock()
// for NATed packets, but our current packet processing API defer n.mu.Unlock()
// doesn't give us that granularity: we'll see both locally n.initLocked()
// originated PacketConn traffic and NATed traffic as inIf ==
// nil, and we need to apply the firewall to locally
// originated traffic. This may create some useless state
// entries in the firewall, but until we implement a much more
// elaborate packet processing pipeline that can distinguish
// local vs. forwarded traffic, this is the best we have.
return Continue
}
k := n.Type.key(p.Src, p.Dst) k := n.Type.key(p.Src, p.Dst)
now := n.timeNow() now := n.timeNow()
@@ -206,12 +209,22 @@ func (n *SNAT44) processOutboundLocked(p *Packet, inIf *Interface) PacketVerdict
} }
m.deadline = now.Add(n.mappingTimeout()) m.deadline = now.Add(n.mappingTimeout())
p.Src = m.wanSrc p.Src = m.wanSrc
p.Trace("snat from %v", p.Src) p.Trace("snat from %v", p.Src)
if err := n.inject(p); err != nil { return p
p.Trace("inject failed: %v", err) case iif == n.ExternalInterface:
// Packet was already un-NAT-ed, we just need to either
// firewall it or let it through.
if n.Firewall != nil {
return n.Firewall.HandleForward(p, iif, oif)
}
return p
default:
// No NAT applies, invoke firewall or drop.
if n.Firewall != nil {
return n.Firewall.HandleForward(p, iif, oif)
}
return nil
} }
return Drop
} }
func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) { func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) {

View File

@@ -12,6 +12,7 @@
package natlab package natlab
import ( import (
"bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
@@ -40,6 +41,12 @@ type Packet struct {
locator string locator string
} }
// Equivalent returns true if Src, Dst and Payload are the same in p
// and p2.
func (p *Packet) Equivalent(p2 *Packet) bool {
return p.Src == p2.Src && p.Dst == p2.Dst && bytes.Equal(p.Payload, p2.Payload)
}
// Clone returns a copy of p that shares nothing with p. // Clone returns a copy of p that shares nothing with p.
func (p *Packet) Clone() *Packet { func (p *Packet) Clone() *Packet {
return &Packet{ return &Packet{
@@ -266,8 +273,41 @@ func (v PacketVerdict) String() string {
} }
} }
// A PacketHandler is a function that can process packets. // A PacketHandler can look at packets arriving at, departing, and
type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict // transiting a Machine, and filter or mutate them.
//
// Each method is invoked with a Packet that natlab would like to keep
// processing. Handlers can return that same Packet to allow
// processing to continue; nil to drop the Packet; or a different
// Packet that should be processed instead of the original.
//
// Packets passed to handlers share no state with anything else, and
// are therefore safe to mutate. It's safe to return the original
// packet mutated in-place, or a brand new packet initialized from
// scratch.
//
// Packets mutated by a PacketHandler are processed anew by the
// associated Machine, as if the packet had always been the mutated
// one. For example, if HandleForward is invoked with a Packet, and
// the handler changes the destination IP address to one of the
// Machine's own IPs, the Machine restarts delivery, but this time
// going to a local PacketConn (which in turn will invoke HandleIn,
// since the packet is now destined for local delivery).
type PacketHandler interface {
// HandleIn processes a packet arriving on iif, whose destination
// is an IP address owned by the attached Machine. If p is
// returned unmodified, the Machine will go on to deliver the
// Packet to the appropriate listening PacketConn, if one exists.
HandleIn(p *Packet, iif *Interface) *Packet
// HandleOut processes a packet about to depart on oif from a
// local PacketConn. If p is returned unmodified, the Machine will
// transmit the Packet on oif.
HandleOut(p *Packet, oif *Interface) *Packet
// HandleForward is called when the Machine wants to forward a
// packet from iif to oif. If p is returned unmodified, the
// Machine will transmit the packet on oif.
HandleForward(p *Packet, iif, oif *Interface) *Packet
}
// A Machine is a representation of an operating system's network // A Machine is a representation of an operating system's network
// stack. It has a network routing table and can have multiple // stack. It has a network routing table and can have multiple
@@ -278,19 +318,14 @@ type Machine struct {
// not be globally unique. // not be globally unique.
Name string Name string
// HandlePacket, if not nil, is a function that gets invoked for // PacketHandler, if not nil, is a PacketHandler implementation
// every packet this Machine receives, and every packet sent by a // that inspects all packets arriving, departing, or transiting
// local PacketConn. Returns a verdict for how the packet should // the Machine. See the definition of the PacketHandler interface
// continue to be handled (or not). // for semantics.
// //
// HandlePacket's interface parameter is the interface on which // If PacketHandler is nil, the machine allows all inbound
// the packet was received, or nil for a packet sent by a local // traffic, all outbound traffic, and drops forwarded packets.
// PacketConn or Inject call. PacketHandler PacketHandler
//
// The packet provided to HandlePacket can safely be mutated and
// Inject()ed if desired. This can be used to implement things
// like stateful firewalls and NAT boxes.
HandlePacket PacketHandler
mu sync.Mutex mu sync.Mutex
interfaces []*Interface interfaces []*Interface
@@ -300,26 +335,42 @@ type Machine struct {
conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets
} }
// Inject transmits p from src to dst, without the need for a local socket. func (m *Machine) isLocalIP(ip netaddr.IP) bool {
// It's useful for implementing e.g. NAT boxes that need to mangle IPs. m.mu.Lock()
func (m *Machine) Inject(p *Packet) error { defer m.mu.Unlock()
p = p.Clone() for _, intf := range m.interfaces {
p.setLocator("mach=%s", m.Name) for _, iip := range intf.ips {
p.Trace("Machine.Inject") if ip == iip {
_, err := m.writePacket(p) return true
return err }
}
}
return false
} }
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
p.setLocator("mach=%s if=%s", m.Name, iface.name) p.setLocator("mach=%s if=%s", m.Name, iface.name)
if m.isLocalIP(p.Dst.IP) {
m.deliverLocalPacket(p, iface)
} else {
m.forwardPacket(p, iface)
}
}
func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) {
// TODO: can't hold lock while handling packet. This is safe as // TODO: can't hold lock while handling packet. This is safe as
// long as you set HandlePacket before traffic starts flowing. // long as you set HandlePacket before traffic starts flowing.
if m.HandlePacket != nil { if m.PacketHandler != nil {
p.Trace("Machine.HandlePacket") p2 := m.PacketHandler.HandleIn(p.Clone(), iface)
verdict := m.HandlePacket(p.Clone(), iface) if p2 == nil {
p.Trace("Machine.HandlePacket verdict=%s", verdict) // Packet dropped, nothing left to do.
if verdict == Drop { return
// Custom packet handler ate the packet, we're done. }
if !p.Equivalent(p2) {
// Restart delivery, this packet might be a forward packet
// now.
m.deliverIncomingPacket(p2, iface)
return return
} }
} }
@@ -353,6 +404,35 @@ func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
p.Trace("dropped, no listening conn") p.Trace("dropped, no listening conn")
} }
func (m *Machine) forwardPacket(p *Packet, iif *Interface) {
oif, err := m.interfaceForIP(p.Dst.IP)
if err != nil {
p.Trace("%v", err)
return
}
if m.PacketHandler == nil {
// Forwarding not allowed by default
p.Trace("drop, forwarding not allowed")
return
}
p2 := m.PacketHandler.HandleForward(p.Clone(), iif, oif)
if p2 == nil {
p.Trace("drop")
// Packet dropped, done.
return
}
if !p.Equivalent(p2) {
// Packet changed, restart delivery.
p2.Trace("PacketHandler mutated packet")
m.deliverIncomingPacket(p2, iif)
return
}
p.Trace("-> net=%s oif=%s", oif.net.Name, oif)
oif.net.write(p)
}
func unspecOf(ip netaddr.IP) netaddr.IP { func unspecOf(ip netaddr.IP) netaddr.IP {
if ip.Is4() { if ip.Is4() {
return v4unspec return v4unspec
@@ -455,13 +535,17 @@ func (m *Machine) writePacket(p *Packet) (n int, err error) {
return 0, err return 0, err
} }
if m.HandlePacket != nil { if m.PacketHandler != nil {
p.Trace("Machine.HandlePacket") p2 := m.PacketHandler.HandleOut(p.Clone(), iface)
verdict := m.HandlePacket(p.Clone(), nil) if p2 == nil {
p.Trace("Machine.HandlePacket verdict=%s", verdict) // Packet dropped, done.
if verdict == Drop {
return len(p.Payload), nil return len(p.Payload), nil
} }
if !p.Equivalent(p2) {
// Restart transmission, src may have changed weirdly
m.writePacket(p2)
return
}
} }
p.Trace("-> net=%s if=%s", iface.net.Name, iface) p.Trace("-> net=%s if=%s", iface.net.Name, iface)

View File

@@ -148,6 +148,38 @@ func TestMultiNetwork(t *testing.T) {
} }
} }
type trivialNAT struct {
clientIP netaddr.IP
lanIf, wanIf *Interface
}
func (n *trivialNAT) HandleIn(p *Packet, iface *Interface) *Packet {
if iface == n.wanIf && p.Dst.IP == n.wanIf.V4() {
p.Dst.IP = n.clientIP
}
return p
}
func (n trivialNAT) HandleOut(p *Packet, iface *Interface) *Packet {
return p
}
func (n *trivialNAT) HandleForward(p *Packet, iif, oif *Interface) *Packet {
// Outbound from LAN -> apply NAT, continue
if iif == n.lanIf && oif == n.wanIf {
if p.Src.IP == n.clientIP {
p.Src.IP = n.wanIf.V4()
}
return p
}
// Return traffic to LAN, allow if right dst.
if iif == n.wanIf && oif == n.lanIf && p.Dst.IP == n.clientIP {
return p
}
// Else drop.
return nil
}
func TestPacketHandler(t *testing.T) { func TestPacketHandler(t *testing.T) {
lan := &Network{ lan := &Network{
Name: "lan", Name: "lan",
@@ -167,29 +199,10 @@ func TestPacketHandler(t *testing.T) {
lan.SetDefaultGateway(ifNATLAN) lan.SetDefaultGateway(ifNATLAN)
// This HandlePacket implements a basic (some might say "broken") nat.PacketHandler = &trivialNAT{
// 1:1 NAT, where client's IP gets replaced with the NAT's WAN IP, clientIP: ifClient.V4(),
// and vice versa. lanIf: ifNATLAN,
// wanIf: ifNATWAN,
// This NAT is not suitable for actual use, since it doesn't do
// port remappings or any other things that NATs usually to. But
// it works as a demonstrator for a single client behind the NAT,
// where the NAT box itself doesn't also make PacketConns.
nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict {
switch {
case p.Dst.IP.Is6():
return Continue // no NAT for ipv6
case iface == ifNATLAN && p.Src.IP == ifClient.V4():
p.Src.IP = ifNATWAN.V4()
nat.Inject(p)
return Drop
case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4():
p.Dst.IP = ifClient.V4()
nat.Inject(p)
return Drop
default:
return Continue
}
} }
ctx := context.Background() ctx := context.Background()
@@ -246,17 +259,17 @@ func TestFirewall(t *testing.T) {
} }
testFirewall(t, f, []fwTest{ testFirewall(t, f, []fwTest{
// client -> A authorizes A -> client // client -> A authorizes A -> client
{trust, client, serverA, Continue}, {trust, untrust, client, serverA, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
// B1 -> client fails until client -> B1 // B1 -> client fails until client -> B1
{untrust, serverB1, client, Drop}, {untrust, trust, serverB1, client, false},
{trust, client, serverB1, Continue}, {trust, untrust, client, serverB1, true},
{untrust, serverB1, client, Continue}, {untrust, trust, serverB1, client, true},
// B2 -> client still fails // B2 -> client still fails
{untrust, serverB2, client, Drop}, {untrust, trust, serverB2, client, false},
}) })
}) })
t.Run("ip_dependent", func(t *testing.T) { t.Run("ip_dependent", func(t *testing.T) {
@@ -267,17 +280,17 @@ func TestFirewall(t *testing.T) {
} }
testFirewall(t, f, []fwTest{ testFirewall(t, f, []fwTest{
// client -> A authorizes A -> client // client -> A authorizes A -> client
{trust, client, serverA, Continue}, {trust, untrust, client, serverA, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
// B1 -> client fails until client -> B1 // B1 -> client fails until client -> B1
{untrust, serverB1, client, Drop}, {untrust, trust, serverB1, client, false},
{trust, client, serverB1, Continue}, {trust, untrust, client, serverB1, true},
{untrust, serverB1, client, Continue}, {untrust, trust, serverB1, client, true},
// B2 -> client also works now // B2 -> client also works now
{untrust, serverB2, client, Continue}, {untrust, trust, serverB2, client, true},
}) })
}) })
t.Run("endpoint_independent", func(t *testing.T) { t.Run("endpoint_independent", func(t *testing.T) {
@@ -288,23 +301,23 @@ func TestFirewall(t *testing.T) {
} }
testFirewall(t, f, []fwTest{ testFirewall(t, f, []fwTest{
// client -> A authorizes A -> client // client -> A authorizes A -> client
{trust, client, serverA, Continue}, {trust, untrust, client, serverA, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
{untrust, serverA, client, Continue}, {untrust, trust, serverA, client, true},
// B1 -> client also works // B1 -> client also works
{untrust, serverB1, client, Continue}, {untrust, trust, serverB1, client, true},
// B2 -> client also works // B2 -> client also works
{untrust, serverB2, client, Continue}, {untrust, trust, serverB2, client, true},
}) })
}) })
} }
type fwTest struct { type fwTest struct {
iface *Interface iif, oif *Interface
src, dst netaddr.IPPort src, dst netaddr.IPPort
want PacketVerdict ok bool
} }
func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { func testFirewall(t *testing.T, f *Firewall, tests []fwTest) {
@@ -318,9 +331,10 @@ func testFirewall(t *testing.T, f *Firewall, tests []fwTest) {
Dst: test.dst, Dst: test.dst,
Payload: []byte{}, Payload: []byte{},
} }
got := f.HandlePacket(p, test.iface) got := f.HandleForward(p, test.iif, test.oif)
if got != test.want { gotOK := got != nil
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want) if gotOK != test.ok {
t.Errorf("iif=%s oif=%s src=%s dst=%s got ok=%v, want ok=%v", test.iif, test.oif, test.src, test.dst, gotOK, test.ok)
} }
} }
} }
@@ -344,14 +358,13 @@ func TestNAT(t *testing.T) {
lanIf := m.Attach("lan", lan) lanIf := m.Attach("lan", lan)
t.Run("endpoint_independent_mapping", func(t *testing.T) { t.Run("endpoint_independent_mapping", func(t *testing.T) {
fw := &Firewall{
TrustedInterface: lanIf,
}
n := &SNAT44{ n := &SNAT44{
Machine: m, Machine: m,
ExternalInterface: wanIf, ExternalInterface: wanIf,
Type: EndpointIndependentNAT, Type: EndpointIndependentNAT,
Firewall: fw.HandlePacket, Firewall: &Firewall{
TrustedInterface: lanIf,
},
} }
testNAT(t, n, lanIf, wanIf, []natTest{ testNAT(t, n, lanIf, wanIf, []natTest{
{ {
@@ -373,14 +386,13 @@ func TestNAT(t *testing.T) {
}) })
t.Run("address_dependent_mapping", func(t *testing.T) { t.Run("address_dependent_mapping", func(t *testing.T) {
fw := &Firewall{
TrustedInterface: lanIf,
}
n := &SNAT44{ n := &SNAT44{
Machine: m, Machine: m,
ExternalInterface: wanIf, ExternalInterface: wanIf,
Type: AddressDependentNAT, Type: AddressDependentNAT,
Firewall: fw.HandlePacket, Firewall: &Firewall{
TrustedInterface: lanIf,
},
} }
testNAT(t, n, lanIf, wanIf, []natTest{ testNAT(t, n, lanIf, wanIf, []natTest{
{ {
@@ -407,14 +419,13 @@ func TestNAT(t *testing.T) {
}) })
t.Run("address_and_port_dependent_mapping", func(t *testing.T) { t.Run("address_and_port_dependent_mapping", func(t *testing.T) {
fw := &Firewall{
TrustedInterface: lanIf,
}
n := &SNAT44{ n := &SNAT44{
Machine: m, Machine: m,
ExternalInterface: wanIf, ExternalInterface: wanIf,
Type: AddressAndPortDependentNAT, Type: AddressAndPortDependentNAT,
Firewall: fw.HandlePacket, Firewall: &Firewall{
TrustedInterface: lanIf,
},
} }
testNAT(t, n, lanIf, wanIf, []natTest{ testNAT(t, n, lanIf, wanIf, []natTest{
{ {
@@ -448,16 +459,7 @@ type natTest struct {
func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) { func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) {
clock := &tstest.Clock{} clock := &tstest.Clock{}
injected := make(chan *Packet, 100) // arbitrary
n.TimeNow = clock.Now n.TimeNow = clock.Now
n.inject = func(p *Packet) error {
select {
case injected <- p:
default:
panic("inject overflow")
}
return nil
}
mappings := map[netaddr.IPPort]bool{} mappings := map[netaddr.IPPort]bool{}
for _, test := range tests { for _, test := range tests {
@@ -467,25 +469,18 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest)
Dst: test.dst, Dst: test.dst,
Payload: []byte("foo"), Payload: []byte("foo"),
} }
gotVerdict := n.HandlePacket(p.Clone(), lanIf) gotPacket := n.HandleForward(p.Clone(), lanIf, wanIf)
if gotVerdict != Drop { if gotPacket == nil {
t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict) t.Errorf("n.HandleForward(%v) dropped packet", p)
} continue
var gotPacket *Packet
select {
default:
t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p)
case gotPacket = <-injected:
} }
if gotPacket.Dst != p.Dst { if gotPacket.Dst != p.Dst {
t.Errorf("p.HandlePacket(%v) mutated dest ip:port, got %v", p, gotPacket.Dst) t.Errorf("n.HandleForward(%v) mutated dest ip:port, got %v", p, gotPacket.Dst)
} }
gotNewMapping := !mappings[gotPacket.Src] gotNewMapping := !mappings[gotPacket.Src]
if gotNewMapping != test.wantNewMapping { if gotNewMapping != test.wantNewMapping {
t.Errorf("p.HandlePacket(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping) t.Errorf("n.HandleForward(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping)
} }
mappings[gotPacket.Src] = true mappings[gotPacket.Src] = true
@@ -497,16 +492,11 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest)
Dst: gotPacket.Src, Dst: gotPacket.Src,
Payload: []byte("bar"), Payload: []byte("bar"),
} }
gotVerdict = n.HandlePacket(p2.Clone(), wanIf) gotPacket2 := n.HandleIn(p2.Clone(), wanIf)
if gotVerdict != Drop {
t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict)
}
var gotPacket2 *Packet if gotPacket2 == nil {
select { t.Errorf("return packet was dropped")
default: continue
t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p)
case gotPacket2 = <-injected:
} }
if gotPacket2.Src != test.dst { if gotPacket2.Src != test.dst {

View File

@@ -371,15 +371,13 @@ func TestTwoDevicePing(t *testing.T) {
t.Run("facing firewalls", func(t *testing.T) { t.Run("facing firewalls", func(t *testing.T) {
mstun := &natlab.Machine{Name: "stun"} mstun := &natlab.Machine{Name: "stun"}
f1 := &natlab.Firewall{}
f2 := &natlab.Firewall{}
m1 := &natlab.Machine{ m1 := &natlab.Machine{
Name: "m1", Name: "m1",
HandlePacket: f1.HandlePacket, PacketHandler: &natlab.Firewall{},
} }
m2 := &natlab.Machine{ m2 := &natlab.Machine{
Name: "m2", Name: "m2",
HandlePacket: f2.HandlePacket, PacketHandler: &natlab.Firewall{},
} }
inet := natlab.NewInternet() inet := natlab.NewInternet()
sif := mstun.Attach("eth0", inet) sif := mstun.Attach("eth0", inet)