diff --git a/derp/derp_server.go b/derp/derp_server.go index f38ae6621..2e17cbfe5 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -46,6 +46,7 @@ "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/slicesx" "tailscale.com/version" @@ -164,11 +165,11 @@ type Server struct { // remote). If the value is non-nil, it's remote (+ maybe also // local). clientsMesh map[key.NodePublic]PacketForwarder - // sentTo tracks which peers have sent to which other peers, - // and at which connection number. This isn't on sclient - // because it includes intra-region forwarded packets as the - // src. - sentTo map[key.NodePublic]map[key.NodePublic]int64 // src => dst => dst's latest sclient.connNum + // peerGoneWatchers maps each watched peer to the set of watcher + // funcs subscribed to that peer disconnecting from the region overall. + // When a peer is gone from the region, we notify all of its watchers, + // calling each func in its own new goroutine. + peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)] // maps from netip.AddrPort to a client's public key keyOfAddr map[netip.AddrPort]key.NodePublic @@ -343,7 +344,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { netConns: map[Conn]chan struct{}{}, memSys0: ms.Sys, watchers: set.Set[*sclient]{}, - sentTo: map[key.NodePublic]map[key.NodePublic]int64{}, + peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, avgQueueDuration: new(uint64), tcpRtt: metrics.LabelMap{Label: "le"}, meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), @@ -689,6 +690,40 @@ func (s *Server) unregisterClient(c *sclient) { } } +// addPeerGoneFromRegionWatcher adds a function to be called when peer is gone +// from the region overall. It returns a handle that can be used to remove the +// watcher later. 
+// +// The provided f func is usually [sclient.onPeerGoneFromRegion], added by +// [sclient.noteSendFromSrc]; this func doesn't take a whole *sclient to make it +// clear what has access to what. +func (s *Server) addPeerGoneFromRegionWatcher(peer key.NodePublic, f func(key.NodePublic)) set.Handle { + s.mu.Lock() + defer s.mu.Unlock() + hset, ok := s.peerGoneWatchers[peer] + if !ok { + hset = set.HandleSet[func(key.NodePublic)]{} + s.peerGoneWatchers[peer] = hset + } + return hset.Add(f) +} + +// removePeerGoneFromRegionWatcher removes a peer watcher previously added by +// addPeerGoneFromRegionWatcher, using the handle it returned. It is a no-op if +// peer has no watchers; a peer's set is dropped from the map once it is empty. +func (s *Server) removePeerGoneFromRegionWatcher(peer key.NodePublic, h set.Handle) { + s.mu.Lock() + defer s.mu.Unlock() + hset, ok := s.peerGoneWatchers[peer] + if !ok { + return + } + delete(hset, h) + if len(hset) == 0 { + delete(s.peerGoneWatchers, peer) + } +} + // notePeerGoneFromRegionLocked sends peerGone frames to parties that // key has sent to previously (whether those sends were from a local // client or forwarded). It must only be called after the key has @@ -702,18 +737,11 @@ func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) { // so they can drop their route entries to us (issue 150) // or move them over to the active client (in case a replaced client // connection is being unregistered). 
- for pubKey, connNum := range s.sentTo[key] { - set, ok := s.clients[pubKey] - if !ok { - continue - } - set.ForeachClient(func(peer *sclient) { - if peer.connNum == connNum { - go peer.requestPeerGoneWrite(key, PeerGoneReasonDisconnected) - } - }) + set := s.peerGoneWatchers[key] + for _, f := range set { + go f(key) } - delete(s.sentTo, key) + delete(s.peerGoneWatchers, key) } // requestPeerGoneWriteLimited sends a request to write a "peer gone" @@ -1004,9 +1032,6 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { dstLen = set.Len() dst = set.activeClient.Load() } - if dst != nil { - s.notePeerSendLocked(srcKey, dst) - } s.mu.Unlock() if dst == nil { @@ -1029,18 +1054,6 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { }) } -// notePeerSendLocked records that src sent to dst. We keep track of -// that so when src disconnects, we can tell dst (if it's still -// around) that src is gone (a peerGone frame). -func (s *Server) notePeerSendLocked(src key.NodePublic, dst *sclient) { - m, ok := s.sentTo[src] - if !ok { - m = map[key.NodePublic]int64{} - s.sentTo[src] = m - } - m[dst.key] = dst.connNum -} - // handleFrameSendPacket reads a "send packet" frame from the client. func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { s := c.s @@ -1059,9 +1072,7 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { dstLen = set.Len() dst = set.activeClient.Load() } - if dst != nil { - s.notePeerSendLocked(c.key, dst) - } else if dstLen < 1 { + if dst == nil && dstLen < 1 { fwd = s.clientsMesh[dstKey] } s.mu.Unlock() @@ -1181,6 +1192,13 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { return nil } +// onPeerGoneFromRegion is the callback registered with the Server to be +// notified (in a new goroutine) whenever a peer has disconnected from all DERP +// nodes in the current region. It asks c to write a peer-gone frame with reason PeerGoneReasonDisconnected. 
+func (c *sclient) onPeerGoneFromRegion(peer key.NodePublic) { + c.requestPeerGoneWrite(peer, PeerGoneReasonDisconnected) +} + // requestPeerGoneWrite sends a request to write a "peer gone" frame // with an explanation of why it is gone. It blocks until either the // write request is scheduled, or the client has closed. @@ -1494,8 +1512,9 @@ type sclient struct { connectedAt time.Time preferred bool - // Owned by sender, not thread-safe. - bw *lazyBufioWriter + // Owned by sendLoop, not thread-safe. sawSrc maps each src noted by noteSendFromSrc to its peer-gone watcher handle. + sawSrc map[key.NodePublic]set.Handle + bw *lazyBufioWriter // Guarded by s.mu // @@ -1598,24 +1617,36 @@ func (c *sclient) recordQueueTime(enqueuedAt time.Time) { } } -func (c *sclient) sendLoop(ctx context.Context) error { - defer func() { - // If the sender shuts down unilaterally due to an error, close so - // that the receive loop unblocks and cleans up the rest. - c.nc.Close() +// onSendLoopDone cleans up when the send loop is done: it closes c.nc, +// removes this client's peer-gone watchers, and drains the send queues. +// +// It must only be called from the sendLoop goroutine. +func (c *sclient) onSendLoopDone() { + // If the sender shuts down unilaterally due to an error, close so + // that the receive loop unblocks and cleans up the rest. + c.nc.Close() - // Drain the send queue to count dropped packets - for { - select { - case pkt := <-c.sendQueue: - c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) - case pkt := <-c.discoSendQueue: - c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) - default: - return - } + // Remove every peer-gone watcher this client registered with the Server. 
+ for peer, h := range c.sawSrc { + c.s.removePeerGoneFromRegionWatcher(peer, h) + } + + // Drain both send queues without blocking, recording each leftover packet as dropped. + for { + select { + case pkt := <-c.sendQueue: + c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) + case pkt := <-c.discoSendQueue: + c.s.recordDrop(pkt.bs, pkt.src, c.key, dropReasonGoneDisconnected) + default: + return } - }() + } + +} + +func (c *sclient) sendLoop(ctx context.Context) error { + defer c.onSendLoopDone() jitter := rand.N(5 * time.Second) keepAliveTick, keepAliveTickChannel := c.s.clock.NewTicker(keepAlive + jitter) @@ -1811,6 +1842,7 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) pktLen := len(contents) if withKey { pktLen += key.NodePublicRawLen + c.noteSendFromSrc(srcKey) } if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil { return err @@ -1824,6 +1856,18 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) return err } +// noteSendFromSrc notes that we are about to write a packet from src to c. The first +// time a given src is seen, it registers a peer-gone watcher and stores the handle in c.sawSrc. +// NOTE(review): nothing visible here clears c.sawSrc[src] after the watcher fires; confirm a returning src is re-watched. +// It must only be called from the sendLoop goroutine. +func (c *sclient) noteSendFromSrc(src key.NodePublic) { + if _, ok := c.sawSrc[src]; ok { + return + } + h := c.s.addPeerGoneFromRegionWatcher(src, c.onPeerGoneFromRegion) + mak.Set(&c.sawSrc, src, h) +} + // AddPacketForwarder registers fwd as a packet forwarder for dst. // fwd must be comparable. func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {