diff --git a/natlab/natlabd/natlabd.go b/natlab/natlabd/natlabd.go index ea8750e52..29529051e 100644 --- a/natlab/natlabd/natlabd.go +++ b/natlab/natlabd/natlabd.go @@ -2,22 +2,41 @@ package main import ( "bufio" + "context" "encoding/binary" + "errors" "flag" + "fmt" "io" "log" "net" "net/netip" + "strconv" + "sync" "github.com/google/gopacket" "github.com/google/gopacket/layers" "go4.org/mem" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" + "tailscale.com/syncs" ) var ( listen = flag.String("listen", "/tmp/qemu.sock", "path to listen on") ) +const nicID = 1 + func main() { log.Printf("natlabd.") flag.Parse() @@ -26,7 +45,11 @@ func main() { if err != nil { log.Fatal(err) } - var s Server + s, err := newServer() + if err != nil { + log.Fatalf("newServer: %v", err) + } + for { c, err := srv.Accept() if err != nil { @@ -37,17 +60,198 @@ func main() { } } +func (s *Server) initStack() error { + s.ns = stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + arp.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + icmp.NewProtocol4, + }, + }) + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := s.ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) + } + s.linkEP = channel.New(512, 1500, tcpip.LinkAddress(gwMAC)) + if tcpipProblem := s.ns.CreateNIC(nicID, s.linkEP); tcpipProblem != nil { + return fmt.Errorf("CreateNIC: %v", tcpipProblem) + } + s.ns.SetPromiscuousMode(nicID, true) + s.ns.SetSpoofing(nicID, true) + + prefix := tcpip.AddrFrom4Slice(gwIP.AsSlice()).WithPrefix() + prefix.PrefixLen = 24 + if tcpProb := s.ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: prefix, + }, stack.AddressProperties{}); tcpProb != nil { + return errors.New(tcpProb.String()) + } + + ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) + if err != nil { + return fmt.Errorf("could not create IPv4 subnet: %v", err) + } + s.ns.SetRouteTable([]tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID, + }, + }) + + const tcpReceiveBufferSize = 0 // default + const maxInFlightConnectionAttempts = 8192 + tcpFwd := tcp.NewForwarder(s.ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.acceptTCP) + s.ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { + log.Printf("TCP packet: %+v", tei) + return tcpFwd.HandlePacket(tei, pb) + }) + + go func() { + for { + pkt := s.linkEP.ReadContext(s.shutdownCtx) + if pkt.IsNil() { + if s.shutdownCtx.Err() != nil { + // Return without logging. + return + } + log.Printf("ReadContext got nil packet") + continue + } + + ipRaw := pkt.ToView().AsSlice() + log.Printf("Read packet from linkEP: % 02x", ipRaw) + goPkt := gopacket.NewPacket( + ipRaw, + layers.LayerTypeIPv4, gopacket.Lazy) + layerV4 := goPkt.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + eth := &layers.Ethernet{ + SrcMAC: gwMAC, + DstMAC: client1mac.HWAddr(), + EthernetType: layers.EthernetTypeIPv4, + } + buffer := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + sls := []gopacket.SerializableLayer{ + eth, + } + for _, layer := range goPkt.Layers() { + sl, ok := layer.(gopacket.SerializableLayer) + if !ok { + log.Fatalf("layer %s is not serializable", layer.LayerType().String()) + } + log.Printf("Appending %v ... ", layer.LayerType()) + switch gl := layer.(type) { + case *layers.TCP: + gl.SetNetworkLayerForChecksum(layerV4) + case *layers.UDP: + gl.SetNetworkLayerForChecksum(layerV4) + } + sls = append(sls, sl) + } + + if err := gopacket.SerializeLayers(buffer, options, sls...); err != nil { + log.Printf("Serialize error: %v", err) + continue + } + if writeFunc, ok := s.writeFunc.Load(client1mac); ok { + writeFunc(buffer.Bytes()) + log.Printf("wrote packet to client: % 02x", buffer.Bytes()) + } else { + log.Printf("No writeFunc for %v", client1mac) + } + } + }() + return nil +} + +func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { + switch s.Len() { + case 4: + return netip.AddrFrom4(s.As4()) + case 16: + return netip.AddrFrom16(s.As16()).Unmap() + } + return netip.Addr{} +} + +func stringifyTEI(tei stack.TransportEndpointID) string { + localHostPort := net.JoinHostPort(tei.LocalAddress.String(), strconv.Itoa(int(tei.LocalPort))) + remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort))) + return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) +} + +func (s *Server) acceptTCP(r *tcp.ForwarderRequest) { + reqDetails := r.ID() + + log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails)) + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + if !clientRemoteIP.IsValid() { + r.Complete(true) // sends a RST + return + } + + var wq waiter.Queue + ep, err := r.CreateEndpoint(&wq) + if err != nil { + log.Printf("CreateEndpoint error for %s: %v", stringifyTEI(reqDetails), err) + r.Complete(true) // sends a RST + return + } + r.Complete(false) + ep.SocketOptions().SetKeepAlive(true) + + tc := gonet.NewTCPConn(&wq, ep) + io.WriteString(tc, "Hello from Go\nGoodbye.\n") + tc.Close() +} + var gwMAC = net.HardwareAddr{0x52, 0x54, 0x00, 0x01, 0x01, 0x01} var fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11}) +var gwIP = netip.AddrFrom4([4]byte{192, 168, 1, 1}) + +var client1mac = MAC{0x5a, 0x94, 0xef, 0xe4, 0x0c, 0xee} + type MAC [6]byte +func (m MAC) HWAddr() net.HardwareAddr { + return net.HardwareAddr(m[:]) +} + +func (m MAC) String() string { + return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", m[0], m[1], m[2], m[3], m[4], m[5]) +} + type Server struct { + shutdownCtx context.Context + shutdownCancel context.CancelFunc + + writeFunc syncs.Map[MAC, func([]byte)] // MAC -> func to write to that MAC + + ns *stack.Stack + linkEP *channel.Endpoint +} + +func newServer() (*Server, error) { + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + shutdownCtx: ctx, + shutdownCancel: cancel, + } + if err := s.initStack(); err != nil { + return nil, fmt.Errorf("newServer: initStack: %v", err) + } + return s, nil } func (s *Server) MacOfIP(ip netip.Addr) (MAC, bool) { - if ip == netip.AddrFrom4([4]byte{192, 168, 1, 1}) { + if ip == gwIP { return MAC(gwMAC), true } return MAC{}, false @@ -72,10 +276,13 @@ func (s *Server) serveConn(uc net.Conn) { defer uc.Close() bw := bufio.NewWriterSize(uc, 2<<10) + var writeMu sync.Mutex writePkt := func(pkt []byte) { if pkt == nil { return } + writeMu.Lock() + defer writeMu.Unlock() hdr := binary.BigEndian.AppendUint32(bw.AvailableBuffer()[:0], uint32(len(pkt))) if _, err := bw.Write(hdr); err != nil { log.Printf("Write hdr: %v", err) @@ -89,6 +296,7 @@ func (s *Server) serveConn(uc net.Conn) { log.Printf("Flush: %v", err) } } + s.writeFunc.Store(client1mac, writePkt) buf := make([]byte, 16<<10) for { @@ -103,7 +311,8 @@ func (s *Server) serveConn(uc net.Conn) { return } - packet := gopacket.NewPacket(buf[4:4+n], layers.LayerTypeEthernet, gopacket.Lazy) + packetRaw := buf[4 : 4+n] + packet := gopacket.NewPacket(packetRaw, layers.LayerTypeEthernet, gopacket.Lazy) ll, ok := packet.LinkLayer().(*layers.Ethernet) if !ok { continue @@ -151,6 +360,26 @@ func (s *Server) serveConn(uc net.Conn) { continue } + if isTCPTo123(packet) { + log.Printf("Injecting TCP to 123") + ipp := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + pktCopy := make([]byte, 0, len(ipp.Contents)+len(ipp.Payload)) + pktCopy = append(pktCopy, ipp.Contents...) + pktCopy = append(pktCopy, ipp.Payload...) + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(pktCopy), + }) + s.linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) + + // var list stack.PacketBufferList + // list.PushBack(packetBuf) + // n, err := s.linkEP.WritePackets(list) + // log.Printf("Injected: %v, %v", n, err) + + packetBuf.DecRef() + continue + } + log.Printf("Got packet: %v", packet) } } @@ -275,6 +504,11 @@ func isMDNSQuery(pkt gopacket.Packet) bool { return ok && udp.SrcPort == 5353 && udp.DstPort == 5353 } +func isTCPTo123(pkt gopacket.Packet) bool { + tcp, ok := pkt.Layer(layers.LayerTypeTCP).(*layers.TCP) + return ok && tcp.DstPort == 123 +} + // isDNSRequest reports whether pkt is a DNS request to the fake DNS server. func isDNSRequest(pkt gopacket.Packet) bool { udp, ok := pkt.Layer(layers.LayerTypeUDP).(*layers.UDP)