mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-10 23:58:44 +00:00
tstest/natlab: refactor, expose a Packet type.
HandlePacket and Inject now receive/take Packets. This is a handy container for the packet, and the attached Trace method can be used to print traces from custom packet handlers that integrate nicely with natlab's internal traces. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
5eedbcedd1
commit
b3d65ba943
@ -43,7 +43,7 @@ func (f *Firewall) timeNow() time.Time {
|
|||||||
return time.Now()
|
return time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict {
|
func (f *Firewall) HandlePacket(p *Packet, inIf *Interface) PacketVerdict {
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
if f.seen == nil {
|
if f.seen == nil {
|
||||||
@ -52,25 +52,25 @@ func (f *Firewall) HandlePacket(p []byte, inIf *Interface, dst, src netaddr.IPPo
|
|||||||
|
|
||||||
if inIf == f.TrustedInterface {
|
if inIf == f.TrustedInterface {
|
||||||
sess := session{
|
sess := session{
|
||||||
src: src,
|
src: p.Src,
|
||||||
dst: dst,
|
dst: p.Dst,
|
||||||
}
|
}
|
||||||
f.seen[sess] = f.timeNow().Add(f.SessionTimeout)
|
f.seen[sess] = f.timeNow().Add(f.SessionTimeout)
|
||||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall out ok", inIf.Machine().Name, inIf.name, src, dst)
|
p.Trace("firewall out ok")
|
||||||
return Continue
|
return Continue
|
||||||
} else {
|
} else {
|
||||||
// reverse src and dst because the session table is from the
|
// reverse src and dst because the session table is from the
|
||||||
// POV of outbound packets.
|
// POV of outbound packets.
|
||||||
sess := session{
|
sess := session{
|
||||||
src: dst,
|
src: p.Dst,
|
||||||
dst: src,
|
dst: p.Src,
|
||||||
}
|
}
|
||||||
now := f.timeNow()
|
now := f.timeNow()
|
||||||
if now.After(f.seen[sess]) {
|
if now.After(f.seen[sess]) {
|
||||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall drop", inIf.Machine().Name, inIf.name, src, dst)
|
p.Trace("firewall drop")
|
||||||
return Drop
|
return Drop
|
||||||
}
|
}
|
||||||
trace(p, "mach=%s iface=%s src=%s dst=%s firewall in ok", inIf.Machine().Name, inIf.name, src, dst)
|
p.Trace("firewall in ok")
|
||||||
return Continue
|
return Continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,21 +30,49 @@ import (
|
|||||||
|
|
||||||
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
|
var traceOn, _ = strconv.ParseBool(os.Getenv("NATLAB_TRACE"))
|
||||||
|
|
||||||
func trace(p []byte, msg string, args ...interface{}) {
|
// Packet represents a UDP packet flowing through the virtual network.
|
||||||
|
type Packet struct {
|
||||||
|
Src, Dst netaddr.IPPort
|
||||||
|
Payload []byte
|
||||||
|
|
||||||
|
// Prefix set by various internal methods of natlab, to locate
|
||||||
|
// where in the network a trace occured.
|
||||||
|
locator string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a copy of p that shares nothing with p.
|
||||||
|
func (p *Packet) Clone() *Packet {
|
||||||
|
return &Packet{
|
||||||
|
Src: p.Src,
|
||||||
|
Dst: p.Dst,
|
||||||
|
Payload: append([]byte(nil), p.Payload...),
|
||||||
|
locator: p.locator,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// short returns a short identifier for a packet payload,
|
||||||
|
// suitable for printing trace information.
|
||||||
|
func (p *Packet) short() string {
|
||||||
|
s := sha256.Sum256(p.Payload)
|
||||||
|
payload := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||||
|
|
||||||
|
s = sha256.Sum256([]byte(p.Src.String() + "_" + p.Dst.String()))
|
||||||
|
tuple := base64.RawStdEncoding.EncodeToString(s[:])[:2]
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s/%s", payload, tuple)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Packet) Trace(msg string, args ...interface{}) {
|
||||||
if !traceOn {
|
if !traceOn {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
id := packetShort(p)
|
allArgs := []interface{}{p.short(), p.locator, p.Src, p.Dst}
|
||||||
as := []interface{}{id}
|
allArgs = append(allArgs, args...)
|
||||||
as = append(as, args...)
|
fmt.Fprintf(os.Stderr, "[%s]%s src=%s dst=%s "+msg+"\n", allArgs...)
|
||||||
fmt.Fprintf(os.Stderr, "[%s] "+msg+"\n", as...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// packetShort returns a short identifier for a packet payload,
|
func (p *Packet) setLocator(msg string, args ...interface{}) {
|
||||||
// suitable for pritning trace information.
|
p.locator = fmt.Sprintf(" "+msg, args...)
|
||||||
func packetShort(p []byte) string {
|
|
||||||
s := sha256.Sum256(p)
|
|
||||||
return base64.RawStdEncoding.EncodeToString(s[:])[:4]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustPrefix(s string) netaddr.IPPrefix {
|
func mustPrefix(s string) netaddr.IPPrefix {
|
||||||
@ -79,6 +107,9 @@ type Network struct {
|
|||||||
func (n *Network) SetDefaultGateway(gwIf *Interface) {
|
func (n *Network) SetDefaultGateway(gwIf *Interface) {
|
||||||
n.mu.Lock()
|
n.mu.Lock()
|
||||||
defer n.mu.Unlock()
|
defer n.mu.Unlock()
|
||||||
|
if gwIf.net != n {
|
||||||
|
panic(fmt.Sprintf("can't set if=%s as net=%s's default gw, if not connected to net", gwIf.name, gwIf.net.Name))
|
||||||
|
}
|
||||||
n.defaultGW = gwIf
|
n.defaultGW = gwIf
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,24 +170,25 @@ func addOne(a *[16]byte, index int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Network) write(p []byte, dst, src netaddr.IPPort) (num int, err error) {
|
func (n *Network) write(p *Packet) (num int, err error) {
|
||||||
|
p.setLocator("net=%s", n.Name)
|
||||||
|
|
||||||
n.mu.Lock()
|
n.mu.Lock()
|
||||||
defer n.mu.Unlock()
|
defer n.mu.Unlock()
|
||||||
iface, ok := n.machine[dst.IP]
|
iface, ok := n.machine[p.Dst.IP]
|
||||||
if !ok {
|
if !ok {
|
||||||
if n.defaultGW == nil {
|
if n.defaultGW == nil {
|
||||||
trace(p, "net=%s dropped, no route to %v", n.Name, dst.IP)
|
p.Trace("no route to %v", p.Dst.IP)
|
||||||
return len(p), nil
|
return len(p.Payload), nil
|
||||||
}
|
}
|
||||||
iface = n.defaultGW
|
iface = n.defaultGW
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pretend it went across the network. Make a copy so nobody
|
// Pretend it went across the network. Make a copy so nobody
|
||||||
// can later mess with caller's memory.
|
// can later mess with caller's memory.
|
||||||
trace(p, "net=%s src=%v dst=%v -> mach=%s iface=%s", n.Name, src, dst, iface.machine.Name, iface.name)
|
p.Trace("-> mach=%s if=%s", iface.machine.Name, iface.name)
|
||||||
pcopy := append([]byte(nil), p...)
|
go iface.machine.deliverIncomingPacket(p, iface)
|
||||||
go iface.machine.deliverIncomingPacket(pcopy, iface, dst, src)
|
return len(p.Payload), nil
|
||||||
return len(p), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Interface struct {
|
type Interface struct {
|
||||||
@ -235,7 +267,7 @@ func (v PacketVerdict) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// A PacketHandler is a function that can process packets.
|
// A PacketHandler is a function that can process packets.
|
||||||
type PacketHandler func(p []byte, inIf *Interface, dst, src netaddr.IPPort) PacketVerdict
|
type PacketHandler func(p *Packet, inIf *Interface) PacketVerdict
|
||||||
|
|
||||||
// A Machine is a representation of an operating system's network
|
// A Machine is a representation of an operating system's network
|
||||||
// stack. It has a network routing table and can have multiple
|
// stack. It has a network routing table and can have multiple
|
||||||
@ -250,8 +282,9 @@ type Machine struct {
|
|||||||
// every packet this Machine receives. Returns a verdict for how
|
// every packet this Machine receives. Returns a verdict for how
|
||||||
// the packet should continue to be handled (or not).
|
// the packet should continue to be handled (or not).
|
||||||
//
|
//
|
||||||
// This can be used to implement things like stateful firewalls
|
// The packet provided to HandlePacket can safely be mutated and
|
||||||
// and NAT boxes.
|
// Inject()ed if desired. This can be used to implement things
|
||||||
|
// like stateful firewalls and NAT boxes.
|
||||||
HandlePacket PacketHandler
|
HandlePacket PacketHandler
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -264,18 +297,22 @@ type Machine struct {
|
|||||||
|
|
||||||
// Inject transmits p from src to dst, without the need for a local socket.
|
// Inject transmits p from src to dst, without the need for a local socket.
|
||||||
// It's useful for implementing e.g. NAT boxes that need to mangle IPs.
|
// It's useful for implementing e.g. NAT boxes that need to mangle IPs.
|
||||||
func (m *Machine) Inject(p []byte, dst, src netaddr.IPPort) error {
|
func (m *Machine) Inject(p *Packet) error {
|
||||||
trace(p, "mach=%s src=%s dst=%s packet injected", m.Name, src, dst)
|
p = p.Clone()
|
||||||
_, err := m.writePacket(p, dst, src)
|
p.setLocator("mach=%s", m.Name)
|
||||||
|
p.Trace("Machine.Inject")
|
||||||
|
_, err := m.writePacket(p)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src netaddr.IPPort) {
|
func (m *Machine) deliverIncomingPacket(p *Packet, iface *Interface) {
|
||||||
|
p.setLocator("mach=%s if=%s", m.Name, iface.name)
|
||||||
// TODO: can't hold lock while handling packet. This is safe as
|
// TODO: can't hold lock while handling packet. This is safe as
|
||||||
// long as you set HandlePacket before traffic starts flowing.
|
// long as you set HandlePacket before traffic starts flowing.
|
||||||
if m.HandlePacket != nil {
|
if m.HandlePacket != nil {
|
||||||
verdict := m.HandlePacket(p, iface, dst, src)
|
p.Trace("Machine.HandlePacket")
|
||||||
trace(p, "mach=%s src=%v dst=%v packethandler verdict=%s", m.Name, src, dst, verdict)
|
verdict := m.HandlePacket(p.Clone(), iface)
|
||||||
|
p.Trace("Machine.HandlePacket verdict=%s", verdict)
|
||||||
if verdict == Drop {
|
if verdict == Drop {
|
||||||
// Custom packet handler ate the packet, we're done.
|
// Custom packet handler ate the packet, we're done.
|
||||||
return
|
return
|
||||||
@ -286,13 +323,13 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
|
|||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
conns := m.conns4
|
conns := m.conns4
|
||||||
if dst.IP.Is6() {
|
if p.Dst.IP.Is6() {
|
||||||
conns = m.conns6
|
conns = m.conns6
|
||||||
}
|
}
|
||||||
possibleDsts := []netaddr.IPPort{
|
possibleDsts := []netaddr.IPPort{
|
||||||
dst,
|
p.Dst,
|
||||||
netaddr.IPPort{IP: v6unspec, Port: dst.Port},
|
netaddr.IPPort{IP: v6unspec, Port: p.Dst.Port},
|
||||||
netaddr.IPPort{IP: v4unspec, Port: dst.Port},
|
netaddr.IPPort{IP: v4unspec, Port: p.Dst.Port},
|
||||||
}
|
}
|
||||||
for _, dest := range possibleDsts {
|
for _, dest := range possibleDsts {
|
||||||
c, ok := conns[dest]
|
c, ok := conns[dest]
|
||||||
@ -300,15 +337,15 @@ func (m *Machine) deliverIncomingPacket(p []byte, iface *Interface, dst, src net
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case c.in <- incomingPacket{src: src, p: p}:
|
case c.in <- p:
|
||||||
trace(p, "mach=%s src=%v dst=%v queued to conn", m.Name, src, dst)
|
p.Trace("queued to conn")
|
||||||
default:
|
default:
|
||||||
trace(p, "mach=%s src=%v dst=%v dropped, queue overflow", m.Name, src, dst)
|
p.Trace("dropped, queue overflow")
|
||||||
// Queue overflow. Just drop it.
|
// Queue overflow. Just drop it.
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
trace(p, "mach=%s src=%v dst=%v dropped, no listening conn", m.Name, src, dst)
|
p.Trace("dropped, no listening conn")
|
||||||
}
|
}
|
||||||
|
|
||||||
func unspecOf(ip netaddr.IP) netaddr.IP {
|
func unspecOf(ip netaddr.IP) netaddr.IP {
|
||||||
@ -378,38 +415,43 @@ var (
|
|||||||
v6unspec = netaddr.IPv6Unspecified()
|
v6unspec = netaddr.IPv6Unspecified()
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Machine) writePacket(p []byte, dst, src netaddr.IPPort) (n int, err error) {
|
func (m *Machine) writePacket(p *Packet) (n int, err error) {
|
||||||
iface, err := m.interfaceForIP(dst.IP)
|
p.setLocator("mach=%s", m.Name)
|
||||||
|
|
||||||
|
iface, err := m.interfaceForIP(p.Dst.IP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
trace(p, "%v", err)
|
p.Trace("%v", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
origSrcIP := src.IP
|
origSrcIP := p.Src.IP
|
||||||
switch {
|
switch {
|
||||||
case src.IP == v4unspec:
|
case p.Src.IP == v4unspec:
|
||||||
src.IP = iface.V4()
|
p.Trace("assigning srcIP=%s", iface.V4())
|
||||||
case src.IP == v6unspec:
|
p.Src.IP = iface.V4()
|
||||||
|
case p.Src.IP == v6unspec:
|
||||||
// v6unspec in Go means "any src, but match address families"
|
// v6unspec in Go means "any src, but match address families"
|
||||||
if dst.IP.Is6() {
|
if p.Dst.IP.Is6() {
|
||||||
src.IP = iface.V6()
|
p.Trace("assigning srcIP=%s", iface.V6())
|
||||||
} else if dst.IP.Is4() {
|
p.Src.IP = iface.V6()
|
||||||
src.IP = iface.V4()
|
} else if p.Dst.IP.Is4() {
|
||||||
|
p.Trace("assigning srcIP=%s", iface.V4())
|
||||||
|
p.Src.IP = iface.V4()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
if !iface.Contains(src.IP) {
|
if !iface.Contains(p.Src.IP) {
|
||||||
err := fmt.Errorf("can't send to %v with src %v on interface %v", dst.IP, src.IP, iface)
|
err := fmt.Errorf("can't send to %v with src %v on interface %v", p.Dst.IP, p.Src.IP, iface)
|
||||||
trace(p, "%v", err)
|
p.Trace("%v", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if src.IP.IsZero() {
|
if p.Src.IP.IsZero() {
|
||||||
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
|
err := fmt.Errorf("no matching address for address family for %v", origSrcIP)
|
||||||
trace(p, "%v", err)
|
p.Trace("%v", err)
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
trace(p, "mach=%s src=%s dst=%s -> net=%s", m.Name, src, dst, iface.net.Name)
|
p.Trace("-> net=%s if=%s", iface.net.Name, iface)
|
||||||
return iface.net.write(p, dst, src)
|
return iface.net.write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
|
func (m *Machine) interfaceForIP(ip netaddr.IP) (*Interface, error) {
|
||||||
@ -552,7 +594,7 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne
|
|||||||
m: m,
|
m: m,
|
||||||
fam: fam,
|
fam: fam,
|
||||||
ipp: ipp,
|
ipp: ipp,
|
||||||
in: make(chan incomingPacket, 100), // arbitrary
|
in: make(chan *Packet, 100), // arbitrary
|
||||||
}
|
}
|
||||||
switch c.fam {
|
switch c.fam {
|
||||||
case 0:
|
case 0:
|
||||||
@ -585,12 +627,7 @@ type conn struct {
|
|||||||
closed bool
|
closed bool
|
||||||
readDeadline time.Time
|
readDeadline time.Time
|
||||||
activeReads map[*activeRead]bool
|
activeReads map[*activeRead]bool
|
||||||
in chan incomingPacket
|
in chan *Packet
|
||||||
}
|
|
||||||
|
|
||||||
type incomingPacket struct {
|
|
||||||
p []byte
|
|
||||||
src netaddr.IPPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type activeRead struct {
|
type activeRead struct {
|
||||||
@ -669,9 +706,9 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case pkt := <-c.in:
|
case pkt := <-c.in:
|
||||||
n = copy(p, pkt.p)
|
n = copy(p, pkt.Payload)
|
||||||
trace(pkt.p, "mach=%s src=%s PacketConn.ReadFrom", c.m.Name, pkt.src)
|
pkt.Trace("PacketConn.ReadFrom")
|
||||||
return n, pkt.src.UDPAddr(), nil
|
return n, pkt.Src.UDPAddr(), nil
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return 0, nil, context.DeadlineExceeded
|
return 0, nil, context.DeadlineExceeded
|
||||||
}
|
}
|
||||||
@ -682,7 +719,14 @@ func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
|
return 0, fmt.Errorf("bogus addr %T %q", addr, addr.String())
|
||||||
}
|
}
|
||||||
return c.m.writePacket(p, ipp, c.ipp)
|
pkt := &Packet{
|
||||||
|
Src: c.ipp,
|
||||||
|
Dst: ipp,
|
||||||
|
Payload: append([]byte(nil), p...),
|
||||||
|
}
|
||||||
|
pkt.setLocator("mach=%s", c.m.Name)
|
||||||
|
pkt.Trace("PacketConn.WriteTo")
|
||||||
|
return c.m.writePacket(pkt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) SetDeadline(t time.Time) error {
|
func (c *conn) SetDeadline(t time.Time) error {
|
||||||
|
@ -175,15 +175,17 @@ func TestPacketHandler(t *testing.T) {
|
|||||||
// port remappings or any other things that NATs usually to. But
|
// port remappings or any other things that NATs usually to. But
|
||||||
// it works as a demonstrator for a single client behind the NAT,
|
// it works as a demonstrator for a single client behind the NAT,
|
||||||
// where the NAT box itself doesn't also make PacketConns.
|
// where the NAT box itself doesn't also make PacketConns.
|
||||||
nat.HandlePacket = func(p []byte, iface *Interface, dst, src netaddr.IPPort) PacketVerdict {
|
nat.HandlePacket = func(p *Packet, iface *Interface) PacketVerdict {
|
||||||
switch {
|
switch {
|
||||||
case dst.IP.Is6():
|
case p.Dst.IP.Is6():
|
||||||
return Continue // no NAT for ipv6
|
return Continue // no NAT for ipv6
|
||||||
case iface == ifNATLAN && src.IP == ifClient.V4():
|
case iface == ifNATLAN && p.Src.IP == ifClient.V4():
|
||||||
nat.Inject(p, dst, netaddr.IPPort{IP: ifNATWAN.V4(), Port: src.Port})
|
p.Src.IP = ifNATWAN.V4()
|
||||||
|
nat.Inject(p)
|
||||||
return Drop
|
return Drop
|
||||||
case iface == ifNATWAN && dst.IP == ifNATWAN.V4():
|
case iface == ifNATWAN && p.Dst.IP == ifNATWAN.V4():
|
||||||
nat.Inject(p, netaddr.IPPort{IP: ifClient.V4(), Port: dst.Port}, src)
|
p.Dst.IP = ifClient.V4()
|
||||||
|
nat.Inject(p)
|
||||||
return Drop
|
return Drop
|
||||||
default:
|
default:
|
||||||
return Continue
|
return Continue
|
||||||
@ -257,7 +259,12 @@ func TestFirewall(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
clock.Advance(time.Second)
|
clock.Advance(time.Second)
|
||||||
got := f.HandlePacket(nil, test.iface, test.dst, test.src)
|
p := &Packet{
|
||||||
|
Src: test.src,
|
||||||
|
Dst: test.dst,
|
||||||
|
Payload: []byte{},
|
||||||
|
}
|
||||||
|
got := f.HandlePacket(p, test.iface)
|
||||||
if got != test.want {
|
if got != test.want {
|
||||||
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want)
|
t.Errorf("iface=%s src=%s dst=%s got %v, want %v", test.iface.name, test.src, test.dst, got, test.want)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user