cmd/lopower: add TCP DNS support

Change-Id: I3288bfd538e2662d644c75e62e6c5cdb24464386
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-11-03 14:33:19 -08:00 committed by Anton Tolchanov
parent b8d9c3bc88
commit 455e926d09

View File

@ -7,6 +7,7 @@ import (
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
@ -465,6 +466,14 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
defer ep.Close()
ep.SocketOptions().SetKeepAlive(true)
if destPort == 53 && lp.c.IsLocalIP(destIP) {
tc := gonet.NewTCPConn(&wq, ep)
defer tc.Close()
r.Complete(false) // accept TCP connection
lp.handleTCPDNSQuery(tc, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort))
return
}
dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort))
cancel()
@ -477,12 +486,7 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) {
tc := gonet.NewTCPConn(&wq, ep)
defer tc.Close()
r.Complete(false)
if destPort == 53 && lp.c.IsLocalIP(destIP) {
// TODO(bradfitz,maisem): do TCP DNS server here.
// ...
}
r.Complete(false) // accept TCP connection
errc := make(chan error, 2)
go func() { _, err := io.Copy(tc, c); errc <- err }()
@ -705,6 +709,41 @@ func (lp *lpServer) filteredDNSQuery(ctx context.Context, q []byte, family strin
return msg.Pack()
}
func (lp *lpServer) handleTCPDNSQuery(c net.Conn, src netip.AddrPort) {
defer c.Close()
var lenBuf [2]byte
for {
c.SetReadDeadline(time.Now().Add(30 * time.Second))
_, err := io.ReadFull(c, lenBuf[:])
if err != nil {
return
}
n := binary.BigEndian.Uint16(lenBuf[:])
buf := make([]byte, n)
c.SetReadDeadline(time.Now().Add(30 * time.Second))
_, err = io.ReadFull(c, buf[:])
if err != nil {
return
}
res, err := lp.filteredDNSQuery(context.Background(), buf, "tcp", src)
if err != nil {
log.Printf("TCP DNS query error: %v", err)
return
}
binary.BigEndian.PutUint16(lenBuf[:], uint16(len(res)))
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
_, err = c.Write(lenBuf[:])
if err != nil {
return
}
c.SetWriteDeadline(time.Now().Add(30 * time.Second))
_, err = c.Write(res)
if err != nil {
return
}
}
}
// caller owns the raw memory.
func (lp *lpServer) handleDNSUDPQuery(raw []byte) {
var pkt packet.Parsed