wgengine/magicsock: use key.NodePublic instead of tailcfg.NodeKey.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-11-01 17:53:40 -07:00
committed by Dave Anderson
parent d6e7cec6a7
commit 72ace0acba
5 changed files with 67 additions and 67 deletions

View File

@@ -92,19 +92,19 @@ func newPeerInfo(ep *endpoint) *peerInfo {
//
// Doesn't do any locking, all access must be done with Conn.mu held.
type peerMap struct {
byNodeKey map[tailcfg.NodeKey]*peerInfo
byNodeKey map[key.NodePublic]*peerInfo
byIPPort map[netaddr.IPPort]*peerInfo
// nodesOfDisco are contains the set of nodes that are using a
// DiscoKey. Usually those sets will be just one node.
nodesOfDisco map[key.DiscoPublic]map[tailcfg.NodeKey]bool
nodesOfDisco map[key.DiscoPublic]map[key.NodePublic]bool
}
func newPeerMap() peerMap {
return peerMap{
byNodeKey: map[tailcfg.NodeKey]*peerInfo{},
byNodeKey: map[key.NodePublic]*peerInfo{},
byIPPort: map[netaddr.IPPort]*peerInfo{},
nodesOfDisco: map[key.DiscoPublic]map[tailcfg.NodeKey]bool{},
nodesOfDisco: map[key.DiscoPublic]map[key.NodePublic]bool{},
}
}
@@ -121,7 +121,7 @@ func (m *peerMap) anyEndpointForDiscoKey(dk key.DiscoPublic) bool {
// endpointForNodeKey returns the endpoint for nk, or nil if
// nk is not known to us.
func (m *peerMap) endpointForNodeKey(nk tailcfg.NodeKey) (ep *endpoint, ok bool) {
func (m *peerMap) endpointForNodeKey(nk key.NodePublic) (ep *endpoint, ok bool) {
if nk.IsZero() {
return nil, false
}
@@ -182,7 +182,7 @@ func (m *peerMap) upsertEndpoint(ep *endpoint) {
if !ep.discoKey.IsZero() {
set := m.nodesOfDisco[ep.discoKey]
if set == nil {
set = map[tailcfg.NodeKey]bool{}
set = map[key.NodePublic]bool{}
m.nodesOfDisco[ep.discoKey] = set
}
set[ep.publicKey] = true
@@ -195,7 +195,7 @@ func (m *peerMap) upsertEndpoint(ep *endpoint) {
// This should only be called with a fully verified mapping of ipp to
// nk, because calling this function defines the endpoint we hand to
// WireGuard for packets received from ipp.
func (m *peerMap) setNodeKeyForIPPort(ipp netaddr.IPPort, nk tailcfg.NodeKey) {
func (m *peerMap) setNodeKeyForIPPort(ipp netaddr.IPPort, nk key.NodePublic) {
if pi := m.byIPPort[ipp]; pi != nil {
delete(pi.ipPorts, ipp)
delete(m.byIPPort, ipp)
@@ -237,7 +237,7 @@ type Conn struct {
derpActiveFunc func()
idleFunc func() time.Duration // nil means unknown
testOnlyPacketListener nettype.PacketListener
noteRecvActivity func(tailcfg.NodeKey) // or nil, see Options.NoteRecvActivity
noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity
// ================================================================
// No locking required to access these fields, either because
@@ -299,7 +299,7 @@ type Conn struct {
// havePrivateKey is whether privateKey is non-zero.
havePrivateKey syncs.AtomicBool
publicKeyAtomic atomic.Value // of tailcfg.NodeKey (or NodeKey zero value if !havePrivateKey)
publicKeyAtomic atomic.Value // of key.NodePublic (or NodeKey zero value if !havePrivateKey)
// port is the preferred port from opts.Port; 0 means auto.
port syncs.AtomicUint32
@@ -492,7 +492,7 @@ type Options struct {
// The provided func is likely to call back into
// Conn.ParseEndpoint, which acquires Conn.mu. As such, you should
// not hold Conn.mu while calling it.
NoteRecvActivity func(tailcfg.NodeKey)
NoteRecvActivity func(key.NodePublic)
// LinkMonitor is the link monitor to use.
// With one, the portmapper won't be used.
@@ -841,7 +841,7 @@ func (c *Conn) callNetInfoCallbackLocked(ni *tailcfg.NetInfo) {
// discoKey. It's used in tests to enable receiving of packets from
// addr without having to spin up the entire active discovery
// machinery.
func (c *Conn) addValidDiscoPathForTest(nodeKey tailcfg.NodeKey, addr netaddr.IPPort) {
func (c *Conn) addValidDiscoPathForTest(nodeKey key.NodePublic, addr netaddr.IPPort) {
c.mu.Lock()
defer c.mu.Unlock()
c.peerMap.setNodeKeyForIPPort(addr, nodeKey)
@@ -863,7 +863,7 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) {
// LastRecvActivityOfNodeKey describes the time we last got traffic from
// this endpoint (updated every ~10 seconds).
func (c *Conn) LastRecvActivityOfNodeKey(nk tailcfg.NodeKey) string {
func (c *Conn) LastRecvActivityOfNodeKey(nk key.NodePublic) string {
c.mu.Lock()
defer c.mu.Unlock()
de, ok := c.peerMap.endpointForNodeKey(nk)
@@ -898,7 +898,7 @@ func (c *Conn) Ping(peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnst
}
}
ep, ok := c.peerMap.endpointForNodeKey(peer.Key)
ep, ok := c.peerMap.endpointForNodeKey(peer.Key.AsNodePublic())
if !ok {
res.Err = "unknown peer"
cb(res)
@@ -944,7 +944,7 @@ func (c *Conn) DiscoPublicKey() key.DiscoPublic {
}
// PeerHasDiscoKey reports whether peer k supports discovery keys (client version 0.100.0+).
func (c *Conn) PeerHasDiscoKey(k tailcfg.NodeKey) bool {
func (c *Conn) PeerHasDiscoKey(k key.NodePublic) bool {
c.mu.Lock()
defer c.mu.Unlock()
if ep, ok := c.peerMap.endpointForNodeKey(k); ok {
@@ -1647,7 +1647,7 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache)
c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b, ipp)
return nil, false
}
if c.handleDiscoMessage(b, ipp, tailcfg.NodeKey{}) {
if c.handleDiscoMessage(b, ipp, key.NodePublic{}) {
return nil, false
}
if !c.havePrivateKey.Get() {
@@ -1710,13 +1710,13 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
}
ipp := netaddr.IPPortFrom(derpMagicIPAddr, uint16(regionID))
if c.handleDiscoMessage(b[:n], ipp, dm.src.AsNodeKey()) {
if c.handleDiscoMessage(b[:n], ipp, dm.src) {
return 0, nil
}
var ok bool
c.mu.Lock()
ep, ok = c.peerMap.endpointForNodeKey(dm.src.AsNodeKey())
ep, ok = c.peerMap.endpointForNodeKey(dm.src)
c.mu.Unlock()
if !ok {
// We don't know anything about this node key, nothing to
@@ -1746,7 +1746,7 @@ const (
//
// The dstKey should only be non-zero if the dstDisco key
// unambiguously maps to exactly one peer.
func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) {
func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
@@ -1764,7 +1764,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
box := di.sharedKey.Seal(m.AppendMarshal(nil))
pkt = append(pkt, box...)
sent, err = c.sendAddr(dst, key.NodePublicFromRaw32(mem.B(dstKey[:])), pkt)
sent, err = c.sendAddr(dst, dstKey, pkt)
if sent {
if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) {
node := "?"
@@ -1797,7 +1797,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
// src.Port() being the region ID) and the derpNodeSrc will be the node key
// it was received from at the DERP layer. derpNodeSrc is zero when received
// over UDP.
func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc tailcfg.NodeKey) (isDiscoMsg bool) {
func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc key.NodePublic) (isDiscoMsg bool) {
const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic {
return false
@@ -1897,7 +1897,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ta
c.logf("[unexpected] CallMeMaybe packets should only come via DERP")
return
}
nodeKey := tailcfg.NodeKey(derpNodeSrc)
nodeKey := derpNodeSrc
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
if !ok {
c.logf("magicsock: disco: ignoring CallMeMaybe from %v; %v is unknown", sender.ShortString(), derpNodeSrc.ShortString())
@@ -1927,7 +1927,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort, derpNodeSrc ta
// derpNodeSrc is non-zero if the disco ping arrived via DERP.
//
// c.mu must be held.
func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic, derpNodeSrc tailcfg.NodeKey) (nk tailcfg.NodeKey, ok bool) {
func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic, derpNodeSrc key.NodePublic) (nk key.NodePublic, ok bool) {
if !derpNodeSrc.IsZero() {
if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok && ep.discoKey == dk {
return derpNodeSrc, true
@@ -1936,8 +1936,8 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic
// Pings after 1.16.0 contains its node source. See if it maps back.
if !dm.NodeKey.IsZero() {
if ep, ok := c.peerMap.endpointForNodeKey(dm.NodeKey.AsNodeKey()); ok && ep.discoKey == dk {
return dm.NodeKey.AsNodeKey(), true
if ep, ok := c.peerMap.endpointForNodeKey(dm.NodeKey); ok && ep.discoKey == dk {
return dm.NodeKey, true
}
}
@@ -1954,7 +1954,7 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic
// di is the discoInfo of the source of the ping.
// derpNodeSrc is non-zero if the ping arrived via DERP.
func (c *Conn) handlePingLocked(dm *disco.Ping, src netaddr.IPPort, di *discoInfo, derpNodeSrc tailcfg.NodeKey) {
func (c *Conn) handlePingLocked(dm *disco.Ping, src netaddr.IPPort, di *discoInfo, derpNodeSrc key.NodePublic) {
likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second
di.lastPingFrom = src
di.lastPingTime = time.Now()
@@ -1999,7 +1999,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netaddr.IPPort, di *discoInf
if numNodes > 1 {
// Zero it out if it's ambiguous, so sendDiscoMessage logging
// isn't confusing.
dstKey = tailcfg.NodeKey{}
dstKey = key.NodePublic{}
}
}
@@ -2130,9 +2130,9 @@ func (c *Conn) SetPrivateKey(privateKey key.NodePrivate) error {
c.havePrivateKey.Set(!newKey.IsZero())
if newKey.IsZero() {
c.publicKeyAtomic.Store(tailcfg.NodeKey{})
c.publicKeyAtomic.Store(key.NodePublic{})
} else {
c.publicKeyAtomic.Store(newKey.Public().AsNodeKey())
c.publicKeyAtomic.Store(newKey.Public())
}
if oldKey.IsZero() {
@@ -2256,7 +2256,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// we'll fall through to the next pass, which allocates but can
// handle full set updates.
for _, n := range nm.Peers {
if ep, ok := c.peerMap.endpointForNodeKey(n.Key); ok {
if ep, ok := c.peerMap.endpointForNodeKey(n.Key.AsNodePublic()); ok {
ep.updateFromNode(n)
c.peerMap.upsertEndpoint(ep) // maybe update discokey mappings in peerMap
continue
@@ -2264,7 +2264,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
ep := &endpoint{
c: c,
publicKey: n.Key,
publicKey: n.Key.AsNodePublic(),
sentPing: map[stun.TxID]sentPing{},
endpointState: map[netaddr.IPPort]*endpointState{},
}
@@ -2307,9 +2307,9 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// current netmap. If that happens, go through the allocful
// deletion path to clean up moribund nodes.
if c.peerMap.nodeCount() != len(nm.Peers) {
keep := make(map[tailcfg.NodeKey]bool, len(nm.Peers))
keep := make(map[key.NodePublic]bool, len(nm.Peers))
for _, n := range nm.Peers {
keep[n.Key] = true
keep[n.Key.AsNodePublic()] = true
}
c.peerMap.forEachEndpoint(func(ep *endpoint) {
if !keep[ep.publicKey] {
@@ -2818,19 +2818,18 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
if err != nil {
return nil, fmt.Errorf("magicsock: ParseEndpoint: parse failed on %q: %w", nodeKeyStr, err)
}
pk := k.AsNodeKey()
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil, errConnClosed
}
ep, ok := c.peerMap.endpointForNodeKey(pk)
ep, ok := c.peerMap.endpointForNodeKey(k)
if !ok {
// We should never be telling WireGuard about a new peer
// before magicsock knows about it.
c.logf("[unexpected] magicsock: ParseEndpoint: unknown node key=%s", pk.ShortString())
return nil, fmt.Errorf("magicsock: ParseEndpoint: unknown peer %q", pk.ShortString())
c.logf("[unexpected] magicsock: ParseEndpoint: unknown node key=%s", k.ShortString())
return nil, fmt.Errorf("magicsock: ParseEndpoint: unknown peer %q", k.ShortString())
}
return ep, nil
@@ -3108,7 +3107,7 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
ps := &ipnstate.PeerStatus{InMagicSock: true}
//ps.Addrs = append(ps.Addrs, n.Endpoints...)
ep.populatePeerStatus(ps)
sb.AddPeer(key.NodePublicFromRaw32(mem.B(ep.publicKey[:])), ps)
sb.AddPeer(ep.publicKey, ps)
})
c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) {
@@ -3134,9 +3133,9 @@ type endpoint struct {
// These fields are initialized once and never modified.
c *Conn
publicKey tailcfg.NodeKey // peer public key (for WireGuard + DERP)
fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using
wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints
publicKey key.NodePublic // peer public key (for WireGuard + DERP)
fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using
wgEndpoint string // string from ParseEndpoint, holds a JSON-serialized wgcfg.Endpoints
// mu protects all following fields.
mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu
@@ -3458,10 +3457,10 @@ func (de *endpoint) send(b []byte) error {
}
var err error
if !udpAddr.IsZero() {
_, err = de.c.sendAddr(udpAddr, key.NodePublicFromRaw32(mem.B(de.publicKey[:])), b)
_, err = de.c.sendAddr(udpAddr, de.publicKey, b)
}
if !derpAddr.IsZero() {
if ok, _ := de.c.sendAddr(derpAddr, key.NodePublicFromRaw32(mem.B(de.publicKey[:])), b); ok && err != nil {
if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok && err != nil {
// UDP failed but DERP worked, so good enough:
return nil
}
@@ -3504,10 +3503,10 @@ func (de *endpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) {
// The caller (startPingLocked) should've already been recorded the ping in
// sentPing and set up the timer.
func (de *endpoint) sendDiscoPing(ep netaddr.IPPort, txid stun.TxID, logLevel discoLogLevel) {
selfPubKey, _ := de.c.publicKeyAtomic.Load().(tailcfg.NodeKey)
selfPubKey, _ := de.c.publicKeyAtomic.Load().(key.NodePublic)
sent, _ := de.sendDiscoMessage(ep, &disco.Ping{
TxID: [12]byte(txid),
NodeKey: selfPubKey.AsNodePublic(),
NodeKey: selfPubKey,
}, logLevel)
if !sent {
de.forgetPing(txid)
@@ -3986,7 +3985,7 @@ type discoInfo struct {
// lastNodeKey is the last NodeKey seen using discoKey.
// It's only updated if the NodeKey is unambiguous.
lastNodeKey tailcfg.NodeKey
lastNodeKey key.NodePublic
// lastNodeKeyTime is the time a NodeKey was last seen using
// this discoKey. It's only updated if the NodeKey is
@@ -3996,7 +3995,7 @@ type discoInfo struct {
// setNodeKey sets the most recent mapping from di.discoKey to the
// NodeKey nk.
func (di *discoInfo) setNodeKey(nk tailcfg.NodeKey) {
func (di *discoInfo) setNodeKey(nk key.NodePublic) {
di.lastNodeKey = nk
di.lastNodeKeyTime = time.Now()
}