derp: include src IPs in mesh watch messages

Updates tailscale/corp#13945

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2023-08-15 19:35:24 -07:00 committed by Brad Fitzpatrick
parent 7ed3681cbe
commit 6c791f7d60
6 changed files with 73 additions and 38 deletions

View File

@ -9,6 +9,7 @@
"fmt" "fmt"
"log" "log"
"net" "net"
"net/netip"
"strings" "strings"
"time" "time"
@ -67,7 +68,7 @@ func startMeshWithHost(s *derp.Server, host string) error {
return d.DialContext(ctx, network, addr) return d.DialContext(ctx, network, addr)
}) })
add := func(k key.NodePublic) { s.AddPacketForwarder(k, c) } add := func(k key.NodePublic, _ netip.AddrPort) { s.AddPacketForwarder(k, c) }
remove := func(k key.NodePublic) { s.RemovePacketForwarder(k, c) } remove := func(k key.NodePublic) { s.RemovePacketForwarder(k, c) }
go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove) go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove)
return nil return nil

View File

@ -85,7 +85,7 @@
// framePeerPresent is like framePeerGone, but for other // framePeerPresent is like framePeerGone, but for other
// members of the DERP region when they're meshed up together. // members of the DERP region when they're meshed up together.
framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected framePeerPresent = frameType(0x09) // 32B pub key of peer that's connected + optional 18B ip:port (16 byte IP + 2 byte BE uint16 port)
// frameWatchConns is how one DERP node in a regional mesh // frameWatchConns is how one DERP node in a regional mesh
// subscribes to the others in the region. // subscribes to the others in the region.

View File

@ -363,7 +363,12 @@ func (PeerGoneMessage) msg() {}
// PeerPresentMessage is a ReceivedMessage that indicates that the client // PeerPresentMessage is a ReceivedMessage that indicates that the client
// is connected to the server. (Only used by trusted mesh clients) // is connected to the server. (Only used by trusted mesh clients)
type PeerPresentMessage key.NodePublic type PeerPresentMessage struct {
// Key is the public key of the client.
Key key.NodePublic
// IPPort is the remote IP and port of the client.
IPPort netip.AddrPort
}
func (PeerPresentMessage) msg() {} func (PeerPresentMessage) msg() {}
@ -546,8 +551,15 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro
c.logf("[unexpected] dropping short peerPresent frame from DERP server") c.logf("[unexpected] dropping short peerPresent frame from DERP server")
continue continue
} }
pg := PeerPresentMessage(key.NodePublicFromRaw32(mem.B(b[:keyLen]))) var msg PeerPresentMessage
return pg, nil msg.Key = key.NodePublicFromRaw32(mem.B(b[:keyLen]))
if n >= keyLen+16+2 {
msg.IPPort = netip.AddrPortFrom(
netip.AddrFrom16([16]byte(b[keyLen:keyLen+16])).Unmap(),
binary.BigEndian.Uint16(b[keyLen+16:keyLen+16+2]),
)
}
return msg, nil
case frameRecvPacket: case frameRecvPacket:
var rp ReceivedPacket var rp ReceivedPacket

View File

@ -12,6 +12,7 @@
crand "crypto/rand" crand "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/binary"
"encoding/json" "encoding/json"
"errors" "errors"
"expvar" "expvar"
@ -43,6 +44,7 @@
"tailscale.com/tstime/rate" "tailscale.com/tstime/rate"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/set"
"tailscale.com/version" "tailscale.com/version"
) )
@ -150,7 +152,7 @@ type Server struct {
closed bool closed bool
netConns map[Conn]chan struct{} // chan is closed when conn closes netConns map[Conn]chan struct{} // chan is closed when conn closes
clients map[key.NodePublic]clientSet clients map[key.NodePublic]clientSet
watchers map[*sclient]bool // mesh peer -> true watchers set.Set[*sclient] // mesh peers
// clientsMesh tracks all clients in the cluster, both locally // clientsMesh tracks all clients in the cluster, both locally
// and to mesh peers. If the value is nil, that means the // and to mesh peers. If the value is nil, that means the
// peer is only local (and thus in the clients Map, but not // peer is only local (and thus in the clients Map, but not
@ -219,8 +221,7 @@ func (s singleClient) ForeachClient(f func(*sclient)) { f(s.c) }
// All fields are guarded by Server.mu. // All fields are guarded by Server.mu.
type dupClientSet struct { type dupClientSet struct {
// set is the set of connected clients for sclient.key. // set is the set of connected clients for sclient.key.
// The values are all true. set set.Set[*sclient]
set map[*sclient]bool
// last is the most recent addition to set, or nil if the most // last is the most recent addition to set, or nil if the most
// recent one has since disconnected and nobody else has send // recent one has since disconnected and nobody else has send
@ -261,7 +262,7 @@ func (s *dupClientSet) removeClient(c *sclient) bool {
trim := s.sendHistory[:0] trim := s.sendHistory[:0]
for _, v := range s.sendHistory { for _, v := range s.sendHistory {
if s.set[v] && (len(trim) == 0 || trim[len(trim)-1] != v) { if s.set.Contains(v) && (len(trim) == 0 || trim[len(trim)-1] != v) {
trim = append(trim, v) trim = append(trim, v)
} }
} }
@ -316,7 +317,7 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server {
clientsMesh: map[key.NodePublic]PacketForwarder{}, clientsMesh: map[key.NodePublic]PacketForwarder{},
netConns: map[Conn]chan struct{}{}, netConns: map[Conn]chan struct{}{},
memSys0: ms.Sys, memSys0: ms.Sys,
watchers: map[*sclient]bool{}, watchers: set.Set[*sclient]{},
sentTo: map[key.NodePublic]map[key.NodePublic]int64{}, sentTo: map[key.NodePublic]map[key.NodePublic]int64{},
avgQueueDuration: new(uint64), avgQueueDuration: new(uint64),
tcpRtt: metrics.LabelMap{Label: "le"}, tcpRtt: metrics.LabelMap{Label: "le"},
@ -498,8 +499,8 @@ func (s *Server) registerClient(c *sclient) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
set := s.clients[c.key] curSet := s.clients[c.key]
switch set := set.(type) { switch curSet := curSet.(type) {
case nil: case nil:
s.clients[c.key] = singleClient{c} s.clients[c.key] = singleClient{c}
c.debugLogf("register single client") c.debugLogf("register single client")
@ -507,14 +508,14 @@ func (s *Server) registerClient(c *sclient) {
s.dupClientKeys.Add(1) s.dupClientKeys.Add(1)
s.dupClientConns.Add(2) // both old and new count s.dupClientConns.Add(2) // both old and new count
s.dupClientConnTotal.Add(1) s.dupClientConnTotal.Add(1)
old := set.ActiveClient() old := curSet.ActiveClient()
old.isDup.Store(true) old.isDup.Store(true)
c.isDup.Store(true) c.isDup.Store(true)
s.clients[c.key] = &dupClientSet{ s.clients[c.key] = &dupClientSet{
last: c, last: c,
set: map[*sclient]bool{ set: set.Set[*sclient]{
old: true, old: struct{}{},
c: true, c: struct{}{},
}, },
sendHistory: []*sclient{old}, sendHistory: []*sclient{old},
} }
@ -523,9 +524,9 @@ func (s *Server) registerClient(c *sclient) {
s.dupClientConns.Add(1) // the gauge s.dupClientConns.Add(1) // the gauge
s.dupClientConnTotal.Add(1) // the counter s.dupClientConnTotal.Add(1) // the counter
c.isDup.Store(true) c.isDup.Store(true)
set.set[c] = true curSet.set.Add(c)
set.last = c curSet.last = c
set.sendHistory = append(set.sendHistory, c) curSet.sendHistory = append(curSet.sendHistory, c)
c.debugLogf("register another duplicate client") c.debugLogf("register another duplicate client")
} }
@ -534,7 +535,7 @@ func (s *Server) registerClient(c *sclient) {
} }
s.keyOfAddr[c.remoteIPPort] = c.key s.keyOfAddr[c.remoteIPPort] = c.key
s.curClients.Add(1) s.curClients.Add(1)
s.broadcastPeerStateChangeLocked(c.key, true) s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, true)
} }
// broadcastPeerStateChangeLocked enqueues a message to all watchers // broadcastPeerStateChangeLocked enqueues a message to all watchers
@ -542,9 +543,13 @@ func (s *Server) registerClient(c *sclient) {
// presence changed. // presence changed.
// //
// s.mu must be held. // s.mu must be held.
func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, present bool) { func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, present bool) {
for w := range s.watchers { for w := range s.watchers {
w.peerStateChange = append(w.peerStateChange, peerConnState{peer: peer, present: present}) w.peerStateChange = append(w.peerStateChange, peerConnState{
peer: peer,
present: present,
ipPort: ipPort,
})
go w.requestMeshUpdate() go w.requestMeshUpdate()
} }
} }
@ -565,7 +570,7 @@ func (s *Server) unregisterClient(c *sclient) {
delete(s.clientsMesh, c.key) delete(s.clientsMesh, c.key)
s.notePeerGoneFromRegionLocked(c.key) s.notePeerGoneFromRegionLocked(c.key)
} }
s.broadcastPeerStateChangeLocked(c.key, false) s.broadcastPeerStateChangeLocked(c.key, netip.AddrPort{}, false)
case *dupClientSet: case *dupClientSet:
c.debugLogf("removed duplicate client") c.debugLogf("removed duplicate client")
if set.removeClient(c) { if set.removeClient(c) {
@ -655,13 +660,21 @@ func (s *Server) addWatcher(c *sclient) {
defer s.mu.Unlock() defer s.mu.Unlock()
// Queue messages for each already-connected client. // Queue messages for each already-connected client.
for peer := range s.clients { for peer, clientSet := range s.clients {
c.peerStateChange = append(c.peerStateChange, peerConnState{peer: peer, present: true}) ac := clientSet.ActiveClient()
if ac == nil {
continue
}
c.peerStateChange = append(c.peerStateChange, peerConnState{
peer: peer,
present: true,
ipPort: ac.remoteIPPort,
})
} }
// And enroll the watcher in future updates (of both // And enroll the watcher in future updates (of both
// connections & disconnections). // connections & disconnections).
s.watchers[c] = true s.watchers.Add(c)
go c.requestMeshUpdate() go c.requestMeshUpdate()
} }
@ -1349,6 +1362,7 @@ type sclient struct {
type peerConnState struct { type peerConnState struct {
peer key.NodePublic peer key.NodePublic
present bool present bool
ipPort netip.AddrPort // if present, the peer's IP:port
} }
// pkt is a request to write a data frame to an sclient. // pkt is a request to write a data frame to an sclient.
@ -1542,12 +1556,18 @@ func (c *sclient) sendPeerGone(peer key.NodePublic, reason PeerGoneReasonType) e
} }
// sendPeerPresent sends a peerPresent frame, without flushing. // sendPeerPresent sends a peerPresent frame, without flushing.
func (c *sclient) sendPeerPresent(peer key.NodePublic) error { func (c *sclient) sendPeerPresent(peer key.NodePublic, ipPort netip.AddrPort) error {
c.setWriteDeadline() c.setWriteDeadline()
if err := writeFrameHeader(c.bw.bw(), framePeerPresent, keyLen); err != nil { const frameLen = keyLen + 16 + 2
if err := writeFrameHeader(c.bw.bw(), framePeerPresent, frameLen); err != nil {
return err return err
} }
_, err := c.bw.Write(peer.AppendTo(nil)) payload := make([]byte, frameLen)
_ = peer.AppendTo(payload[:0])
a16 := ipPort.Addr().As16()
copy(payload[keyLen:], a16[:])
binary.BigEndian.PutUint16(payload[keyLen+16:], ipPort.Port())
_, err := c.bw.Write(payload)
return err return err
} }
@ -1566,7 +1586,7 @@ func (c *sclient) sendMeshUpdates() error {
} }
var err error var err error
if pcs.present { if pcs.present {
err = c.sendPeerPresent(pcs.peer) err = c.sendPeerPresent(pcs.peer, pcs.ipPort)
} else { } else {
err = c.sendPeerGone(pcs.peer, PeerGoneReasonDisconnected) err = c.sendPeerGone(pcs.peer, PeerGoneReasonDisconnected)
} }

View File

@ -92,7 +92,7 @@ func TestSendRecv(t *testing.T) {
defer cancel() defer cancel()
brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)) brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin))
go s.Accept(ctx, cin, brwServer, fmt.Sprintf("test-client-%d", i)) go s.Accept(ctx, cin, brwServer, fmt.Sprintf("[abc::def]:%v", i))
key := clientPrivateKeys[i] key := clientPrivateKeys[i]
brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout))
@ -528,7 +528,7 @@ func newTestServer(t *testing.T, ctx context.Context) *testServer {
// TODO: register c in ts so Close also closes it? // TODO: register c in ts so Close also closes it?
go func(i int) { go func(i int) {
brwServer := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) brwServer := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
go s.Accept(ctx, c, brwServer, fmt.Sprintf("test-client-%d", i)) go s.Accept(ctx, c, brwServer, c.RemoteAddr().String())
}(i) }(i)
} }
}() }()
@ -615,7 +615,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) {
} }
switch m := m.(type) { switch m := m.(type) {
case PeerPresentMessage: case PeerPresentMessage:
got := key.NodePublic(m) got := m.Key
if !want[got] { if !want[got] {
t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) { t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) {
for _, pub := range peers { for _, pub := range peers {
@ -623,6 +623,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) {
} }
})) }))
} }
t.Logf("got present with IP %v", m.IPPort)
delete(want, got) delete(want, got)
if len(want) == 0 { if len(want) == 0 {
return return

View File

@ -5,6 +5,7 @@
import ( import (
"context" "context"
"net/netip"
"sync" "sync"
"time" "time"
@ -26,7 +27,7 @@
// //
// To force RunWatchConnectionLoop to return quickly, its ctx needs to // To force RunWatchConnectionLoop to return quickly, its ctx needs to
// be closed, and c itself needs to be closed. // be closed, and c itself needs to be closed.
func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add, remove func(key.NodePublic)) { func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(key.NodePublic, netip.AddrPort), remove func(key.NodePublic)) {
if infoLogf == nil { if infoLogf == nil {
infoLogf = logger.Discard infoLogf = logger.Discard
} }
@ -68,9 +69,9 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
}) })
defer timer.Stop() defer timer.Stop()
updatePeer := func(k key.NodePublic, isPresent bool) { updatePeer := func(k key.NodePublic, ipPort netip.AddrPort, isPresent bool) {
if isPresent { if isPresent {
add(k) add(k, ipPort)
} else { } else {
remove(k) remove(k)
} }
@ -126,7 +127,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
} }
switch m := m.(type) { switch m := m.(type) {
case derp.PeerPresentMessage: case derp.PeerPresentMessage:
updatePeer(key.NodePublic(m), true) updatePeer(m.Key, m.IPPort, true)
case derp.PeerGoneMessage: case derp.PeerGoneMessage:
switch m.Reason { switch m.Reason {
case derp.PeerGoneReasonDisconnected: case derp.PeerGoneReasonDisconnected:
@ -138,7 +139,7 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key
logf("Recv: peer %s not at server %s for unknown reason %v", logf("Recv: peer %s not at server %s for unknown reason %v",
key.NodePublic(m.Peer).ShortString(), c.ServerPublicKey().ShortString(), m.Reason) key.NodePublic(m.Peer).ShortString(), c.ServerPublicKey().ShortString(), m.Reason)
} }
updatePeer(key.NodePublic(m.Peer), false) updatePeer(key.NodePublic(m.Peer), netip.AddrPort{}, false)
default: default:
continue continue
} }