mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
wgengine/magicsock: unify initial bind and rebind
We had two separate code paths for the initial UDP listener bind and any subsequent rebinds. IPv6 got left out of the rebind code. Rather than duplicate it there, unify the two code paths. Then improve the resulting code: * Rebind had nested listen attempts to try the user-specified port first, and then fall back to :0 if that failed. Convert that into a loop. * Initial bind tried only the user-specified port. Rebind tried the user-specified port and 0. But there are actually three ports of interest: The one the user specified, the most recent port in use, and 0. We now try all three in order, as appropriate. * In the extremely rare case in which binding to port 0 fails, use a dummy net.PacketConn whose reads block until close. This will keep the wireguard-go receive func goroutine alive. As a pleasant side-effect of this, if we decide that we need to resuscitate #1796, it will now be much easier. Fixes #1799 Co-authored-by: David Anderson <danderson@tailscale.com> Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
This commit is contained in:
parent
3c7898728d
commit
6f23087175
@ -139,6 +139,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
|
||||
tailscale.com/util/pidowner from tailscale.com/ipn/ipnserver
|
||||
tailscale.com/util/racebuild from tailscale.com/logpolicy
|
||||
tailscale.com/util/systemd from tailscale.com/control/controlclient+
|
||||
tailscale.com/util/uniq from tailscale.com/wgengine/magicsock
|
||||
tailscale.com/util/winutil from tailscale.com/logpolicy+
|
||||
tailscale.com/version from tailscale.com/cmd/tailscaled+
|
||||
tailscale.com/version/distro from tailscale.com/control/controlclient+
|
||||
|
@ -52,6 +52,7 @@
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/wgkey"
|
||||
"tailscale.com/util/uniq"
|
||||
"tailscale.com/version"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
"tailscale.com/wgengine/wgcfg"
|
||||
@ -2585,11 +2586,11 @@ func (c *Conn) ReSTUN(why string) {
|
||||
}
|
||||
|
||||
func (c *Conn) initialBind() error {
|
||||
if err := c.bind1(&c.pconn4, "udp4"); err != nil {
|
||||
return err
|
||||
if err := c.bindSocket(&c.pconn4, "udp4"); err != nil {
|
||||
return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err)
|
||||
}
|
||||
c.portMapper.SetLocalPort(c.LocalPort())
|
||||
if err := c.bind1(&c.pconn6, "udp6"); err != nil {
|
||||
if err := c.bindSocket(&c.pconn6, "udp6"); err != nil {
|
||||
c.logf("magicsock: ignoring IPv6 bind failure: %v", err)
|
||||
}
|
||||
return nil
|
||||
@ -2605,66 +2606,82 @@ func (c *Conn) listenPacket(network, host string, port uint16) (net.PacketConn,
|
||||
return netns.Listener().ListenPacket(ctx, network, addr)
|
||||
}
|
||||
|
||||
func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error {
|
||||
// bindSocket initializes rucPtr if necessary and binds a UDP socket to it.
|
||||
// Network indicates the UDP socket type; it must be "udp4" or "udp6".
|
||||
// If rucPtr had an existing UDP socket bound, it closes that socket.
|
||||
// The caller is responsible for informing the portMapper of any changes.
|
||||
func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string) error {
|
||||
host := ""
|
||||
if inTest() && !c.simulatedNetwork {
|
||||
host = "127.0.0.1"
|
||||
if which == "udp6" {
|
||||
if network == "udp6" {
|
||||
host = "::1"
|
||||
}
|
||||
}
|
||||
pc, err := c.listenPacket(which, host, c.port)
|
||||
|
||||
if *rucPtr == nil {
|
||||
*rucPtr = new(RebindingUDPConn)
|
||||
}
|
||||
ruc := *rucPtr
|
||||
|
||||
// Hold the ruc lock the entire time, so that the close+bind is atomic
|
||||
// from the perspective of ruc receive functions.
|
||||
ruc.mu.Lock()
|
||||
defer ruc.mu.Unlock()
|
||||
|
||||
// Build a list of preferred ports.
|
||||
// Best is the port that the user requested.
|
||||
// Second best is the port that is currently in use.
|
||||
// If those fail, fall back to 0.
|
||||
var ports []uint16
|
||||
if c.port != 0 {
|
||||
ports = append(ports, c.port)
|
||||
}
|
||||
if ruc.pconn != nil {
|
||||
curPort := uint16(ruc.localAddrLocked().Port)
|
||||
ports = append(ports, curPort)
|
||||
}
|
||||
ports = append(ports, 0)
|
||||
// Remove duplicates. (All duplicates are consecutive.)
|
||||
uniq.ModifySlice(&ports, func(i, j int) bool { return ports[i] == ports[j] })
|
||||
|
||||
var pconn net.PacketConn
|
||||
for _, port := range ports {
|
||||
// Close the existing conn, in case it is sitting on the port we want.
|
||||
err := ruc.closeLocked()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, errNilPConn) {
|
||||
c.logf("magicsock: bindSocket %v close failed: %v", network, err)
|
||||
}
|
||||
// Open a new one with the desired port.
|
||||
pconn, err = c.listenPacket(network, host, port)
|
||||
if err != nil {
|
||||
c.logf("magicsock: bind(%s/%v): %v", which, c.port, err)
|
||||
return fmt.Errorf("magicsock: bind: %s/%d: %v", which, c.port, err)
|
||||
c.logf("magicsock: unable to bind %v port %d: %v", network, port, err)
|
||||
continue
|
||||
}
|
||||
if *ruc == nil {
|
||||
*ruc = new(RebindingUDPConn)
|
||||
}
|
||||
(*ruc).Reset(pc)
|
||||
// Success.
|
||||
ruc.pconn = pconn
|
||||
return nil
|
||||
}
|
||||
|
||||
// Failed to bind, including on port 0 (!).
|
||||
// Set pconn to a dummy conn whose reads block until closed.
|
||||
// This keeps the receive funcs alive for a future in which
|
||||
// we get a link change and we can try binding again.
|
||||
ruc.pconn = newBlockForeverConn()
|
||||
return fmt.Errorf("failed to bind any ports (tried %v)", ports)
|
||||
}
|
||||
|
||||
// Rebind closes and re-binds the UDP sockets.
|
||||
// It should be followed by a call to ReSTUN.
|
||||
func (c *Conn) Rebind() {
|
||||
host := ""
|
||||
if inTest() && !c.simulatedNetwork {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
|
||||
if c.port != 0 {
|
||||
c.pconn4.mu.Lock()
|
||||
oldPort := c.pconn4.localAddrLocked().Port
|
||||
if err := c.pconn4.pconn.Close(); err != nil {
|
||||
c.logf("magicsock: link change close failed: %v", err)
|
||||
}
|
||||
packetConn, err := c.listenPacket("udp4", host, c.port)
|
||||
if err != nil {
|
||||
c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.port, err)
|
||||
packetConn, err = c.listenPacket("udp4", host, 0)
|
||||
if err != nil {
|
||||
c.logf("magicsock: link change failed to bind random port: %v", err)
|
||||
c.pconn4.mu.Unlock()
|
||||
if err := c.bindSocket(&c.pconn4, "udp4"); err != nil {
|
||||
c.logf("magicsock: Rebind IPv4 failed: %w", err)
|
||||
return
|
||||
}
|
||||
newPort := packetConn.LocalAddr().(*net.UDPAddr).Port
|
||||
c.logf("magicsock: link change rebound port: from %v to %v (failed to get %v)", oldPort, newPort, c.port)
|
||||
} else {
|
||||
c.logf("magicsock: link change rebound port from %d to %d", oldPort, c.port)
|
||||
}
|
||||
c.pconn4.pconn = packetConn
|
||||
c.pconn4.mu.Unlock()
|
||||
} else {
|
||||
c.logf("magicsock: link change, binding new port")
|
||||
packetConn, err := c.listenPacket("udp4", host, 0)
|
||||
if err != nil {
|
||||
c.logf("magicsock: link change failed to bind new port: %v", err)
|
||||
return
|
||||
}
|
||||
c.pconn4.Reset(packetConn)
|
||||
}
|
||||
c.portMapper.SetLocalPort(c.LocalPort())
|
||||
if err := c.bindSocket(&c.pconn6, "udp6"); err != nil {
|
||||
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.closeAllDerpLocked("rebind")
|
||||
@ -2764,17 +2781,6 @@ func (c *RebindingUDPConn) currentConn() net.PacketConn {
|
||||
return c.pconn
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) Reset(pconn net.PacketConn) {
|
||||
c.mu.Lock()
|
||||
old := c.pconn
|
||||
c.pconn = pconn
|
||||
c.mu.Unlock()
|
||||
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads a packet from c into b.
|
||||
// It returns the number of bytes copied and the source address.
|
||||
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
|
||||
@ -2844,9 +2850,20 @@ func (c *RebindingUDPConn) localAddrLocked() *net.UDPAddr {
|
||||
return c.pconn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
// errNilPConn is returned by RebindingUDPConn.Close when there is no current pconn.
|
||||
// It is for internal use only and should not be returned to users.
|
||||
var errNilPConn = errors.New("nil pconn")
|
||||
|
||||
func (c *RebindingUDPConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.closeLocked()
|
||||
}
|
||||
|
||||
func (c *RebindingUDPConn) closeLocked() error {
|
||||
if c.pconn == nil {
|
||||
return errNilPConn
|
||||
}
|
||||
return c.pconn.Close()
|
||||
}
|
||||
|
||||
@ -2890,6 +2907,52 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func newBlockForeverConn() *blockForeverConn {
|
||||
c := new(blockForeverConn)
|
||||
c.cond = sync.NewCond(&c.mu)
|
||||
return c
|
||||
}
|
||||
|
||||
// blockForeverConn is a net.PacketConn whose reads block until it is closed.
|
||||
type blockForeverConn struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *blockForeverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
c.mu.Lock()
|
||||
for !c.closed {
|
||||
c.cond.Wait()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (c *blockForeverConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
// Silently drop writes.
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *blockForeverConn) LocalAddr() net.Addr {
|
||||
// Return a *net.UDPAddr because lots of code assumes that it will.
|
||||
return new(net.UDPAddr)
|
||||
}
|
||||
|
||||
func (c *blockForeverConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return net.ErrClosed
|
||||
}
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *blockForeverConn) SetDeadline(t time.Time) error { return errors.New("unimplemented") }
|
||||
func (c *blockForeverConn) SetReadDeadline(t time.Time) error { return errors.New("unimplemented") }
|
||||
func (c *blockForeverConn) SetWriteDeadline(t time.Time) error { return errors.New("unimplemented") }
|
||||
|
||||
// simpleDur rounds d such that it stringifies to something short.
|
||||
func simpleDur(d time.Duration) time.Duration {
|
||||
if d < time.Second {
|
||||
|
Loading…
Reference in New Issue
Block a user