diff --git a/cmd/lopower/lopower.go b/cmd/lopower/lopower.go index 17515f3fa..2140d22f9 100644 --- a/cmd/lopower/lopower.go +++ b/cmd/lopower/lopower.go @@ -17,6 +17,7 @@ import ( "path/filepath" "slices" "sync" + "time" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" @@ -24,6 +25,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -31,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" "tailscale.com/net/packet" "tailscale.com/tsnet" "tailscale.com/types/key" @@ -271,8 +274,50 @@ func (lp *lpServer) initNetstack(ctx context.Context) error { return nil } -func (lp *lpServer) acceptTCP(*tcp.ForwarderRequest) { - // TODO +func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { + switch s.Len() { + case 4: + return netip.AddrFrom4(s.As4()) + case 16: + return netip.AddrFrom16(s.As16()).Unmap() + } + return netip.Addr{} +} + +func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { + var wq waiter.Queue + ep, tcpErr := r.CreateEndpoint(&wq) + if tcpErr != nil { + r.Complete(true) + return + } + defer ep.Close() + reqDetails := r.ID() + + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + destPort := reqDetails.LocalPort + if !clientRemoteIP.IsValid() { + r.Complete(true) // sends a RST + return + } + + dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort)) + cancel() + if err != nil { + r.Complete(true) // sends a RST + return + } + defer c.Close() + + tc := gonet.NewTCPConn(&wq, ep) + defer tc.Close() + r.Complete(false) + errc := make(chan error, 2) + go func() { _, err := io.Copy(tc, c); errc <- err }() + go func() { _, err := io.Copy(c, tc); errc <- err }() + <-errc } type nsTUN struct {