wgengine/magicsock: make peerMap also keyed by NodeID

In prep for incremental netmap update plumbing (#1909), make peerMap
also keyed by NodeID, as all the netmap node mutations passed around
later will be keyed by NodeID.

In the process, also:

* add envknob.InDevMode, as a signal that we can panic more aggressively
  in unexpected cases.
* pull two moderately large blocks of code in Conn.SetNetworkMap out
  into their own methods
* convert a few more sets from maps to set.Set

Updates #1909

Change-Id: I7acdd64452ba58e9d554140ee7a8760f9043f961
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2023-09-11 10:13:00 -07:00
committed by Brad Fitzpatrick
parent 683ba62f3e
commit d050700a3b
8 changed files with 188 additions and 84 deletions

View File

@@ -55,6 +55,7 @@ import (
"tailscale.com/util/mak"
"tailscale.com/util/ringbuffer"
"tailscale.com/util/set"
"tailscale.com/util/testenv"
"tailscale.com/util/uniq"
"tailscale.com/wgengine/capture"
)
@@ -232,8 +233,8 @@ type Conn struct {
// in other maps below that are keyed by peer public key.
peerSet set.Set[key.NodePublic]
// nodeOfDisco tracks the networkmap Node entity for each peer
// discovery key.
// peerMap tracks the networkmap Node entity for each peer
// by node key, node ID, and discovery key.
peerMap peerMap
// discoInfo is the state for an active DiscoKey.
@@ -1742,6 +1743,30 @@ func nodesEqual(x, y []tailcfg.NodeView) bool {
return true
}
// debugRingBufferSize returns a maximum size for our set of endpoint ring
// buffers by assuming that a single large update is ~500 bytes, and that we
// want to not use more than 1MiB of memory on phones / 4MiB on other devices.
// Calculate the per-endpoint ring buffer size by dividing that out, but always
// storing at least two entries.
func debugRingBufferSize(numPeers int) int {
const defaultVal = 2
if numPeers == 0 {
return defaultVal
}
var maxRingBufferSize int
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
maxRingBufferSize = 1 * 1024 * 1024
} else {
maxRingBufferSize = 4 * 1024 * 1024
}
if v := debugRingBufferMaxSizeBytes(); v > 0 {
maxRingBufferSize = v
}
const averageRingBufferElemSize = 512
return max(defaultVal, maxRingBufferSize/(averageRingBufferElemSize*numPeers))
}
// SetNetworkMap is called when the control client gets a new network
// map from the control server. It must always be non-nil.
//
@@ -1771,29 +1796,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
c.logf("[v1] magicsock: got updated network map; %d peers", len(nm.Peers))
heartbeatDisabled := debugEnableSilentDisco()
// Set a maximum size for our set of endpoint ring buffers by assuming
// that a single large update is ~500 bytes, and that we want to not
// use more than 1MiB of memory on phones / 4MiB on other devices.
// Calculate the per-endpoint ring buffer size by dividing that out,
// but always storing at least two entries.
var entriesPerBuffer int = 2
if len(nm.Peers) > 0 {
var maxRingBufferSize int
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
maxRingBufferSize = 1 * 1024 * 1024
} else {
maxRingBufferSize = 4 * 1024 * 1024
}
if v := debugRingBufferMaxSizeBytes(); v > 0 {
maxRingBufferSize = v
}
const averageRingBufferElemSize = 512
entriesPerBuffer = maxRingBufferSize / (averageRingBufferElemSize * len(nm.Peers))
if entriesPerBuffer < 2 {
entriesPerBuffer = 2
}
}
entriesPerBuffer := debugRingBufferSize(len(nm.Peers))
// Try a pass of just upserting nodes and creating missing
// endpoints. If the set of nodes is the same, this is an
@@ -1801,7 +1804,26 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// we'll fall through to the next pass, which allocates but can
// handle full set updates.
for _, n := range nm.Peers {
if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok {
if n.ID() == 0 {
devPanicf("node with zero ID")
continue
}
if n.Key().IsZero() {
devPanicf("node with zero key")
continue
}
ep, ok := c.peerMap.endpointForNodeID(n.ID())
if ok && ep.publicKey != n.Key() {
// The node rotated public keys. Delete the old endpoint and create
// it anew.
c.peerMap.deleteEndpoint(ep)
ok = false
}
if ok {
// At this point we're modifying an existing endpoint (ep) whose
// public key and nodeID match n. Its other fields (such as disco
// key or endpoints) might've changed.
if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() {
// Discokey transitioned from non-zero to zero? This should not
// happen in the wild, however it could mean:
@@ -1821,14 +1843,31 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap
continue
}
if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok {
// At this point n.Key() should be for a key we've never seen before. If
// ok was true above, it was an update to an existing matching key and
// we don't get this far. If ok was false above, that means it's a key
// that differs from the one the NodeID had. But double check.
if ep.nodeID != n.ID() {
// Server error.
devPanicf("public key moved between nodeIDs")
} else {
// Internal data structures out of sync.
devPanicf("public key found in peerMap but not by nodeID")
}
continue
}
if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() {
// Ancient pre-0.100 node, which does not have a disco key, and will only be reachable via DERP.
// Ancient pre-0.100 node, which does not have a disco key.
// No longer supported.
continue
}
ep := &endpoint{
ep = &endpoint{
c: c,
debugUpdates: ringbuffer.New[EndpointChange](entriesPerBuffer),
nodeID: n.ID(),
publicKey: n.Key(),
publicKeyHex: n.Key().UntypedHexString(),
sentPing: map[stun.TxID]sentPing{},
@@ -1847,35 +1886,12 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
key: n.DiscoKey(),
short: n.DiscoKey().ShortString(),
})
if debugDisco() { // rather than making a new knob
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP(), derpPrefix) {
ipp, _ := netip.ParseAddrPort(n.DERP())
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
}
fmt.Fprintf(w, "derp=%v%s ", regionID, code)
}
for i := range n.AllowedIPs().LenIter() {
a := n.AllowedIPs().At(i)
if a.IsSingleIP() {
fmt.Fprintf(w, "aip=%v ", a.Addr())
} else {
fmt.Fprintf(w, "aip=%v ", a)
}
}
for i := range n.Endpoints().LenIter() {
ep := n.Endpoints().At(i)
fmt.Fprintf(w, "ep=%v ", ep)
}
}))
}
}
if debugPeerMap() {
c.logEndpointCreated(n)
}
ep.updateFromNode(n, heartbeatDisabled)
c.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
}
@@ -1886,12 +1902,12 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// current netmap. If that happens, go through the allocful
// deletion path to clean up moribund nodes.
if c.peerMap.nodeCount() != len(nm.Peers) {
keep := make(map[key.NodePublic]bool, len(nm.Peers))
keep := set.Set[key.NodePublic]{}
for _, n := range nm.Peers {
keep[n.Key()] = true
keep.Add(n.Key())
}
c.peerMap.forEachEndpoint(func(ep *endpoint) {
if !keep[ep.publicKey] {
if !keep.Contains(ep.publicKey) {
c.peerMap.deleteEndpoint(ep)
}
})
@@ -1905,6 +1921,40 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
}
}
func devPanicf(format string, a ...any) {
if testenv.InTest() || envknob.CrashOnUnexpected() {
panic(fmt.Sprintf(format, a...))
}
}
func (c *Conn) logEndpointCreated(n tailcfg.NodeView) {
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP(), derpPrefix) {
ipp, _ := netip.ParseAddrPort(n.DERP())
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
}
fmt.Fprintf(w, "derp=%v%s ", regionID, code)
}
for i := range n.AllowedIPs().LenIter() {
a := n.AllowedIPs().At(i)
if a.IsSingleIP() {
fmt.Fprintf(w, "aip=%v ", a.Addr())
} else {
fmt.Fprintf(w, "aip=%v ", a)
}
}
for i := range n.Endpoints().LenIter() {
ep := n.Endpoints().At(i)
fmt.Fprintf(w, "ep=%v ", ep)
}
}))
}
func (c *Conn) logEndpointChange(endpoints []tailcfg.Endpoint) {
c.logf("magicsock: endpoints changed: %s", logger.ArgWriter(func(buf *bufio.Writer) {
for i, ep := range endpoints {