wgengine/magicsock: drop donec channel, rename epUpdateCtx to serve its purpose

This commit is contained in:
Brad Fitzpatrick 2020-03-02 09:31:25 -08:00
parent a399ef3dc7
commit af7a01d6f0

View File

@ -41,11 +41,10 @@ type Conn struct {
startEpUpdate chan struct{} // send to trigger endpoint update
epFunc func(endpoints []string)
logf func(format string, args ...interface{})
donec chan struct{} // closed on Conn.Close
sendLogLimit *rate.Limiter
epUpdateCtx context.Context // endpoint updater context
epUpdateCancel func() // the func to cancel epUpdateCtx
connCtx context.Context // closed on Conn.Close
connCtxCancel func() // closes connCtx
// addrsByUDP is a map of every remote ip:port to a priority
// list of endpoint addresses for a peer.
@ -135,29 +134,30 @@ func Listen(opts Options) (*Conn, error) {
return nil, fmt.Errorf("magicsock.Listen: %v", err)
}
epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background())
connCtx, connCtxCancel := context.WithCancel(context.Background())
c := &Conn{
pconn: new(RebindingUDPConn),
pconnPort: opts.Port,
donec: make(chan struct{}),
sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1),
stunServers: append([]string{}, opts.STUN...),
startEpUpdate: make(chan struct{}, 1),
epUpdateCtx: epUpdateCtx,
epUpdateCancel: epUpdateCancel,
epFunc: opts.endpointsFunc(),
logf: log.Printf,
addrsByUDP: make(map[udpAddr]*AddrSet),
derpRecvCh: make(chan derpReadResult),
udpRecvCh: make(chan udpReadResult),
pconn: new(RebindingUDPConn),
pconnPort: opts.Port,
sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1),
stunServers: append([]string{}, opts.STUN...),
startEpUpdate: make(chan struct{}, 1),
connCtx: connCtx,
connCtxCancel: connCtxCancel,
epFunc: opts.endpointsFunc(),
logf: log.Printf,
addrsByUDP: make(map[udpAddr]*AddrSet),
derpRecvCh: make(chan derpReadResult),
udpRecvCh: make(chan udpReadResult),
}
c.ignoreSTUNPackets()
c.pconn.Reset(packetConn.(*net.UDPConn))
c.reSTUN()
go c.epUpdate(epUpdateCtx)
go c.epUpdate(connCtx)
return c, nil
}
func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() }
// ignoreSTUNPackets sets a STUN packet processing func that does nothing.
func (c *Conn) ignoreSTUNPackets() {
c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {})
@ -497,11 +497,11 @@ func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error {
if ch := c.derpWriteChanOfAddr(addr); ch != nil {
errc := make(chan error, 1)
select {
case <-c.donec:
case <-c.donec():
return errConnClosed
case ch <- derpWriteRequest{addr, pubKey, b, errc}:
select {
case <-c.donec:
case <-c.donec():
return errConnClosed
case err := <-errc:
return err // usually nil
@ -595,7 +595,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
}
if err != nil {
select {
case <-c.donec:
case <-c.donec():
return
case <-ctx.Done():
return
@ -617,7 +617,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid])
}
select {
case <-c.donec:
case <-c.donec():
return
case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}:
<-didCopy
@ -639,7 +639,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
select {
case <-ctx.Done():
return
case <-c.donec:
case <-c.donec():
return
case wr := <-ch:
err := dc.Send(wr.pubKey, wr.b)
@ -648,7 +648,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
}
select {
case wr.errc <- err:
case <-c.donec:
case <-c.donec():
return
}
}
@ -685,7 +685,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
if err != nil {
select {
case c.udpRecvCh <- udpReadResult{err: err}:
case <-c.donec:
case <-c.donec():
}
return
}
@ -698,7 +698,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
addr.IP = addr.IP.To4()
select {
case c.udpRecvCh <- udpReadResult{n: n, addr: addr}:
case <-c.donec:
case <-c.donec():
}
return
}
@ -719,7 +719,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
// The main point of this receive, though, is to make sure that the goroutine
// is done with our b []byte buf.
c.pconn.SetReadDeadline(time.Time{})
case <-c.donec:
case <-c.donec():
return 0, nil, nil, errors.New("Conn closed")
}
n, addr = dm.n, dm.derpAddr
@ -753,6 +753,13 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error
return 0, nil, nil, syscall.EAFNOSUPPORT
}
// SetPrivateKey sets the connection's private key.
//
// This is only used to be able prove our identity when connecting to
// DERP servers.
//
// If the private key changes, any DERP connections are torn down &
// recreated when needed.
func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
c.derpMu.Lock()
defer c.derpMu.Unlock()
@ -768,6 +775,13 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
}
// Key changed. Close any DERP connections.
c.closeAllDerpLocked()
return nil
}
// c.derpMu must be held.
func (c *Conn) closeAllDerpLocked() {
for _, c := range c.derpConn {
go c.Close()
}
@ -777,30 +791,31 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error {
c.derpConn = nil
c.derpCancel = nil
c.derpWriteCh = nil
return nil
}
func (c *Conn) SetMark(value uint32) error { return nil }
func (c *Conn) LastMark() uint32 { return 0 }
func (c *Conn) Close() error {
// TODO: make this safe for concurrent Close? it's safe now only if Close calls are serialized.
select {
case <-c.donec:
case <-c.donec():
return nil
default:
}
close(c.donec)
c.epUpdateCancel()
for _, dc := range c.derpConn {
dc.Close()
}
c.connCtxCancel()
c.derpMu.Lock()
c.closeAllDerpLocked()
c.derpMu.Unlock()
return c.pconn.Close()
}
func (c *Conn) reSTUN() {
select {
case c.startEpUpdate <- struct{}{}:
case <-c.epUpdateCtx.Done():
case <-c.donec():
}
}