diff --git a/derp/derp_server.go b/derp/derp_server.go index cf7f6fe7e..924192477 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -40,6 +40,7 @@ import ( "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/metrics" + "tailscale.com/syncs" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/version" @@ -1560,22 +1561,20 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) { // Duplicate registration of same forwarder. Ignore. return } - if m, ok := prev.(multiForwarder); ok { - if _, ok := m[fwd]; ok { + if m, ok := prev.(*multiForwarder); ok { + if _, ok := m.all[fwd]; ok { // Duplicate registration of same forwarder in set; ignore. return } - m[fwd] = m.maxVal() + 1 + m.add(fwd) return } if prev != nil { // Otherwise, the existing value is not a set, // not a dup, and not local-only (nil) so make - // it a set. - fwd = multiForwarder{ - prev: 1, // existed 1st, higher priority - fwd: 2, // the passed in fwd is in 2nd place - } + // it a set. `prev` existed first, so will have higher + // priority. + fwd = newMultiForwarder(prev, fwd) s.multiForwarderCreated.Add(1) } } @@ -1591,19 +1590,14 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) if !ok { return } - if m, ok := v.(multiForwarder); ok { - if len(m) < 2 { + if m, ok := v.(*multiForwarder); ok { + if len(m.all) < 2 { panic("unexpected") } - delete(m, fwd) - // If fwd was in m and we no longer need to be a - // multiForwarder, replace the entry with the - // remaining PacketForwarder. - if len(m) == 1 { - var remain PacketForwarder - for k := range m { - remain = k - } + if remain, isLast := m.deleteLocked(fwd); isLast { + // If fwd was in m and we no longer need to be a + // multiForwarder, replace the entry with the + // remaining PacketForwarder. s.clientsMesh[dst] = remain s.multiForwarderDeleted.Add(1) } @@ -1635,27 +1629,65 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) // client is. The map value is unique connection number; the lowest // one has been seen the longest. It's used to make sure we forward // packets consistently to the same node and don't pick randomly. -type multiForwarder map[PacketForwarder]uint8 +type multiForwarder struct { + fwd syncs.AtomicValue[PacketForwarder] // preferred forwarder. + all map[PacketForwarder]uint8 // all forwarders, protected by s.mu. +} -func (m multiForwarder) maxVal() (max uint8) { - for _, v := range m { +// newMultiForwarder creates a new multiForwarder. +// The first PacketForwarder passed to this function will be the preferred one. +func newMultiForwarder(fwds ...PacketForwarder) *multiForwarder { + f := &multiForwarder{all: make(map[PacketForwarder]uint8)} + f.fwd.Store(fwds[0]) + for idx, fwd := range fwds { + f.all[fwd] = uint8(idx) + } + return f +} + +// add adds a new forwarder to the map with a connection number that +// is higher than the existing ones. +func (f *multiForwarder) add(fwd PacketForwarder) { + var max uint8 + for _, v := range f.all { if v > max { max = v } } - return + f.all[fwd] = max + 1 } -func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { - var fwd PacketForwarder - var lowest uint8 - for k, v := range m { - if fwd == nil || v < lowest { - fwd = k - lowest = v +// deleteLocked removes a packet forwarder from the map. It expects Server.mu to be held. +// If only one forwarder remains after the removal, it will be returned alongside a `true` boolean value. +func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, isLast bool) { + delete(f.all, fwd) + + if fwd == f.fwd.Load() { + // The preferred forwarder has been removed, choose a new one + // based on the lowest index. + var lowestfwd PacketForwarder + var lowest uint8 + for k, v := range f.all { + if lowestfwd == nil || v < lowest { + lowestfwd = k + lowest = v + } + } + if lowestfwd != nil { + f.fwd.Store(lowestfwd) } } - return fwd.ForwardPacket(src, dst, payload) + + if len(f.all) == 1 { + for k := range f.all { + return k, true + } + } + return nil, false +} + +func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { + return f.fwd.Load().ForwardPacket(src, dst, payload) } func (s *Server) expVarFunc(f func() any) expvar.Func { diff --git a/derp/derp_test.go b/derp/derp_test.go index 6da11197a..2edcb057a 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -19,6 +19,7 @@ import ( "net" "os" "reflect" + "strconv" "sync" "testing" "time" @@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) { s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path want(map[key.NodePublic]PacketForwarder{ - u1: multiForwarder{ - testFwd(1): 1, - testFwd(100): 2, - }, + u1: newMultiForwarder(testFwd(1), testFwd(100)), }) wantCounter(&s.multiForwarderCreated, 1) // Removing a forwarder in a multi set that doesn't exist; does nothing. s.RemovePacketForwarder(u1, testFwd(55)) want(map[key.NodePublic]PacketForwarder{ - u1: multiForwarder{ - testFwd(1): 1, - testFwd(100): 2, - }, + u1: newMultiForwarder(testFwd(1), testFwd(100)), }) // Removing a forwarder in a multi set that does exist should collapse it away @@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) { }) } +type channelFwd struct { + // id is to ensure that different instances that reference the + // same channel are not equal, as they are used as keys in the + // multiForwarder map. + id int + c chan []byte +} + +func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { + f.c <- packet + return nil +} + +func TestMultiForwarder(t *testing.T) { + received := 0 + var wg sync.WaitGroup + ch := make(chan []byte) + ctx, cancel := context.WithCancel(context.Background()) + + s := &Server{ + clients: make(map[key.NodePublic]clientSet), + clientsMesh: map[key.NodePublic]PacketForwarder{}, + } + u := pubAll(1) + s.AddPacketForwarder(u, channelFwd{1, ch}) + + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-ch: + received += 1 + case <-ctx.Done(): + return + } + } + }() + go func() { + defer wg.Done() + for { + s.AddPacketForwarder(u, channelFwd{2, ch}) + s.AddPacketForwarder(u, channelFwd{3, ch}) + s.RemovePacketForwarder(u, channelFwd{2, ch}) + s.RemovePacketForwarder(u, channelFwd{1, ch}) + s.AddPacketForwarder(u, channelFwd{1, ch}) + s.RemovePacketForwarder(u, channelFwd{3, ch}) + if ctx.Err() != nil { + return + } + } + }() + + // Number of messages is chosen arbitrarily, just for this loop to + // run long enough concurrently with {Add,Remove}PacketForwarder loop above. + numMsgs := 5000 + var fwd PacketForwarder + for i := 0; i < numMsgs; i++ { + s.mu.Lock() + fwd = s.clientsMesh[u] + s.mu.Unlock() + fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) + } + + cancel() + wg.Wait() + if received != numMsgs { + t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) + } +} func TestMetaCert(t *testing.T) { priv := key.NewNode() pub := priv.Public()