wgengine/netstack: stop UDP forwarding when one side dies

Updates #504

Updates #707

Signed-off-by: Naman Sood <mail@nsood.in>
This commit is contained in:
Naman Sood 2021-03-08 13:43:01 -05:00
parent 7325b5a7ba
commit 4c80344e27

View File

@ -419,7 +419,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
timer := time.AfterFunc(2*time.Minute, func() { timer := time.AfterFunc(2*time.Minute, func() {
ns.logf("netstack: forwarder UDP connection on port %v closed", port) ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
cancel() cancel()
client.Close() client.Close()
backendConn.Close() backendConn.Close()
@ -427,16 +427,17 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA
extend := func() { extend := func() {
timer.Reset(2 * time.Minute) timer.Reset(2 * time.Minute)
} }
startPacketCopy(ctx, client, &net.UDPAddr{ startPacketCopy(ctx, cancel, client, &net.UDPAddr{
IP: net.ParseIP(clientRemoteAddr.Addr.String()), IP: net.ParseIP(clientRemoteAddr.Addr.String()),
Port: int(clientRemoteAddr.Port), Port: int(clientRemoteAddr.Port),
}, backendConn, ns.logf, extend) }, backendConn, ns.logf, extend)
startPacketCopy(ctx, backendConn, backendRemoteAddr, client, ns.logf, extend) startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
} }
func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) { func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
go func() { go func() {
defer cancel() // tear down the other direction's copy
pkt := make([]byte, mtu) pkt := make([]byte, mtu)
for { for {
select { select {
@ -457,7 +458,9 @@ func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr,
} }
return return
} }
logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr) if debugNetstack {
logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr)
}
extend() extend()
} }
} }