wgengine/magicsock: fix panic when rebinding fails

We would replace the existing real implementation of nettype.PacketConn
with a blockForeverConn, but that violates the contract of atomic.Value
(where the type cannot change). Fix by switching to a pointer value
(atomic.Pointer[nettype.PacketConn]).

A longstanding issue, but became more prevalent when we started binding
connections to interfaces on macOS and iOS (), which could lead to
the bind call failing if the interface was no longer available.

Fixes 

Signed-off-by: Mihai Parparita <mihai@tailscale.com>
This commit is contained in:
Mihai Parparita 2022-12-06 17:42:40 -08:00 committed by Mihai Parparita
parent e27f4f022e
commit bdc45b9066
2 changed files with 29 additions and 21 deletions

@ -3008,13 +3008,14 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
// RebindingUDPConn is a UDP socket that can be re-bound.
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
type RebindingUDPConn struct {
// pconnAtomic is the same as pconn, but doesn't require acquiring mu. It's
// used for reads/writes and only upon failure do the reads/writes then
// check pconn (after acquiring mu) to see if there's been a rebind
// meanwhile.
// pconnAtomic is a pointer to the value stored in pconn, but doesn't
// require acquiring mu. It's used for reads/writes and only upon failure
// do the reads/writes then check pconn (after acquiring mu) to see if
// there's been a rebind meanwhile.
// pconn isn't really needed, but makes some of the code simpler
// to keep it in a type safe form.
pconnAtomic syncs.AtomicValue[nettype.PacketConn]
// to keep it distinct.
// Neither is expected to be nil, sockets are bound on creation.
pconnAtomic atomic.Pointer[nettype.PacketConn]
mu sync.Mutex // held while changing pconn (and pconnAtomic)
pconn nettype.PacketConn
@ -3023,7 +3024,7 @@ type RebindingUDPConn struct {
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) {
c.pconn = p
c.pconnAtomic.Store(p)
c.pconnAtomic.Store(&p)
c.port = uint16(c.localAddrLocked().Port)
}
@ -3038,7 +3039,7 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
// It returns the number of bytes copied and the source address.
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
for {
pconn := c.pconnAtomic.Load()
pconn := *c.pconnAtomic.Load()
n, addr, err := pconn.ReadFrom(b)
if err != nil && pconn != c.currentConn() {
continue
@ -3056,7 +3057,7 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
// when c's underlying connection is a net.UDPConn.
func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) {
for {
pconn := c.pconnAtomic.Load()
pconn := *c.pconnAtomic.Load()
// Optimization: Treat *net.UDPConn specially.
// This lets us avoid allocations by calling ReadFromUDPAddrPort.
@ -3122,13 +3123,10 @@ func (c *RebindingUDPConn) closeLocked() error {
func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
for {
pconn := c.pconnAtomic.Load()
pconn := *c.pconnAtomic.Load()
n, err := pconn.WriteTo(b, addr)
if err != nil {
if pconn != c.currentConn() {
continue
}
if err != nil && pconn != c.currentConn() {
continue
}
return n, err
}
@ -3136,13 +3134,10 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
for {
pconn := c.pconnAtomic.Load()
pconn := *c.pconnAtomic.Load()
n, err := pconn.WriteToUDPAddrPort(b, addr)
if err != nil {
if pconn != c.currentConn() {
continue
}
if err != nil && pconn != c.currentConn() {
continue
}
return n, err
}

@ -1803,3 +1803,16 @@ func TestDiscoMagicMatches(t *testing.T) {
t.Errorf("last 2 bytes of disco magic don't match, got %v want %v", discoMagic2, m2)
}
}
func TestRebindingUDPConn(t *testing.T) {
// Test that RebindingUDPConn can be re-bound to different connection
// types.
c := RebindingUDPConn{}
realConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer realConn.Close()
c.setConnLocked(realConn.(nettype.PacketConn))
c.setConnLocked(newBlockForeverConn())
}