diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index ddae29668..aea90a8b6 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -35,7 +35,6 @@ "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" - "tailscale.com/types/nettype" "tailscale.com/util/dnsname" "tailscale.com/util/mak" ) @@ -169,24 +168,27 @@ func (c *connector) run(ctx context.Context) { log.Fatalf("failed to advertise routes: %v", err) } c.ts.RegisterFallbackTCPHandler(c.handleTCPFlow) + c.serveDNS() +} - ln, err := c.ts.Listen("udp", net.JoinHostPort(c.dnsAddr.String(), "53")) +func (c *connector) serveDNS() { + pc, err := c.ts.ListenPacket("udp", net.JoinHostPort(c.dnsAddr.String(), "53")) if err != nil { log.Fatalf("failed listening on port 53: %v", err) } - defer ln.Close() - log.Printf("Listening for DNS on %s", ln.Addr()) - c.serveDNS(ln) -} - -func (c *connector) serveDNS(ln net.Listener) { + defer pc.Close() + log.Printf("Listening for DNS on %s", pc.LocalAddr().String()) for { - conn, err := ln.Accept() + buf := make([]byte, 1500) + n, addr, err := pc.ReadFrom(buf) if err != nil { - log.Printf("serveDNS accept: %v", err) - return + if errors.Is(err, net.ErrClosed) { + return + } + log.Printf("serveDNS.ReadFrom failed: %v", err) + continue } - go c.handleDNS(conn.(nettype.ConnPacketConn)) + go c.handleDNS(pc, buf[:n], addr.(*net.UDPAddr)) } } @@ -201,27 +203,17 @@ func (c *connector) serveDNS(ln net.Listener) { // // This assignment later allows the connector to determine where to forward // traffic based on the destination IP address. -func (c *connector) handleDNS(conn nettype.ConnPacketConn) { - defer conn.Close() +func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - remoteAddr := conn.RemoteAddr().(*net.UDPAddr).AddrPort() who, err := c.lc.WhoIs(ctx, remoteAddr.String()) if err != nil { log.Printf("HandleDNS: WhoIs failed: %v\n", err) return } - conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - - buf := make([]byte, 1500) - n, err := conn.Read(buf) - if err != nil { - log.Printf("HandleDNS: read failed: %v\n ", err) - return - } var msg dnsmessage.Message - err = msg.Unpack(buf[:n]) + err = msg.Unpack(buf) if err != nil { log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) return @@ -236,7 +228,7 @@ func (c *connector) handleDNS(conn nettype.ConnPacketConn) { return } // This connector handled the DNS request - _, err = conn.Write(resp) + _, err = pc.WriteTo(resp, remoteAddr) if err != nil { log.Printf("HandleDNS: write failed: %v\n", err) }