cmd/natc: use ListenPacket

Now that tsnet supports it, use it.

Updates tailscale/corp#20503

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-06-03 14:53:34 -07:00 committed by Maisem Ali
parent e84751217a
commit 2f2f588c80

View File

@ -35,7 +35,6 @@
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/types/nettype"
"tailscale.com/util/dnsname"
"tailscale.com/util/mak"
)
@ -169,24 +168,27 @@ func (c *connector) run(ctx context.Context) {
log.Fatalf("failed to advertise routes: %v", err)
}
c.ts.RegisterFallbackTCPHandler(c.handleTCPFlow)
c.serveDNS()
}
ln, err := c.ts.Listen("udp", net.JoinHostPort(c.dnsAddr.String(), "53"))
func (c *connector) serveDNS() {
pc, err := c.ts.ListenPacket("udp", net.JoinHostPort(c.dnsAddr.String(), "53"))
if err != nil {
log.Fatalf("failed listening on port 53: %v", err)
}
defer ln.Close()
log.Printf("Listening for DNS on %s", ln.Addr())
c.serveDNS(ln)
}
func (c *connector) serveDNS(ln net.Listener) {
defer pc.Close()
log.Printf("Listening for DNS on %s", pc.LocalAddr().String())
for {
conn, err := ln.Accept()
buf := make([]byte, 1500)
n, addr, err := pc.ReadFrom(buf)
if err != nil {
log.Printf("serveDNS accept: %v", err)
return
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("serveDNS.ReadFrom failed: %v", err)
continue
}
go c.handleDNS(conn.(nettype.ConnPacketConn))
go c.handleDNS(pc, buf[:n], addr.(*net.UDPAddr))
}
}
@ -201,27 +203,17 @@ func (c *connector) serveDNS(ln net.Listener) {
//
// This assignment later allows the connector to determine where to forward
// traffic based on the destination IP address.
func (c *connector) handleDNS(conn nettype.ConnPacketConn) {
defer conn.Close()
func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
remoteAddr := conn.RemoteAddr().(*net.UDPAddr).AddrPort()
who, err := c.lc.WhoIs(ctx, remoteAddr.String())
if err != nil {
log.Printf("HandleDNS: WhoIs failed: %v\n", err)
return
}
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 1500)
n, err := conn.Read(buf)
if err != nil {
log.Printf("HandleDNS: read failed: %v\n ", err)
return
}
var msg dnsmessage.Message
err = msg.Unpack(buf[:n])
err = msg.Unpack(buf)
if err != nil {
log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
return
@ -236,7 +228,7 @@ func (c *connector) handleDNS(conn nettype.ConnPacketConn) {
return
}
// This connector handled the DNS request
_, err = conn.Write(resp)
_, err = pc.WriteTo(resp, remoteAddr)
if err != nil {
log.Printf("HandleDNS: write failed: %v\n", err)
}