diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 81a7f14f4..7d4a8f724 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -315,4 +315,4 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ - unique from net/netip + unique from net/netip+ diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go index ee1807f00..4bf1e0fed 100644 --- a/cmd/derper/mesh.go +++ b/cmd/derper/mesh.go @@ -69,8 +69,8 @@ func startMeshWithHost(s *derp.Server, host string) error { return d.DialContext(ctx, network, addr) }) - add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) } - remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) } + add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key.Handle(), c) } + remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer.Handle(), c) } go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove) return nil } diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 900d10efe..7f9c36548 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -1010,4 +1010,4 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ - unique from net/netip + unique from net/netip+ diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index 7031b18e2..e0a5ece1f 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -199,4 +199,4 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ - unique from net/netip + unique from net/netip+ diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index d18d88873..6c939cb1e 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -336,4 +336,4 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ - unique from net/netip + unique from net/netip+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 81cd53271..d112a046c 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -586,4 +586,4 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ - unique from net/netip + unique from net/netip+ diff --git a/derp/derp_client.go b/derp/derp_client.go index 7a646fa51..c70cc19d3 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -13,6 +13,7 @@ "net/netip" "sync" "time" + "unique" "go4.org/mem" "golang.org/x/time/rate" @@ -236,7 +237,7 @@ func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) { return c.bw.Flush() } -func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err error) { +func (c *Client) ForwardPacket(srcHandle, dstHandle unique.Handle[key.NodePublic], pkt []byte) (err error) { defer func() { if err != nil { err = fmt.Errorf("derp.ForwardPacket: %w", err) @@ -256,10 +257,10 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil { return err } - if _, err := c.bw.Write(srcKey.AppendTo(nil)); err != nil { + if _, err := c.bw.Write(srcHandle.Value().AppendTo(nil)); err != nil { return err } - if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil { + if _, err := c.bw.Write(dstHandle.Value().AppendTo(nil)); err != nil { return err } if _, err := c.bw.Write(pkt); err != nil { diff --git a/derp/derp_server.go b/derp/derp_server.go index ab0ab0a90..7e1e12ea8 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -34,6 +34,7 @@ "sync" "sync/atomic" "time" + "unique" "go4.org/mem" "golang.org/x/sync/errgroup" @@ -48,15 +49,18 @@ "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/ctxkey" - "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/slicesx" "tailscale.com/version" ) +// NodeHandle is an interned and cheap to pass around and compare representation +// of a NodePublic key. +type NodeHandle = unique.Handle[key.NodePublic] + // verboseDropKeys is the set of destination public keys that should // verbosely log whenever DERP drops a packet. -var verboseDropKeys = map[key.NodePublic]bool{} +var verboseDropKeys = map[NodeHandle]bool{} // IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests // to indicate that they're connecting to their ideal (Region.Nodes[0]) node. @@ -78,7 +82,7 @@ func init() { if err != nil { log.Printf("ignoring invalid debug key %q: %v", keyStr, err) } else { - verboseDropKeys[k] = true + verboseDropKeys[k.Handle()] = true } } } @@ -173,22 +177,22 @@ type Server struct { mu sync.Mutex closed bool netConns map[Conn]chan struct{} // chan is closed when conn closes - clients map[key.NodePublic]*clientSet + clients map[NodeHandle]*clientSet watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. If the value is nil, that means the // peer is only local (and thus in the clients Map, but not // remote). If the value is non-nil, it's remote (+ maybe also // local). - clientsMesh map[key.NodePublic]PacketForwarder + clientsMesh map[NodeHandle]PacketForwarder // peerGoneWatchers is the set of watchers that subscribed to a // peer disconnecting from the region overall. When a peer // is gone from the region, we notify all of these watchers, // calling their funcs in a new goroutine. - peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)] + peerGoneWatchers map[NodeHandle]map[NodeHandle]func(NodeHandle) // maps from netip.AddrPort to a client's public key - keyOfAddr map[netip.AddrPort]key.NodePublic + keyOfAddr map[netip.AddrPort]NodeHandle clock tstime.Clock } @@ -325,7 +329,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool { // is a multiForwarder, which this package creates as needed if a // public key gets more than one PacketForwarder registered for it. type PacketForwarder interface { - ForwardPacket(src, dst key.NodePublic, payload []byte) error + ForwardPacket(src, dst NodeHandle, payload []byte) error String() string } @@ -355,18 +359,18 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { packetsRecvByKind: metrics.LabelMap{Label: "kind"}, packetsDroppedReason: metrics.LabelMap{Label: "reason"}, packetsDroppedType: metrics.LabelMap{Label: "type"}, - clients: map[key.NodePublic]*clientSet{}, - clientsMesh: map[key.NodePublic]PacketForwarder{}, + clients: map[NodeHandle]*clientSet{}, + clientsMesh: map[NodeHandle]PacketForwarder{}, netConns: map[Conn]chan struct{}{}, memSys0: ms.Sys, watchers: set.Set[*sclient]{}, - peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, + peerGoneWatchers: map[NodeHandle]map[NodeHandle]func(NodeHandle){}, avgQueueDuration: new(uint64), tcpRtt: metrics.LabelMap{Label: "le"}, meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}), bufferedWriteFrames: metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}), - keyOfAddr: map[netip.AddrPort]key.NodePublic{}, + keyOfAddr: map[netip.AddrPort]NodeHandle{}, clock: tstime.StdClock{}, } s.initMetacert() @@ -479,7 +483,7 @@ func (s *Server) isClosed() bool { func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - x, ok := s.clients[k] + x, ok := s.clients[k.Handle()] if !ok { return false } @@ -573,11 +577,11 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - cs, ok := s.clients[c.key] + cs, ok := s.clients[c.handle] if !ok { c.debugLogf("register single client") cs = &clientSet{} - s.clients[c.key] = cs + s.clients[c.handle] = cs } was := cs.activeClient.Load() if was == nil { @@ -610,15 +614,15 @@ func (s *Server) registerClient(c *sclient) { cs.activeClient.Store(c) - if _, ok := s.clientsMesh[c.key]; !ok { - s.clientsMesh[c.key] = nil // just for varz of total users in cluster + if _, ok := s.clientsMesh[c.handle]; !ok { + s.clientsMesh[c.handle] = nil // just for varz of total users in cluster } - s.keyOfAddr[c.remoteIPPort] = c.key + s.keyOfAddr[c.remoteIPPort] = c.handle s.curClients.Add(1) if c.isNotIdealConn { s.curClientsNotIdeal.Add(1) } - s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true) + s.broadcastPeerStateChangeLocked(c.handle, c.remoteIPPort, c.presentFlags(), true) } // broadcastPeerStateChangeLocked enqueues a message to all watchers @@ -626,10 +630,10 @@ func (s *Server) registerClient(c *sclient) { // presence changed. // // s.mu must be held. -func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) { +func (s *Server) broadcastPeerStateChangeLocked(peer NodeHandle, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) { for w := range s.watchers { w.peerStateChange = append(w.peerStateChange, peerConnState{ - peer: peer, + peer: peer.Value(), present: present, ipPort: ipPort, flags: flags, @@ -643,7 +647,7 @@ func (s *Server) unregisterClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - set, ok := s.clients[c.key] + set, ok := s.clients[c.handle] if !ok { c.logf("[unexpected]; clients map is empty") return @@ -663,12 +667,12 @@ func (s *Server) unregisterClient(c *sclient) { } c.debugLogf("removed connection") set.activeClient.Store(nil) - delete(s.clients, c.key) - if v, ok := s.clientsMesh[c.key]; ok && v == nil { - delete(s.clientsMesh, c.key) - s.notePeerGoneFromRegionLocked(c.key) + delete(s.clients, c.handle) + if v, ok := s.clientsMesh[c.handle]; ok && v == nil { + delete(s.clientsMesh, c.handle) + s.notePeerGoneFromRegionLocked(c.handle) } - s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, 0, false) + s.broadcastPeerStateChangeLocked(c.handle, netip.AddrPort{}, 0, false) } else { c.debugLogf("removed duplicate client") if dup.removeClient(c) { @@ -720,30 +724,30 @@ func (s *Server) unregisterClient(c *sclient) { // The provided f func is usually [sclient.onPeerGoneFromRegion], added by // [sclient.noteSendFromSrc]; this func doesn't take a whole *sclient to make it // clear what has access to what. -func (s *Server) addPeerGoneFromRegionWatcher(peer key.NodePublic, f func(key.NodePublic)) set.Handle { +func (s *Server) addPeerGoneFromRegionWatcher(peerA, peerB NodeHandle, f func(NodeHandle)) { s.mu.Lock() defer s.mu.Unlock() - hset, ok := s.peerGoneWatchers[peer] + m, ok := s.peerGoneWatchers[peerA] if !ok { - hset = set.HandleSet[func(key.NodePublic)]{} - s.peerGoneWatchers[peer] = hset + m = map[NodeHandle]func(NodeHandle){} + s.peerGoneWatchers[peerA] = m } - return hset.Add(f) + m[peerB] = f } // removePeerGoneFromRegionWatcher removes a peer watcher previously added by // addPeerGoneFromRegionWatcher, using the handle returned by // addPeerGoneFromRegionWatcher. -func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Handle) { +func (s *Server) removePeerGoneFromRegionWatcher(peerA, peerB NodeHandle) { s.mu.Lock() defer s.mu.Unlock() - hset, ok := s.peerGoneWatchers[peer] + hset, ok := s.peerGoneWatchers[peerA] if !ok { return } - delete(hset, h) + delete(hset, peerB) if len(hset) == 0 { - delete(s.peerGoneWatchers, peer) + delete(s.peerGoneWatchers, peerA) } } @@ -751,8 +755,8 @@ func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Hand // key has sent to previously (whether those sends were from a local // client or forwarded). It must only be called after the key has // been removed from clientsMesh. -func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) { - if _, ok := s.clientsMesh[key]; ok { +func (s *Server) notePeerGoneFromRegionLocked(handle NodeHandle) { + if _, ok := s.clientsMesh[handle]; ok { panic("usage") } @@ -760,17 +764,17 @@ func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) { // so they can drop their route entries to us (issue 150) // or move them over to the active client (in case a replaced client // connection is being unregistered). - set := s.peerGoneWatchers[key] + set := s.peerGoneWatchers[handle] for _, f := range set { - go f(key) + go f(handle) } - delete(s.peerGoneWatchers, key) + delete(s.peerGoneWatchers, handle) } // requestPeerGoneWriteLimited sends a request to write a "peer gone" // frame, but only in reply to a disco packet, and only if we haven't // sent one recently. -func (c *sclient) requestPeerGoneWriteLimited(peer key.NodePublic, contents []byte, reason PeerGoneReasonType) { +func (c *sclient) requestPeerGoneWriteLimited(peer NodeHandle, contents []byte, reason PeerGoneReasonType) { if disco.LooksLikeDiscoWrapper(contents) != true { return } @@ -800,7 +804,7 @@ func (s *Server) addWatcher(c *sclient) { continue } c.peerStateChange = append(c.peerStateChange, peerConnState{ - peer: peer, + peer: peer.Value(), present: true, ipPort: ac.remoteIPPort, flags: ac.presentFlags(), @@ -826,6 +830,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem if err != nil { return fmt.Errorf("receive client key: %v", err) } + handle := clientKey.Handle() + clientKey = handle.Value() // interned remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil { @@ -842,6 +848,7 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem connNum: connNum, s: s, key: clientKey, + handle: handle, nc: nc, br: br, bw: bw, @@ -1014,22 +1021,23 @@ func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { if err := targetKey.ReadRawWithoutAllocating(c.br); err != nil { return err } + handle := targetKey.Handle() s := c.s s.mu.Lock() defer s.mu.Unlock() - if set, ok := s.clients[targetKey]; ok { + if set, ok := s.clients[handle]; ok { if set.Len() == 1 { - c.logf("frameClosePeer closing peer %x", targetKey) + c.logf("frameClosePeer closing peer %x", handle.Value()) } else { - c.logf("frameClosePeer closing peer %x (%d connections)", targetKey, set.Len()) + c.logf("frameClosePeer closing peer %x (%d connections)", handle.Value(), set.Len()) } set.ForeachClient(func(target *sclient) { go target.nc.Close() }) } else { - c.logf("frameClosePeer failed to find peer %x", targetKey) + c.logf("frameClosePeer failed to find peer %x", handle.Value()) } return nil @@ -1043,7 +1051,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { } s := c.s - srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl) + srcHandle, dstHandle, contents, err := s.recvForwardPacket(c.br, fl) if err != nil { return fmt.Errorf("client %v: recvForwardPacket: %v", c.key, err) } @@ -1053,7 +1061,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { var dst *sclient s.mu.Lock() - if set, ok := s.clients[dstKey]; ok { + if set, ok := s.clients[dstHandle]; ok { dstLen = set.Len() dst = set.activeClient.Load() } @@ -1064,18 +1072,18 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { if dstLen > 1 { reason = dropReasonDupClient } else { - c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere) + c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere) } - s.recordDrop(contents, srcKey, dstKey, reason) + s.recordDrop(contents, srcHandle, dstHandle, reason) return nil } - dst.debugLogf("received forwarded packet from %s via %s", srcKey.ShortString(), c.key.ShortString()) + dst.debugLogf("received forwarded packet from %s via %s", srcHandle.Value().ShortString(), c.key.ShortString()) return c.sendPkt(dst, pkt{ bs: contents, enqueuedAt: c.s.clock.Now(), - src: srcKey, + src: srcHandle, }) } @@ -1083,7 +1091,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s := c.s - dstKey, contents, err := s.recvPacket(c.br, fl) + dstHandle, contents, err := s.recvPacket(c.br, fl) if err != nil { return fmt.Errorf("client %v: recvPacket: %v", c.key, err) } @@ -1093,20 +1101,20 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { var dst *sclient s.mu.Lock() - if set, ok := s.clients[dstKey]; ok { + if set, ok := s.clients[dstHandle]; ok { dstLen = set.Len() dst = set.activeClient.Load() } if dst == nil && dstLen < 1 { - fwd = s.clientsMesh[dstKey] + fwd = s.clientsMesh[dstHandle] } s.mu.Unlock() if dst == nil { if fwd != nil { s.packetsForwardedOut.Add(1) - err := fwd.ForwardPacket(c.key, dstKey, contents) - c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstKey.ShortString(), fwd, err) + err := fwd.ForwardPacket(c.handle, dstHandle, contents) + c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstHandle.Value().ShortString(), fwd, err) if err != nil { // TODO: return nil @@ -1117,18 +1125,18 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { if dstLen > 1 { reason = dropReasonDupClient } else { - c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere) + c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere) } - s.recordDrop(contents, c.key, dstKey, reason) - c.debugLogf("SendPacket for %s, dropping with reason=%s", dstKey.ShortString(), reason) + s.recordDrop(contents, c.handle, dstHandle, reason) + c.debugLogf("SendPacket for %s, dropping with reason=%s", dstHandle.Value().ShortString(), reason) return nil } - c.debugLogf("SendPacket for %s, sending directly", dstKey.ShortString()) + c.debugLogf("SendPacket for %s, sending directly", dstHandle.Value().ShortString()) p := pkt{ bs: contents, enqueuedAt: c.s.clock.Now(), - src: c.key, + src: c.handle, } return c.sendPkt(dst, p) } @@ -1155,7 +1163,7 @@ func (c *sclient) debugLogf(format string, v ...any) { numDropReasons // unused; keep last ) -func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) { +func (s *Server) recordDrop(packetBytes []byte, srcHandle, dstHandle NodeHandle, reason dropReason) { s.packetsDropped.Add(1) s.packetsDroppedReasonCounters[reason].Add(1) looksDisco := disco.LooksLikeDiscoWrapper(packetBytes) @@ -1164,20 +1172,19 @@ func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, r } else { s.packetsDroppedTypeOther.Add(1) } - if verboseDropKeys[dstKey] { + if verboseDropKeys[dstHandle] { // Preformat the log string prior to calling limitedLogf. The // limiter acts based on the format string, and we want to // rate-limit per src/dst keys, not on the generic "dropped // stuff" message. - msg := fmt.Sprintf("drop (%s) %s -> %s", srcKey.ShortString(), reason, dstKey.ShortString()) + msg := fmt.Sprintf("drop (%s) %s -> %s", srcHandle.Value().ShortString(), reason, dstHandle.Value().ShortString()) s.limitedLogf(msg) } - s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstKey, looksDisco) + s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstHandle.Value(), looksDisco) } func (c *sclient) sendPkt(dst *sclient, p pkt) error { s := c.s - dstKey := dst.key // Attempt to queue for sending up to 3 times. On each attempt, if // the queue is full, try to drop from queue head to prioritize @@ -1189,7 +1196,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { for attempt := 0; attempt < 3; attempt++ { select { case <-dst.done: - s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected) + s.recordDrop(p.bs, c.handle, dst.handle, dropReasonGoneDisconnected) dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt) return nil default: @@ -1203,7 +1210,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { select { case pkt := <-sendQueue: - s.recordDrop(pkt.bs, c.key, dstKey, dropReasonQueueHead) + s.recordDrop(pkt.bs, c.handle, dst.handle, dropReasonQueueHead) c.recordQueueTime(pkt.enqueuedAt) default: } @@ -1211,7 +1218,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { // Failed to make room for packet. This can happen in a heavily // contended queue with racing writers. Give up and tail-drop in // this case to keep reader unblocked. - s.recordDrop(p.bs, c.key, dstKey, dropReasonQueueTail) + s.recordDrop(p.bs, c.handle, dst.handle, dropReasonQueueTail) dst.debugLogf("sendPkt attempt %d dropped, queue full") return nil @@ -1220,17 +1227,17 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { // onPeerGoneFromRegion is the callback registered with the Server to be // notified (in a new goroutine) whenever a peer has disconnected from all DERP // nodes in the current region. -func (c *sclient) onPeerGoneFromRegion(peer key.NodePublic) { +func (c *sclient) onPeerGoneFromRegion(peer NodeHandle) { c.requestPeerGoneWrite(peer, PeerGoneReasonDisconnected) } // requestPeerGoneWrite sends a request to write a "peer gone" frame // with an explanation of why it is gone. It blocks until either the // write request is scheduled, or the client has closed. -func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason PeerGoneReasonType) { +func (c *sclient) requestPeerGoneWrite(peer NodeHandle, reason PeerGoneReasonType) { select { case c.peerGone <- peerGoneMsg{ - peer: peer, + peer: peer.Value(), reason: reason, }: case <-c.done: @@ -1346,7 +1353,7 @@ func (s *Server) noteClientActivity(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - cs, ok := s.clients[c.key] + cs, ok := s.clients[c.handle] if !ok { return } @@ -1453,20 +1460,22 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info return clientKey, info, nil } -func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) { +func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstHandle NodeHandle, contents []byte, err error) { if frameLen < keyLen { - return zpub, nil, errors.New("short send packet frame") + return zpubHandle, nil, errors.New("short send packet frame") } + var dstKey key.NodePublic if err := dstKey.ReadRawWithoutAllocating(br); err != nil { - return zpub, nil, err + return zpubHandle, nil, err } + dstHandle = dstKey.Handle() packetLen := frameLen - keyLen if packetLen > MaxPacketSize { - return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) + return zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) } contents = make([]byte, packetLen) if _, err := io.ReadFull(br, contents); err != nil { - return zpub, nil, err + return zpubHandle, nil, err } s.packetsRecv.Add(1) s.bytesRecv.Add(int64(len(contents))) @@ -1475,33 +1484,35 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodeP } else { s.packetsRecvOther.Add(1) } - return dstKey, contents, nil + return dstHandle, contents, nil } // zpub is the key.NodePublic zero value. var zpub key.NodePublic +var zpubHandle NodeHandle = zpub.Handle() -func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.NodePublic, contents []byte, err error) { +func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcHandle, dstHandle NodeHandle, contents []byte, err error) { if frameLen < keyLen*2 { - return zpub, zpub, nil, errors.New("short send packet frame") + return zpubHandle, zpubHandle, nil, errors.New("short send packet frame") } + var srcKey, dstKey key.NodePublic if err := srcKey.ReadRawWithoutAllocating(br); err != nil { - return zpub, zpub, nil, err + return zpubHandle, zpubHandle, nil, err } if err := dstKey.ReadRawWithoutAllocating(br); err != nil { - return zpub, zpub, nil, err + return zpubHandle, zpubHandle, nil, err } packetLen := frameLen - keyLen*2 if packetLen > MaxPacketSize { - return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) + return zpubHandle, zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) } contents = make([]byte, packetLen) if _, err := io.ReadFull(br, contents); err != nil { - return zpub, zpub, nil, err + return zpubHandle, zpubHandle, nil, err } // TODO: was s.packetsRecv.Add(1) // TODO: was s.bytesRecv.Add(int64(len(contents))) - return srcKey, dstKey, contents, nil + return srcKey.Handle(), dstKey.Handle(), contents, nil } // sclient is a client connection to the server. @@ -1518,6 +1529,7 @@ type sclient struct { s *Server nc Conn key key.NodePublic + handle NodeHandle // handle is a cached handle for key info clientInfo logf logger.Logf done <-chan struct{} // closed when connection closes @@ -1539,7 +1551,7 @@ type sclient struct { preferred bool // Owned by sendLoop, not thread-safe. - sawSrc map[key.NodePublic]set.Handle + sawSrc set.Set[NodeHandle] bw *lazyBufioWriter // Guarded by s.mu @@ -1593,7 +1605,7 @@ type pkt struct { bs []byte // src is the who's the sender of the packet. - src key.NodePublic + src NodeHandle } // peerGoneMsg is a request to write a peerGone frame to an sclient @@ -1656,17 +1668,17 @@ func (c *sclient) onSendLoopDone() { c.nc.Close() // Clean up watches. - for peer, h := range c.sawSrc { - c.s.removePeerGoneFromRegionWatcher(peer, h) + for peer := range c.sawSrc { + c.s.removePeerGoneFromRegionWatcher(peer, c.handle) } // Drain the send queue to count dropped packets for { select { case pkt := <-c.sendQueue: - c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) + c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected) case pkt := <-c.discoSendQueue: - c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) + c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected) default: return } @@ -1869,31 +1881,31 @@ func (c *sclient) sendMeshUpdates() error { // DERPv2. The bytes of contents are only valid until this function // returns, do not retain slices. // It does not flush its bufio.Writer. -func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) { +func (c *sclient) sendPacket(srcHandle NodeHandle, contents []byte) (err error) { defer func() { // Stats update. if err != nil { - c.s.recordDrop(contents, srcKey, c.key, dropReasonWriteError) + c.s.recordDrop(contents, srcHandle, c.handle, dropReasonWriteError) } else { c.s.packetsSent.Add(1) c.s.bytesSent.Add(int64(len(contents))) } - c.debugLogf("sendPacket from %s: %v", srcKey.ShortString(), err) + c.debugLogf("sendPacket from %s: %v", srcHandle.Value().ShortString(), err) }() c.setWriteDeadline() - withKey := !srcKey.IsZero() + withKey := !srcHandle.Value().IsZero() pktLen := len(contents) if withKey { pktLen += key.NodePublicRawLen - c.noteSendFromSrc(srcKey) + c.noteSendFromSrc(srcHandle) } if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil { return err } if withKey { - if err := srcKey.WriteRawWithoutAllocating(c.bw.bw()); err != nil { + if err := srcHandle.Value().WriteRawWithoutAllocating(c.bw.bw()); err != nil { return err } } @@ -1905,17 +1917,18 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) // from src to sclient. // // It must only be called from the sendLoop goroutine. -func (c *sclient) noteSendFromSrc(src key.NodePublic) { +func (c *sclient) noteSendFromSrc(src NodeHandle) { if _, ok := c.sawSrc[src]; ok { return } - h := c.s.addPeerGoneFromRegionWatcher(src, c.onPeerGoneFromRegion) - mak.Set(&c.sawSrc, src, h) + c.s.addPeerGoneFromRegionWatcher(src, c.handle, c.onPeerGoneFromRegion) + c.sawSrc.Make() // ensure sawSrc is non-nil + c.sawSrc.Add(src) } // AddPacketForwarder registers fwd as a packet forwarder for dst. // fwd must be comparable. -func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) { +func (s *Server) AddPacketForwarder(dst NodeHandle, fwd PacketForwarder) { s.mu.Lock() defer s.mu.Unlock() if prev, ok := s.clientsMesh[dst]; ok { @@ -1945,7 +1958,7 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) { // RemovePacketForwarder removes fwd as a packet forwarder for dst. // fwd must be comparable. -func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) { +func (s *Server) RemovePacketForwarder(dst NodeHandle, fwd PacketForwarder) { s.mu.Lock() defer s.mu.Unlock() v, ok := s.clientsMesh[dst] @@ -2048,7 +2061,7 @@ func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, i return nil, false } -func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { +func (f *multiForwarder) ForwardPacket(src, dst NodeHandle, payload []byte) error { return f.fwd.Load().ForwardPacket(src, dst, payload) } @@ -2238,8 +2251,8 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) { for k, next := range newState { prev := prevState[k] if prev.Sent < next.Sent || prev.Recv < next.Recv { - if pkey, ok := s.keyOfAddr[k]; ok { - next.Key = pkey + if pHandle, ok := s.keyOfAddr[k]; ok { + next.Key = pHandle.Value() if err := enc.Encode(next); err != nil { s.mu.Unlock() return diff --git a/derp/derp_test.go b/derp/derp_test.go index 9185194dd..3af53a64b 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -714,27 +714,27 @@ func TestWatch(t *testing.T) { type testFwd int -func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { +func (testFwd) ForwardPacket(NodeHandle, NodeHandle, []byte) error { panic("not called in tests") } func (testFwd) String() string { panic("not called in tests") } -func pubAll(b byte) (ret key.NodePublic) { +func pubAll(b byte) (ret NodeHandle) { var bs [32]byte for i := range bs { bs[i] = b } - return key.NodePublicFromRaw32(mem.B(bs[:])) + return key.NodePublicFromRaw32(mem.B(bs[:])).Handle() } func TestForwarderRegistration(t *testing.T) { s := &Server{ - clients: make(map[key.NodePublic]*clientSet), - clientsMesh: map[key.NodePublic]PacketForwarder{}, + clients: make(map[NodeHandle]*clientSet), + clientsMesh: map[NodeHandle]PacketForwarder{}, } - want := func(want map[key.NodePublic]PacketForwarder) { + want := func(want map[NodeHandle]PacketForwarder) { t.Helper() if got := s.clientsMesh; !reflect.DeepEqual(got, want) { t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) @@ -758,28 +758,28 @@ func TestForwarderRegistration(t *testing.T) { s.AddPacketForwarder(u1, testFwd(1)) s.AddPacketForwarder(u2, testFwd(2)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered forwarder is no-op. s.RemovePacketForwarder(u2, testFwd(999)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Verify a remove of non-registered user is no-op. s.RemovePacketForwarder(u3, testFwd(1)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(1), u2: testFwd(2), }) // Actual removal. s.RemovePacketForwarder(u2, testFwd(2)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(1), }) @@ -787,14 +787,14 @@ func TestForwarderRegistration(t *testing.T) { wantCounter(&s.multiForwarderCreated, 0) s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) wantCounter(&s.multiForwarderCreated, 1) // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)), }) @@ -802,7 +802,7 @@ func TestForwarderRegistration(t *testing.T) { // from being a multiForwarder. wantCounter(&s.multiForwarderDeleted, 0) s.RemovePacketForwarder(u1, testFwd(1)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(100), }) wantCounter(&s.multiForwarderDeleted, 1) @@ -810,23 +810,24 @@ func TestForwarderRegistration(t *testing.T) { // Removing an entry for a client that's still connected locally should result // in a nil forwarder. u1c := &sclient{ - key: u1, - logf: logger.Discard, + key: u1.Value(), + handle: u1, + logf: logger.Discard, } s.clients[u1] = singleClient(u1c) s.RemovePacketForwarder(u1, testFwd(100)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: nil, }) // But once that client disconnects, it should go away. s.unregisterClient(u1c) - want(map[key.NodePublic]PacketForwarder{}) + want(map[NodeHandle]PacketForwarder{}) // But if it already has a forwarder, it's not removed. s.AddPacketForwarder(u1, testFwd(2)) s.unregisterClient(u1c) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(2), }) @@ -835,11 +836,11 @@ func TestForwarderRegistration(t *testing.T) { // from nil to the new one, not a multiForwarder. s.clients[u1] = singleClient(u1c) s.clientsMesh[u1] = nil - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: nil, }) s.AddPacketForwarder(u1, testFwd(3)) - want(map[key.NodePublic]PacketForwarder{ + want(map[NodeHandle]PacketForwarder{ u1: testFwd(3), }) } @@ -853,7 +854,7 @@ type channelFwd struct { } func (f channelFwd) String() string { return "" } -func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { +func (f channelFwd) ForwardPacket(_ NodeHandle, _ NodeHandle, packet []byte) error { f.c <- packet return nil } @@ -865,8 +866,8 @@ func TestMultiForwarder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ - clients: make(map[key.NodePublic]*clientSet), - clientsMesh: map[key.NodePublic]PacketForwarder{}, + clients: make(map[NodeHandle]*clientSet), + clientsMesh: map[NodeHandle]PacketForwarder{}, } u := pubAll(1) s.AddPacketForwarder(u, channelFwd{1, ch}) @@ -1067,9 +1068,9 @@ func TestServerDupClients(t *testing.T) { run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { s = NewServer(serverPriv, t.Logf) s.dupPolicy = dupPolicy - c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} - c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} - c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} + c1 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c1: ")} + c2 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c2: ")} + c3 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c3: ")} clientName = map[*sclient]string{ c1: "c1", c2: "c2", @@ -1083,7 +1084,7 @@ func TestServerDupClients(t *testing.T) { } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() - got, ok := s.clients[want.key] + got, ok := s.clients[want.key.Handle()] if !ok { t.Error("no clients for key") return @@ -1106,7 +1107,7 @@ func TestServerDupClients(t *testing.T) { } wantNoClient := func(t *testing.T) { t.Helper() - _, ok := s.clients[clientPub] + _, ok := s.clients[clientPub.Handle()] if !ok { // Good return @@ -1115,7 +1116,7 @@ func TestServerDupClients(t *testing.T) { } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() - cs, ok := s.clients[clientPub] + cs, ok := s.clients[clientPub.Handle()] if !ok { t.Fatal("no set for key; want dup set") return nil @@ -1128,7 +1129,7 @@ func TestServerDupClients(t *testing.T) { } wantActive := func(t *testing.T, want *sclient) { t.Helper() - set, ok := s.clients[clientPub] + set, ok := s.clients[clientPub.Handle()] if !ok { t.Error("no set for key") return diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index c95d072b1..2a8fc5a8c 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -27,6 +27,7 @@ "strings" "sync" "time" + "unique" "go4.org/mem" "tailscale.com/derp" @@ -971,7 +972,7 @@ func (c *Client) LocalAddr() (netip.AddrPort, error) { return la, nil } -func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error { +func (c *Client) ForwardPacket(from, to unique.Handle[key.NodePublic], b []byte) error { client, _, err := c.connect(c.newContext(), "derphttp.Client.ForwardPacket") if err != nil { return err diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index cf6032a5e..97294be2a 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -327,7 +327,7 @@ func TestBreakWatcherConnRecv(t *testing.T) { } watcher1.breakConnection(watcher1.client) // re-establish connection by sending a packet - watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) + watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus")) timer.Reset(5 * time.Second) } @@ -400,7 +400,7 @@ func TestBreakWatcherConn(t *testing.T) { } watcher1.breakConnection(watcher1.client) // re-establish connection by sending a packet - watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) + watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus")) // signal that the breaker is done breakerChan <- true diff --git a/types/key/node.go b/types/key/node.go index 11ee1fa3c..ad7e204c3 100644 --- a/types/key/node.go +++ b/types/key/node.go @@ -10,6 +10,7 @@ "encoding/hex" "errors" "fmt" + "unique" "go4.org/mem" "golang.org/x/crypto/curve25519" @@ -174,6 +175,16 @@ func (p NodePublic) Compare(p2 NodePublic) int { return bytes.Compare(p.k[:], p2.k[:]) } +// Handle returns a unique.Handle for this NodePublic. The Handle is more +// efficient for storage and comparison than the NodePublic itself, but is also +// more expensive to create. It is best to keep a copy of the Handle on a longer +// term object representing a NodePublic, rather than creating it on the fly, +// but in doing so if the Handle is used in multiple other data structures the +// cost of Handle storage and comparisons on lookups will quickly amortize. +func (p NodePublic) Handle() unique.Handle[NodePublic] { + return unique.Make(p) +} + // ParseNodePublicUntyped parses an untyped 64-character hex value // as a NodePublic. //