diff --git a/wgengine/magicsock/legacy.go b/wgengine/magicsock/legacy.go index 7620cc1ce..eb4ce9da6 100644 --- a/wgengine/magicsock/legacy.go +++ b/wgengine/magicsock/legacy.go @@ -53,7 +53,6 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End return nil, fmt.Errorf("bogus address %q", ep) } a.ipPorts = append(a.ipPorts, ipp) - a.addrs = append(a.addrs, *ipp.UDPAddr()) } } @@ -84,14 +83,14 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End return a, nil } -func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte) conn.Endpoint { +func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, packet []byte) conn.Endpoint { if c.disableLegacy { return nil } // Pre-disco: look up their addrSet. if as, ok := c.addrsByUDP[ipp]; ok { - as.updateDst(addr) + as.updateDst(ipp) return as } @@ -100,7 +99,7 @@ func (c *Conn) findLegacyEndpointLocked(ipp netaddr.IPPort, addr *net.UDPAddr, p // know. If this is a handshake packet, we can try to identify the // peer in question. if as := c.peerFromPacketLocked(packet); as != nil { - as.updateDst(addr) + as.updateDst(ipp) return as } @@ -268,14 +267,6 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP as.lastSend = now - // Some internal invariant checks. - if len(as.addrs) != len(as.ipPorts) { - panic(fmt.Sprintf("lena %d != leni %d", len(as.addrs), len(as.ipPorts))) - } - if n1, n2 := as.roamAddr != nil, as.roamAddrStd != nil; n1 != n2 { - panic(fmt.Sprintf("roamnil %v != roamstdnil %v", n1, n2)) - } - // Spray logic. // // After exchanging a handshake with a peer, we send some outbound @@ -320,8 +311,8 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP // roamAddr should be special like this. dsts = append(dsts, *as.roamAddr) case as.curAddr != -1: - if as.curAddr >= len(as.addrs) { - as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.addrs): %d >= %d", as.curAddr, len(as.addrs)) + if as.curAddr >= len(as.ipPorts) { + as.Logf("[unexpected] magicsock bug: as.curAddr >= len(as.ipPorts): %d >= %d", as.curAddr, len(as.ipPorts)) break } // No roaming addr, but we've seen packets from a known peer @@ -352,15 +343,14 @@ func (as *addrSet) appendDests(dsts []netaddr.IPPort, b []byte) (_ []netaddr.IPP type addrSet struct { publicKey key.Public // peer public key used for DERP communication - // addrs is an ordered priority list provided by wgengine, + // ipPorts is an ordered priority list provided by wgengine, // sorted from expensive+slow+reliable at the begnining to // fast+cheap at the end. More concretely, it's typically: // // [DERP fakeip:node, Global IP:port, LAN ip:port] // // But there could be multiple or none of each. - addrs []net.UDPAddr - ipPorts []netaddr.IPPort // same as addrs, in different form + ipPorts []netaddr.IPPort // clock, if non-nil, is used in tests instead of time.Now. clock func() time.Time @@ -376,8 +366,7 @@ type addrSet struct { // this should hopefully never be used (or at least used // rarely) in the case that all the components of Tailscale // are correctly learning/sharing the network map details. - roamAddr *netaddr.IPPort - roamAddrStd *net.UDPAddr + roamAddr *netaddr.IPPort // curAddr is an index into addrs of the highest-priority // address a valid packet has been received from so far. @@ -400,9 +389,9 @@ type addrSet struct { // derpID returns this addrSet's home DERP node, or 0 if none is found. func (as *addrSet) derpID() int { - for _, ua := range as.addrs { - if ua.IP.Equal(derpMagicIP) { - return ua.Port + for _, ua := range as.ipPorts { + if ua.IP == derpMagicIPAddr { + return int(ua.Port) } } return 0 @@ -424,7 +413,7 @@ func (a *addrSet) dst() netaddr.IPPort { if a.roamAddr != nil { return *a.roamAddr } - if len(a.addrs) == 0 { + if len(a.ipPorts) == 0 { return noAddr } i := a.curAddr @@ -439,7 +428,7 @@ func (a *addrSet) DstToBytes() []byte { } func (a *addrSet) DstToString() string { var addrs []string - for _, addr := range a.addrs { + for _, addr := range a.ipPorts { addrs = append(addrs, addr.String()) } @@ -459,8 +448,8 @@ func (a *addrSet) ClearSrc() {} // updateDst records receipt of a packet from new. This is used to // potentially update the transmit address used for this addrSet. -func (a *addrSet) updateDst(new *net.UDPAddr) error { - if new.IP.Equal(derpMagicIP) { +func (a *addrSet) updateDst(new netaddr.IPPort) error { + if new.IP == derpMagicIPAddr { // Never consider DERP addresses as a viable candidate for // either curAddr or roamAddr. It's only ever a last resort // choice, never a preferred choice. @@ -471,25 +460,20 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error { a.mu.Lock() defer a.mu.Unlock() - if a.roamAddrStd != nil && equalUDPAddr(new, a.roamAddrStd) { + if a.roamAddr != nil && new == *a.roamAddr { // Packet from the current roaming address, no logging. // This is a hot path for established connections. return nil } - if a.roamAddr == nil && a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) { + if a.roamAddr == nil && a.curAddr >= 0 && new == a.ipPorts[a.curAddr] { // Packet from current-priority address, no logging. // This is a hot path for established connections. return nil } - newa, ok := netaddr.FromStdAddr(new.IP, new.Port, new.Zone) - if !ok { - return nil - } - index := -1 - for i := range a.addrs { - if equalUDPAddr(new, &a.addrs[i]) { + for i := range a.ipPorts { + if new == a.ipPorts[i] { index = i break } @@ -499,7 +483,7 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error { pk := publicKey.ShortString() old := "" if a.curAddr >= 0 { - old = a.addrs[a.curAddr].String() + old = a.ipPorts[a.curAddr].String() } switch { @@ -509,18 +493,16 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error { } else { a.Logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) } - a.roamAddr = &newa - a.roamAddrStd = new + a.roamAddr = &new case a.roamAddr != nil: a.Logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) a.roamAddr = nil - a.roamAddrStd = nil a.curAddr = index a.loggedLogPriMask = 0 case a.curAddr == -1: - a.Logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs)) + a.Logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.ipPorts)) a.curAddr = index a.loggedLogPriMask = 0 @@ -531,7 +513,7 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error { } default: // index > a.curAddr - a.Logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old) + a.Logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.ipPorts), old) a.curAddr = index a.loggedLogPriMask = 0 } @@ -539,10 +521,6 @@ func (a *addrSet) updateDst(new *net.UDPAddr) error { return nil } -func equalUDPAddr(x, y *net.UDPAddr) bool { - return x.Port == y.Port && x.IP.Equal(y.IP) -} - func (a *addrSet) String() string { a.mu.Lock() defer a.mu.Unlock() @@ -551,9 +529,9 @@ func (a *addrSet) String() string { buf.WriteByte('[') if a.roamAddr != nil { buf.WriteString("roam:") - sbPrintAddr(buf, *a.roamAddrStd) + sbPrintAddr(buf, *a.roamAddr) } - for i, addr := range a.addrs { + for i, addr := range a.ipPorts { if i > 0 || a.roamAddr != nil { buf.WriteString(", ") } @@ -572,8 +550,8 @@ func (as *addrSet) populatePeerStatus(ps *ipnstate.PeerStatus) { defer as.mu.Unlock() ps.LastWrite = as.lastSend - for i, ua := range as.addrs { - if ua.IP.Equal(derpMagicIP) { + for i, ua := range as.ipPorts { + if ua.IP == derpMagicIPAddr { continue } uaStr := ua.String() @@ -583,7 +561,7 @@ func (as *addrSet) populatePeerStatus(ps *ipnstate.PeerStatus) { } } if as.roamAddr != nil { - ps.CurAddr = udpAddrDebugString(*as.roamAddrStd) + ps.CurAddr = ippDebugString(*as.roamAddr) } } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index c3c2c9027..afeeb8323 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -348,8 +348,7 @@ func (c *Conn) addDerpPeerRoute(peer key.Public, derpID int, dc *derphttp.Client // Mnemonic: 3.3.40 are numbers above the keys D, E, R, P. const DerpMagicIP = "127.3.3.40" -var derpMagicIP = net.ParseIP(DerpMagicIP).To4() -var derpMagicIPAddr = netaddr.IPv4(127, 3, 3, 40) +var derpMagicIPAddr = netaddr.MustParseIP(DerpMagicIP) // activeDerp contains fields for an active DERP connection. type activeDerp struct { @@ -1539,7 +1538,6 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan // findEndpoint maps from a UDP address to a WireGuard endpoint, for // ReceiveIPv4/ReceiveIPv6. -// The provided addr and ipp must match. // // TODO(bradfitz): add a fast path that returns nil here for normal // wireguard-go transport packets; wireguard-go only uses this @@ -1547,7 +1545,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan // Endpoint to find the UDPAddr to return to wireguard anyway, so no // benefit unless we can, say, always return the same fake UDPAddr for // all packets. -func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte) conn.Endpoint { +func (c *Conn) findEndpoint(ipp netaddr.IPPort, packet []byte) conn.Endpoint { c.mu.Lock() defer c.mu.Unlock() @@ -1559,10 +1557,7 @@ func (c *Conn) findEndpoint(ipp netaddr.IPPort, addr *net.UDPAddr, packet []byte } } - if addr == nil { - addr = ipp.UDPAddr() - } - return c.findLegacyEndpointLocked(ipp, addr, packet) + return c.findLegacyEndpointLocked(ipp, packet) } // aLongTimeAgo is a non-zero time, far in the past, used for @@ -1590,7 +1585,12 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) { if err != nil { return 0, nil, err } - if ep, ok := c.receiveIP(b[:n], pAddr.(*net.UDPAddr), &c.ippEndpoint6); ok { + udpAddr := pAddr.(*net.UDPAddr) + ipp, ok := netaddr.FromStdAddr(udpAddr.IP, udpAddr.Port, udpAddr.Zone) + if !ok { + continue + } + if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint6); ok { return n, ep, nil } } @@ -1604,13 +1604,16 @@ func (c *Conn) derpPacketArrived() bool { // In Tailscale's case, that packet might also arrive via DERP. A DERP packet arrival // aborts the pconn4 read deadline to make it fail. func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { - var pAddr net.Addr + var addr net.Addr + var pAddr *net.UDPAddr + var ipp netaddr.IPPort + var ippOK bool for { // Drain DERP queues before reading new UDP packets. if c.derpPacketArrived() { goto ReadDERP } - n, pAddr, err = c.pconn4.ReadFrom(b) + n, addr, err = c.pconn4.ReadFrom(b) if err != nil { // If the pconn4 read failed, the likely reason is a DERP reader received // a packet and interrupted us. @@ -1622,7 +1625,12 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { } return 0, nil, err } - if ep, ok := c.receiveIP(b[:n], pAddr.(*net.UDPAddr), &c.ippEndpoint4); ok { + pAddr, _ = addr.(*net.UDPAddr) + ipp, ippOK = netaddr.FromStdAddr(pAddr.IP, pAddr.Port, pAddr.Zone) + if !ippOK { + continue + } + if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok { return n, ep, nil } else { continue @@ -1640,11 +1648,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { // // ok is whether this read should be reported up to wireguard-go (our // caller). -func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) { - ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone) - if !ok { - return nil, false - } +func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) { if stun.Is(b) { c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp) return nil, false @@ -1662,7 +1666,7 @@ func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() { ep = cache.de } else { - ep = c.findEndpoint(ipp, ua, b) + ep = c.findEndpoint(ipp, b) if ep == nil { return nil, false } @@ -1759,7 +1763,7 @@ func (c *Conn) receiveIPv4DERP(b []byte) (n int, ep conn.Endpoint, err error) { } else { key := wgkey.Key(dm.src) c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString()) - ep = c.findEndpoint(ipp, nil, b[:n]) + ep = c.findEndpoint(ipp, b[:n]) if ep == nil { return 0, nil, errLoopAgain } @@ -2833,8 +2837,8 @@ func peerShort(k key.Public) string { return k2.ShortString() } -func sbPrintAddr(sb *strings.Builder, a net.UDPAddr) { - is6 := a.IP.To4() == nil +func sbPrintAddr(sb *strings.Builder, a netaddr.IPPort) { + is6 := a.IP.Is6() if is6 { sb.WriteByte('[') } @@ -2931,8 +2935,8 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { }) } -func udpAddrDebugString(ua net.UDPAddr) string { - if ua.IP.Equal(derpMagicIP) { +func ippDebugString(ua netaddr.IPPort) string { + if ua.IP == derpMagicIPAddr { return fmt.Sprintf("derp-%d", ua.Port) } return ua.String() diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 274735cd0..ea00fe22e 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -398,18 +398,6 @@ func pickPort(t testing.TB) uint16 { return uint16(conn.LocalAddr().(*net.UDPAddr).Port) } -func TestDerpIPConstant(t *testing.T) { - tstest.PanicOnLog() - tstest.ResourceCheck(t) - - if DerpMagicIP != derpMagicIP.String() { - t.Errorf("str %q != IP %v", DerpMagicIP, derpMagicIP) - } - if len(derpMagicIP) != 4 { - t.Errorf("derpMagicIP is len %d; want 4", len(derpMagicIP)) - } -} - func TestPickDERPFallback(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) @@ -452,7 +440,7 @@ func TestPickDERPFallback(t *testing.T) { // But move if peers are elsewhere. const otherNode = 789 c.addrsByKey = map[key.Public]*addrSet{ - key.Public{1}: &addrSet{addrs: []net.UDPAddr{{IP: derpMagicIP, Port: otherNode}}}, + key.Public{1}: &addrSet{ipPorts: []netaddr.IPPort{{IP: derpMagicIPAddr, Port: otherNode}}}, } if got := c.pickDERPFallback(); got != otherNode { t.Errorf("didn't join peers: got %v; want %v", got, someNode) @@ -1156,20 +1144,13 @@ func TestAddrSet(t *testing.T) { tstest.ResourceCheck(t) mustIPPortPtr := func(s string) *netaddr.IPPort { - t.Helper() - ipp, err := netaddr.ParseIPPort(s) - if err != nil { - t.Fatal(err) - } + ipp := netaddr.MustParseIPPort(s) return &ipp } - mustUDPAddr := func(s string) *net.UDPAddr { - return mustIPPortPtr(s).UDPAddr() - } - udpAddrs := func(ss ...string) (ret []net.UDPAddr) { + ipps := func(ss ...string) (ret []netaddr.IPPort) { t.Helper() for _, s := range ss { - ret = append(ret, *mustUDPAddr(s)) + ret = append(ret, netaddr.MustParseIPPort(s)) } return ret } @@ -1201,7 +1182,7 @@ type step struct { // updateDst, if set, does an UpdateDst call and // b+want are ignored. - updateDst *net.UDPAddr + updateDst *netaddr.IPPort b []byte want string // comma-separated @@ -1215,7 +1196,7 @@ type step struct { { name: "reg_packet_no_curaddr", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: -1, // unknown roamAddr: nil, }, @@ -1226,7 +1207,7 @@ type step struct { { name: "reg_packet_have_curaddr", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: 1, // global IP roamAddr: nil, }, @@ -1237,36 +1218,36 @@ type step struct { { name: "reg_packet_have_roamaddr", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: 2, // should be ignored roamAddr: mustIPPortPtr("5.6.7.8:123"), }, steps: []step{ {b: regPacket, want: "5.6.7.8:123"}, - {updateDst: mustUDPAddr("10.0.0.1:123")}, // no more roaming + {updateDst: mustIPPortPtr("10.0.0.1:123")}, // no more roaming {b: regPacket, want: "10.0.0.1:123"}, }, }, { name: "start_roaming", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: 2, }, steps: []step{ {b: regPacket, want: "10.0.0.1:123"}, - {updateDst: mustUDPAddr("4.5.6.7:123")}, + {updateDst: mustIPPortPtr("4.5.6.7:123")}, {b: regPacket, want: "4.5.6.7:123"}, - {updateDst: mustUDPAddr("5.6.7.8:123")}, + {updateDst: mustIPPortPtr("5.6.7.8:123")}, {b: regPacket, want: "5.6.7.8:123"}, - {updateDst: mustUDPAddr("123.45.67.89:123")}, // end roaming + {updateDst: mustIPPortPtr("123.45.67.89:123")}, // end roaming {b: regPacket, want: "123.45.67.89:123"}, }, }, { name: "spray_packet", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: 2, // should be ignored roamAddr: mustIPPortPtr("5.6.7.8:123"), }, @@ -1275,19 +1256,19 @@ type step struct { {advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, {advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, {advance: 3, b: regPacket, want: "5.6.7.8:123"}, - {advance: 2 * time.Millisecond, updateDst: mustUDPAddr("10.0.0.1:123")}, + {advance: 2 * time.Millisecond, updateDst: mustIPPortPtr("10.0.0.1:123")}, {advance: 3, b: regPacket, want: "10.0.0.1:123"}, }, }, { name: "low_pri", as: &addrSet{ - addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + ipPorts: ipps("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), curAddr: 2, }, steps: []step{ - {updateDst: mustUDPAddr("123.45.67.89:123")}, - {updateDst: mustUDPAddr("123.45.67.89:123")}, + {updateDst: mustIPPortPtr("123.45.67.89:123")}, + {updateDst: mustIPPortPtr("123.45.67.89:123")}, }, logCheck: func(t *testing.T, logged []byte) { if n := bytes.Count(logged, []byte(", keeping current ")); n != 1 { @@ -1306,12 +1287,11 @@ type step struct { t.Logf(format, args...) } tt.as.clock = func() time.Time { return faket } - initAddrSet(tt.as) for i, st := range tt.steps { faket = faket.Add(st.advance) if st.updateDst != nil { - if err := tt.as.updateDst(st.updateDst); err != nil { + if err := tt.as.updateDst(*st.updateDst); err != nil { t.Fatal(err) } continue @@ -1328,23 +1308,6 @@ type step struct { } } -// initAddrSet initializes fields in the provided incomplete addrSet -// to satisfying invariants within magicsock. -func initAddrSet(as *addrSet) { - if as.roamAddr != nil && as.roamAddrStd == nil { - as.roamAddrStd = as.roamAddr.UDPAddr() - } - if len(as.ipPorts) == 0 { - for _, ua := range as.addrs { - ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone) - if !ok { - panic(fmt.Sprintf("bogus UDPAddr %+v", ua)) - } - as.ipPorts = append(as.ipPorts, ipp) - } - } -} - func TestDiscoMessage(t *testing.T) { c := newConn() c.logf = t.Logf