wgengine, magicsock: fix SetPrivateKey data race

Updates #112

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-02-28 11:13:28 -08:00
parent 4cd3e82bbd
commit 67ede8d6d2
2 changed files with 54 additions and 15 deletions

View File

@ -36,7 +36,6 @@
type Conn struct { type Conn struct {
pconn *RebindingUDPConn pconn *RebindingUDPConn
pconnPort uint16 pconnPort uint16
privateKey key.Private
stunServers []string stunServers []string
startEpUpdate chan struct{} // send to trigger endpoint update startEpUpdate chan struct{} // send to trigger endpoint update
epFunc func(endpoints []string) epFunc func(endpoints []string)
@ -68,7 +67,9 @@ type Conn struct {
derpRecvCh chan derpReadResult derpRecvCh chan derpReadResult
derpMu sync.Mutex derpMu sync.Mutex
derpConn map[int]*derphttp.Client // magic derp port (see derpmap.go) to its client privateKey key.Private
derpConn map[int]*derphttp.Client // magic derp port (see derpmap.go) to its client
derpCancel map[int]context.CancelFunc // to close derp goroutines
derpWriteCh map[int]chan<- derpWriteRequest derpWriteCh map[int]chan<- derpWriteRequest
} }
@ -525,25 +526,33 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest {
} }
c.derpMu.Lock() c.derpMu.Lock()
defer c.derpMu.Unlock() defer c.derpMu.Unlock()
if c.privateKey.IsZero() {
c.logf("DERP lookup of %v with no private key; ignoring", addr.IP)
return nil
}
ch, ok := c.derpWriteCh[addr.Port] ch, ok := c.derpWriteCh[addr.Port]
if !ok { if !ok {
if c.derpWriteCh == nil { if c.derpWriteCh == nil {
c.derpWriteCh = make(map[int]chan<- derpWriteRequest) c.derpWriteCh = make(map[int]chan<- derpWriteRequest)
c.derpConn = make(map[int]*derphttp.Client) c.derpConn = make(map[int]*derphttp.Client)
c.derpCancel = make(map[int]context.CancelFunc)
} }
host := derpHost(addr.Port) host := derpHost(addr.Port)
dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", log.Printf) dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", log.Printf)
if err != nil { if err != nil {
log.Printf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err)
return nil return nil
} }
ctx, cancel := context.WithCancel(context.Background())
bidiCh := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) bidiCh := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop)
ch = bidiCh ch = bidiCh
c.derpConn[addr.Port] = dc c.derpConn[addr.Port] = dc
c.derpWriteCh[addr.Port] = ch c.derpWriteCh[addr.Port] = ch
go c.runDerpReader(addr, dc) c.derpCancel[addr.Port] = cancel
go c.runDerpWriter(addr, dc, bidiCh) go c.runDerpReader(ctx, addr, dc)
go c.runDerpWriter(ctx, addr, dc, bidiCh)
} }
return ch return ch
} }
@ -564,7 +573,7 @@ type derpReadResult struct {
// runDerpReader runs in a goroutine for the life of a DERP // runDerpReader runs in a goroutine for the life of a DERP
// connection, handling received packets. // connection, handling received packets.
func (c *Conn) runDerpReader(derpFakeAddr *net.UDPAddr, dc *derphttp.Client) { func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc *derphttp.Client) {
didCopy := make(chan struct{}, 1) didCopy := make(chan struct{}, 1)
var buf [derp.MaxPacketSize]byte var buf [derp.MaxPacketSize]byte
var bufValid int // bytes in buf that are valid var bufValid int // bytes in buf that are valid
@ -576,13 +585,15 @@ func (c *Conn) runDerpReader(derpFakeAddr *net.UDPAddr, dc *derphttp.Client) {
for { for {
msg, err := dc.Recv(buf[:]) msg, err := dc.Recv(buf[:])
if err == derphttp.ErrClientClosed {
return
}
if err != nil { if err != nil {
if err == derphttp.ErrClientClosed {
return
}
select { select {
case <-c.donec: case <-c.donec:
return return
case <-ctx.Done():
return
default: default:
} }
log.Printf("derp.Recv: %v", err) log.Printf("derp.Recv: %v", err)
@ -618,9 +629,11 @@ type derpWriteRequest struct {
// runDerpWriter runs in a goroutine for the life of a DERP // runDerpWriter runs in a goroutine for the life of a DERP
// connection, handling received packets. // connection, handling received packets.
func (c *Conn) runDerpWriter(derpFakeAddr *net.UDPAddr, dc *derphttp.Client, ch <-chan derpWriteRequest) { func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc *derphttp.Client, ch <-chan derpWriteRequest) {
for { for {
select { select {
case <-ctx.Done():
return
case <-c.donec: case <-c.donec:
return return
case wr := <-ch: case wr := <-ch:
@ -740,7 +753,29 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error
} }
func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
c.privateKey = key.Private(privateKey) c.derpMu.Lock()
defer c.derpMu.Unlock()
oldKey, newKey := c.privateKey, key.Private(privateKey)
if newKey == oldKey {
return nil
}
c.privateKey = newKey
if oldKey.IsZero() {
// Initial configuration on start.
return nil
}
// Key changed. Close any DERP connections.
for _, c := range c.derpConn {
go c.Close()
}
for _, cancel := range c.derpCancel {
cancel()
}
c.derpConn = nil
c.derpCancel = nil
c.derpWriteCh = nil
return nil return nil
} }

View File

@ -325,15 +325,19 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error
e.lastReconfig = rc e.lastReconfig = rc
e.lastCfg = cfg.Copy() e.lastCfg = cfg.Copy()
// Tell magicsock about the new (or initial) private key
// (which is needed by DERP) before wgdev gets it, as wgdev
// will start trying to handshake, which we want to be able to
// go over DERP.
if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil {
e.logf("magicsock: %v\n", err)
}
if err := e.wgdev.Reconfig(cfg); err != nil { if err := e.wgdev.Reconfig(cfg); err != nil {
e.logf("wgdev.Reconfig: %v\n", err) e.logf("wgdev.Reconfig: %v\n", err)
return err return err
} }
if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil {
e.logf("magicsock: %v\n", err)
}
// TODO(apenwarr): only handling the first local address. // TODO(apenwarr): only handling the first local address.
// Currently we never use more than one anyway. // Currently we never use more than one anyway.
var cidr wgcfg.CIDR var cidr wgcfg.CIDR