move more to network, hardcode less

Change-Id: If1c773153f7f3fa7ea483d1b7231193ab093278a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-07-27 11:38:00 -07:00
parent 87f777d21b
commit 013ea64e94

View File

@ -53,11 +53,12 @@ func main() {
// Hard-coded world shape for me. // Hard-coded world shape for me.
net1 := &network{ net1 := &network{
s: s,
mac: MAC{0x52, 0x54, 0x00, 0x01, 0x01, 0x01}, mac: MAC{0x52, 0x54, 0x00, 0x01, 0x01, 0x01},
wanIP: netip.MustParseAddr("2.1.1.1"), wanIP: netip.MustParseAddr("2.1.1.1"),
lanIP: netip.MustParsePrefix("192.168.2.1/24"), lanIP: netip.MustParsePrefix("192.168.2.1/24"),
} }
s.nodes[client1mac] = &node{ s.nodes[MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}] = &node{
net: net1, net: net1,
lanIP: netip.MustParseAddr("192.168.2.102"), lanIP: netip.MustParseAddr("192.168.2.102"),
} }
@ -96,6 +97,11 @@ func (s *Server) checkWorld() error {
if n.net.nodesByIP == nil { if n.net.nodesByIP == nil {
n.net.nodesByIP = map[netip.Addr]*node{} n.net.nodesByIP = map[netip.Addr]*node{}
} }
if n.net.ns == nil {
if err := n.net.initStack(); err != nil {
return fmt.Errorf("newServer: initStack: %v", err)
}
}
if _, ok := n.net.nodesByIP[n.lanIP]; ok { if _, ok := n.net.nodesByIP[n.lanIP]; ok {
return fmt.Errorf("node %v has duplicate LAN IP %v", mac, n.lanIP) return fmt.Errorf("node %v has duplicate LAN IP %v", mac, n.lanIP)
} }
@ -104,8 +110,8 @@ func (s *Server) checkWorld() error {
return nil return nil
} }
func (s *Server) initStack() error { func (n *network) initStack() error {
s.ns = stack.New(stack.Options{ n.ns = stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol, ipv4.NewProtocol,
arp.NewProtocol, arp.NewProtocol,
@ -116,20 +122,20 @@ func (s *Server) initStack() error {
}, },
}) })
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
tcpipErr := s.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) tcpipErr := n.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
if tcpipErr != nil { if tcpipErr != nil {
return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr)
} }
s.linkEP = channel.New(512, 1500, tcpip.LinkAddress(gwMACTOREMOVE)) n.linkEP = channel.New(512, 1500, tcpip.LinkAddress(n.mac.HWAddr()))
if tcpipProblem := s.ns.CreateNIC(nicID, s.linkEP); tcpipProblem != nil { if tcpipProblem := n.ns.CreateNIC(nicID, n.linkEP); tcpipProblem != nil {
return fmt.Errorf("CreateNIC: %v", tcpipProblem) return fmt.Errorf("CreateNIC: %v", tcpipProblem)
} }
s.ns.SetPromiscuousMode(nicID, true) n.ns.SetPromiscuousMode(nicID, true)
s.ns.SetSpoofing(nicID, true) n.ns.SetSpoofing(nicID, true)
prefix := tcpip.AddrFrom4Slice(gwIP.AsSlice()).WithPrefix() prefix := tcpip.AddrFrom4Slice(n.lanIP.Addr().AsSlice()).WithPrefix()
prefix.PrefixLen = 24 prefix.PrefixLen = n.lanIP.Bits()
if tcpProb := s.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ if tcpProb := n.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: prefix, AddressWithPrefix: prefix,
}, stack.AddressProperties{}); tcpProb != nil { }, stack.AddressProperties{}); tcpProb != nil {
@ -140,7 +146,7 @@ func (s *Server) initStack() error {
if err != nil { if err != nil {
return fmt.Errorf("could not create IPv4 subnet: %v", err) return fmt.Errorf("could not create IPv4 subnet: %v", err)
} }
s.ns.SetRouteTable([]tcpip.Route{ n.ns.SetRouteTable([]tcpip.Route{
{ {
Destination: ipv4Subnet, Destination: ipv4Subnet,
NIC: nicID, NIC: nicID,
@ -149,17 +155,17 @@ func (s *Server) initStack() error {
const tcpReceiveBufferSize = 0 // default const tcpReceiveBufferSize = 0 // default
const maxInFlightConnectionAttempts = 8192 const maxInFlightConnectionAttempts = 8192
tcpFwd := tcp.NewForwarder(s.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.acceptTCP) tcpFwd := tcp.NewForwarder(n.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, n.acceptTCP)
s.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { n.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
log.Printf("TCP packet: %+v", tei) log.Printf("TCP packet: %+v", tei)
return tcpFwd.HandlePacket(tei, pb) return tcpFwd.HandlePacket(tei, pb)
}) })
go func() { go func() {
for { for {
pkt := s.linkEP.ReadContext(s.shutdownCtx) pkt := n.linkEP.ReadContext(n.s.shutdownCtx)
if pkt.IsNil() { if pkt.IsNil() {
if s.shutdownCtx.Err() != nil { if n.s.shutdownCtx.Err() != nil {
// Return without logging. // Return without logging.
return return
} }
@ -168,14 +174,20 @@ func (s *Server) initStack() error {
} }
ipRaw := pkt.ToView().AsSlice() ipRaw := pkt.ToView().AsSlice()
log.Printf("Read packet from linkEP: % 02x", ipRaw)
goPkt := gopacket.NewPacket( goPkt := gopacket.NewPacket(
ipRaw, ipRaw,
layers.LayerTypeIPv4, gopacket.Lazy) layers.LayerTypeIPv4, gopacket.Lazy)
layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
dstIP, _ := netip.AddrFromSlice(layerV4.DstIP)
node, ok := n.nodesByIP[dstIP]
if !ok {
log.Printf("no MAC for dest IP %v", dstIP)
continue
}
eth := &layers.Ethernet{ eth := &layers.Ethernet{
SrcMAC: gwMACTOREMOVE, SrcMAC: n.mac.HWAddr(),
DstMAC: client1mac.HWAddr(), DstMAC: node.mac.HWAddr(),
EthernetType: layers.EthernetTypeIPv4, EthernetType: layers.EthernetTypeIPv4,
} }
buffer := gopacket.NewSerializeBuffer() buffer := gopacket.NewSerializeBuffer()
@ -201,11 +213,11 @@ func (s *Server) initStack() error {
log.Printf("Serialize error: %v", err) log.Printf("Serialize error: %v", err)
continue continue
} }
if writeFunc, ok := s.writeFunc.Load(client1mac); ok { if writeFunc, ok := n.writeFunc.Load(node.mac); ok {
writeFunc(buffer.Bytes()) writeFunc(buffer.Bytes())
log.Printf("wrote packet to client: % 02x", buffer.Bytes()) log.Printf("wrote packet to client: % 02x", buffer.Bytes())
} else { } else {
log.Printf("No writeFunc for %v", client1mac) log.Printf("No writeFunc for %v", node.mac)
} }
} }
}() }()
@ -228,7 +240,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string {
return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort)
} }
func (s *Server) acceptTCP(r *tcp.ForwarderRequest) { func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
reqDetails := r.ID() reqDetails := r.ID()
log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails))
@ -275,17 +287,10 @@ func (s *Server) acceptTCP(r *tcp.ForwarderRequest) {
} }
var ( var (
// TODO: remove this and run a netstack per *network instead.
gwMACTOREMOVE = net.HardwareAddr{0x52, 0x54, 0x00, 0x01, 0x01, 0x01}
fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11})
fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1})
) )
var gwIP = netip.AddrFrom4([4]byte{192, 168, 1, 1})
var client1mac = MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}
type MAC [6]byte type MAC [6]byte
func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) { func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) {
@ -304,12 +309,27 @@ func (m MAC) String() string {
} }
type network struct { type network struct {
mac MAC s *Server
doesNAT bool mac MAC
wanIP netip.Addr doesNAT bool
lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24) wanIP netip.Addr
lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24)
nodesByIP map[netip.Addr]*node nodesByIP map[netip.Addr]*node
ns *stack.Stack
linkEP *channel.Endpoint
// writeFunc is a map of MAC -> func to write to that MAC.
// It contains entries for connected nodes only.
writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC
}
func (n *network) registerWriter(mac MAC, f func([]byte)) {
if f != nil {
n.writeFunc.Store(mac, f)
} else {
n.writeFunc.Delete(mac)
}
} }
func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) {
@ -333,11 +353,6 @@ type Server struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
nodes map[MAC]*node nodes map[MAC]*node
writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC
ns *stack.Stack
linkEP *channel.Endpoint
} }
func newServer() (*Server, error) { func newServer() (*Server, error) {
@ -347,9 +362,6 @@ func newServer() (*Server, error) {
shutdownCancel: cancel, shutdownCancel: cancel,
nodes: map[MAC]*node{}, nodes: map[MAC]*node{},
} }
if err := s.initStack(); err != nil {
return nil, fmt.Errorf("newServer: initStack: %v", err)
}
return s, nil return s, nil
} }
@ -371,7 +383,7 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) {
} }
func (s *Server) serveConn(uc net.Conn) { func (s *Server) serveConn(uc net.Conn) {
log.Printf("Got conn") log.Printf("Got conn %p", uc)
defer uc.Close() defer uc.Close()
bw := bufio.NewWriterSize(uc, 2<<10) bw := bufio.NewWriterSize(uc, 2<<10)
@ -395,9 +407,10 @@ func (s *Server) serveConn(uc net.Conn) {
log.Printf("Flush: %v", err) log.Printf("Flush: %v", err)
} }
} }
s.writeFunc.Store(client1mac, writePkt)
buf := make([]byte, 16<<10) buf := make([]byte, 16<<10)
var srcNode *node
var netw *network // non-nil after first packet
for { for {
if _, err := io.ReadFull(uc, buf[:4]); err != nil { if _, err := io.ReadFull(uc, buf[:4]); err != nil {
log.Printf("ReadFull header: %v", err) log.Printf("ReadFull header: %v", err)
@ -416,9 +429,26 @@ func (s *Server) serveConn(uc net.Conn) {
if !ok { if !ok {
continue continue
} }
srcMAC := MAC(ll.SrcMAC)
if srcNode == nil {
srcNode, ok = s.nodes[srcMAC]
if !ok {
log.Printf("[conn %p] ignoring frame from unknown MAC %v", uc, srcMAC)
continue
}
log.Printf("[conn %p] MAC %v is node %v", uc, srcMAC, srcNode.lanIP)
netw = srcNode.net
netw.registerWriter(srcMAC, writePkt)
defer netw.registerWriter(srcMAC, nil)
} else {
if srcMAC != srcNode.mac {
log.Printf("[conn %p] ignoring frame from MAC %v, expected %v", uc, srcMAC, srcNode.mac)
continue
}
}
if ll.EthernetType == layers.EthernetTypeARP { if ll.EthernetType == layers.EthernetTypeARP {
res, err := s.createARPResponse(packet) res, err := netw.createARPResponse(packet)
if err != nil { if err != nil {
log.Printf("createARPResponse: %v", err) log.Printf("createARPResponse: %v", err)
} else { } else {
@ -478,7 +508,7 @@ func (s *Server) serveConn(uc net.Conn) {
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(pktCopy), Payload: buffer.MakeWithData(pktCopy),
}) })
s.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) netw.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
// var list stack.PacketBufferList // var list stack.PacketBufferList
// list.PushBack(packetBuf) // list.PushBack(packetBuf)
@ -679,7 +709,7 @@ func (s *Server) createSTUNResponse(pkt gopacket.Packet) ([]byte, error) {
log.Printf("invalid STUN request: %v", err) log.Printf("invalid STUN request: %v", err)
return nil, nil return nil, nil
} }
stunRes := stun.Response(txid, netip.AddrPortFrom(gwIP, 31234)) stunRes := stun.Response(txid, netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 12345))
eth2 := &layers.Ethernet{ eth2 := &layers.Ethernet{
SrcMAC: ethLayer.DstMAC, SrcMAC: ethLayer.DstMAC,
@ -786,30 +816,24 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) {
return nil, err return nil, err
} }
if len(response.Answers) > 0 { const debugDNS = false
back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy) if debugDNS {
log.Printf("Generated: %v", back) if len(response.Answers) > 0 {
} else { back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy)
log.Printf("made empty response for %q", names) log.Printf("Generated: %v", back)
} else {
log.Printf("made empty response for %q", names)
}
} }
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) { func (n *network) createARPResponse(pkt gopacket.Packet) ([]byte, error) {
ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet) ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet)
if !ok { if !ok {
return nil, nil return nil, nil
} }
srcMAC, ok := macOf(ethLayer.SrcMAC)
if !ok {
return nil, nil
}
node, ok := s.nodes[srcMAC]
if !ok {
return nil, nil
}
arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP) arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP)
if !ok || if !ok ||
arpLayer.Operation != layers.ARPRequest || arpLayer.Operation != layers.ARPRequest ||
@ -822,7 +846,7 @@ func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) {
} }
wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress)) wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress))
foundMAC, ok := node.net.MACOfIP(wantIP) foundMAC, ok := n.MACOfIP(wantIP)
if !ok { if !ok {
return nil, nil return nil, nil
} }