mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 10:09:17 +00:00 
			
		
		
		
	tstest/natlab: refactor PacketHandler into a larger interface.
The new interface lets implementors more precisely distinguish local traffic from forwarded traffic, and applies different forwarding logic within Machines for each type. This allows Machines to be packet forwarders, which didn't quite work with the implementation of Inject. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
		 David Anderson
					David Anderson
				
			
				
					committed by
					
						 Dave Anderson
						Dave Anderson
					
				
			
			
				
	
			
			
			 Dave Anderson
						Dave Anderson
					
				
			
						parent
						
							723b9eecb0
						
					
				
				
					commit
					45578b47f3
				
			| @@ -101,32 +101,57 @@ func (f *Firewall) timeNow() time.Time { | |||||||
| 	return time.Now() | 	return time.Now() | ||||||
| } | } | ||||||
|  |  | ||||||
| // HandlePacket implements the PacketHandler type. | func (f *Firewall) init() { | ||||||
| func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { |  | ||||||
| 	f.mu.Lock() |  | ||||||
| 	defer f.mu.Unlock() |  | ||||||
| 	if f.seen == nil { | 	if f.seen == nil { | ||||||
| 		f.seen = map[fwKey]time.Time{} | 		f.seen = map[fwKey]time.Time{} | ||||||
| 	} | 	} | ||||||
| 	if f.SessionTimeout == 0 { | } | ||||||
| 		f.SessionTimeout = 30 * time.Second |  | ||||||
| 	} | func (f *Firewall) HandleOut(p *Packet, oif *Interface) *Packet { | ||||||
|  | 	f.mu.Lock() | ||||||
|  | 	defer f.mu.Unlock() | ||||||
|  | 	f.init() | ||||||
|  |  | ||||||
| 	if inIf == f.TrustedInterface || inIf == nil { |  | ||||||
| 	k := f.Type.key(p.Src, p.Dst) | 	k := f.Type.key(p.Src, p.Dst) | ||||||
| 		f.seen[k] = f.timeNow().Add(f.SessionTimeout) | 	f.seen[k] = f.timeNow().Add(f.sessionTimeoutLocked()) | ||||||
| 	p.Trace("firewall out ok") | 	p.Trace("firewall out ok") | ||||||
| 		return Continue | 	return p | ||||||
| 	} else { | } | ||||||
| 		// reverse src and dst because the session table is from the |  | ||||||
| 		// POV of outbound packets. | func (f *Firewall) HandleIn(p *Packet, iif *Interface) *Packet { | ||||||
|  | 	f.mu.Lock() | ||||||
|  | 	defer f.mu.Unlock() | ||||||
|  | 	f.init() | ||||||
|  |  | ||||||
|  | 	// reverse src and dst because the session table is from the POV | ||||||
|  | 	// of outbound packets. | ||||||
| 	k := f.Type.key(p.Dst, p.Src) | 	k := f.Type.key(p.Dst, p.Src) | ||||||
| 	now := f.timeNow() | 	now := f.timeNow() | ||||||
| 	if now.After(f.seen[k]) { | 	if now.After(f.seen[k]) { | ||||||
| 		p.Trace("firewall drop") | 		p.Trace("firewall drop") | ||||||
| 			return Drop | 		return nil | ||||||
| 	} | 	} | ||||||
| 	p.Trace("firewall in ok") | 	p.Trace("firewall in ok") | ||||||
| 		return Continue | 	return p | ||||||
| 	} | } | ||||||
|  |  | ||||||
|  | func (f *Firewall) HandleForward(p *Packet, iif *Interface, oif *Interface) *Packet { | ||||||
|  | 	if iif == f.TrustedInterface { | ||||||
|  | 		// Treat just like a locally originated packet | ||||||
|  | 		return f.HandleOut(p, oif) | ||||||
|  | 	} | ||||||
|  | 	if oif != f.TrustedInterface { | ||||||
|  | 		// Not a possible return packet from our trusted interface, drop. | ||||||
|  | 		p.Trace("firewall drop, unexpected oif") | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	// Otherwise, a session must exist, same as HandleIn. | ||||||
|  | 	return f.HandleIn(p, iif) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (f *Firewall) sessionTimeoutLocked() time.Duration { | ||||||
|  | 	if f.SessionTimeout == 0 { | ||||||
|  | 		return DefaultSessionTimeout | ||||||
|  | 	} | ||||||
|  | 	return f.SessionTimeout | ||||||
| } | } | ||||||
|   | |||||||
| @@ -99,11 +99,6 @@ type SNAT44 struct { | |||||||
| 	// nil, time.Now is used. | 	// nil, time.Now is used. | ||||||
| 	TimeNow func() time.Time | 	TimeNow func() time.Time | ||||||
|  |  | ||||||
| 	// inject, if not nil, will be invoked instead of Machine.Inject |  | ||||||
| 	// to inject NATed packets into the network. It is used for tests |  | ||||||
| 	// only. |  | ||||||
| 	inject func(*Packet) error |  | ||||||
|  |  | ||||||
| 	mu    sync.Mutex | 	mu    sync.Mutex | ||||||
| 	byLAN map[natKey]*mapping         // lookup by outbound packet tuple | 	byLAN map[natKey]*mapping         // lookup by outbound packet tuple | ||||||
| 	byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only | 	byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only | ||||||
| @@ -131,64 +126,72 @@ func (n *SNAT44) initLocked() { | |||||||
| 	if n.ExternalInterface.Machine() != n.Machine { | 	if n.ExternalInterface.Machine() != n.Machine { | ||||||
| 		panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) | 		panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name)) | ||||||
| 	} | 	} | ||||||
| 	if n.inject == nil { |  | ||||||
| 		n.inject = n.Machine.Inject |  | ||||||
| 	} |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *SNAT44) HandlePacket(p *Packet, inIf *Interface) PacketVerdict { | func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet { | ||||||
|  | 	// NATs don't affect locally originated packets. | ||||||
|  | 	if n.Firewall != nil { | ||||||
|  | 		return n.Firewall.HandleOut(p, oif) | ||||||
|  | 	} | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet { | ||||||
|  | 	if iif != n.ExternalInterface { | ||||||
|  | 		// NAT can't apply, defer to firewall. | ||||||
|  | 		if n.Firewall != nil { | ||||||
|  | 			return n.Firewall.HandleIn(p, iif) | ||||||
|  | 		} | ||||||
|  | 		return p | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	n.mu.Lock() | 	n.mu.Lock() | ||||||
| 	defer n.mu.Unlock() | 	defer n.mu.Unlock() | ||||||
| 	n.initLocked() | 	n.initLocked() | ||||||
|  |  | ||||||
| 	if inIf == n.ExternalInterface { |  | ||||||
| 		return n.processInboundLocked(p, inIf) |  | ||||||
| 	} else { |  | ||||||
| 		return n.processOutboundLocked(p, inIf) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (n *SNAT44) processInboundLocked(p *Packet, inIf *Interface) PacketVerdict { |  | ||||||
| 	// TODO: packets to local addrs should fall through to local |  | ||||||
| 	// socket processing. |  | ||||||
| 	now := n.timeNow() | 	now := n.timeNow() | ||||||
| 	mapping := n.byWAN[p.Dst] | 	mapping := n.byWAN[p.Dst] | ||||||
| 	if mapping == nil || now.After(mapping.deadline) { | 	if mapping == nil || now.After(mapping.deadline) { | ||||||
| 		p.Trace("nat drop, no mapping/expired mapping") | 		// NAT didn't hit, defer to firewall or allow in for local | ||||||
| 		return Drop | 		// socket handling. | ||||||
| 	} |  | ||||||
| 	p.Dst = mapping.lanSrc |  | ||||||
|  |  | ||||||
| 		if n.Firewall != nil { | 		if n.Firewall != nil { | ||||||
| 		if verdict := n.Firewall(p.Clone(), inIf); verdict == Drop { | 			return n.Firewall.HandleIn(p, iif) | ||||||
| 			return Drop |  | ||||||
| 		} | 		} | ||||||
|  | 		return p | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if err := n.inject(p); err != nil { | 	p.Dst = mapping.lanSrc | ||||||
| 		p.Trace("inject failed: %v", err) | 	p.Trace("dnat to %v", p.Dst) | ||||||
| 	} | 	// Don't process firewall here. We mutated the packet such that | ||||||
| 	return Drop | 	// it's no longer destined locally, so we'll get reinvoked as | ||||||
|  | 	// HandleForward and need to process the altered packet there. | ||||||
|  | 	return p | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *SNAT44) processOutboundLocked(p *Packet, inIf *Interface) PacketVerdict { | func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet { | ||||||
|  | 	switch { | ||||||
|  | 	case oif == n.ExternalInterface: | ||||||
|  | 		if p.Src.IP == oif.V4() { | ||||||
|  | 			// Packet already NATed and is just retraversing Forward, | ||||||
|  | 			// don't touch it again. | ||||||
|  | 			return p | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		if n.Firewall != nil { | 		if n.Firewall != nil { | ||||||
| 		if verdict := n.Firewall(p, inIf); verdict == Drop { | 			p2 := n.Firewall.HandleForward(p, iif, oif) | ||||||
| 			return Drop | 			if p2 == nil { | ||||||
|  | 				// firewall dropped, done | ||||||
|  | 				return nil | ||||||
|  | 			} | ||||||
|  | 			if !p.Equivalent(p2) { | ||||||
|  | 				// firewall mutated packet? Weird, but okay. | ||||||
|  | 				return p2 | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	if inIf == nil { |  | ||||||
| 		// Technically, we don't need to process the outbound firewall | 		n.mu.Lock() | ||||||
| 		// for NATed packets, but our current packet processing API | 		defer n.mu.Unlock() | ||||||
| 		// doesn't give us that granularity: we'll see both locally | 		n.initLocked() | ||||||
| 		// originated PacketConn traffic and NATed traffic as inIf == |  | ||||||
| 		// nil, and we need to apply the firewall to locally |  | ||||||
| 		// originated traffic. This may create some useless state |  | ||||||
| 		// entries in the firewall, but until we implement a much more |  | ||||||
| 		// elaborate packet processing pipeline that can distinguish |  | ||||||
| 		// local vs. forwarded traffic, this is the best we have. |  | ||||||
| 		return Continue |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 		k := n.Type.key(p.Src, p.Dst) | 		k := n.Type.key(p.Src, p.Dst) | ||||||
| 		now := n.timeNow() | 		now := n.timeNow() | ||||||
| @@ -206,12 +209,22 @@ func (n *SNAT44) processOutboundLocked(p *Packet, inIf *Interface) PacketVerdict | |||||||
| 		} | 		} | ||||||
| 		m.deadline = now.Add(n.mappingTimeout()) | 		m.deadline = now.Add(n.mappingTimeout()) | ||||||
| 		p.Src = m.wanSrc | 		p.Src = m.wanSrc | ||||||
|  |  | ||||||
| 		p.Trace("snat from %v", p.Src) | 		p.Trace("snat from %v", p.Src) | ||||||
| 	if err := n.inject(p); err != nil { | 		return p | ||||||
| 		p.Trace("inject failed: %v", err) | 	case iif == n.ExternalInterface: | ||||||
|  | 		// Packet was already un-NAT-ed, we just need to either | ||||||
|  | 		// firewall it or let it through. | ||||||
|  | 		if n.Firewall != nil { | ||||||
|  | 			return n.Firewall.HandleForward(p, iif, oif) | ||||||
|  | 		} | ||||||
|  | 		return p | ||||||
|  | 	default: | ||||||
|  | 		// No NAT applies, invoke firewall or drop. | ||||||
|  | 		if n.Firewall != nil { | ||||||
|  | 			return n.Firewall.HandleForward(p, iif, oif) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
| 	} | 	} | ||||||
| 	return Drop |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) { | func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) { | ||||||
|   | |||||||
| @@ -12,6 +12,7 @@ | |||||||
| package natlab | package natlab | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/sha256" | 	"crypto/sha256" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| @@ -40,6 +41,12 @@ type Packet struct { | |||||||
| 	locator string | 	locator string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Equivalent returns true if Src, Dst and Payload are the same in p | ||||||
|  | // and p2. | ||||||
|  | func (p *Packet) Equivalent(p2 *Packet) bool { | ||||||
|  | 	return p.Src == p2.Src && p.Dst == p2.Dst && bytes.Equal(p.Payload, p2.Payload) | ||||||
|  | } | ||||||
|  |  | ||||||
| // Clone returns a copy of p that shares nothing with p. | // Clone returns a copy of p that shares nothing with p. | ||||||
| func (p *Packet) Clone() *Packet { | func (p *Packet) Clone() *Packet { | ||||||
| 	return &Packet{ | 	return &Packet{ | ||||||
| @@ -266,8 +273,41 @@ func (v PacketVerdict) String() string { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| // A PacketHandler is a function that can process packets. | // A PacketHandler can look at packets arriving at, departing, and | ||||||
| type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict | // transiting a Machine, and filter or mutate them. | ||||||
|  | // | ||||||
|  | // Each method is invoked with a Packet that natlab would like to keep | ||||||
|  | // processing. Handlers can return that same Packet to allow | ||||||
|  | // processing to continue; nil to drop the Packet; or a different | ||||||
|  | // Packet that should be processed instead of the original. | ||||||
|  | // | ||||||
|  | // Packets passed to handlers share no state with anything else, and | ||||||
|  | // are therefore safe to mutate. It's safe to return the original | ||||||
|  | // packet mutated in-place, or a brand new packet initialized from | ||||||
|  | // scratch. | ||||||
|  | // | ||||||
|  | // Packets mutated by a PacketHandler are processed anew by the | ||||||
|  | // associated Machine, as if the packet had always been the mutated | ||||||
|  | // one. For example, if HandleForward is invoked with a Packet, and | ||||||
|  | // the handler changes the destination IP address to one of the | ||||||
|  | // Machine's own IPs, the Machine restarts delivery, but this time | ||||||
|  | // going to a local PacketConn (which in turn will invoke HandleIn, | ||||||
|  | // since the packet is now destined for local delivery). | ||||||
|  | type PacketHandler interface { | ||||||
|  | 	// HandleIn processes a packet arriving on iif, whose destination | ||||||
|  | 	// is an IP address owned by the attached Machine. If p is | ||||||
|  | 	// returned unmodified, the Machine will go on to deliver the | ||||||
|  | 	// Packet to the appropriate listening PacketConn, if one exists. | ||||||
|  | 	HandleIn(p *Packet, iif *Interface) *Packet | ||||||
|  | 	// HandleOut processes a packet about to depart on oif from a | ||||||
|  | 	// local PacketConn. If p is returned unmodified, the Machine will | ||||||
|  | 	// transmit the Packet on oif. | ||||||
|  | 	HandleOut(p *Packet, oif *Interface) *Packet | ||||||
|  | 	// HandleForward is called when the Machine wants to forward a | ||||||
|  | 	// packet from iif to oif. If p is returned unmodified, the | ||||||
|  | 	// Machine will transmit the packet on oif. | ||||||
|  | 	HandleForward(p *Packet, iif, oif *Interface) *Packet | ||||||
|  | } | ||||||
|  |  | ||||||
| // A Machine is a representation of an operating system's network | // A Machine is a representation of an operating system's network | ||||||
| // stack. It has a network routing table and can have multiple | // stack. It has a network routing table and can have multiple | ||||||
| @@ -278,19 +318,14 @@ type Machine struct { | |||||||
| 	// not be globally unique. | 	// not be globally unique. | ||||||
| 	Name string | 	Name string | ||||||
|  |  | ||||||
| 	// HandlePacket, if not nil, is a function that gets invoked for | 	// PacketHandler, if not nil, is a PacketHandler implementation | ||||||
| 	// every packet this Machine receives, and every packet sent by a | 	// that inspects all packets arriving, departing, or transiting | ||||||
| 	// local PacketConn. Returns a verdict for how the packet should | 	// the Machine. See the definition of the PacketHandler interface | ||||||
| 	// continue to be handled (or not). | 	// for semantics. | ||||||
| 	// | 	// | ||||||
| 	// HandlePacket's interface parameter is the interface on which | 	// If PacketHandler is nil, the machine allows all inbound | ||||||
| 	// the packet was received, or nil for a packet sent by a local | 	// traffic, all outbound traffic, and drops forwarded packets. | ||||||
| 	// PacketConn or Inject call. | 	PacketHandler PacketHandler | ||||||
| 	// |  | ||||||
| 	// The packet provided to HandlePacket can safely be mutated and |  | ||||||
| 	// Inject()ed if desired. This can be used to implement things |  | ||||||
| 	// like stateful firewalls and NAT boxes. |  | ||||||
| 	HandlePacket PacketHandler |  | ||||||
|  |  | ||||||
| 	mu         sync.Mutex | 	mu         sync.Mutex | ||||||
| 	interfaces []*Interface | 	interfaces []*Interface | ||||||
| @@ -300,26 +335,42 @@ type Machine struct { | |||||||
| 	conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets | 	conns6 map[netaddr.IPPort]*conn // conns that want IPv6 packets | ||||||
| } | } | ||||||
|  |  | ||||||
| // Inject transmits p from src to dst, without the need for a local socket. | func (m *Machine) isLocalIP(ip netaddr.IP) bool { | ||||||
| // It's useful for implementing e.g. NAT boxes that need to mangle IPs. | 	m.mu.Lock() | ||||||
| func (m *Machine) Inject(p *Packet) error { | 	defer m.mu.Unlock() | ||||||
| 	p = p.Clone() | 	for _, intf := range m.interfaces { | ||||||
| 	p.setLocator("mach=%s", m.Name) | 		for _, iip := range intf.ips { | ||||||
| 	p.Trace("Machine.Inject") | 			if ip == iip { | ||||||
| 	_, err := m.writePacket(p) | 				return true | ||||||
| 	return err | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
| } | } | ||||||
|  |  | ||||||
| func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { | func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { | ||||||
| 	p.setLocator("mach=%s if=%s", m.Name, iface.name) | 	p.setLocator("mach=%s if=%s", m.Name, iface.name) | ||||||
|  |  | ||||||
|  | 	if m.isLocalIP(p.Dst.IP) { | ||||||
|  | 		m.deliverLocalPacket(p, iface) | ||||||
|  | 	} else { | ||||||
|  | 		m.forwardPacket(p, iface) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *Machine) deliverLocalPacket(p *Packet, iface *Interface) { | ||||||
| 	// TODO: can't hold lock while handling packet. This is safe as | 	// TODO: can't hold lock while handling packet. This is safe as | ||||||
| 	// long as you set HandlePacket before traffic starts flowing. | 	// long as you set HandlePacket before traffic starts flowing. | ||||||
| 	if m.HandlePacket != nil { | 	if m.PacketHandler != nil { | ||||||
| 		p.Trace("Machine.HandlePacket") | 		p2 := m.PacketHandler.HandleIn(p.Clone(), iface) | ||||||
| 		verdict := m.HandlePacket(p.Clone(), iface) | 		if p2 == nil { | ||||||
| 		p.Trace("Machine.HandlePacket verdict=%s", verdict) | 			// Packet dropped, nothing left to do. | ||||||
| 		if verdict == Drop { | 			return | ||||||
| 			// Custom packet handler ate the packet, we're done. | 		} | ||||||
|  | 		if !p.Equivalent(p2) { | ||||||
|  | 			// Restart delivery, this packet might be a forward packet | ||||||
|  | 			// now. | ||||||
|  | 			m.deliverIncomingPacket(p2, iface) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| @@ -353,6 +404,35 @@ func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) { | |||||||
| 	p.Trace("dropped, no listening conn") | 	p.Trace("dropped, no listening conn") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (m *Machine) forwardPacket(p *Packet, iif *Interface) { | ||||||
|  | 	oif, err := m.interfaceForIP(p.Dst.IP) | ||||||
|  | 	if err != nil { | ||||||
|  | 		p.Trace("%v", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if m.PacketHandler == nil { | ||||||
|  | 		// Forwarding not allowed by default | ||||||
|  | 		p.Trace("drop, forwarding not allowed") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	p2 := m.PacketHandler.HandleForward(p.Clone(), iif, oif) | ||||||
|  | 	if p2 == nil { | ||||||
|  | 		p.Trace("drop") | ||||||
|  | 		// Packet dropped, done. | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if !p.Equivalent(p2) { | ||||||
|  | 		// Packet changed, restart delivery. | ||||||
|  | 		p2.Trace("PacketHandler mutated packet") | ||||||
|  | 		m.deliverIncomingPacket(p2, iif) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	p.Trace("-> net=%s oif=%s", oif.net.Name, oif) | ||||||
|  | 	oif.net.write(p) | ||||||
|  | } | ||||||
|  |  | ||||||
| func unspecOf(ip netaddr.IP) netaddr.IP { | func unspecOf(ip netaddr.IP) netaddr.IP { | ||||||
| 	if ip.Is4() { | 	if ip.Is4() { | ||||||
| 		return v4unspec | 		return v4unspec | ||||||
| @@ -455,13 +535,17 @@ func (m *Machine) writePacket(p *Packet) (n int, err error) { | |||||||
| 		return 0, err | 		return 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if m.HandlePacket != nil { | 	if m.PacketHandler != nil { | ||||||
| 		p.Trace("Machine.HandlePacket") | 		p2 := m.PacketHandler.HandleOut(p.Clone(), iface) | ||||||
| 		verdict := m.HandlePacket(p.Clone(), nil) | 		if p2 == nil { | ||||||
| 		p.Trace("Machine.HandlePacket verdict=%s", verdict) | 			// Packet dropped, done. | ||||||
| 		if verdict == Drop { |  | ||||||
| 			return len(p.Payload), nil | 			return len(p.Payload), nil | ||||||
| 		} | 		} | ||||||
|  | 		if !p.Equivalent(p2) { | ||||||
|  | 			// Restart transmission, src may have changed weirdly | ||||||
|  | 			m.writePacket(p2) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	p.Trace("-> net=%s if=%s", iface.net.Name, iface) | 	p.Trace("-> net=%s if=%s", iface.net.Name, iface) | ||||||
|   | |||||||
| @@ -148,6 +148,38 @@ func TestMultiNetwork(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type trivialNAT struct { | ||||||
|  | 	clientIP     netaddr.IP | ||||||
|  | 	lanIf, wanIf *Interface | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *trivialNAT) HandleIn(p *Packet, iface *Interface) *Packet { | ||||||
|  | 	if iface == n.wanIf && p.Dst.IP == n.wanIf.V4() { | ||||||
|  | 		p.Dst.IP = n.clientIP | ||||||
|  | 	} | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n trivialNAT) HandleOut(p *Packet, iface *Interface) *Packet { | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *trivialNAT) HandleForward(p *Packet, iif, oif *Interface) *Packet { | ||||||
|  | 	// Outbound from LAN -> apply NAT, continue | ||||||
|  | 	if iif == n.lanIf && oif == n.wanIf { | ||||||
|  | 		if p.Src.IP == n.clientIP { | ||||||
|  | 			p.Src.IP = n.wanIf.V4() | ||||||
|  | 		} | ||||||
|  | 		return p | ||||||
|  | 	} | ||||||
|  | 	// Return traffic to LAN, allow if right dst. | ||||||
|  | 	if iif == n.wanIf && oif == n.lanIf && p.Dst.IP == n.clientIP { | ||||||
|  | 		return p | ||||||
|  | 	} | ||||||
|  | 	// Else drop. | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestPacketHandler(t *testing.T) { | func TestPacketHandler(t *testing.T) { | ||||||
| 	lan := &Network{ | 	lan := &Network{ | ||||||
| 		Name:    "lan", | 		Name:    "lan", | ||||||
| @@ -167,29 +199,10 @@ func TestPacketHandler(t *testing.T) { | |||||||
|  |  | ||||||
| 	lan.SetDefaultGateway(ifNATLAN) | 	lan.SetDefaultGateway(ifNATLAN) | ||||||
|  |  | ||||||
| 	// This HandlePacket implements a basic (some might say "broken") | 	nat.PacketHandler = &trivialNAT{ | ||||||
| 	// 1:1 NAT, where client's IP gets replaced with the NAT's WAN IP, | 		clientIP: ifClient.V4(), | ||||||
| 	// and vice versa. | 		lanIf:    ifNATLAN, | ||||||
| 	// | 		wanIf:    ifNATWAN, | ||||||
| 	// This NAT is not suitable for actual use, since it doesn't do |  | ||||||
| 	// port remappings or any other things that NATs usually to. But |  | ||||||
| 	// it works as a demonstrator for a single client behind the NAT, |  | ||||||
| 	// where the NAT box itself doesn't also make PacketConns. |  | ||||||
| 	nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict { |  | ||||||
| 		switch { |  | ||||||
| 		case p.Dst.IP.Is6(): |  | ||||||
| 			return Continue // no NAT for ipv6 |  | ||||||
| 		case iface == ifNATLAN && p.Src.IP == ifClient.V4(): |  | ||||||
| 			p.Src.IP = ifNATWAN.V4() |  | ||||||
| 			nat.Inject(p) |  | ||||||
| 			return Drop |  | ||||||
| 		case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4(): |  | ||||||
| 			p.Dst.IP = ifClient.V4() |  | ||||||
| 			nat.Inject(p) |  | ||||||
| 			return Drop |  | ||||||
| 		default: |  | ||||||
| 			return Continue |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| @@ -246,17 +259,17 @@ func TestFirewall(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 		testFirewall(t, f, []fwTest{ | 		testFirewall(t, f, []fwTest{ | ||||||
| 			// client -> A authorizes A -> client | 			// client -> A authorizes A -> client | ||||||
| 			{trust, client, serverA, Continue}, | 			{trust, untrust, client, serverA, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
|  |  | ||||||
| 			// B1 -> client fails until client -> B1 | 			// B1 -> client fails until client -> B1 | ||||||
| 			{untrust, serverB1, client, Drop}, | 			{untrust, trust, serverB1, client, false}, | ||||||
| 			{trust, client, serverB1, Continue}, | 			{trust, untrust, client, serverB1, true}, | ||||||
| 			{untrust, serverB1, client, Continue}, | 			{untrust, trust, serverB1, client, true}, | ||||||
|  |  | ||||||
| 			// B2 -> client still fails | 			// B2 -> client still fails | ||||||
| 			{untrust, serverB2, client, Drop}, | 			{untrust, trust, serverB2, client, false}, | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 	t.Run("ip_dependent", func(t *testing.T) { | 	t.Run("ip_dependent", func(t *testing.T) { | ||||||
| @@ -267,17 +280,17 @@ func TestFirewall(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 		testFirewall(t, f, []fwTest{ | 		testFirewall(t, f, []fwTest{ | ||||||
| 			// client -> A authorizes A -> client | 			// client -> A authorizes A -> client | ||||||
| 			{trust, client, serverA, Continue}, | 			{trust, untrust, client, serverA, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
|  |  | ||||||
| 			// B1 -> client fails until client -> B1 | 			// B1 -> client fails until client -> B1 | ||||||
| 			{untrust, serverB1, client, Drop}, | 			{untrust, trust, serverB1, client, false}, | ||||||
| 			{trust, client, serverB1, Continue}, | 			{trust, untrust, client, serverB1, true}, | ||||||
| 			{untrust, serverB1, client, Continue}, | 			{untrust, trust, serverB1, client, true}, | ||||||
|  |  | ||||||
| 			// B2 -> client also works now | 			// B2 -> client also works now | ||||||
| 			{untrust, serverB2, client, Continue}, | 			{untrust, trust, serverB2, client, true}, | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| 	t.Run("endpoint_independent", func(t *testing.T) { | 	t.Run("endpoint_independent", func(t *testing.T) { | ||||||
| @@ -288,23 +301,23 @@ func TestFirewall(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 		testFirewall(t, f, []fwTest{ | 		testFirewall(t, f, []fwTest{ | ||||||
| 			// client -> A authorizes A -> client | 			// client -> A authorizes A -> client | ||||||
| 			{trust, client, serverA, Continue}, | 			{trust, untrust, client, serverA, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
| 			{untrust, serverA, client, Continue}, | 			{untrust, trust, serverA, client, true}, | ||||||
|  |  | ||||||
| 			// B1 -> client also works | 			// B1 -> client also works | ||||||
| 			{untrust, serverB1, client, Continue}, | 			{untrust, trust, serverB1, client, true}, | ||||||
|  |  | ||||||
| 			// B2 -> client also works | 			// B2 -> client also works | ||||||
| 			{untrust, serverB2, client, Continue}, | 			{untrust, trust, serverB2, client, true}, | ||||||
| 		}) | 		}) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| type fwTest struct { | type fwTest struct { | ||||||
| 	iface    *Interface | 	iif, oif *Interface | ||||||
| 	src, dst netaddr.IPPort | 	src, dst netaddr.IPPort | ||||||
| 	want     PacketVerdict | 	ok       bool | ||||||
| } | } | ||||||
|  |  | ||||||
| func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { | func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { | ||||||
| @@ -318,9 +331,10 @@ func testFirewall(t *testing.T, f *Firewall, tests []fwTest) { | |||||||
| 			Dst:     test.dst, | 			Dst:     test.dst, | ||||||
| 			Payload: []byte{}, | 			Payload: []byte{}, | ||||||
| 		} | 		} | ||||||
| 		got := f.HandlePacket(p, test.iface) | 		got := f.HandleForward(p, test.iif, test.oif) | ||||||
| 		if got != test.want { | 		gotOK := got != nil | ||||||
| 			t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want) | 		if gotOK != test.ok { | ||||||
|  | 			t.Errorf("iif=%s oif=%s src=%s dst=%s got ok=%v, want ok=%v", test.iif, test.oif, test.src, test.dst, gotOK, test.ok) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @@ -344,14 +358,13 @@ func TestNAT(t *testing.T) { | |||||||
| 	lanIf := m.Attach("lan", lan) | 	lanIf := m.Attach("lan", lan) | ||||||
|  |  | ||||||
| 	t.Run("endpoint_independent_mapping", func(t *testing.T) { | 	t.Run("endpoint_independent_mapping", func(t *testing.T) { | ||||||
| 		fw := &Firewall{ |  | ||||||
| 			TrustedInterface: lanIf, |  | ||||||
| 		} |  | ||||||
| 		n := &SNAT44{ | 		n := &SNAT44{ | ||||||
| 			Machine:           m, | 			Machine:           m, | ||||||
| 			ExternalInterface: wanIf, | 			ExternalInterface: wanIf, | ||||||
| 			Type:              EndpointIndependentNAT, | 			Type:              EndpointIndependentNAT, | ||||||
| 			Firewall:          fw.HandlePacket, | 			Firewall: &Firewall{ | ||||||
|  | 				TrustedInterface: lanIf, | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		testNAT(t, n, lanIf, wanIf, []natTest{ | 		testNAT(t, n, lanIf, wanIf, []natTest{ | ||||||
| 			{ | 			{ | ||||||
| @@ -373,14 +386,13 @@ func TestNAT(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("address_dependent_mapping", func(t *testing.T) { | 	t.Run("address_dependent_mapping", func(t *testing.T) { | ||||||
| 		fw := &Firewall{ |  | ||||||
| 			TrustedInterface: lanIf, |  | ||||||
| 		} |  | ||||||
| 		n := &SNAT44{ | 		n := &SNAT44{ | ||||||
| 			Machine:           m, | 			Machine:           m, | ||||||
| 			ExternalInterface: wanIf, | 			ExternalInterface: wanIf, | ||||||
| 			Type:              AddressDependentNAT, | 			Type:              AddressDependentNAT, | ||||||
| 			Firewall:          fw.HandlePacket, | 			Firewall: &Firewall{ | ||||||
|  | 				TrustedInterface: lanIf, | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		testNAT(t, n, lanIf, wanIf, []natTest{ | 		testNAT(t, n, lanIf, wanIf, []natTest{ | ||||||
| 			{ | 			{ | ||||||
| @@ -407,14 +419,13 @@ func TestNAT(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("address_and_port_dependent_mapping", func(t *testing.T) { | 	t.Run("address_and_port_dependent_mapping", func(t *testing.T) { | ||||||
| 		fw := &Firewall{ |  | ||||||
| 			TrustedInterface: lanIf, |  | ||||||
| 		} |  | ||||||
| 		n := &SNAT44{ | 		n := &SNAT44{ | ||||||
| 			Machine:           m, | 			Machine:           m, | ||||||
| 			ExternalInterface: wanIf, | 			ExternalInterface: wanIf, | ||||||
| 			Type:              AddressAndPortDependentNAT, | 			Type:              AddressAndPortDependentNAT, | ||||||
| 			Firewall:          fw.HandlePacket, | 			Firewall: &Firewall{ | ||||||
|  | 				TrustedInterface: lanIf, | ||||||
|  | 			}, | ||||||
| 		} | 		} | ||||||
| 		testNAT(t, n, lanIf, wanIf, []natTest{ | 		testNAT(t, n, lanIf, wanIf, []natTest{ | ||||||
| 			{ | 			{ | ||||||
| @@ -448,16 +459,7 @@ type natTest struct { | |||||||
|  |  | ||||||
| func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) { | func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) { | ||||||
| 	clock := &tstest.Clock{} | 	clock := &tstest.Clock{} | ||||||
| 	injected := make(chan *Packet, 100) // arbitrary |  | ||||||
| 	n.TimeNow = clock.Now | 	n.TimeNow = clock.Now | ||||||
| 	n.inject = func(p *Packet) error { |  | ||||||
| 		select { |  | ||||||
| 		case injected <- p: |  | ||||||
| 		default: |  | ||||||
| 			panic("inject overflow") |  | ||||||
| 		} |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	mappings := map[netaddr.IPPort]bool{} | 	mappings := map[netaddr.IPPort]bool{} | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
| @@ -467,25 +469,18 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) | |||||||
| 			Dst:     test.dst, | 			Dst:     test.dst, | ||||||
| 			Payload: []byte("foo"), | 			Payload: []byte("foo"), | ||||||
| 		} | 		} | ||||||
| 		gotVerdict := n.HandlePacket(p.Clone(), lanIf) | 		gotPacket := n.HandleForward(p.Clone(), lanIf, wanIf) | ||||||
| 		if gotVerdict != Drop { | 		if gotPacket == nil { | ||||||
| 			t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict) | 			t.Errorf("n.HandleForward(%v) dropped packet", p) | ||||||
| 		} | 			continue | ||||||
|  |  | ||||||
| 		var gotPacket *Packet |  | ||||||
|  |  | ||||||
| 		select { |  | ||||||
| 		default: |  | ||||||
| 			t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p) |  | ||||||
| 		case gotPacket = <-injected: |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if gotPacket.Dst != p.Dst { | 		if gotPacket.Dst != p.Dst { | ||||||
| 			t.Errorf("p.HandlePacket(%v) mutated dest ip:port, got %v", p, gotPacket.Dst) | 			t.Errorf("n.HandleForward(%v) mutated dest ip:port, got %v", p, gotPacket.Dst) | ||||||
| 		} | 		} | ||||||
| 		gotNewMapping := !mappings[gotPacket.Src] | 		gotNewMapping := !mappings[gotPacket.Src] | ||||||
| 		if gotNewMapping != test.wantNewMapping { | 		if gotNewMapping != test.wantNewMapping { | ||||||
| 			t.Errorf("p.HandlePacket(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping) | 			t.Errorf("n.HandleForward(%v) mapping was new=%v, want %v", p, gotNewMapping, test.wantNewMapping) | ||||||
| 		} | 		} | ||||||
| 		mappings[gotPacket.Src] = true | 		mappings[gotPacket.Src] = true | ||||||
|  |  | ||||||
| @@ -497,16 +492,11 @@ func testNAT(t *testing.T, n *SNAT44, lanIf, wanIf *Interface, tests []natTest) | |||||||
| 			Dst:     gotPacket.Src, | 			Dst:     gotPacket.Src, | ||||||
| 			Payload: []byte("bar"), | 			Payload: []byte("bar"), | ||||||
| 		} | 		} | ||||||
| 		gotVerdict = n.HandlePacket(p2.Clone(), wanIf) | 		gotPacket2 := n.HandleIn(p2.Clone(), wanIf) | ||||||
| 		if gotVerdict != Drop { |  | ||||||
| 			t.Errorf("p.HandlePacket(%v) = %v, want Drop", p, gotVerdict) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		var gotPacket2 *Packet | 		if gotPacket2 == nil { | ||||||
| 		select { | 			t.Errorf("return packet was dropped") | ||||||
| 		default: | 			continue | ||||||
| 			t.Errorf("p.HandlePacket(%v) didn't inject expected packet", p) |  | ||||||
| 		case gotPacket2 = <-injected: |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if gotPacket2.Src != test.dst { | 		if gotPacket2.Src != test.dst { | ||||||
|   | |||||||
| @@ -371,15 +371,13 @@ func TestTwoDevicePing(t *testing.T) { | |||||||
|  |  | ||||||
| 		t.Run("facing firewalls", func(t *testing.T) { | 		t.Run("facing firewalls", func(t *testing.T) { | ||||||
| 			mstun := &natlab.Machine{Name: "stun"} | 			mstun := &natlab.Machine{Name: "stun"} | ||||||
| 			f1 := &natlab.Firewall{} |  | ||||||
| 			f2 := &natlab.Firewall{} |  | ||||||
| 			m1 := &natlab.Machine{ | 			m1 := &natlab.Machine{ | ||||||
| 				Name:          "m1", | 				Name:          "m1", | ||||||
| 				HandlePacket: f1.HandlePacket, | 				PacketHandler: &natlab.Firewall{}, | ||||||
| 			} | 			} | ||||||
| 			m2 := &natlab.Machine{ | 			m2 := &natlab.Machine{ | ||||||
| 				Name:          "m2", | 				Name:          "m2", | ||||||
| 				HandlePacket: f2.HandlePacket, | 				PacketHandler: &natlab.Firewall{}, | ||||||
| 			} | 			} | ||||||
| 			inet := natlab.NewInternet() | 			inet := natlab.NewInternet() | ||||||
| 			sif := mstun.Attach("eth0", inet) | 			sif := mstun.Attach("eth0", inet) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user