derp: intern key.NodePublic across the server

Consistently interning the NodePublic keys throughout DERP, particularly
inside the maps, reduces memory usage and lookup costs in the
associated data structures.

It is not clear exactly how efficient the weak pointers will be in
practice, but estimating this using derpstress with 10k conns pushing
40kpps in each direction, this patch grows the heap at approximately
half the rate vs. the old code and has fewer instances of long stalls
that trigger i/o timeouts for the clients.

Updates tailscale/corp#24485

Signed-off-by: James Tucker <james@tailscale.com>
James Tucker 2024-11-07 12:26:43 -08:00
parent 3b93fd9c44
commit f7ad04bea4
12 changed files with 178 additions and 151 deletions
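
The interning mechanism used throughout this patch is Go 1.23's unique package:
the maps and comparisons that previously keyed on the 32-byte key.NodePublic now
key on a unique.Handle[key.NodePublic], produced by the new NodePublic.Handle
method added at the bottom of this diff. The sketch below illustrates the pattern
in isolation; it is not code from the patch, and the stand-in NodePublic type is
hypothetical.

    package main

    import (
        "fmt"
        "unique"
    )

    // NodePublic stands in for tailscale.com/types/key.NodePublic: a
    // comparable 32-byte value that is comparatively expensive to hash
    // and copy when used directly as a map key.
    type NodePublic struct {
        k [32]byte
    }

    // Handle interns the key. Handles to equal keys share one canonical
    // copy, so a Handle is pointer-sized, cheap to compare, and cheap to
    // hash as a map key.
    func (p NodePublic) Handle() unique.Handle[NodePublic] {
        return unique.Make(p)
    }

    func main() {
        a := NodePublic{k: [32]byte{1}}
        b := NodePublic{k: [32]byte{1}}

        clients := map[unique.Handle[NodePublic]]string{}
        clients[a.Handle()] = "conn-1"

        fmt.Println(clients[b.Handle()])      // "conn-1": equal keys intern to the same handle
        fmt.Println(a.Handle() == b.Handle()) // true: handle comparison is a pointer comparison
        fmt.Println(a.Handle().Value() == b)  // true: Value() recovers the original key
    }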

View File

@@ -315,4 +315,4 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
 unicode from bytes+
 unicode/utf16 from crypto/x509+
 unicode/utf8 from bufio+
-unique from net/netip
+unique from net/netip+

View File

@@ -69,8 +69,8 @@ func startMeshWithHost(s *derp.Server, host string) error {
 return d.DialContext(ctx, network, addr)
 })
-add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) }
-remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) }
+add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key.Handle(), c) }
+remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer.Handle(), c) }
 go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove)
 return nil
 }

View File

@@ -1010,4 +1010,4 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
 unicode from bytes+
 unicode/utf16 from crypto/x509+
 unicode/utf8 from bufio+
-unique from net/netip
+unique from net/netip+

View File

@@ -199,4 +199,4 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar
 unicode from bytes+
 unicode/utf16 from crypto/x509+
 unicode/utf8 from bufio+
-unique from net/netip
+unique from net/netip+

View File

@@ -336,4 +336,4 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
 unicode from bytes+
 unicode/utf16 from crypto/x509+
 unicode/utf8 from bufio+
-unique from net/netip
+unique from net/netip+

View File

@@ -586,4 +586,4 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
 unicode from bytes+
 unicode/utf16 from crypto/x509+
 unicode/utf8 from bufio+
-unique from net/netip
+unique from net/netip+

View File

@@ -13,6 +13,7 @@
 "net/netip"
 "sync"
 "time"
+"unique"
 
 "go4.org/mem"
 "golang.org/x/time/rate"
@@ -236,7 +237,7 @@ func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) {
 return c.bw.Flush()
 }
 
-func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err error) {
+func (c *Client) ForwardPacket(srcHandle, dstHandle unique.Handle[key.NodePublic], pkt []byte) (err error) {
 defer func() {
 if err != nil {
 err = fmt.Errorf("derp.ForwardPacket: %w", err)
@@ -256,10 +257,10 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e
 if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil {
 return err
 }
-if _, err := c.bw.Write(srcKey.AppendTo(nil)); err != nil {
+if _, err := c.bw.Write(srcHandle.Value().AppendTo(nil)); err != nil {
 return err
 }
-if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil {
+if _, err := c.bw.Write(dstHandle.Value().AppendTo(nil)); err != nil {
 return err
 }
 if _, err := c.bw.Write(pkt); err != nil {

View File

@@ -34,6 +34,7 @@
 "sync"
 "sync/atomic"
 "time"
+"unique"
 
 "go4.org/mem"
 "golang.org/x/sync/errgroup"
@@ -48,15 +49,18 @@
 "tailscale.com/types/key"
 "tailscale.com/types/logger"
 "tailscale.com/util/ctxkey"
-"tailscale.com/util/mak"
 "tailscale.com/util/set"
 "tailscale.com/util/slicesx"
 "tailscale.com/version"
 )
 
+// NodeHandle is an interned and cheap to pass around and compare representation
+// of a NodePublic key.
+type NodeHandle = unique.Handle[key.NodePublic]
+
 // verboseDropKeys is the set of destination public keys that should
 // verbosely log whenever DERP drops a packet.
-var verboseDropKeys = map[key.NodePublic]bool{}
+var verboseDropKeys = map[NodeHandle]bool{}
 
 // IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests
 // to indicate that they're connecting to their ideal (Region.Nodes[0]) node.
@@ -78,7 +82,7 @@ func init() {
 if err != nil {
 log.Printf("ignoring invalid debug key %q: %v", keyStr, err)
 } else {
-verboseDropKeys[k] = true
+verboseDropKeys[k.Handle()] = true
 }
 }
 }
@@ -173,22 +177,22 @@ type Server struct {
 mu sync.Mutex
 closed bool
 netConns map[Conn]chan struct{} // chan is closed when conn closes
-clients map[key.NodePublic]*clientSet
+clients map[NodeHandle]*clientSet
 watchers set.Set[*sclient] // mesh peers
 // clientsMesh tracks all clients in the cluster, both locally
 // and to mesh peers. If the value is nil, that means the
 // peer is only local (and thus in the clients Map, but not
 // remote). If the value is non-nil, it's remote (+ maybe also
 // local).
-clientsMesh map[key.NodePublic]PacketForwarder
+clientsMesh map[NodeHandle]PacketForwarder
 // peerGoneWatchers is the set of watchers that subscribed to a
 // peer disconnecting from the region overall. When a peer
 // is gone from the region, we notify all of these watchers,
 // calling their funcs in a new goroutine.
-peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)]
+peerGoneWatchers map[NodeHandle]map[NodeHandle]func(NodeHandle)
 // maps from netip.AddrPort to a client's public key
-keyOfAddr map[netip.AddrPort]key.NodePublic
+keyOfAddr map[netip.AddrPort]NodeHandle
 clock tstime.Clock
 }
@@ -325,7 +329,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool {
 // is a multiForwarder, which this package creates as needed if a
 // public key gets more than one PacketForwarder registered for it.
 type PacketForwarder interface {
-ForwardPacket(src, dst key.NodePublic, payload []byte) error
+ForwardPacket(src, dst NodeHandle, payload []byte) error
 String() string
 }
@@ -355,18 +359,18 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
 packetsRecvByKind: metrics.LabelMap{Label: "kind"},
 packetsDroppedReason: metrics.LabelMap{Label: "reason"},
 packetsDroppedType: metrics.LabelMap{Label: "type"},
-clients: map[key.NodePublic]*clientSet{},
-clientsMesh: map[key.NodePublic]PacketForwarder{},
+clients: map[NodeHandle]*clientSet{},
+clientsMesh: map[NodeHandle]PacketForwarder{},
 netConns: map[Conn]chan struct{}{},
 memSys0: ms.Sys,
 watchers: set.Set[*sclient]{},
-peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{},
+peerGoneWatchers: map[NodeHandle]map[NodeHandle]func(NodeHandle){},
 avgQueueDuration: new(uint64),
 tcpRtt: metrics.LabelMap{Label: "le"},
 meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}),
 meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}),
 bufferedWriteFrames: metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}),
-keyOfAddr: map[netip.AddrPort]key.NodePublic{},
+keyOfAddr: map[netip.AddrPort]NodeHandle{},
 clock: tstime.StdClock{},
 }
 s.initMetacert()
@@ -479,7 +483,7 @@ func (s *Server) isClosed() bool {
 func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool {
 s.mu.Lock()
 defer s.mu.Unlock()
-x, ok := s.clients[k]
+x, ok := s.clients[k.Handle()]
 if !ok {
 return false
 }
@@ -573,11 +577,11 @@ func (s *Server) registerClient(c *sclient) {
 s.mu.Lock()
 defer s.mu.Unlock()
-cs, ok := s.clients[c.key]
+cs, ok := s.clients[c.handle]
 if !ok {
 c.debugLogf("register single client")
 cs = &clientSet{}
-s.clients[c.key] = cs
+s.clients[c.handle] = cs
 }
 was := cs.activeClient.Load()
 if was == nil {
@@ -610,15 +614,15 @@ func (s *Server) registerClient(c *sclient) {
 cs.activeClient.Store(c)
-if _, ok := s.clientsMesh[c.key]; !ok {
-s.clientsMesh[c.key] = nil // just for varz of total users in cluster
+if _, ok := s.clientsMesh[c.handle]; !ok {
+s.clientsMesh[c.handle] = nil // just for varz of total users in cluster
 }
-s.keyOfAddr[c.remoteIPPort] = c.key
+s.keyOfAddr[c.remoteIPPort] = c.handle
 s.curClients.Add(1)
 if c.isNotIdealConn {
 s.curClientsNotIdeal.Add(1)
 }
-s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true)
+s.broadcastPeerStateChangeLocked(c.handle, c.remoteIPPort, c.presentFlags(), true)
 }
 // broadcastPeerStateChangeLocked enqueues a message to all watchers
@@ -626,10 +630,10 @@ func (s *Server) registerClient(c *sclient) {
 // presence changed.
 //
 // s.mu must be held.
-func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) {
+func (s *Server) broadcastPeerStateChangeLocked(peer NodeHandle, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) {
 for w := range s.watchers {
 w.peerStateChange = append(w.peerStateChange, peerConnState{
-peer: peer,
+peer: peer.Value(),
 present: present,
 ipPort: ipPort,
 flags: flags,
@@ -643,7 +647,7 @@ func (s *Server) unregisterClient(c *sclient) {
 s.mu.Lock()
 defer s.mu.Unlock()
-set, ok := s.clients[c.key]
+set, ok := s.clients[c.handle]
 if !ok {
 c.logf("[unexpected]; clients map is empty")
 return
@@ -663,12 +667,12 @@ func (s *Server) unregisterClient(c *sclient) {
 }
 c.debugLogf("removed connection")
 set.activeClient.Store(nil)
-delete(s.clients, c.key)
-if v, ok := s.clientsMesh[c.key]; ok && v == nil {
-delete(s.clientsMesh, c.key)
-s.notePeerGoneFromRegionLocked(c.key)
+delete(s.clients, c.handle)
+if v, ok := s.clientsMesh[c.handle]; ok && v == nil {
+delete(s.clientsMesh, c.handle)
+s.notePeerGoneFromRegionLocked(c.handle)
 }
-s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, 0, false)
+s.broadcastPeerStateChangeLocked(c.handle, netip.AddrPort{}, 0, false)
 } else {
 c.debugLogf("removed duplicate client")
 if dup.removeClient(c) {
@@ -720,30 +724,30 @@ func (s *Server) unregisterClient(c *sclient) {
 // The provided f func is usually [sclient.onPeerGoneFromRegion], added by
 // [sclient.noteSendFromSrc]; this func doesn't take a whole *sclient to make it
 // clear what has access to what.
-func (s *Server) addPeerGoneFromRegionWatcher(peer key.NodePublic, f func(key.NodePublic)) set.Handle {
+func (s *Server) addPeerGoneFromRegionWatcher(peerA, peerB NodeHandle, f func(NodeHandle)) {
 s.mu.Lock()
 defer s.mu.Unlock()
-hset, ok := s.peerGoneWatchers[peer]
+m, ok := s.peerGoneWatchers[peerA]
 if !ok {
-hset = set.HandleSet[func(key.NodePublic)]{}
-s.peerGoneWatchers[peer] = hset
+m = map[NodeHandle]func(NodeHandle){}
+s.peerGoneWatchers[peerA] = m
 }
-return hset.Add(f)
+m[peerB] = f
 }
 // removePeerGoneFromRegionWatcher removes a peer watcher previously added by
 // addPeerGoneFromRegionWatcher, using the handle returned by
 // addPeerGoneFromRegionWatcher.
-func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Handle) {
+func (s *Server) removePeerGoneFromRegionWatcher(peerA, peerB NodeHandle) {
 s.mu.Lock()
 defer s.mu.Unlock()
-hset, ok := s.peerGoneWatchers[peer]
+hset, ok := s.peerGoneWatchers[peerA]
 if !ok {
 return
 }
-delete(hset, h)
+delete(hset, peerB)
 if len(hset) == 0 {
-delete(s.peerGoneWatchers, peer)
+delete(s.peerGoneWatchers, peerA)
 }
 }
@@ -751,8 +755,8 @@ func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Hand
 // key has sent to previously (whether those sends were from a local
 // client or forwarded). It must only be called after the key has
 // been removed from clientsMesh.
-func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) {
-if _, ok := s.clientsMesh[key]; ok {
+func (s *Server) notePeerGoneFromRegionLocked(handle NodeHandle) {
+if _, ok := s.clientsMesh[handle]; ok {
 panic("usage")
 }
@@ -760,17 +764,17 @@ func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) {
 // so they can drop their route entries to us (issue 150)
 // or move them over to the active client (in case a replaced client
 // connection is being unregistered).
-set := s.peerGoneWatchers[key]
+set := s.peerGoneWatchers[handle]
 for _, f := range set {
-go f(key)
+go f(handle)
 }
-delete(s.peerGoneWatchers, key)
+delete(s.peerGoneWatchers, handle)
 }
 // requestPeerGoneWriteLimited sends a request to write a "peer gone"
 // frame, but only in reply to a disco packet, and only if we haven't
 // sent one recently.
-func (c *sclient) requestPeerGoneWriteLimited(peer key.NodePublic, contents []byte, reason PeerGoneReasonType) {
+func (c *sclient) requestPeerGoneWriteLimited(peer NodeHandle, contents []byte, reason PeerGoneReasonType) {
 if disco.LooksLikeDiscoWrapper(contents) != true {
 return
 }
@@ -800,7 +804,7 @@ func (s *Server) addWatcher(c *sclient) {
 continue
 }
 c.peerStateChange = append(c.peerStateChange, peerConnState{
-peer: peer,
+peer: peer.Value(),
 present: true,
 ipPort: ac.remoteIPPort,
 flags: ac.presentFlags(),
@@ -826,6 +830,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem
 if err != nil {
 return fmt.Errorf("receive client key: %v", err)
 }
+handle := clientKey.Handle()
+clientKey = handle.Value() // interned
 remoteIPPort, _ := netip.ParseAddrPort(remoteAddr)
 if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil {
@@ -842,6 +848,7 @@
 connNum: connNum,
 s: s,
 key: clientKey,
+handle: handle,
 nc: nc,
 br: br,
 bw: bw,
@@ -1014,22 +1021,23 @@ func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error {
 if err := targetKey.ReadRawWithoutAllocating(c.br); err != nil {
 return err
 }
+handle := targetKey.Handle()
 s := c.s
 s.mu.Lock()
 defer s.mu.Unlock()
-if set, ok := s.clients[targetKey]; ok {
+if set, ok := s.clients[handle]; ok {
 if set.Len() == 1 {
-c.logf("frameClosePeer closing peer %x", targetKey)
+c.logf("frameClosePeer closing peer %x", handle.Value())
 } else {
-c.logf("frameClosePeer closing peer %x (%d connections)", targetKey, set.Len())
+c.logf("frameClosePeer closing peer %x (%d connections)", handle.Value(), set.Len())
 }
 set.ForeachClient(func(target *sclient) {
 go target.nc.Close()
 })
 } else {
-c.logf("frameClosePeer failed to find peer %x", targetKey)
+c.logf("frameClosePeer failed to find peer %x", handle.Value())
 }
 return nil
@@ -1043,7 +1051,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
 }
 s := c.s
-srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl)
+srcHandle, dstHandle, contents, err := s.recvForwardPacket(c.br, fl)
 if err != nil {
 return fmt.Errorf("client %v: recvForwardPacket: %v", c.key, err)
 }
@@ -1053,7 +1061,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
 var dst *sclient
 s.mu.Lock()
-if set, ok := s.clients[dstKey]; ok {
+if set, ok := s.clients[dstHandle]; ok {
 dstLen = set.Len()
 dst = set.activeClient.Load()
 }
@@ -1064,18 +1072,18 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
 if dstLen > 1 {
 reason = dropReasonDupClient
 } else {
-c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere)
+c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere)
 }
-s.recordDrop(contents, srcKey, dstKey, reason)
+s.recordDrop(contents, srcHandle, dstHandle, reason)
 return nil
 }
-dst.debugLogf("received forwarded packet from %s via %s", srcKey.ShortString(), c.key.ShortString())
+dst.debugLogf("received forwarded packet from %s via %s", srcHandle.Value().ShortString(), c.key.ShortString())
 return c.sendPkt(dst, pkt{
 bs: contents,
 enqueuedAt: c.s.clock.Now(),
-src: srcKey,
+src: srcHandle,
 })
 }
@@ -1083,7 +1091,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
 func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 s := c.s
-dstKey, contents, err := s.recvPacket(c.br, fl)
+dstHandle, contents, err := s.recvPacket(c.br, fl)
 if err != nil {
 return fmt.Errorf("client %v: recvPacket: %v", c.key, err)
 }
@@ -1093,20 +1101,20 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 var dst *sclient
 s.mu.Lock()
-if set, ok := s.clients[dstKey]; ok {
+if set, ok := s.clients[dstHandle]; ok {
 dstLen = set.Len()
 dst = set.activeClient.Load()
 }
 if dst == nil && dstLen < 1 {
-fwd = s.clientsMesh[dstKey]
+fwd = s.clientsMesh[dstHandle]
 }
 s.mu.Unlock()
 if dst == nil {
 if fwd != nil {
 s.packetsForwardedOut.Add(1)
-err := fwd.ForwardPacket(c.key, dstKey, contents)
-c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstKey.ShortString(), fwd, err)
+err := fwd.ForwardPacket(c.handle, dstHandle, contents)
+c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstHandle.Value().ShortString(), fwd, err)
 if err != nil {
 // TODO:
 return nil
@@ -1117,18 +1125,18 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
 if dstLen > 1 {
 reason = dropReasonDupClient
 } else {
-c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere)
+c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere)
 }
-s.recordDrop(contents, c.key, dstKey, reason)
-c.debugLogf("SendPacket for %s, dropping with reason=%s", dstKey.ShortString(), reason)
+s.recordDrop(contents, c.handle, dstHandle, reason)
+c.debugLogf("SendPacket for %s, dropping with reason=%s", dstHandle.Value().ShortString(), reason)
 return nil
 }
-c.debugLogf("SendPacket for %s, sending directly", dstKey.ShortString())
+c.debugLogf("SendPacket for %s, sending directly", dstHandle.Value().ShortString())
 p := pkt{
 bs: contents,
 enqueuedAt: c.s.clock.Now(),
-src: c.key,
+src: c.handle,
 }
 return c.sendPkt(dst, p)
 }
@@ -1155,7 +1163,7 @@ func (c *sclient) debugLogf(format string, v ...any) {
 numDropReasons // unused; keep last
 )
-func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) {
+func (s *Server) recordDrop(packetBytes []byte, srcHandle, dstHandle NodeHandle, reason dropReason) {
 s.packetsDropped.Add(1)
 s.packetsDroppedReasonCounters[reason].Add(1)
 looksDisco := disco.LooksLikeDiscoWrapper(packetBytes)
@@ -1164,20 +1172,19 @@ func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, r
 } else {
 s.packetsDroppedTypeOther.Add(1)
 }
-if verboseDropKeys[dstKey] {
+if verboseDropKeys[dstHandle] {
 // Preformat the log string prior to calling limitedLogf. The
 // limiter acts based on the format string, and we want to
 // rate-limit per src/dst keys, not on the generic "dropped
 // stuff" message.
-msg := fmt.Sprintf("drop (%s) %s -> %s", srcKey.ShortString(), reason, dstKey.ShortString())
+msg := fmt.Sprintf("drop (%s) %s -> %s", srcHandle.Value().ShortString(), reason, dstHandle.Value().ShortString())
 s.limitedLogf(msg)
 }
-s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstKey, looksDisco)
+s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstHandle.Value(), looksDisco)
 }
 func (c *sclient) sendPkt(dst *sclient, p pkt) error {
 s := c.s
-dstKey := dst.key
 
 // Attempt to queue for sending up to 3 times. On each attempt, if
 // the queue is full, try to drop from queue head to prioritize
@@ -1189,7 +1196,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
 for attempt := 0; attempt < 3; attempt++ {
 select {
 case <-dst.done:
-s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected)
+s.recordDrop(p.bs, c.handle, dst.handle, dropReasonGoneDisconnected)
 dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt)
 return nil
 default:
@@ -1203,7 +1210,7 @@
 select {
 case pkt := <-sendQueue:
-s.recordDrop(pkt.bs, c.key, dstKey, dropReasonQueueHead)
+s.recordDrop(pkt.bs, c.handle, dst.handle, dropReasonQueueHead)
 c.recordQueueTime(pkt.enqueuedAt)
 default:
 }
@@ -1211,7 +1218,7 @@
 // Failed to make room for packet. This can happen in a heavily
 // contended queue with racing writers. Give up and tail-drop in
 // this case to keep reader unblocked.
-s.recordDrop(p.bs, c.key, dstKey, dropReasonQueueTail)
+s.recordDrop(p.bs, c.handle, dst.handle, dropReasonQueueTail)
 dst.debugLogf("sendPkt attempt %d dropped, queue full")
 return nil
@@ -1220,17 +1227,17 @@
 // onPeerGoneFromRegion is the callback registered with the Server to be
 // notified (in a new goroutine) whenever a peer has disconnected from all DERP
 // nodes in the current region.
-func (c *sclient) onPeerGoneFromRegion(peer key.NodePublic) {
+func (c *sclient) onPeerGoneFromRegion(peer NodeHandle) {
 c.requestPeerGoneWrite(peer, PeerGoneReasonDisconnected)
 }
 // requestPeerGoneWrite sends a request to write a "peer gone" frame
 // with an explanation of why it is gone. It blocks until either the
 // write request is scheduled, or the client has closed.
-func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason PeerGoneReasonType) {
+func (c *sclient) requestPeerGoneWrite(peer NodeHandle, reason PeerGoneReasonType) {
 select {
 case c.peerGone <- peerGoneMsg{
-peer: peer,
+peer: peer.Value(),
 reason: reason,
 }:
 case <-c.done:
@@ -1346,7 +1353,7 @@ func (s *Server) noteClientActivity(c *sclient) {
 s.mu.Lock()
 defer s.mu.Unlock()
-cs, ok := s.clients[c.key]
+cs, ok := s.clients[c.handle]
 if !ok {
 return
 }
@@ -1453,20 +1460,22 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info
 return clientKey, info, nil
 }
-func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) {
+func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstHandle NodeHandle, contents []byte, err error) {
 if frameLen < keyLen {
-return zpub, nil, errors.New("short send packet frame")
+return zpubHandle, nil, errors.New("short send packet frame")
 }
+var dstKey key.NodePublic
 if err := dstKey.ReadRawWithoutAllocating(br); err != nil {
-return zpub, nil, err
+return zpubHandle, nil, err
 }
+dstHandle = dstKey.Handle()
 packetLen := frameLen - keyLen
 if packetLen > MaxPacketSize {
-return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
+return zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
 }
 contents = make([]byte, packetLen)
 if _, err := io.ReadFull(br, contents); err != nil {
-return zpub, nil, err
+return zpubHandle, nil, err
 }
 s.packetsRecv.Add(1)
 s.bytesRecv.Add(int64(len(contents)))
@@ -1475,33 +1484,35 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodeP
 } else {
 s.packetsRecvOther.Add(1)
 }
-return dstKey, contents, nil
+return dstHandle, contents, nil
 }
 // zpub is the key.NodePublic zero value.
 var zpub key.NodePublic
+var zpubHandle NodeHandle = zpub.Handle()
 
-func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.NodePublic, contents []byte, err error) {
+func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcHandle, dstHandle NodeHandle, contents []byte, err error) {
 if frameLen < keyLen*2 {
-return zpub, zpub, nil, errors.New("short send packet frame")
+return zpubHandle, zpubHandle, nil, errors.New("short send packet frame")
 }
+var srcKey, dstKey key.NodePublic
 if err := srcKey.ReadRawWithoutAllocating(br); err != nil {
-return zpub, zpub, nil, err
+return zpubHandle, zpubHandle, nil, err
 }
 if err := dstKey.ReadRawWithoutAllocating(br); err != nil {
-return zpub, zpub, nil, err
+return zpubHandle, zpubHandle, nil, err
 }
 packetLen := frameLen - keyLen*2
 if packetLen > MaxPacketSize {
-return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
+return zpubHandle, zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
 }
 contents = make([]byte, packetLen)
 if _, err := io.ReadFull(br, contents); err != nil {
-return zpub, zpub, nil, err
+return zpubHandle, zpubHandle, nil, err
 }
 // TODO: was s.packetsRecv.Add(1)
 // TODO: was s.bytesRecv.Add(int64(len(contents)))
-return srcKey, dstKey, contents, nil
+return srcKey.Handle(), dstKey.Handle(), contents, nil
 }
 // sclient is a client connection to the server.
@@ -1518,6 +1529,7 @@ type sclient struct {
 s *Server
 nc Conn
 key key.NodePublic
+handle NodeHandle // handle is a cached handle for key
 info clientInfo
 logf logger.Logf
 done <-chan struct{} // closed when connection closes
@@ -1539,7 +1551,7 @@ type sclient struct {
 preferred bool
 // Owned by sendLoop, not thread-safe.
-sawSrc map[key.NodePublic]set.Handle
+sawSrc set.Set[NodeHandle]
 bw *lazyBufioWriter
 // Guarded by s.mu
@@ -1593,7 +1605,7 @@ type pkt struct {
 bs []byte
 // src is the who's the sender of the packet.
-src key.NodePublic
+src NodeHandle
 }
 // peerGoneMsg is a request to write a peerGone frame to an sclient
@@ -1656,17 +1668,17 @@ func (c *sclient) onSendLoopDone() {
 c.nc.Close()
 // Clean up watches.
-for peer, h := range c.sawSrc {
-c.s.removePeerGoneFromRegionWatcher(peer, h)
+for peer := range c.sawSrc {
+c.s.removePeerGoneFromRegionWatcher(peer, c.handle)
 }
 // Drain the send queue to count dropped packets
 for {
 select {
 case pkt := <-c.sendQueue:
-c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected)
+c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected)
 case pkt := <-c.discoSendQueue:
-c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected)
+c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected)
 default:
 return
 }
@@ -1869,31 +1881,31 @@ func (c *sclient) sendMeshUpdates() error {
 // DERPv2. The bytes of contents are only valid until this function
 // returns, do not retain slices.
 // It does not flush its bufio.Writer.
-func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) {
+func (c *sclient) sendPacket(srcHandle NodeHandle, contents []byte) (err error) {
 defer func() {
 // Stats update.
 if err != nil {
-c.s.recordDrop(contents, srcKey, c.key, dropReasonWriteError)
+c.s.recordDrop(contents, srcHandle, c.handle, dropReasonWriteError)
 } else {
 c.s.packetsSent.Add(1)
 c.s.bytesSent.Add(int64(len(contents)))
 }
-c.debugLogf("sendPacket from %s: %v", srcKey.ShortString(), err)
+c.debugLogf("sendPacket from %s: %v", srcHandle.Value().ShortString(), err)
 }()
 c.setWriteDeadline()
-withKey := !srcKey.IsZero()
+withKey := !srcHandle.Value().IsZero()
 pktLen := len(contents)
 if withKey {
 pktLen += key.NodePublicRawLen
-c.noteSendFromSrc(srcKey)
+c.noteSendFromSrc(srcHandle)
 }
 if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil {
 return err
 }
 if withKey {
-if err := srcKey.WriteRawWithoutAllocating(c.bw.bw()); err != nil {
+if err := srcHandle.Value().WriteRawWithoutAllocating(c.bw.bw()); err != nil {
 return err
 }
 }
@@ -1905,17 +1917,18 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error)
 // from src to sclient.
 //
 // It must only be called from the sendLoop goroutine.
-func (c *sclient) noteSendFromSrc(src key.NodePublic) {
+func (c *sclient) noteSendFromSrc(src NodeHandle) {
 if _, ok := c.sawSrc[src]; ok {
 return
 }
-h := c.s.addPeerGoneFromRegionWatcher(src, c.onPeerGoneFromRegion)
-mak.Set(&c.sawSrc, src, h)
+c.s.addPeerGoneFromRegionWatcher(src, c.handle, c.onPeerGoneFromRegion)
+c.sawSrc.Make() // ensure sawSrc is non-nil
+c.sawSrc.Add(src)
 }
 // AddPacketForwarder registers fwd as a packet forwarder for dst.
 // fwd must be comparable.
-func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
+func (s *Server) AddPacketForwarder(dst NodeHandle, fwd PacketForwarder) {
 s.mu.Lock()
 defer s.mu.Unlock()
 if prev, ok := s.clientsMesh[dst]; ok {
@@ -1945,7 +1958,7 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
 // RemovePacketForwarder removes fwd as a packet forwarder for dst.
 // fwd must be comparable.
-func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
+func (s *Server) RemovePacketForwarder(dst NodeHandle, fwd PacketForwarder) {
 s.mu.Lock()
 defer s.mu.Unlock()
 v, ok := s.clientsMesh[dst]
@@ -2048,7 +2061,7 @@ func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, i
 return nil, false
 }
-func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
+func (f *multiForwarder) ForwardPacket(src, dst NodeHandle, payload []byte) error {
 return f.fwd.Load().ForwardPacket(src, dst, payload)
 }
@@ -2238,8 +2251,8 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) {
 for k, next := range newState {
 prev := prevState[k]
 if prev.Sent < next.Sent || prev.Recv < next.Recv {
-if pkey, ok := s.keyOfAddr[k]; ok {
-next.Key = pkey
+if pHandle, ok := s.keyOfAddr[k]; ok {
+next.Key = pHandle.Value()
 if err := enc.Encode(next); err != nil {
 s.mu.Unlock()
 return

View File

@@ -714,27 +714,27 @@ func TestWatch(t *testing.T) {
 type testFwd int
-func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error {
+func (testFwd) ForwardPacket(NodeHandle, NodeHandle, []byte) error {
 panic("not called in tests")
 }
 func (testFwd) String() string {
 panic("not called in tests")
 }
-func pubAll(b byte) (ret key.NodePublic) {
+func pubAll(b byte) (ret NodeHandle) {
 var bs [32]byte
 for i := range bs {
 bs[i] = b
 }
-return key.NodePublicFromRaw32(mem.B(bs[:]))
+return key.NodePublicFromRaw32(mem.B(bs[:])).Handle()
 }
 func TestForwarderRegistration(t *testing.T) {
 s := &Server{
-clients: make(map[key.NodePublic]*clientSet),
-clientsMesh: map[key.NodePublic]PacketForwarder{},
+clients: make(map[NodeHandle]*clientSet),
+clientsMesh: map[NodeHandle]PacketForwarder{},
 }
-want := func(want map[key.NodePublic]PacketForwarder) {
+want := func(want map[NodeHandle]PacketForwarder) {
 t.Helper()
 if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
 t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
@@ -758,28 +758,28 @@ func TestForwarderRegistration(t *testing.T) {
 s.AddPacketForwarder(u1, testFwd(1))
 s.AddPacketForwarder(u2, testFwd(2))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(1),
 u2: testFwd(2),
 })
 // Verify a remove of non-registered forwarder is no-op.
 s.RemovePacketForwarder(u2, testFwd(999))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(1),
 u2: testFwd(2),
 })
 // Verify a remove of non-registered user is no-op.
 s.RemovePacketForwarder(u3, testFwd(1))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(1),
 u2: testFwd(2),
 })
 // Actual removal.
 s.RemovePacketForwarder(u2, testFwd(2))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(1),
 })
@@ -787,14 +787,14 @@
 wantCounter(&s.multiForwarderCreated, 0)
 s.AddPacketForwarder(u1, testFwd(100))
 s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: newMultiForwarder(testFwd(1), testFwd(100)),
 })
 wantCounter(&s.multiForwarderCreated, 1)
 // Removing a forwarder in a multi set that doesn't exist; does nothing.
 s.RemovePacketForwarder(u1, testFwd(55))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: newMultiForwarder(testFwd(1), testFwd(100)),
 })
@@ -802,7 +802,7 @@
 // from being a multiForwarder.
 wantCounter(&s.multiForwarderDeleted, 0)
 s.RemovePacketForwarder(u1, testFwd(1))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(100),
 })
 wantCounter(&s.multiForwarderDeleted, 1)
@@ -810,23 +810,24 @@
 // Removing an entry for a client that's still connected locally should result
 // in a nil forwarder.
 u1c := &sclient{
-key: u1,
-logf: logger.Discard,
+key: u1.Value(),
+handle: u1,
+logf: logger.Discard,
 }
 s.clients[u1] = singleClient(u1c)
 s.RemovePacketForwarder(u1, testFwd(100))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: nil,
 })
 // But once that client disconnects, it should go away.
 s.unregisterClient(u1c)
-want(map[key.NodePublic]PacketForwarder{})
+want(map[NodeHandle]PacketForwarder{})
 // But if it already has a forwarder, it's not removed.
 s.AddPacketForwarder(u1, testFwd(2))
 s.unregisterClient(u1c)
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(2),
 })
@@ -835,11 +836,11 @@
 // from nil to the new one, not a multiForwarder.
 s.clients[u1] = singleClient(u1c)
 s.clientsMesh[u1] = nil
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: nil,
 })
 s.AddPacketForwarder(u1, testFwd(3))
-want(map[key.NodePublic]PacketForwarder{
+want(map[NodeHandle]PacketForwarder{
 u1: testFwd(3),
 })
 }
@@ -853,7 +854,7 @@ type channelFwd struct {
 }
 func (f channelFwd) String() string { return "" }
-func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
+func (f channelFwd) ForwardPacket(_ NodeHandle, _ NodeHandle, packet []byte) error {
 f.c <- packet
 return nil
 }
@@ -865,8 +866,8 @@ func TestMultiForwarder(t *testing.T) {
 ctx, cancel := context.WithCancel(context.Background())
 s := &Server{
-clients: make(map[key.NodePublic]*clientSet),
-clientsMesh: map[key.NodePublic]PacketForwarder{},
+clients: make(map[NodeHandle]*clientSet),
+clientsMesh: map[NodeHandle]PacketForwarder{},
 }
 u := pubAll(1)
 s.AddPacketForwarder(u, channelFwd{1, ch})
@@ -1067,9 +1068,9 @@ func TestServerDupClients(t *testing.T) {
 run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) {
 s = NewServer(serverPriv, t.Logf)
 s.dupPolicy = dupPolicy
-c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")}
-c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")}
-c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")}
+c1 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c1: ")}
+c2 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c2: ")}
+c3 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c3: ")}
 clientName = map[*sclient]string{
 c1: "c1",
 c2: "c2",
@@ -1083,7 +1084,7 @@
 }
 wantSingleClient := func(t *testing.T, want *sclient) {
 t.Helper()
-got, ok := s.clients[want.key]
+got, ok := s.clients[want.key.Handle()]
 if !ok {
 t.Error("no clients for key")
 return
@@ -1106,7 +1107,7 @@
 }
 wantNoClient := func(t *testing.T) {
 t.Helper()
-_, ok := s.clients[clientPub]
+_, ok := s.clients[clientPub.Handle()]
 if !ok {
 // Good
 return
@@ -1115,7 +1116,7 @@
 }
 wantDupSet := func(t *testing.T) *dupClientSet {
 t.Helper()
-cs, ok := s.clients[clientPub]
+cs, ok := s.clients[clientPub.Handle()]
 if !ok {
 t.Fatal("no set for key; want dup set")
 return nil
@@ -1128,7 +1129,7 @@
 }
 wantActive := func(t *testing.T, want *sclient) {
 t.Helper()
-set, ok := s.clients[clientPub]
+set, ok := s.clients[clientPub.Handle()]
 if !ok {
 t.Error("no set for key")
 return

View File

@@ -27,6 +27,7 @@
 "strings"
 "sync"
 "time"
+"unique"
 
 "go4.org/mem"
 "tailscale.com/derp"
@@ -971,7 +972,7 @@ func (c *Client) LocalAddr() (netip.AddrPort, error) {
 return la, nil
 }
-func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error {
+func (c *Client) ForwardPacket(from, to unique.Handle[key.NodePublic], b []byte) error {
 client, _, err := c.connect(c.newContext(), "derphttp.Client.ForwardPacket")
 if err != nil {
 return err

View File

@@ -327,7 +327,7 @@ func TestBreakWatcherConnRecv(t *testing.T) {
 }
 watcher1.breakConnection(watcher1.client)
 // re-establish connection by sending a packet
-watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus"))
+watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus"))
 timer.Reset(5 * time.Second)
 }
@@ -400,7 +400,7 @@ func TestBreakWatcherConn(t *testing.T) {
 }
 watcher1.breakConnection(watcher1.client)
 // re-establish connection by sending a packet
-watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus"))
+watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus"))
 // signal that the breaker is done
 breakerChan <- true

View File

@@ -10,6 +10,7 @@
 "encoding/hex"
 "errors"
 "fmt"
+"unique"
 
 "go4.org/mem"
 "golang.org/x/crypto/curve25519"
@@ -174,6 +175,16 @@ func (p NodePublic) Compare(p2 NodePublic) int {
 return bytes.Compare(p.k[:], p2.k[:])
 }
 
+// Handle returns a unique.Handle for this NodePublic. The Handle is more
+// efficient for storage and comparison than the NodePublic itself, but is also
+// more expensive to create. It is best to keep a copy of the Handle on a longer
+// term object representing a NodePublic, rather than creating it on the fly,
+// but in doing so if the Handle is used in multiple other data structures the
+// cost of Handle storage and comparisons on lookups will quickly amortize.
+func (p NodePublic) Handle() unique.Handle[NodePublic] {
+return unique.Make(p)
+}
+
 // ParseNodePublicUntyped parses an untyped 64-character hex value
 // as a NodePublic.
 //