diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 537ec2098..cd2846934 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -29,14 +29,16 @@ // A Conn routes UDP packets and actively manages a list of its endpoints. // It implements wireguard/device.Bind. type Conn struct { - pconn *RebindingUDPConn - pconnPort uint16 - stunServers []string - derpServer string - startEpUpdate chan struct{} // send to trigger endpoint update - epUpdateCancel func() - epFunc func(endpoints []string) - logf func(format string, args ...interface{}) + pconn *RebindingUDPConn + pconnPort uint16 + stunServers []string + derpServer string + startEpUpdate chan struct{} // send to trigger endpoint update + epFunc func(endpoints []string) + logf func(format string, args ...interface{}) + + epUpdateCtx context.Context // endpoint updater context + epUpdateCancel func() // the func to cancel epUpdateCtx // indexedAddrs is a map of every remote ip:port to a priority // list of endpoint addresses for a peer. @@ -137,6 +139,7 @@ func Listen(opts Options) (*Conn, error) { stunServers: append([]string{}, opts.STUN...), derpServer: opts.DERP, startEpUpdate: make(chan struct{}, 1), + epUpdateCtx: epUpdateCtx, epUpdateCancel: epUpdateCancel, epFunc: opts.endpointsFunc(), logf: log.Printf, @@ -144,7 +147,7 @@ func Listen(opts Options) (*Conn, error) { } c.ignoreSTUNPackets() c.pconn.Reset(packetConn.(*net.UDPConn)) - c.startEpUpdate <- struct{}{} // STUN immediately on start + c.reSTUN() go c.epUpdate(epUpdateCtx) return c, nil } @@ -472,8 +475,7 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { time.Sleep(250 * time.Millisecond) } - // Trigger re-STUN. - c.startEpUpdate <- struct{}{} + c.reSTUN() addr := c.pconn.LocalAddr() if _, err := c.pconn.WriteToUDP(b[:n], addr); err != nil { @@ -501,10 +503,15 @@ func (c *Conn) Close() error { return c.pconn.Close() } +func (c *Conn) reSTUN() { + select { + case c.startEpUpdate <- struct{}{}: + case <-c.epUpdateCtx.Done(): + } +} + func (c *Conn) LinkChange() { - defer func() { - c.startEpUpdate <- struct{}{} // re-STUN - }() + defer c.reSTUN() if c.pconnPort != 0 { c.pconn.mu.Lock()