wgengine/magicsock: add new discoInfo type for DiscoKey state, move some fields

As more prep for removing the false assumption that you're able to
map from DiscoKey to a single peer, move the lastPingFrom and lastPingTime
fields from the endpoint type to a new discoInfo type, effectively upgrading
the old sharedDiscoKey map (which only held a *[32]byte nacl precomputed key
as its value) to discoInfo which then includes that naclbox key.

Then start plumbing it into handlePing in prep for removing the need
for handlePing to take an endpoint parameter.

Updates #3088

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
(cherry picked from commit d86081f3535168921cd9b4ddf1dd127f3639bd00)
This commit is contained in:
Brad Fitzpatrick 2021-10-15 20:45:33 -07:00
parent fd85b3274e
commit 07b569fe26
2 changed files with 77 additions and 38 deletions

View File

@ -343,9 +343,9 @@ type Conn struct {
// nodeOfDisco tracks the networkmap Node entity for each peer // nodeOfDisco tracks the networkmap Node entity for each peer
// discovery key. // discovery key.
peerMap peerMap peerMap peerMap
// sharedDiscoKey is the precomputed nacl/box key for
// communication with the peer that has the given DiscoKey. // discoInfo is the state for an active DiscoKey.
sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte discoInfo map[tailcfg.DiscoKey]*discoInfo
// netInfoFunc is a callback that provides a tailcfg.NetInfo when // netInfoFunc is a callback that provides a tailcfg.NetInfo when
// discovered network conditions change. // discovered network conditions change.
@ -506,11 +506,11 @@ func (o *Options) derpActiveFunc() func() {
// of NewConn. Mostly for tests. // of NewConn. Mostly for tests.
func newConn() *Conn { func newConn() *Conn {
c := &Conn{ c := &Conn{
derpRecvCh: make(chan derpReadResult), derpRecvCh: make(chan derpReadResult),
derpStarted: make(chan struct{}), derpStarted: make(chan struct{}),
peerLastDerp: make(map[key.Public]int), peerLastDerp: make(map[key.Public]int),
peerMap: newPeerMap(), peerMap: newPeerMap(),
sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), discoInfo: make(map[tailcfg.DiscoKey]*discoInfo),
} }
c.bind = &connBind{Conn: c, closed: true} c.bind = &connBind{Conn: c, closed: true}
c.muCond = sync.NewCond(&c.mu) c.muCond = sync.NewCond(&c.mu)
@ -1596,7 +1596,7 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache)
c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp) c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp)
return nil, false return nil, false
} }
if c.handleDiscoMessage(b, ipp, key.Public{}) { if c.handleDiscoMessage(b, ipp, tailcfg.NodeKey{}) {
return nil, false return nil, false
} }
if !c.havePrivateKey.Get() { if !c.havePrivateKey.Get() {
@ -1659,7 +1659,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
} }
ipp := netaddr.IPPortFrom(derpMagicIPAddr, uint16(regionID)) ipp := netaddr.IPPortFrom(derpMagicIPAddr, uint16(regionID))
if c.handleDiscoMessage(b[:n], ipp, dm.src) { if c.handleDiscoMessage(b[:n], ipp, tailcfg.NodeKey(dm.src)) {
return 0, nil return 0, nil
} }
@ -1703,10 +1703,10 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
pkt = append(pkt, disco.Magic...) pkt = append(pkt, disco.Magic...)
pkt = append(pkt, c.discoPublic[:]...) pkt = append(pkt, c.discoPublic[:]...)
pkt = append(pkt, nonce[:]...) pkt = append(pkt, nonce[:]...)
sharedKey := c.sharedDiscoKeyLocked(dstDisco) di := c.discoInfoLocked(dstDisco)
c.mu.Unlock() c.mu.Unlock()
pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey) pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, di.sharedKey)
sent, err = c.sendAddr(dst, key.Public(dstKey), pkt) sent, err = c.sendAddr(dst, key.Public(dstKey), pkt)
if sent { if sent {
if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) {
@ -1736,7 +1736,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
// src.Port() being the region ID) and the derpNodeSrc will be the node key // src.Port() being the region ID) and the derpNodeSrc will be the node key
// it was received from at the DERP layer. derpNodeSrc is zero when received // it was received from at the DERP layer. derpNodeSrc is zero when received
// over UDP. // over UDP.
func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc key.Public) (isDiscoMsg bool) { func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc tailcfg.NodeKey) (isDiscoMsg bool) {
const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen
if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic {
return false return false
@ -1784,10 +1784,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke
// //
// From here on, peerNode and de are non-nil. // From here on, peerNode and de are non-nil.
di := c.discoInfoLocked(sender)
var nonce [disco.NonceLen]byte var nonce [disco.NonceLen]byte
copy(nonce[:], msg[len(disco.Magic)+len(key.Public{}):]) copy(nonce[:], msg[len(disco.Magic)+len(key.Public{}):])
sealedBox := msg[headerLen:] sealedBox := msg[headerLen:]
payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, c.sharedDiscoKeyLocked(sender)) payload, ok := box.OpenAfterPrecomputation(nil, sealedBox, &nonce, di.sharedKey)
if !ok { if !ok {
// This might be have been intended for a previous // This might be have been intended for a previous
// disco key. When we restart we get a new disco key // disco key. When we restart we get a new disco key
@ -1834,7 +1836,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke
switch dm := dm.(type) { switch dm := dm.(type) {
case *disco.Ping: case *disco.Ping:
c.handlePingLocked(dm, ep, src, sender) c.handlePingLocked(dm, ep, src, di, derpNodeSrc)
case *disco.Pong: case *disco.Pong:
ep.handlePongConnLocked(dm, src) ep.handlePongConnLocked(dm, src)
case *disco.CallMeMaybe: case *disco.CallMeMaybe:
@ -1860,20 +1862,22 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ke
return return
} }
func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, sender tailcfg.DiscoKey) { // di is the discoInfo of the source of the ping.
likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second // derpNodeSrc is non-zero if the ping arrived via DERP.
de.lastPingFrom = src func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) {
de.lastPingTime = time.Now() likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second
di.lastPingFrom = src
di.lastPingTime = time.Now()
if !likelyHeartBeat || debugDisco { if !likelyHeartBeat || debugDisco {
c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, di.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6])
} }
// Remember this route if not present. // Remember this route if not present.
c.setAddrToDiscoLocked(src, sender) c.setAddrToDiscoLocked(src, di.discoKey)
de.addCandidateEndpoint(src) de.addCandidateEndpoint(src)
ipDst := src ipDst := src
discoDest := sender discoDest := di.discoKey
go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{ go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{
TxID: dm.TxID, TxID: dm.TxID,
Src: src, Src: src,
@ -1935,14 +1939,21 @@ func (c *Conn) setAddrToDiscoLocked(src netaddr.IPPort, newk tailcfg.DiscoKey) {
c.peerMap.setDiscoKeyForIPPort(src, newk) c.peerMap.setDiscoKeyForIPPort(src, newk)
} }
func (c *Conn) sharedDiscoKeyLocked(k tailcfg.DiscoKey) *[32]byte { // discoInfoLocked returns the previous or new discoInfo for k.
if v, ok := c.sharedDiscoKey[k]; ok { //
return v // c.mu must be held.
func (c *Conn) discoInfoLocked(k tailcfg.DiscoKey) *discoInfo {
di, ok := c.discoInfo[k]
if !ok {
di = &discoInfo{
discoKey: k,
discoShort: k.ShortString(),
sharedKey: new([32]byte),
}
box.Precompute(di.sharedKey, key.Public(k).B32(), c.discoPrivate.B32())
c.discoInfo[k] = di
} }
shared := new([32]byte) return di
box.Precompute(shared, key.Public(k).B32(), c.discoPrivate.B32())
c.sharedDiscoKey[k] = shared
return shared
} }
func (c *Conn) SetNetworkUp(up bool) { func (c *Conn) SetNetworkUp(up bool) {
@ -2191,10 +2202,10 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
}) })
} }
// discokeys might have changed in the above. Discard unused cached keys. // discokeys might have changed in the above. Discard unused info.
for discoKey := range c.sharedDiscoKey { for dk := range c.discoInfo {
if !c.peerMap.anyEndpointForDiscoKey(discoKey) { if !c.peerMap.anyEndpointForDiscoKey(dk) {
delete(c.sharedDiscoKey, discoKey) delete(c.discoInfo, dk)
} }
} }
} }
@ -2999,10 +3010,6 @@ type endpoint struct {
fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using
wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints
// Owned by Conn.mu:
lastPingFrom netaddr.IPPort
lastPingTime time.Time
// mu protects all following fields. // mu protects all following fields.
mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu
@ -3788,3 +3795,35 @@ type ippEndpointCache struct {
gen int64 gen int64
de *endpoint de *endpoint
} }
// discoInfo is the info and state for the DiscoKey
// in the Conn.discoInfo map key.
//
// Note that a DiscoKey does not necessarily map to exactly one
// node. In the case of shared nodes and users switching accounts, two
// nodes in the NetMap may legitimately have the same DiscoKey. As
// such, no fields in here should be considered node-specific.
type discoInfo struct {
// discoKey is the same as the Conn.discoInfo map key,
// just so you can pass around a *discoInfo alone.
// Not modifed once initiazed.
discoKey tailcfg.DiscoKey
// discoShort is discoKey.ShortString().
// Not modifed once initiazed;
discoShort string
// sharedKey is the precomputed nacl/box key for
// communication with the peer that has the DiscoKey
// used to look up this *discoInfo in Conn.discoInfo.
// Not modifed once initialized.
sharedKey *[32]byte
// Mutable fields follow, owned by Conn.mu:
// lastPingFrom is the src of a ping for discoKey.
lastPingFrom netaddr.IPPort
// lastPingTime is the last time of a ping for discoKey.
lastPingTime time.Time
}

View File

@ -1158,7 +1158,7 @@ func TestDiscoMessage(t *testing.T) {
pkt = append(pkt, nonce[:]...) pkt = append(pkt, nonce[:]...)
pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32())
got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, key.Public{}) got := c.handleDiscoMessage(pkt, netaddr.IPPort{}, tailcfg.NodeKey{})
if !got { if !got {
t.Error("failed to open it") t.Error("failed to open it")
} }