wgengine/magicsock: unify initial bind and rebind

We had two separate code paths for the initial UDP listener bind
and any subsequent rebinds.

IPv6 got left out of the rebind code.
Rather than duplicate it there, unify the two code paths.
Then improve the resulting code:

* Rebind had nested listen attempts to try the user-specified port first,
  and then fall back to :0 if that failed. Convert that into a loop.
* Initial bind tried only the user-specified port.
  Rebind tried the user-specified port and 0.
  But there are actually three ports of interest:
  The one the user specified, the most recent port in use, and 0.
  We now try all three in order, as appropriate.
* In the extremely rare case in which binding to port 0 fails,
  use a dummy net.PacketConn whose reads block until close.
  This will keep the wireguard-go receive func goroutine alive.

As a pleasant side-effect of this, if we decide that
we need to resuscitate #1796, it will now be much easier.

Fixes #1799

Co-authored-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
This commit is contained in:
Josh Bleecher Snyder 2021-04-27 14:40:29 -07:00
parent 3c7898728d
commit 6f23087175
2 changed files with 123 additions and 59 deletions

View File

@ -139,6 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver
tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/racebuild from tailscale.com/logpolicy
tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/systemd from tailscale.com/control/controlclient+
tailscale.com/util/uniq from tailscale.com/wgengine/magicsock
tailscale.com/util/winutil from tailscale.com/logpolicy+ tailscale.com/util/winutil from tailscale.com/logpolicy+
tailscale.com/version from tailscale.com/cmd/tailscaled+ tailscale.com/version from tailscale.com/cmd/tailscaled+
tailscale.com/version/distro from tailscale.com/control/controlclient+ tailscale.com/version/distro from tailscale.com/control/controlclient+

View File

@ -52,6 +52,7 @@
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
"tailscale.com/util/uniq"
"tailscale.com/version" "tailscale.com/version"
"tailscale.com/wgengine/monitor" "tailscale.com/wgengine/monitor"
"tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg"
@ -2585,11 +2586,11 @@ func (c *Conn) ReSTUN(why string) {
} }
func (c *Conn) initialBind() error { func (c *Conn) initialBind() error {
if err := c.bind1(&c.pconn4, "udp4"); err != nil { if err := c.bindSocket(&c.pconn4, "udp4"); err != nil {
return err return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err)
} }
c.portMapper.SetLocalPort(c.LocalPort()) c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bind1(&c.pconn6, "udp6"); err != nil { if err := c.bindSocket(&c.pconn6, "udp6"); err != nil {
c.logf("magicsock: ignoring IPv6 bind failure: %v", err) c.logf("magicsock: ignoring IPv6 bind failure: %v", err)
} }
return nil return nil
@ -2605,66 +2606,82 @@ func (c *Conn) listenPacket(network, host string, port uint16) (net.PacketConn,
return netns.Listener().ListenPacket(ctx, network, addr) return netns.Listener().ListenPacket(ctx, network, addr)
} }
func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { // bindSocket initializes rucPtr if necessary and binds a UDP socket to it.
// Network indicates the UDP socket type; it must be "udp4" or "udp6".
// If rucPtr had an existing UDP socket bound, it closes that socket.
// The caller is responsible for informing the portMapper of any changes.
func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string) error {
host := "" host := ""
if inTest() && !c.simulatedNetwork { if inTest() && !c.simulatedNetwork {
host = "127.0.0.1" host = "127.0.0.1"
if which == "udp6" { if network == "udp6" {
host = "::1" host = "::1"
} }
} }
pc, err := c.listenPacket(which, host, c.port)
if err != nil { if *rucPtr == nil {
c.logf("magicsock: bind(%s/%v): %v", which, c.port, err) *rucPtr = new(RebindingUDPConn)
return fmt.Errorf("magicsock: bind: %s/%d: %v", which, c.port, err)
} }
if *ruc == nil { ruc := *rucPtr
*ruc = new(RebindingUDPConn)
// Hold the ruc lock the entire time, so that the close+bind is atomic
// from the perspective of ruc receive functions.
ruc.mu.Lock()
defer ruc.mu.Unlock()
// Build a list of preferred ports.
// Best is the port that the user requested.
// Second best is the port that is currently in use.
// If those fail, fall back to 0.
var ports []uint16
if c.port != 0 {
ports = append(ports, c.port)
} }
(*ruc).Reset(pc) if ruc.pconn != nil {
return nil curPort := uint16(ruc.localAddrLocked().Port)
ports = append(ports, curPort)
}
ports = append(ports, 0)
// Remove duplicates. (All duplicates are consecutive.)
uniq.ModifySlice(&ports, func(i, j int) bool { return ports[i] == ports[j] })
var pconn net.PacketConn
for _, port := range ports {
// Close the existing conn, in case it is sitting on the port we want.
err := ruc.closeLocked()
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, errNilPConn) {
c.logf("magicsock: bindSocket %v close failed: %v", network, err)
}
// Open a new one with the desired port.
pconn, err = c.listenPacket(network, host, port)
if err != nil {
c.logf("magicsock: unable to bind %v port %d: %v", network, port, err)
continue
}
// Success.
ruc.pconn = pconn
return nil
}
// Failed to bind, including on port 0 (!).
// Set pconn to a dummy conn whose reads block until closed.
// This keeps the receive funcs alive for a future in which
// we get a link change and we can try binding again.
ruc.pconn = newBlockForeverConn()
return fmt.Errorf("failed to bind any ports (tried %v)", ports)
} }
// Rebind closes and re-binds the UDP sockets. // Rebind closes and re-binds the UDP sockets.
// It should be followed by a call to ReSTUN. // It should be followed by a call to ReSTUN.
func (c *Conn) Rebind() { func (c *Conn) Rebind() {
host := "" if err := c.bindSocket(&c.pconn4, "udp4"); err != nil {
if inTest() && !c.simulatedNetwork { c.logf("magicsock: Rebind IPv4 failed: %w", err)
host = "127.0.0.1" return
}
if c.port != 0 {
c.pconn4.mu.Lock()
oldPort := c.pconn4.localAddrLocked().Port
if err := c.pconn4.pconn.Close(); err != nil {
c.logf("magicsock: link change close failed: %v", err)
}
packetConn, err := c.listenPacket("udp4", host, c.port)
if err != nil {
c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err)
packetConn, err = c.listenPacket("udp4", host, 0)
if err != nil {
c.logf("magicsock: link change failed to bind random port: %v", err)
c.pconn4.mu.Unlock()
return
}
newPort := packetConn.LocalAddr().(*net.UDPAddr).Port
c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port)
} else {
c.logf("magicsock: link change rebound port from %d to %d", oldPort, c.port)
}
c.pconn4.pconn = packetConn
c.pconn4.mu.Unlock()
} else {
c.logf("magicsock: link change, binding new port")
packetConn, err := c.listenPacket("udp4", host, 0)
if err != nil {
c.logf("magicsock: link change failed to bind new port: %v", err)
return
}
c.pconn4.Reset(packetConn)
} }
c.portMapper.SetLocalPort(c.LocalPort()) c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bindSocket(&c.pconn6, "udp6"); err != nil {
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
}
c.mu.Lock() c.mu.Lock()
c.closeAllDerpLocked("rebind") c.closeAllDerpLocked("rebind")
@ -2764,17 +2781,6 @@ func (c *RebindingUDPConn) currentConn() net.PacketConn {
return c.pconn return c.pconn
} }
func (c *RebindingUDPConn) Reset(pconn net.PacketConn) {
c.mu.Lock()
old := c.pconn
c.pconn = pconn
c.mu.Unlock()
if old != nil {
old.Close()
}
}
// ReadFrom reads a packet from c into b. // ReadFrom reads a packet from c into b.
// It returns the number of bytes copied and the source address. // It returns the number of bytes copied and the source address.
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
@ -2844,9 +2850,20 @@ func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr {
return c.pconn.LocalAddr().(*net.UDPAddr) return c.pconn.LocalAddr().(*net.UDPAddr)
} }
// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn.
// It is for internal use only and should not be returned to users.
var errNilPConn = errors.New("nil pconn")
func (c *RebindingUDPConn) Close() error { func (c *RebindingUDPConn) Close() error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
return c.closeLocked()
}
func (c *RebindingUDPConn) closeLocked() error {
if c.pconn == nil {
return errNilPConn
}
return c.pconn.Close() return c.pconn.Close()
} }
@ -2890,6 +2907,52 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
} }
} }
func newBlockForeverConn() *blockForeverConn {
c := new(blockForeverConn)
c.cond = sync.NewCond(&c.mu)
return c
}
// blockForeverConn is a net.PacketConn whose reads block until it is closed.
type blockForeverConn struct {
mu sync.Mutex
cond *sync.Cond
closed bool
}
func (c *blockForeverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
c.mu.Lock()
for !c.closed {
c.cond.Wait()
}
c.mu.Unlock()
return 0, nil, net.ErrClosed
}
func (c *blockForeverConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// Silently drop writes.
return len(p), nil
}
func (c *blockForeverConn) LocalAddr() net.Addr {
// Return a *net.UDPAddr because lots of code assumes that it will.
return new(net.UDPAddr)
}
func (c *blockForeverConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return net.ErrClosed
}
c.closed = true
return nil
}
func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") }
func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") }
func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") }
// simpleDur rounds d such that it stringifies to something short. // simpleDur rounds d such that it stringifies to something short.
func simpleDur(d time.Duration) time.Duration { func simpleDur(d time.Duration) time.Duration {
if d < time.Second { if d < time.Second {