move more to network, hardcode less

Change-Id: If1c773153f7f3fa7ea483d1b7231193ab093278a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-07-27 11:38:00 -07:00
parent 87f777d21b
commit 013ea64e94

View File

@ -53,11 +53,12 @@ func main() {
// Hard-coded world shape for me.
net1 := &network{
s: s,
mac: MAC{0x52, 0x54, 0x00, 0x01, 0x01, 0x01},
wanIP: netip.MustParseAddr("2.1.1.1"),
lanIP: netip.MustParsePrefix("192.168.2.1/24"),
}
s.nodes[client1mac] = &node{
s.nodes[MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}] = &node{
net: net1,
lanIP: netip.MustParseAddr("192.168.2.102"),
}
@ -96,6 +97,11 @@ func (s *Server) checkWorld() error {
if n.net.nodesByIP == nil {
n.net.nodesByIP = map[netip.Addr]*node{}
}
if n.net.ns == nil {
if err := n.net.initStack(); err != nil {
return fmt.Errorf("newServer: initStack: %v", err)
}
}
if _, ok := n.net.nodesByIP[n.lanIP]; ok {
return fmt.Errorf("node %v has duplicate LAN IP %v", mac, n.lanIP)
}
@ -104,8 +110,8 @@ func (s *Server) checkWorld() error {
return nil
}
func (s *Server) initStack() error {
s.ns = stack.New(stack.Options{
func (n *network) initStack() error {
n.ns = stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
arp.NewProtocol,
@ -116,20 +122,20 @@ func (s *Server) initStack() error {
},
})
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
tcpipErr := s.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
tcpipErr := n.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
if tcpipErr != nil {
return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr)
}
s.linkEP = channel.New(512, 1500, tcpip.LinkAddress(gwMACTOREMOVE))
if tcpipProblem := s.ns.CreateNIC(nicID, s.linkEP); tcpipProblem != nil {
n.linkEP = channel.New(512, 1500, tcpip.LinkAddress(n.mac.HWAddr()))
if tcpipProblem := n.ns.CreateNIC(nicID, n.linkEP); tcpipProblem != nil {
return fmt.Errorf("CreateNIC: %v", tcpipProblem)
}
s.ns.SetPromiscuousMode(nicID, true)
s.ns.SetSpoofing(nicID, true)
n.ns.SetPromiscuousMode(nicID, true)
n.ns.SetSpoofing(nicID, true)
prefix := tcpip.AddrFrom4Slice(gwIP.AsSlice()).WithPrefix()
prefix.PrefixLen = 24
if tcpProb := s.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
prefix := tcpip.AddrFrom4Slice(n.lanIP.Addr().AsSlice()).WithPrefix()
prefix.PrefixLen = n.lanIP.Bits()
if tcpProb := n.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: prefix,
}, stack.AddressProperties{}); tcpProb != nil {
@ -140,7 +146,7 @@ func (s *Server) initStack() error {
if err != nil {
return fmt.Errorf("could not create IPv4 subnet: %v", err)
}
s.ns.SetRouteTable([]tcpip.Route{
n.ns.SetRouteTable([]tcpip.Route{
{
Destination: ipv4Subnet,
NIC: nicID,
@ -149,17 +155,17 @@ func (s *Server) initStack() error {
const tcpReceiveBufferSize = 0 // default
const maxInFlightConnectionAttempts = 8192
tcpFwd := tcp.NewForwarder(s.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.acceptTCP)
s.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
tcpFwd := tcp.NewForwarder(n.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, n.acceptTCP)
n.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) {
log.Printf("TCP packet: %+v", tei)
return tcpFwd.HandlePacket(tei, pb)
})
go func() {
for {
pkt := s.linkEP.ReadContext(s.shutdownCtx)
pkt := n.linkEP.ReadContext(n.s.shutdownCtx)
if pkt.IsNil() {
if s.shutdownCtx.Err() != nil {
if n.s.shutdownCtx.Err() != nil {
// Return without logging.
return
}
@ -168,14 +174,20 @@ func (s *Server) initStack() error {
}
ipRaw := pkt.ToView().AsSlice()
log.Printf("Read packet from linkEP: % 02x", ipRaw)
goPkt := gopacket.NewPacket(
ipRaw,
layers.LayerTypeIPv4, gopacket.Lazy)
layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
dstIP, _ := netip.AddrFromSlice(layerV4.DstIP)
node, ok := n.nodesByIP[dstIP]
if !ok {
log.Printf("no MAC for dest IP %v", dstIP)
continue
}
eth := &layers.Ethernet{
SrcMAC: gwMACTOREMOVE,
DstMAC: client1mac.HWAddr(),
SrcMAC: n.mac.HWAddr(),
DstMAC: node.mac.HWAddr(),
EthernetType: layers.EthernetTypeIPv4,
}
buffer := gopacket.NewSerializeBuffer()
@ -201,11 +213,11 @@ func (s *Server) initStack() error {
log.Printf("Serialize error: %v", err)
continue
}
if writeFunc, ok := s.writeFunc.Load(client1mac); ok {
if writeFunc, ok := n.writeFunc.Load(node.mac); ok {
writeFunc(buffer.Bytes())
log.Printf("wrote packet to client: % 02x", buffer.Bytes())
} else {
log.Printf("No writeFunc for %v", client1mac)
log.Printf("No writeFunc for %v", node.mac)
}
}
}()
@ -228,7 +240,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string {
return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort)
}
func (s *Server) acceptTCP(r *tcp.ForwarderRequest) {
func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
reqDetails := r.ID()
log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails))
@ -275,17 +287,10 @@ func (s *Server) acceptTCP(r *tcp.ForwarderRequest) {
}
var (
// TODO: remove this and run a netstack per *network instead.
gwMACTOREMOVE = net.HardwareAddr{0x52, 0x54, 0x00, 0x01, 0x01, 0x01}
fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11})
fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1})
)
var gwIP = netip.AddrFrom4([4]byte{192, 168, 1, 1})
var client1mac = MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee}
type MAC [6]byte
func macOf(hwa net.HardwareAddr) (_ MAC, ok bool) {
@ -304,12 +309,27 @@ func (m MAC) String() string {
}
type network struct {
s *Server
mac MAC
doesNAT bool
wanIP netip.Addr
lanIP netip.Prefix // with host bits set (e.g. 192.168.2.1/24)
nodesByIP map[netip.Addr]*node
ns *stack.Stack
linkEP *channel.Endpoint
// writeFunc is a map of MAC -> func to write to that MAC.
// It contains entries for connected nodes only.
writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC
}
func (n *network) registerWriter(mac MAC, f func([]byte)) {
if f != nil {
n.writeFunc.Store(mac, f)
} else {
n.writeFunc.Delete(mac)
}
}
func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) {
@ -333,11 +353,6 @@ type Server struct {
shutdownCancel context.CancelFunc
nodes map[MAC]*node
writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC
ns *stack.Stack
linkEP *channel.Endpoint
}
func newServer() (*Server, error) {
@ -347,9 +362,6 @@ func newServer() (*Server, error) {
shutdownCancel: cancel,
nodes: map[MAC]*node{},
}
if err := s.initStack(); err != nil {
return nil, fmt.Errorf("newServer: initStack: %v", err)
}
return s, nil
}
@ -371,7 +383,7 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) {
}
func (s *Server) serveConn(uc net.Conn) {
log.Printf("Got conn")
log.Printf("Got conn %p", uc)
defer uc.Close()
bw := bufio.NewWriterSize(uc, 2<<10)
@ -395,9 +407,10 @@ func (s *Server) serveConn(uc net.Conn) {
log.Printf("Flush: %v", err)
}
}
s.writeFunc.Store(client1mac, writePkt)
buf := make([]byte, 16<<10)
var srcNode *node
var netw *network // non-nil after first packet
for {
if _, err := io.ReadFull(uc, buf[:4]); err != nil {
log.Printf("ReadFull header: %v", err)
@ -416,9 +429,26 @@ func (s *Server) serveConn(uc net.Conn) {
if !ok {
continue
}
srcMAC := MAC(ll.SrcMAC)
if srcNode == nil {
srcNode, ok = s.nodes[srcMAC]
if !ok {
log.Printf("[conn %p] ignoring frame from unknown MAC %v", uc, srcMAC)
continue
}
log.Printf("[conn %p] MAC %v is node %v", uc, srcMAC, srcNode.lanIP)
netw = srcNode.net
netw.registerWriter(srcMAC, writePkt)
defer netw.registerWriter(srcMAC, nil)
} else {
if srcMAC != srcNode.mac {
log.Printf("[conn %p] ignoring frame from MAC %v, expected %v", uc, srcMAC, srcNode.mac)
continue
}
}
if ll.EthernetType == layers.EthernetTypeARP {
res, err := s.createARPResponse(packet)
res, err := netw.createARPResponse(packet)
if err != nil {
log.Printf("createARPResponse: %v", err)
} else {
@ -478,7 +508,7 @@ func (s *Server) serveConn(uc net.Conn) {
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(pktCopy),
})
s.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
netw.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
// var list stack.PacketBufferList
// list.PushBack(packetBuf)
@ -679,7 +709,7 @@ func (s *Server) createSTUNResponse(pkt gopacket.Packet) ([]byte, error) {
log.Printf("invalid STUN request: %v", err)
return nil, nil
}
stunRes := stun.Response(txid, netip.AddrPortFrom(gwIP, 31234))
stunRes := stun.Response(txid, netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 12345))
eth2 := &layers.Ethernet{
SrcMAC: ethLayer.DstMAC,
@ -786,30 +816,24 @@ func (s *Server) createDNSResponse(pkt gopacket.Packet) ([]byte, error) {
return nil, err
}
const debugDNS = false
if debugDNS {
if len(response.Answers) > 0 {
back := gopacket.NewPacket(buffer.Bytes(), layers.LayerTypeEthernet, gopacket.Lazy)
log.Printf("Generated: %v", back)
} else {
log.Printf("made empty response for %q", names)
}
}
return buffer.Bytes(), nil
}
func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) {
func (n *network) createARPResponse(pkt gopacket.Packet) ([]byte, error) {
ethLayer, ok := pkt.Layer(layers.LayerTypeEthernet).(*layers.Ethernet)
if !ok {
return nil, nil
}
srcMAC, ok := macOf(ethLayer.SrcMAC)
if !ok {
return nil, nil
}
node, ok := s.nodes[srcMAC]
if !ok {
return nil, nil
}
arpLayer, ok := pkt.Layer(layers.LayerTypeARP).(*layers.ARP)
if !ok ||
arpLayer.Operation != layers.ARPRequest ||
@ -822,7 +846,7 @@ func (s *Server) createARPResponse(pkt gopacket.Packet) ([]byte, error) {
}
wantIP := netip.AddrFrom4([4]byte(arpLayer.DstProtAddress))
foundMAC, ok := node.net.MACOfIP(wantIP)
foundMAC, ok := n.MACOfIP(wantIP)
if !ok {
return nil, nil
}