diff --git a/cmd/lopower/lopower.go b/cmd/lopower/lopower.go index a8599a5f6..8fe1f4826 100644 --- a/cmd/lopower/lopower.go +++ b/cmd/lopower/lopower.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "encoding/base64" + "encoding/binary" "encoding/hex" "encoding/json" "errors" @@ -465,6 +466,14 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { defer ep.Close() ep.SocketOptions().SetKeepAlive(true) + if destPort == 53 && lp.c.IsLocalIP(destIP) { + tc := gonet.NewTCPConn(&wq, ep) + defer tc.Close() + r.Complete(false) // accept TCP connection + lp.handleTCPDNSQuery(tc, netip.AddrPortFrom(clientRemoteIP, reqDetails.RemotePort)) + return + } + dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort)) cancel() @@ -477,12 +486,7 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { tc := gonet.NewTCPConn(&wq, ep) defer tc.Close() - r.Complete(false) - - if destPort == 53 && lp.c.IsLocalIP(destIP) { - // TODO(bradfitz,maisem): do TCP DNS server here. - // ... - } + r.Complete(false) // accept TCP connection errc := make(chan error, 2) go func() { _, err := io.Copy(tc, c); errc <- err }() @@ -705,6 +709,41 @@ func (lp *lpServer) filteredDNSQuery(ctx context.Context, q []byte, family strin return msg.Pack() } +func (lp *lpServer) handleTCPDNSQuery(c net.Conn, src netip.AddrPort) { + defer c.Close() + var lenBuf [2]byte + for { + c.SetReadDeadline(time.Now().Add(30 * time.Second)) + _, err := io.ReadFull(c, lenBuf[:]) + if err != nil { + return + } + n := binary.BigEndian.Uint16(lenBuf[:]) + buf := make([]byte, n) + c.SetReadDeadline(time.Now().Add(30 * time.Second)) + _, err = io.ReadFull(c, buf[:]) + if err != nil { + return + } + res, err := lp.filteredDNSQuery(context.Background(), buf, "tcp", src) + if err != nil { + log.Printf("TCP DNS query error: %v", err) + return + } + binary.BigEndian.PutUint16(lenBuf[:], uint16(len(res))) + c.SetWriteDeadline(time.Now().Add(30 * time.Second)) + _, err = c.Write(lenBuf[:]) + if err != nil { + return + } + c.SetWriteDeadline(time.Now().Add(30 * time.Second)) + _, err = c.Write(res) + if err != nil { + return + } + } +} + // caller owns the raw memory. func (lp *lpServer) handleDNSUDPQuery(raw []byte) { var pkt packet.Parsed