wgengine/magicsock: filter disco packets and packets when stopped from wireguard

Fixes #1167
Fixes tailscale/corp#219

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-02-03 18:15:01 -08:00 committed by Brad Fitzpatrick
parent 81466eef81
commit f7eed25bb9

View File

@ -304,6 +304,9 @@ type Conn struct {
// with IPv4 or IPv6). It's used to suppress log spam and prevent // with IPv4 or IPv6). It's used to suppress log spam and prevent
// new connection that'll fail. // new connection that'll fail.
networkUp syncs.AtomicBool networkUp syncs.AtomicBool
// havePrivateKey is whether privateKey is non-zero.
havePrivateKey syncs.AtomicBool
} }
// derpRoute is a route entry for a public key, saying that a certain // derpRoute is a route entry for a public key, saying that a certain
@ -1588,6 +1591,9 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
} }
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
//
// ok is whether this read should be reported up to wireguard-go (our
// caller).
func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) { func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep conn.Endpoint, ok bool) {
ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone) ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone)
if !ok { if !ok {
@ -1600,6 +1606,13 @@ func (c *Conn) receiveIP(b []byte, ua *net.UDPAddr, cache *ippEndpointCache) (ep
if c.handleDiscoMessage(b, ipp) { if c.handleDiscoMessage(b, ipp) {
return nil, false return nil, false
} }
if !c.havePrivateKey.Get() {
// If we have no private key, we're logged out or
// stopped. Don't try to pass these wireguard packets
// up to wireguard-go; it'll just complain (Issue
// 1167).
return nil, false
}
if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() { if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() {
ep = cache.de ep = cache.de
} else { } else {
@ -1750,8 +1763,8 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
return sent, err return sent, err
} }
// handleDiscoMessage reports whether msg was a Tailscale inter-node discovery message // handleDiscoMessage handles a discovery message and reports whether
// that was handled. // msg was a Tailscale inter-node discovery message.
// //
// A discovery message has the form: // A discovery message has the form:
// //
@ -1762,11 +1775,18 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstD
// //
// For messages received over DERP, the addr will be derpMagicIP (with // For messages received over DERP, the addr will be derpMagicIP (with
// port being the region) // port being the region)
func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) (isDiscoMsg bool) {
const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen const headerLen = len(disco.Magic) + len(tailcfg.DiscoKey{}) + disco.NonceLen
if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic {
return false return false
} }
// If the first four parts are the prefix of disco.Magic
// (0x5453f09f) then it's definitely not a valid Wireguard
// packet (which starts with little-endian uint32 1, 2, 3, 4).
// Use naked returns for all following paths.
isDiscoMsg = true
var sender tailcfg.DiscoKey var sender tailcfg.DiscoKey
copy(sender[:], msg[len(disco.Magic):]) copy(sender[:], msg[len(disco.Magic):])
@ -1774,20 +1794,21 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
defer c.mu.Unlock() defer c.mu.Unlock()
if c.closed { if c.closed {
return true return
} }
if debugDisco { if debugDisco {
c.logf("magicsock: disco: got disco-looking frame from %v", sender.ShortString()) c.logf("magicsock: disco: got disco-looking frame from %v", sender.ShortString())
} }
if c.privateKey.IsZero() { if c.privateKey.IsZero() {
// Ignore disco messages when we're stopped. // Ignore disco messages when we're stopped.
return false // Still return true, to not pass it down to wireguard.
return
} }
if c.discoPrivate.IsZero() { if c.discoPrivate.IsZero() {
if debugDisco { if debugDisco {
c.logf("magicsock: disco: ignoring disco-looking frame, no local key") c.logf("magicsock: disco: ignoring disco-looking frame, no local key")
} }
return false return
} }
peerNode, ok := c.nodeOfDisco[sender] peerNode, ok := c.nodeOfDisco[sender]
@ -1795,9 +1816,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
if debugDisco { if debugDisco {
c.logf("magicsock: disco: ignoring disco-looking frame, don't know node for %v", sender.ShortString()) c.logf("magicsock: disco: ignoring disco-looking frame, don't know node for %v", sender.ShortString())
} }
// Returning false keeps passing it down, to WireGuard. return
// WireGuard will almost surely reject it, but give it a chance.
return false
} }
needsRecvActivityCall := false needsRecvActivityCall := false
@ -1810,7 +1829,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
c.logf("magicsock: got disco message from idle peer, starting lazy conf for %v, %v", peerNode.Key.ShortString(), sender.ShortString()) c.logf("magicsock: got disco message from idle peer, starting lazy conf for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
if c.noteRecvActivity == nil { if c.noteRecvActivity == nil {
c.logf("magicsock: [unexpected] have node without endpoint, without c.noteRecvActivity hook") c.logf("magicsock: [unexpected] have node without endpoint, without c.noteRecvActivity hook")
return false return
} }
needsRecvActivityCall = true needsRecvActivityCall = true
} else { } else {
@ -1829,7 +1848,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
// Now, recheck invariants that might've changed while we'd // Now, recheck invariants that might've changed while we'd
// released the lock, which isn't much: // released the lock, which isn't much:
if c.closed || c.privateKey.IsZero() { if c.closed || c.privateKey.IsZero() {
return true return
} }
de, ok = c.endpointOfDisco[sender] de, ok = c.endpointOfDisco[sender]
if !ok { if !ok {
@ -1838,7 +1857,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
return false return false
} }
c.logf("magicsock: [unexpected] lazy endpoint not created for %v, %v", peerNode.Key.ShortString(), sender.ShortString()) c.logf("magicsock: [unexpected] lazy endpoint not created for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
return false return
} }
if !endpointFound0 { if !endpointFound0 {
c.logf("magicsock: lazy endpoint created via disco message for %v, %v", peerNode.Key.ShortString(), sender.ShortString()) c.logf("magicsock: lazy endpoint created via disco message for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
@ -1865,7 +1884,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?)", sender) c.logf("magicsock: disco: failed to open naclbox from %v (wrong rcpt?)", sender)
} }
// TODO(bradfitz): add some counter for this that logs rarely // TODO(bradfitz): add some counter for this that logs rarely
return false return
} }
dm, err := disco.Parse(payload) dm, err := disco.Parse(payload)
@ -1879,7 +1898,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
// understand. Not even worth logging about, lest it // understand. Not even worth logging about, lest it
// be too spammy for old clients. // be too spammy for old clients.
// TODO(bradfitz): add some counter for this that logs rarely // TODO(bradfitz): add some counter for this that logs rarely
return true return
} }
switch dm := dm.(type) { switch dm := dm.(type) {
@ -1887,14 +1906,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
c.handlePingLocked(dm, de, src, sender, peerNode) c.handlePingLocked(dm, de, src, sender, peerNode)
case *disco.Pong: case *disco.Pong:
if de == nil { if de == nil {
return true return
} }
de.handlePongConnLocked(dm, src) de.handlePongConnLocked(dm, src)
case *disco.CallMeMaybe: case *disco.CallMeMaybe:
if src.IP != derpMagicIPAddr { if src.IP != derpMagicIPAddr {
// CallMeMaybe messages should only come via DERP. // CallMeMaybe messages should only come via DERP.
c.logf("[unexpected] CallMeMaybe packets should only come via DERP") c.logf("[unexpected] CallMeMaybe packets should only come via DERP")
return true return
} }
if de != nil { if de != nil {
c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
@ -1904,8 +1923,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
go de.handleCallMeMaybe(dm) go de.handleCallMeMaybe(dm)
} }
} }
return
return true
} }
func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort, sender tailcfg.DiscoKey, peerNode *tailcfg.Node) { func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort, sender tailcfg.DiscoKey, peerNode *tailcfg.Node) {
@ -2082,6 +2100,7 @@ func (c *Conn) SetPrivateKey(privateKey wgkey.Private) error {
return nil return nil
} }
c.privateKey = newKey c.privateKey = newKey
c.havePrivateKey.Set(!newKey.IsZero())
if oldKey.IsZero() { if oldKey.IsZero() {
c.everHadKey = true c.everHadKey = true