derp: intern key.NodePublic across the server

Consistently interning NodePublic keys throughout DERP, particularly
inside the maps, reduces memory usage and lowers lookup costs in the
associated data structures.

It is not clear exactly how efficient the weak pointers will be in
practice, but estimating this with derpstress (10k conns pushing 40kpps
in each direction), this patch grows the heap at approximately half the
rate of the old code and has fewer long stalls that trigger i/o
timeouts for the clients.
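
For readers unfamiliar with Go's unique package: the pattern in this patch
is to intern each key.NodePublic into a unique.Handle, a pointer-sized
value that compares with a single pointer comparison and deduplicates the
underlying 32-byte key, so Handle-keyed maps hash and store handles rather
than full keys. Below is a minimal, self-contained sketch of the idea,
using a placeholder 32-byte key type in place of
tailscale.com/types/key.NodePublic; the names are illustrative, not the
actual API.

    package main

    import (
    	"fmt"
    	"unique"
    )

    // nodePublic stands in for key.NodePublic: a comparable 32-byte value.
    type nodePublic [32]byte

    // nodeHandle mirrors the NodeHandle alias added in this patch.
    type nodeHandle = unique.Handle[nodePublic]

    func main() {
    	var k nodePublic
    	k[0] = 0x42

    	// unique.Make interns the value; equal values yield the same handle,
    	// so comparisons and map hashing are done on a pointer-sized value.
    	h1 := unique.Make(k)
    	h2 := unique.Make(k)
    	fmt.Println(h1 == h2) // true

    	// Handle-keyed maps, like the server's clients/clientsMesh maps in
    	// the diff below, hash the small handle instead of the full key.
    	clients := map[nodeHandle]string{}
    	clients[h1] = "sclient"
    	fmt.Println(clients[h2]) // sclient

    	// Value() recovers the original key when it must be written to the
    	// wire or logged, as with handle.Value().AppendTo(nil) in the diff.
    	fmt.Printf("%x\n", h1.Value())
    }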

Updates tailscale/corp#24485

Signed-off-by: James Tucker <james@tailscale.com>
James Tucker 2024-11-07 12:26:43 -08:00
parent 3b93fd9c44
commit f7ad04bea4
12 changed files with 178 additions and 151 deletions


@ -315,4 +315,4 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
unicode from bytes+
unicode/utf16 from crypto/x509+
unicode/utf8 from bufio+
unique from net/netip
unique from net/netip+


@ -69,8 +69,8 @@ func startMeshWithHost(s *derp.Server, host string) error {
return d.DialContext(ctx, network, addr)
})
add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) }
remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) }
add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key.Handle(), c) }
remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer.Handle(), c) }
go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove)
return nil
}


@ -1010,4 +1010,4 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
unicode from bytes+
unicode/utf16 from crypto/x509+
unicode/utf8 from bufio+
unique from net/netip
unique from net/netip+


@ -199,4 +199,4 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar
unicode from bytes+
unicode/utf16 from crypto/x509+
unicode/utf8 from bufio+
unique from net/netip
unique from net/netip+


@ -336,4 +336,4 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
unicode from bytes+
unicode/utf16 from crypto/x509+
unicode/utf8 from bufio+
unique from net/netip
unique from net/netip+


@ -586,4 +586,4 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
unicode from bytes+
unicode/utf16 from crypto/x509+
unicode/utf8 from bufio+
unique from net/netip
unique from net/netip+


@ -13,6 +13,7 @@
"net/netip"
"sync"
"time"
"unique"
"go4.org/mem"
"golang.org/x/time/rate"
@ -236,7 +237,7 @@ func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) {
return c.bw.Flush()
}
func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err error) {
func (c *Client) ForwardPacket(srcHandle, dstHandle unique.Handle[key.NodePublic], pkt []byte) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("derp.ForwardPacket: %w", err)
@ -256,10 +257,10 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e
if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil {
return err
}
if _, err := c.bw.Write(srcKey.AppendTo(nil)); err != nil {
if _, err := c.bw.Write(srcHandle.Value().AppendTo(nil)); err != nil {
return err
}
if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil {
if _, err := c.bw.Write(dstHandle.Value().AppendTo(nil)); err != nil {
return err
}
if _, err := c.bw.Write(pkt); err != nil {


@ -34,6 +34,7 @@
"sync"
"sync/atomic"
"time"
"unique"
"go4.org/mem"
"golang.org/x/sync/errgroup"
@ -48,15 +49,18 @@
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/ctxkey"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/slicesx"
"tailscale.com/version"
)
// NodeHandle is an interned representation of a NodePublic key that is cheap
// to pass around and compare.
type NodeHandle = unique.Handle[key.NodePublic]
// verboseDropKeys is the set of destination public keys that should
// verbosely log whenever DERP drops a packet.
var verboseDropKeys = map[key.NodePublic]bool{}
var verboseDropKeys = map[NodeHandle]bool{}
// IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests
// to indicate that they're connecting to their ideal (Region.Nodes[0]) node.
@ -78,7 +82,7 @@ func init() {
if err != nil {
log.Printf("ignoring invalid debug key %q: %v", keyStr, err)
} else {
verboseDropKeys[k] = true
verboseDropKeys[k.Handle()] = true
}
}
}
@ -173,22 +177,22 @@ type Server struct {
mu sync.Mutex
closed bool
netConns map[Conn]chan struct{} // chan is closed when conn closes
clients map[key.NodePublic]*clientSet
clients map[NodeHandle]*clientSet
watchers set.Set[*sclient] // mesh peers
// clientsMesh tracks all clients in the cluster, both locally
// and to mesh peers. If the value is nil, that means the
// peer is only local (and thus in the clients Map, but not
// remote). If the value is non-nil, it's remote (+ maybe also
// local).
clientsMesh map[key.NodePublic]PacketForwarder
clientsMesh map[NodeHandle]PacketForwarder
// peerGoneWatchers is the set of watchers that subscribed to a
// peer disconnecting from the region overall. When a peer
// is gone from the region, we notify all of these watchers,
// calling their funcs in a new goroutine.
peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)]
peerGoneWatchers map[NodeHandle]map[NodeHandle]func(NodeHandle)
// maps from netip.AddrPort to a client's public key
keyOfAddr map[netip.AddrPort]key.NodePublic
keyOfAddr map[netip.AddrPort]NodeHandle
clock tstime.Clock
}
@ -325,7 +329,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool {
// is a multiForwarder, which this package creates as needed if a
// public key gets more than one PacketForwarder registered for it.
type PacketForwarder interface {
ForwardPacket(src, dst key.NodePublic, payload []byte) error
ForwardPacket(src, dst NodeHandle, payload []byte) error
String() string
}
@ -355,18 +359,18 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
packetsRecvByKind: metrics.LabelMap{Label: "kind"},
packetsDroppedReason: metrics.LabelMap{Label: "reason"},
packetsDroppedType: metrics.LabelMap{Label: "type"},
clients: map[key.NodePublic]*clientSet{},
clientsMesh: map[key.NodePublic]PacketForwarder{},
clients: map[NodeHandle]*clientSet{},
clientsMesh: map[NodeHandle]PacketForwarder{},
netConns: map[Conn]chan struct{}{},
memSys0: ms.Sys,
watchers: set.Set[*sclient]{},
peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{},
peerGoneWatchers: map[NodeHandle]map[NodeHandle]func(NodeHandle){},
avgQueueDuration: new(uint64),
tcpRtt: metrics.LabelMap{Label: "le"},
meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}),
meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}),
bufferedWriteFrames: metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}),
keyOfAddr: map[netip.AddrPort]key.NodePublic{},
keyOfAddr: map[netip.AddrPort]NodeHandle{},
clock: tstime.StdClock{},
}
s.initMetacert()
@ -479,7 +483,7 @@ func (s *Server) isClosed() bool {
func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool {
s.mu.Lock()
defer s.mu.Unlock()
x, ok := s.clients[k]
x, ok := s.clients[k.Handle()]
if !ok {
return false
}
@ -573,11 +577,11 @@ func (s *Server) registerClient(c *sclient) {
s.mu.Lock()
defer s.mu.Unlock()
cs, ok := s.clients[c.key]
cs, ok := s.clients[c.handle]
if !ok {
c.debugLogf("register single client")
cs = &clientSet{}
s.clients[c.key] = cs
s.clients[c.handle] = cs
}
was := cs.activeClient.Load()
if was == nil {
@ -610,15 +614,15 @@ func (s *Server) registerClient(c *sclient) {
cs.activeClient.Store(c)
if _, ok := s.clientsMesh[c.key]; !ok {
s.clientsMesh[c.key] = nil // just for varz of total users in cluster
if _, ok := s.clientsMesh[c.handle]; !ok {
s.clientsMesh[c.handle] = nil // just for varz of total users in cluster
}
s.keyOfAddr[c.remoteIPPort] = c.key
s.keyOfAddr[c.remoteIPPort] = c.handle
s.curClients.Add(1)
if c.isNotIdealConn {
s.curClientsNotIdeal.Add(1)
}
s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true)
s.broadcastPeerStateChangeLocked(c.handle, c.remoteIPPort, c.presentFlags(), true)
}
// broadcastPeerStateChangeLocked enqueues a message to all watchers
@ -626,10 +630,10 @@ func (s *Server) registerClient(c *sclient) {
// presence changed.
//
// s.mu must be held.
func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) {
func (s *Server) broadcastPeerStateChangeLocked(peer NodeHandle, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) {
for w := range s.watchers {
w.peerStateChange = append(w.peerStateChange, peerConnState{
peer: peer,
peer: peer.Value(),
present: present,
ipPort: ipPort,
flags: flags,
@ -643,7 +647,7 @@ func (s *Server) unregisterClient(c *sclient) {
s.mu.Lock()
defer s.mu.Unlock()
set, ok := s.clients[c.key]
set, ok := s.clients[c.handle]
if !ok {
c.logf("[unexpected]; clients map is empty")
return
@ -663,12 +667,12 @@ func (s *Server) unregisterClient(c *sclient) {
}
c.debugLogf("removed connection")
set.activeClient.Store(nil)
delete(s.clients, c.key)
if v, ok := s.clientsMesh[c.key]; ok && v == nil {
delete(s.clientsMesh, c.key)
s.notePeerGoneFromRegionLocked(c.key)
delete(s.clients, c.handle)
if v, ok := s.clientsMesh[c.handle]; ok && v == nil {
delete(s.clientsMesh, c.handle)
s.notePeerGoneFromRegionLocked(c.handle)
}
s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, 0, false)
s.broadcastPeerStateChangeLocked(c.handle, netip.AddrPort{}, 0, false)
} else {
c.debugLogf("removed duplicate client")
if dup.removeClient(c) {
@ -720,30 +724,30 @@ func (s *Server) unregisterClient(c *sclient) {
// The provided f func is usually [sclient.onPeerGoneFromRegion], added by
// [sclient.noteSendFromSrc]; this func doesn't take a whole *sclient to make it
// clear what has access to what.
func (s *Server) addPeerGoneFromRegionWatcher(peer key.NodePublic, f func(key.NodePublic)) set.Handle {
func (s *Server) addPeerGoneFromRegionWatcher(peerA, peerB NodeHandle, f func(NodeHandle)) {
s.mu.Lock()
defer s.mu.Unlock()
hset, ok := s.peerGoneWatchers[peer]
m, ok := s.peerGoneWatchers[peerA]
if !ok {
hset = set.HandleSet[func(key.NodePublic)]{}
s.peerGoneWatchers[peer] = hset
m = map[NodeHandle]func(NodeHandle){}
s.peerGoneWatchers[peerA] = m
}
return hset.Add(f)
m[peerB] = f
}
// removePeerGoneFromRegionWatcher removes a peer watcher previously added by
// addPeerGoneFromRegionWatcher, using the handle returned by
// addPeerGoneFromRegionWatcher.
func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Handle) {
func (s *Server) removePeerGoneFromRegionWatcher(peerA, peerB NodeHandle) {
s.mu.Lock()
defer s.mu.Unlock()
hset, ok := s.peerGoneWatchers[peer]
hset, ok := s.peerGoneWatchers[peerA]
if !ok {
return
}
delete(hset, h)
delete(hset, peerB)
if len(hset) == 0 {
delete(s.peerGoneWatchers, peer)
delete(s.peerGoneWatchers, peerA)
}
}
@ -751,8 +755,8 @@ func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Hand
// key has sent to previously (whether those sends were from a local
// client or forwarded). It must only be called after the key has
// been removed from clientsMesh.
func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) {
if _, ok := s.clientsMesh[key]; ok {
func (s *Server) notePeerGoneFromRegionLocked(handle NodeHandle) {
if _, ok := s.clientsMesh[handle]; ok {
panic("usage")
}
@ -760,17 +764,17 @@ func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) {
// so they can drop their route entries to us (issue 150)
// or move them over to the active client (in case a replaced client
// connection is being unregistered).
set := s.peerGoneWatchers[key]
set := s.peerGoneWatchers[handle]
for _, f := range set {
go f(key)
go f(handle)
}
delete(s.peerGoneWatchers, key)
delete(s.peerGoneWatchers, handle)
}
// requestPeerGoneWriteLimited sends a request to write a "peer gone"
// frame, but only in reply to a disco packet, and only if we haven't
// sent one recently.
func (c *sclient) requestPeerGoneWriteLimited(peer key.NodePublic, contents []byte, reason PeerGoneReasonType) {
func (c *sclient) requestPeerGoneWriteLimited(peer NodeHandle, contents []byte, reason PeerGoneReasonType) {
if disco.LooksLikeDiscoWrapper(contents) != true {
return
}
@ -800,7 +804,7 @@ func (s *Server) addWatcher(c *sclient) {
continue
}
c.peerStateChange = append(c.peerStateChange, peerConnState{
peer: peer,
peer: peer.Value(),
present: true,
ipPort: ac.remoteIPPort,
flags: ac.presentFlags(),
@ -826,6 +830,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem
if err != nil {
return fmt.Errorf("receive client key: %v", err)
}
handle := clientKey.Handle()
clientKey = handle.Value() // interned
remoteIPPort, _ := netip.ParseAddrPort(remoteAddr)
if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil {
@ -842,6 +848,7 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem
connNum: connNum,
s: s,
key: clientKey,
handle: handle,
nc: nc,
br: br,
bw: bw,
@ -1014,22 +1021,23 @@ func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error {
if err := targetKey.ReadRawWithoutAllocating(c.br); err != nil {
return err
}
handle := targetKey.Handle()
s := c.s
s.mu.Lock()
defer s.mu.Unlock()
if set, ok := s.clients[targetKey]; ok {
if set, ok := s.clients[handle]; ok {
if set.Len() == 1 {
c.logf("frameClosePeer closing peer %x", targetKey)
c.logf("frameClosePeer closing peer %x", handle.Value())
} else {
c.logf("frameClosePeer closing peer %x (%d connections)", targetKey, set.Len())
c.logf("frameClosePeer closing peer %x (%d connections)", handle.Value(), set.Len())
}
set.ForeachClient(func(target *sclient) {
go target.nc.Close()
})
} else {
c.logf("frameClosePeer failed to find peer %x", targetKey)
c.logf("frameClosePeer failed to find peer %x", handle.Value())
}
return nil
@ -1043,7 +1051,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
}
s := c.s
srcKey, dstKey, contents, err := s.recvForwardPacket(c.br, fl)
srcHandle, dstHandle, contents, err := s.recvForwardPacket(c.br, fl)
if err != nil {
return fmt.Errorf("client %v: recvForwardPacket: %v", c.key, err)
}
@ -1053,7 +1061,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
var dst *sclient
s.mu.Lock()
if set, ok := s.clients[dstKey]; ok {
if set, ok := s.clients[dstHandle]; ok {
dstLen = set.Len()
dst = set.activeClient.Load()
}
@ -1064,18 +1072,18 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
if dstLen > 1 {
reason = dropReasonDupClient
} else {
c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere)
c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere)
}
s.recordDrop(contents, srcKey, dstKey, reason)
s.recordDrop(contents, srcHandle, dstHandle, reason)
return nil
}
dst.debugLogf("received forwarded packet from %s via %s", srcKey.ShortString(), c.key.ShortString())
dst.debugLogf("received forwarded packet from %s via %s", srcHandle.Value().ShortString(), c.key.ShortString())
return c.sendPkt(dst, pkt{
bs: contents,
enqueuedAt: c.s.clock.Now(),
src: srcKey,
src: srcHandle,
})
}
@ -1083,7 +1091,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error {
func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
s := c.s
dstKey, contents, err := s.recvPacket(c.br, fl)
dstHandle, contents, err := s.recvPacket(c.br, fl)
if err != nil {
return fmt.Errorf("client %v: recvPacket: %v", c.key, err)
}
@ -1093,20 +1101,20 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
var dst *sclient
s.mu.Lock()
if set, ok := s.clients[dstKey]; ok {
if set, ok := s.clients[dstHandle]; ok {
dstLen = set.Len()
dst = set.activeClient.Load()
}
if dst == nil && dstLen < 1 {
fwd = s.clientsMesh[dstKey]
fwd = s.clientsMesh[dstHandle]
}
s.mu.Unlock()
if dst == nil {
if fwd != nil {
s.packetsForwardedOut.Add(1)
err := fwd.ForwardPacket(c.key, dstKey, contents)
c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstKey.ShortString(), fwd, err)
err := fwd.ForwardPacket(c.handle, dstHandle, contents)
c.debugLogf("SendPacket for %s, forwarding via %s: %v", dstHandle.Value().ShortString(), fwd, err)
if err != nil {
// TODO:
return nil
@ -1117,18 +1125,18 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error {
if dstLen > 1 {
reason = dropReasonDupClient
} else {
c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere)
c.requestPeerGoneWriteLimited(dstHandle, contents, PeerGoneReasonNotHere)
}
s.recordDrop(contents, c.key, dstKey, reason)
c.debugLogf("SendPacket for %s, dropping with reason=%s", dstKey.ShortString(), reason)
s.recordDrop(contents, c.handle, dstHandle, reason)
c.debugLogf("SendPacket for %s, dropping with reason=%s", dstHandle.Value().ShortString(), reason)
return nil
}
c.debugLogf("SendPacket for %s, sending directly", dstKey.ShortString())
c.debugLogf("SendPacket for %s, sending directly", dstHandle.Value().ShortString())
p := pkt{
bs: contents,
enqueuedAt: c.s.clock.Now(),
src: c.key,
src: c.handle,
}
return c.sendPkt(dst, p)
}
@ -1155,7 +1163,7 @@ func (c *sclient) debugLogf(format string, v ...any) {
numDropReasons // unused; keep last
)
func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) {
func (s *Server) recordDrop(packetBytes []byte, srcHandle, dstHandle NodeHandle, reason dropReason) {
s.packetsDropped.Add(1)
s.packetsDroppedReasonCounters[reason].Add(1)
looksDisco := disco.LooksLikeDiscoWrapper(packetBytes)
@ -1164,20 +1172,19 @@ func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, r
} else {
s.packetsDroppedTypeOther.Add(1)
}
if verboseDropKeys[dstKey] {
if verboseDropKeys[dstHandle] {
// Preformat the log string prior to calling limitedLogf. The
// limiter acts based on the format string, and we want to
// rate-limit per src/dst keys, not on the generic "dropped
// stuff" message.
msg := fmt.Sprintf("drop (%s) %s -> %s", srcKey.ShortString(), reason, dstKey.ShortString())
msg := fmt.Sprintf("drop (%s) %s -> %s", srcHandle.Value().ShortString(), reason, dstHandle.Value().ShortString())
s.limitedLogf(msg)
}
s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstKey, looksDisco)
s.debugLogf("dropping packet reason=%s dst=%s disco=%v", reason, dstHandle.Value(), looksDisco)
}
func (c *sclient) sendPkt(dst *sclient, p pkt) error {
s := c.s
dstKey := dst.key
// Attempt to queue for sending up to 3 times. On each attempt, if
// the queue is full, try to drop from queue head to prioritize
@ -1189,7 +1196,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
for attempt := 0; attempt < 3; attempt++ {
select {
case <-dst.done:
s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected)
s.recordDrop(p.bs, c.handle, dst.handle, dropReasonGoneDisconnected)
dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt)
return nil
default:
@ -1203,7 +1210,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
select {
case pkt := <-sendQueue:
s.recordDrop(pkt.bs, c.key, dstKey, dropReasonQueueHead)
s.recordDrop(pkt.bs, c.handle, dst.handle, dropReasonQueueHead)
c.recordQueueTime(pkt.enqueuedAt)
default:
}
@ -1211,7 +1218,7 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
// Failed to make room for packet. This can happen in a heavily
// contended queue with racing writers. Give up and tail-drop in
// this case to keep reader unblocked.
s.recordDrop(p.bs, c.key, dstKey, dropReasonQueueTail)
s.recordDrop(p.bs, c.handle, dst.handle, dropReasonQueueTail)
dst.debugLogf("sendPkt attempt %d dropped, queue full")
return nil
@ -1220,17 +1227,17 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error {
// onPeerGoneFromRegion is the callback registered with the Server to be
// notified (in a new goroutine) whenever a peer has disconnected from all DERP
// nodes in the current region.
func (c *sclient) onPeerGoneFromRegion(peer key.NodePublic) {
func (c *sclient) onPeerGoneFromRegion(peer NodeHandle) {
c.requestPeerGoneWrite(peer, PeerGoneReasonDisconnected)
}
// requestPeerGoneWrite sends a request to write a "peer gone" frame
// with an explanation of why it is gone. It blocks until either the
// write request is scheduled, or the client has closed.
func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason PeerGoneReasonType) {
func (c *sclient) requestPeerGoneWrite(peer NodeHandle, reason PeerGoneReasonType) {
select {
case c.peerGone <- peerGoneMsg{
peer: peer,
peer: peer.Value(),
reason: reason,
}:
case <-c.done:
@ -1346,7 +1353,7 @@ func (s *Server) noteClientActivity(c *sclient) {
s.mu.Lock()
defer s.mu.Unlock()
cs, ok := s.clients[c.key]
cs, ok := s.clients[c.handle]
if !ok {
return
}
@ -1453,20 +1460,22 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info
return clientKey, info, nil
}
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) {
func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstHandle NodeHandle, contents []byte, err error) {
if frameLen < keyLen {
return zpub, nil, errors.New("short send packet frame")
return zpubHandle, nil, errors.New("short send packet frame")
}
var dstKey key.NodePublic
if err := dstKey.ReadRawWithoutAllocating(br); err != nil {
return zpub, nil, err
return zpubHandle, nil, err
}
dstHandle = dstKey.Handle()
packetLen := frameLen - keyLen
if packetLen > MaxPacketSize {
return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
return zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
}
contents = make([]byte, packetLen)
if _, err := io.ReadFull(br, contents); err != nil {
return zpub, nil, err
return zpubHandle, nil, err
}
s.packetsRecv.Add(1)
s.bytesRecv.Add(int64(len(contents)))
@ -1475,33 +1484,35 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodeP
} else {
s.packetsRecvOther.Add(1)
}
return dstKey, contents, nil
return dstHandle, contents, nil
}
// zpub is the key.NodePublic zero value.
var zpub key.NodePublic
var zpubHandle NodeHandle = zpub.Handle()
func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.NodePublic, contents []byte, err error) {
func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcHandle, dstHandle NodeHandle, contents []byte, err error) {
if frameLen < keyLen*2 {
return zpub, zpub, nil, errors.New("short send packet frame")
return zpubHandle, zpubHandle, nil, errors.New("short send packet frame")
}
var srcKey, dstKey key.NodePublic
if err := srcKey.ReadRawWithoutAllocating(br); err != nil {
return zpub, zpub, nil, err
return zpubHandle, zpubHandle, nil, err
}
if err := dstKey.ReadRawWithoutAllocating(br); err != nil {
return zpub, zpub, nil, err
return zpubHandle, zpubHandle, nil, err
}
packetLen := frameLen - keyLen*2
if packetLen > MaxPacketSize {
return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
return zpubHandle, zpubHandle, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize)
}
contents = make([]byte, packetLen)
if _, err := io.ReadFull(br, contents); err != nil {
return zpub, zpub, nil, err
return zpubHandle, zpubHandle, nil, err
}
// TODO: was s.packetsRecv.Add(1)
// TODO: was s.bytesRecv.Add(int64(len(contents)))
return srcKey, dstKey, contents, nil
return srcKey.Handle(), dstKey.Handle(), contents, nil
}
// sclient is a client connection to the server.
@ -1518,6 +1529,7 @@ type sclient struct {
s *Server
nc Conn
key key.NodePublic
handle NodeHandle // handle is a cached handle for key
info clientInfo
logf logger.Logf
done <-chan struct{} // closed when connection closes
@ -1539,7 +1551,7 @@ type sclient struct {
preferred bool
// Owned by sendLoop, not thread-safe.
sawSrc map[key.NodePublic]set.Handle
sawSrc set.Set[NodeHandle]
bw *lazyBufioWriter
// Guarded by s.mu
@ -1593,7 +1605,7 @@ type pkt struct {
bs []byte
// src is the sender of the packet.
src key.NodePublic
src NodeHandle
}
// peerGoneMsg is a request to write a peerGone frame to an sclient
@ -1656,17 +1668,17 @@ func (c *sclient) onSendLoopDone() {
c.nc.Close()
// Clean up watches.
for peer, h := range c.sawSrc {
c.s.removePeerGoneFromRegionWatcher(peer, h)
for peer := range c.sawSrc {
c.s.removePeerGoneFromRegionWatcher(peer, c.handle)
}
// Drain the send queue to count dropped packets
for {
select {
case pkt := <-c.sendQueue:
c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected)
c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected)
case pkt := <-c.discoSendQueue:
c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected)
c.s.recordDrop(pkt.bs, pkt.src, c.handle, dropReasonGoneDisconnected)
default:
return
}
@ -1869,31 +1881,31 @@ func (c *sclient) sendMeshUpdates() error {
// DERPv2. The bytes of contents are only valid until this function
// returns, do not retain slices.
// It does not flush its bufio.Writer.
func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) {
func (c *sclient) sendPacket(srcHandle NodeHandle, contents []byte) (err error) {
defer func() {
// Stats update.
if err != nil {
c.s.recordDrop(contents, srcKey, c.key, dropReasonWriteError)
c.s.recordDrop(contents, srcHandle, c.handle, dropReasonWriteError)
} else {
c.s.packetsSent.Add(1)
c.s.bytesSent.Add(int64(len(contents)))
}
c.debugLogf("sendPacket from %s: %v", srcKey.ShortString(), err)
c.debugLogf("sendPacket from %s: %v", srcHandle.Value().ShortString(), err)
}()
c.setWriteDeadline()
withKey := !srcKey.IsZero()
withKey := !srcHandle.Value().IsZero()
pktLen := len(contents)
if withKey {
pktLen += key.NodePublicRawLen
c.noteSendFromSrc(srcKey)
c.noteSendFromSrc(srcHandle)
}
if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil {
return err
}
if withKey {
if err := srcKey.WriteRawWithoutAllocating(c.bw.bw()); err != nil {
if err := srcHandle.Value().WriteRawWithoutAllocating(c.bw.bw()); err != nil {
return err
}
}
@ -1905,17 +1917,18 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error)
// from src to sclient.
//
// It must only be called from the sendLoop goroutine.
func (c *sclient) noteSendFromSrc(src key.NodePublic) {
func (c *sclient) noteSendFromSrc(src NodeHandle) {
if _, ok := c.sawSrc[src]; ok {
return
}
h := c.s.addPeerGoneFromRegionWatcher(src, c.onPeerGoneFromRegion)
mak.Set(&c.sawSrc, src, h)
c.s.addPeerGoneFromRegionWatcher(src, c.handle, c.onPeerGoneFromRegion)
c.sawSrc.Make() // ensure sawSrc is non-nil
c.sawSrc.Add(src)
}
// AddPacketForwarder registers fwd as a packet forwarder for dst.
// fwd must be comparable.
func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
func (s *Server) AddPacketForwarder(dst NodeHandle, fwd PacketForwarder) {
s.mu.Lock()
defer s.mu.Unlock()
if prev, ok := s.clientsMesh[dst]; ok {
@ -1945,7 +1958,7 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
// RemovePacketForwarder removes fwd as a packet forwarder for dst.
// fwd must be comparable.
func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
func (s *Server) RemovePacketForwarder(dst NodeHandle, fwd PacketForwarder) {
s.mu.Lock()
defer s.mu.Unlock()
v, ok := s.clientsMesh[dst]
@ -2048,7 +2061,7 @@ func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, i
return nil, false
}
func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
func (f *multiForwarder) ForwardPacket(src, dst NodeHandle, payload []byte) error {
return f.fwd.Load().ForwardPacket(src, dst, payload)
}
@ -2238,8 +2251,8 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) {
for k, next := range newState {
prev := prevState[k]
if prev.Sent < next.Sent || prev.Recv < next.Recv {
if pkey, ok := s.keyOfAddr[k]; ok {
next.Key = pkey
if pHandle, ok := s.keyOfAddr[k]; ok {
next.Key = pHandle.Value()
if err := enc.Encode(next); err != nil {
s.mu.Unlock()
return


@ -714,27 +714,27 @@ func TestWatch(t *testing.T) {
type testFwd int
func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error {
func (testFwd) ForwardPacket(NodeHandle, NodeHandle, []byte) error {
panic("not called in tests")
}
func (testFwd) String() string {
panic("not called in tests")
}
func pubAll(b byte) (ret key.NodePublic) {
func pubAll(b byte) (ret NodeHandle) {
var bs [32]byte
for i := range bs {
bs[i] = b
}
return key.NodePublicFromRaw32(mem.B(bs[:]))
return key.NodePublicFromRaw32(mem.B(bs[:])).Handle()
}
func TestForwarderRegistration(t *testing.T) {
s := &Server{
clients: make(map[key.NodePublic]*clientSet),
clientsMesh: map[key.NodePublic]PacketForwarder{},
clients: make(map[NodeHandle]*clientSet),
clientsMesh: map[NodeHandle]PacketForwarder{},
}
want := func(want map[key.NodePublic]PacketForwarder) {
want := func(want map[NodeHandle]PacketForwarder) {
t.Helper()
if got := s.clientsMesh; !reflect.DeepEqual(got, want) {
t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want)
@ -758,28 +758,28 @@ func TestForwarderRegistration(t *testing.T) {
s.AddPacketForwarder(u1, testFwd(1))
s.AddPacketForwarder(u2, testFwd(2))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered forwarder is no-op.
s.RemovePacketForwarder(u2, testFwd(999))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Verify a remove of non-registered user is no-op.
s.RemovePacketForwarder(u3, testFwd(1))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(1),
u2: testFwd(2),
})
// Actual removal.
s.RemovePacketForwarder(u2, testFwd(2))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(1),
})
@ -787,14 +787,14 @@ func TestForwarderRegistration(t *testing.T) {
wantCounter(&s.multiForwarderCreated, 0)
s.AddPacketForwarder(u1, testFwd(100))
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: newMultiForwarder(testFwd(1), testFwd(100)),
})
wantCounter(&s.multiForwarderCreated, 1)
// Removing a forwarder in a multi set that doesn't exist; does nothing.
s.RemovePacketForwarder(u1, testFwd(55))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: newMultiForwarder(testFwd(1), testFwd(100)),
})
@ -802,7 +802,7 @@ func TestForwarderRegistration(t *testing.T) {
// from being a multiForwarder.
wantCounter(&s.multiForwarderDeleted, 0)
s.RemovePacketForwarder(u1, testFwd(1))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(100),
})
wantCounter(&s.multiForwarderDeleted, 1)
@ -810,23 +810,24 @@ func TestForwarderRegistration(t *testing.T) {
// Removing an entry for a client that's still connected locally should result
// in a nil forwarder.
u1c := &sclient{
key: u1,
key: u1.Value(),
handle: u1,
logf: logger.Discard,
}
s.clients[u1] = singleClient(u1c)
s.RemovePacketForwarder(u1, testFwd(100))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: nil,
})
// But once that client disconnects, it should go away.
s.unregisterClient(u1c)
want(map[key.NodePublic]PacketForwarder{})
want(map[NodeHandle]PacketForwarder{})
// But if it already has a forwarder, it's not removed.
s.AddPacketForwarder(u1, testFwd(2))
s.unregisterClient(u1c)
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(2),
})
@ -835,11 +836,11 @@ func TestForwarderRegistration(t *testing.T) {
// from nil to the new one, not a multiForwarder.
s.clients[u1] = singleClient(u1c)
s.clientsMesh[u1] = nil
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: nil,
})
s.AddPacketForwarder(u1, testFwd(3))
want(map[key.NodePublic]PacketForwarder{
want(map[NodeHandle]PacketForwarder{
u1: testFwd(3),
})
}
@ -853,7 +854,7 @@ type channelFwd struct {
}
func (f channelFwd) String() string { return "" }
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
func (f channelFwd) ForwardPacket(_ NodeHandle, _ NodeHandle, packet []byte) error {
f.c <- packet
return nil
}
@ -865,8 +866,8 @@ func TestMultiForwarder(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
clients: make(map[key.NodePublic]*clientSet),
clientsMesh: map[key.NodePublic]PacketForwarder{},
clients: make(map[NodeHandle]*clientSet),
clientsMesh: map[NodeHandle]PacketForwarder{},
}
u := pubAll(1)
s.AddPacketForwarder(u, channelFwd{1, ch})
@ -1067,9 +1068,9 @@ func TestServerDupClients(t *testing.T) {
run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) {
s = NewServer(serverPriv, t.Logf)
s.dupPolicy = dupPolicy
c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")}
c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")}
c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")}
c1 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c1: ")}
c2 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c2: ")}
c3 = &sclient{key: clientPub, handle: clientPub.Handle(), logf: logger.WithPrefix(t.Logf, "c3: ")}
clientName = map[*sclient]string{
c1: "c1",
c2: "c2",
@ -1083,7 +1084,7 @@ func TestServerDupClients(t *testing.T) {
}
wantSingleClient := func(t *testing.T, want *sclient) {
t.Helper()
got, ok := s.clients[want.key]
got, ok := s.clients[want.key.Handle()]
if !ok {
t.Error("no clients for key")
return
@ -1106,7 +1107,7 @@ func TestServerDupClients(t *testing.T) {
}
wantNoClient := func(t *testing.T) {
t.Helper()
_, ok := s.clients[clientPub]
_, ok := s.clients[clientPub.Handle()]
if !ok {
// Good
return
@ -1115,7 +1116,7 @@ func TestServerDupClients(t *testing.T) {
}
wantDupSet := func(t *testing.T) *dupClientSet {
t.Helper()
cs, ok := s.clients[clientPub]
cs, ok := s.clients[clientPub.Handle()]
if !ok {
t.Fatal("no set for key; want dup set")
return nil
@ -1128,7 +1129,7 @@ func TestServerDupClients(t *testing.T) {
}
wantActive := func(t *testing.T, want *sclient) {
t.Helper()
set, ok := s.clients[clientPub]
set, ok := s.clients[clientPub.Handle()]
if !ok {
t.Error("no set for key")
return


@ -27,6 +27,7 @@
"strings"
"sync"
"time"
"unique"
"go4.org/mem"
"tailscale.com/derp"
@ -971,7 +972,7 @@ func (c *Client) LocalAddr() (netip.AddrPort, error) {
return la, nil
}
func (c *Client) ForwardPacket(from, to key.NodePublic, b []byte) error {
func (c *Client) ForwardPacket(from, to unique.Handle[key.NodePublic], b []byte) error {
client, _, err := c.connect(c.newContext(), "derphttp.Client.ForwardPacket")
if err != nil {
return err


@ -327,7 +327,7 @@ func TestBreakWatcherConnRecv(t *testing.T) {
}
watcher1.breakConnection(watcher1.client)
// re-establish connection by sending a packet
watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus"))
watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus"))
timer.Reset(5 * time.Second)
}
@ -400,7 +400,7 @@ func TestBreakWatcherConn(t *testing.T) {
}
watcher1.breakConnection(watcher1.client)
// re-establish connection by sending a packet
watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus"))
watcher1.ForwardPacket(key.NodePublic{}.Handle(), key.NodePublic{}.Handle(), []byte("bogus"))
// signal that the breaker is done
breakerChan <- true


@ -10,6 +10,7 @@
"encoding/hex"
"errors"
"fmt"
"unique"
"go4.org/mem"
"golang.org/x/crypto/curve25519"
@ -174,6 +175,16 @@ func (p NodePublic) Compare(p2 NodePublic) int {
return bytes.Compare(p.k[:], p2.k[:])
}
// Handle returns a unique.Handle for this NodePublic. The Handle is more
// efficient to store and compare than the NodePublic itself, but is also more
// expensive to create. It is best to keep a copy of the Handle on a longer-term
// object representing a NodePublic, rather than creating it on the fly; when
// that Handle is then used in multiple other data structures, the cost of
// creating it quickly amortizes across the cheaper storage and lookups.
func (p NodePublic) Handle() unique.Handle[NodePublic] {
return unique.Make(p)
}
// ParseNodePublicUntyped parses an untyped 64-character hex value
// as a NodePublic.
//