From 013ea64e941d86b99ef534fef9da59005939f33e Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 27 Jul 2024 11:38:00 -0700 Subject: [PATCH] move more to network, hardcode less Change-Id: If1c773153f7f3fa7ea483d1b7231193ab093278a Signed-off-by: Brad Fitzpatrick --- natlab/natlabd/natlabd.go | 150 ++++++++++++++++++++++---------------- 1 file changed, 87 insertions(+), 63 deletions(-) diff --git a/natlab/natlabd/natlabd.go b/natlab/natlabd/natlabd.go index 3ae881579..f89571412 100644 --- a/natlab/natlabd/natlabd.go +++ b/natlab/natlabd/natlabd.go @@ -53,11 +53,12 @@ func main() { // Hard-coded world shape for me. net1 := &network{ + s: s, mac: MAC{0x52, 0x54, 0x00, 0x01, 0x01, 0x01}, wanIP: netip.MustParseAddr("2.1.1.1"), lanIP: netip.MustParsePrefix("192.168.2.1/24"), } - s.nodes[client1mac] = &node{ + s.nodes[MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}] = &node{ net: net1, lanIP: netip.MustParseAddr("192.168.2.102"), } @@ -96,6 +97,11 @@ func (s *Server) checkWorld() error { if n.net.nodesByIP == nil { n.net.nodesByIP = map[netip.Addr]*node{} } + if n.net.ns == nil { + if err := n.net.initStack(); err != nil { + return fmt.Errorf("newServer: initStack: %v", err) + } + } if _, ok := n.net.nodesByIP[n.lanIP]; ok { return fmt.Errorf("node %v has duplicate LAN IP %v", mac, n.lanIP) } @@ -104,8 +110,8 @@ func (s *Server) checkWorld() error { return nil } -func (s *Server) initStack() error { - s.ns = stack.New(stack.Options{ +func (n *network) initStack() error { + n.ns = stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, arp.NewProtocol, @@ -116,20 +122,20 @@ func (s *Server) initStack() error { }, }) sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default - tcpipErr := s.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + tcpipErr := n.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) } - s.linkEP = channel.New(512, 1500, tcpip.LinkAddress(gwMACTOREMOVE)) - if tcpipProblem := s.ns.CreateNIC(nicID, s.linkEP); tcpipProblem != nil { + n.linkEP = channel.New(512, 1500, tcpip.LinkAddress(n.mac.HWAddr())) + if tcpipProblem := n.ns.CreateNIC(nicID, n.linkEP); tcpipProblem != nil { return fmt.Errorf("CreateNIC: %v", tcpipProblem) } - s.ns.SetPromiscuousMode(nicID, true) - s.ns.SetSpoofing(nicID, true) + n.ns.SetPromiscuousMode(nicID, true) + n.ns.SetSpoofing(nicID, true) - prefix := tcpip.AddrFrom4Slice(gwIP.AsSlice()).WithPrefix() - prefix.PrefixLen = 24 - if tcpProb := s.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ + prefix := tcpip.AddrFrom4Slice(n.lanIP.Addr().AsSlice()).WithPrefix() + prefix.PrefixLen = n.lanIP.Bits() + if tcpProb := n.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: prefix, }, stack.AddressProperties{}); tcpProb != nil { @@ -140,7 +146,7 @@ func (s *Server) initStack() error { if err != nil { return fmt.Errorf("could not create IPv4 subnet: %v", err) } - s.ns.SetRouteTable([]tcpip.Route{ + n.ns.SetRouteTable([]tcpip.Route{ { Destination: ipv4Subnet, NIC: nicID, @@ -149,17 +155,17 @@ func (s *Server) initStack() error { const tcpReceiveBufferSize = 0 // default const maxInFlightConnectionAttempts = 8192 - tcpFwd := tcp.NewForwarder(s.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.acceptTCP) - s.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { + tcpFwd := tcp.NewForwarder(n.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, n.acceptTCP) + n.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { log.Printf("TCP packet: %+v", tei) return tcpFwd.HandlePacket(tei, pb) }) go func() { for { - pkt := s.linkEP.ReadContext(s.shutdownCtx) + pkt := n.linkEP.ReadContext(n.s.shutdownCtx) if pkt.IsNil() { - if s.shutdownCtx.Err() != nil { + if n.s.shutdownCtx.Err() != nil { // Return without logging. return } @@ -168,14 +174,20 @@ func (s *Server) initStack() error { } ipRaw := pkt.ToView().AsSlice() - log.Printf("Read packet from linkEP: % 02x", ipRaw) goPkt := gopacket.NewPacket( ipRaw, layers.LayerTypeIPv4, gopacket.Lazy) layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + + dstIP, _ := netip.AddrFromSlice(layerV4.DstIP) + node, ok := n.nodesByIP[dstIP] + if !ok { + log.Printf("no MAC for dest IP %v", dstIP) + continue + } eth := &layers.Ethernet{ - SrcMAC: gwMACTOREMOVE, - DstMAC: client1mac.HWAddr(), + SrcMAC: n.mac.HWAddr(), + DstMAC: node.mac.HWAddr(), EthernetType: layers.EthernetTypeIPv4, } buffer := gopacket.NewSerializeBuffer() @@ -201,11 +213,11 @@ func (s *Server) initStack() error { log.Printf("Serialize error: %v", err) continue } - if writeFunc, ok := s.writeFunc.Load(client1mac); ok { + if writeFunc, ok := n.writeFunc.Load(node.mac); ok { writeFunc(buffer.Bytes()) log.Printf("wrote packet to client: % 02x", buffer.Bytes()) } else { - log.Printf("No writeFunc for %v", client1mac) + log.Printf("No writeFunc for %v", node.mac) } } }() @@ -228,7 +240,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string { return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) } -func (s *Server) acceptTCP(r *tcp.ForwarderRequest) { +func (n *network) acceptTCP(r *tcp.ForwarderRequest) { reqDetails := r.ID() log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) @@ -275,17 +287,10 @@ func (s *Server) acceptTCP(r *tcp.ForwarderRequest) { } var ( - // TODO: remove this and run a netstack per *network instead. - gwMACTOREMOVE = net.HardwareAddr{0x52, 0x54, 0x00, 0x01, 0x01, 0x01} - fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) ) -var gwIP = netip.AddrFrom4([4]byte{192, 168, 1, 1}) - -var client1mac = MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee} - type MAC [6]byte func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) { @@ -304,12 +309,27 @@ func (m MAC) String() string { } type network struct { - mac MAC - doesNAT bool - wanIP netip.Addr - lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24) - + s *Server + mac MAC + doesNAT bool + wanIP netip.Addr + lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24) nodesByIP map[netip.Addr]*node + + ns *stack.Stack + linkEP *channel.Endpoint + + // writeFunc is a map of MAC -> func to write to that MAC. + // It contains entries for connected nodes only. + writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC +} + +func (n *network) registerWriter(mac MAC, f func([]byte)) { + if f != nil { + n.writeFunc.Store(mac, f) + } else { + n.writeFunc.Delete(mac) + } } func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { @@ -333,11 +353,6 @@ type Server struct { shutdownCancel context.CancelFunc nodes map[MAC]*node - - writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC - - ns *stack.Stack - linkEP *channel.Endpoint } func newServer() (*Server, error) { @@ -347,9 +362,6 @@ func newServer() (*Server, error) { shutdownCancel: cancel, nodes: map[MAC]*node{}, } - if err := s.initStack(); err != nil { - return nil, fmt.Errorf("newServer: initStack: %v", err) - } return s, nil } @@ -371,7 +383,7 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) { } func (s *Server) serveConn(uc net.Conn) { - log.Printf("Got conn") + log.Printf("Got conn %p", uc) defer uc.Close() bw := bufio.NewWriterSize(uc, 2<<10) @@ -395,9 +407,10 @@ func (s *Server) serveConn(uc net.Conn) { log.Printf("Flush: %v", err) } } - s.writeFunc.Store(client1mac, writePkt) buf := make([]byte, 16<<10) + var srcNode *node + var netw *network // non-nil after first packet for { if _, err := io.ReadFull(uc, buf[:4]); err != nil { log.Printf("ReadFull header: %v", err) @@ -416,9 +429,26 @@ func (s *Server) serveConn(uc net.Conn) { if !ok { continue } + srcMAC := MAC(ll.SrcMAC) + if srcNode == nil { + srcNode, ok = s.nodes[srcMAC] + if !ok { + log.Printf("[conn %p] ignoring frame from unknown MAC %v", uc, srcMAC) + continue + } + log.Printf("[conn %p] MAC %v is node %v", uc, srcMAC, srcNode.lanIP) + netw = srcNode.net + netw.registerWriter(srcMAC, writePkt) + defer netw.registerWriter(srcMAC, nil) + } else { + if srcMAC != srcNode.mac { + log.Printf("[conn %p] ignoring frame from MAC %v, expected %v", uc, srcMAC, srcNode.mac) + continue + } + } if ll.EthernetType == layers.EthernetTypeARP { - res, err := s.createARPResponse(packet) + res, err := netw.createARPResponse(packet) if err != nil { log.Printf("createARPResponse: %v", err) } else { @@ -478,7 +508,7 @@ func (s *Server) serveConn(uc net.Conn) { packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(pktCopy), }) - s.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + netw.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) // var list stack.PacketBufferList // list.PushBack(packetBuf) @@ -679,7 +709,7 @@ func (s *Server) createSTUNResponse(pkt gopacket.Packet) ([]byte, error) { log.Printf("invalid STUN request: %v", err) return nil, nil } - stunRes := stun.Response(txid, netip.AddrPortFrom(gwIP, 31234)) + stunRes := stun.Response(txid, netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 12345)) eth2 := &layers.Ethernet{ SrcMAC: ethLayer.DstMAC, @@ -786,30 +816,24 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) { return nil, err } - if len(response.Answers) > 0 { - back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy) - log.Printf("Generated: %v", back) - } else { - log.Printf("made empty response for %q", names) + const debugDNS = false + if debugDNS { + if len(response.Answers) > 0 { + back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy) + log.Printf("Generated: %v", back) + } else { + log.Printf("made empty response for %q", names) + } } return buffer.Bytes(), nil } -func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) { +func (n *network) createARPResponse(pkt gopacket.Packet) ([]byte, error) { ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) if !ok { return nil, nil } - srcMAC, ok := macOf(ethLayer.SrcMAC) - if !ok { - return nil, nil - } - node, ok := s.nodes[srcMAC] - if !ok { - return nil, nil - } - arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP) if !ok || arpLayer.Operation != layers.ARPRequest || @@ -822,7 +846,7 @@ func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) { } wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress)) - foundMAC, ok := node.net.MACOfIP(wantIP) + foundMAC, ok := n.MACOfIP(wantIP) if !ok { return nil, nil }