mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 02:02:51 +00:00 
			
		
		
		
	wgengine/magicsock: drop donec channel, rename epUpdateCtx to serve its purpose
This commit is contained in:
		| @@ -41,11 +41,10 @@ type Conn struct { | |||||||
| 	startEpUpdate chan struct{} // send to trigger endpoint update | 	startEpUpdate chan struct{} // send to trigger endpoint update | ||||||
| 	epFunc        func(endpoints []string) | 	epFunc        func(endpoints []string) | ||||||
| 	logf          func(format string, args ...interface{}) | 	logf          func(format string, args ...interface{}) | ||||||
| 	donec         chan struct{} // closed on Conn.Close |  | ||||||
| 	sendLogLimit  *rate.Limiter | 	sendLogLimit  *rate.Limiter | ||||||
|  |  | ||||||
| 	epUpdateCtx    context.Context // endpoint updater context | 	connCtx       context.Context // closed on Conn.Close | ||||||
| 	epUpdateCancel func()          // the func to cancel epUpdateCtx | 	connCtxCancel func()          // closes connCtx | ||||||
|  |  | ||||||
| 	// addrsByUDP is a map of every remote ip:port to a priority | 	// addrsByUDP is a map of every remote ip:port to a priority | ||||||
| 	// list of endpoint addresses for a peer. | 	// list of endpoint addresses for a peer. | ||||||
| @@ -135,29 +134,30 @@ func Listen(opts Options) (*Conn, error) { | |||||||
| 		return nil, fmt.Errorf("magicsock.Listen: %v", err) | 		return nil, fmt.Errorf("magicsock.Listen: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background()) | 	connCtx, connCtxCancel := context.WithCancel(context.Background()) | ||||||
| 	c := &Conn{ | 	c := &Conn{ | ||||||
| 		pconn:          new(RebindingUDPConn), | 		pconn:         new(RebindingUDPConn), | ||||||
| 		pconnPort:      opts.Port, | 		pconnPort:     opts.Port, | ||||||
| 		donec:          make(chan struct{}), | 		sendLogLimit:  rate.NewLimiter(rate.Every(1*time.Minute), 1), | ||||||
| 		sendLogLimit:   rate.NewLimiter(rate.Every(1*time.Minute), 1), | 		stunServers:   append([]string{}, opts.STUN...), | ||||||
| 		stunServers:    append([]string{}, opts.STUN...), | 		startEpUpdate: make(chan struct{}, 1), | ||||||
| 		startEpUpdate:  make(chan struct{}, 1), | 		connCtx:       connCtx, | ||||||
| 		epUpdateCtx:    epUpdateCtx, | 		connCtxCancel: connCtxCancel, | ||||||
| 		epUpdateCancel: epUpdateCancel, | 		epFunc:        opts.endpointsFunc(), | ||||||
| 		epFunc:         opts.endpointsFunc(), | 		logf:          log.Printf, | ||||||
| 		logf:           log.Printf, | 		addrsByUDP:    make(map[udpAddr]*AddrSet), | ||||||
| 		addrsByUDP:     make(map[udpAddr]*AddrSet), | 		derpRecvCh:    make(chan derpReadResult), | ||||||
| 		derpRecvCh:     make(chan derpReadResult), | 		udpRecvCh:     make(chan udpReadResult), | ||||||
| 		udpRecvCh:      make(chan udpReadResult), |  | ||||||
| 	} | 	} | ||||||
| 	c.ignoreSTUNPackets() | 	c.ignoreSTUNPackets() | ||||||
| 	c.pconn.Reset(packetConn.(*net.UDPConn)) | 	c.pconn.Reset(packetConn.(*net.UDPConn)) | ||||||
| 	c.reSTUN() | 	c.reSTUN() | ||||||
| 	go c.epUpdate(epUpdateCtx) | 	go c.epUpdate(connCtx) | ||||||
| 	return c, nil | 	return c, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() } | ||||||
|  |  | ||||||
| // ignoreSTUNPackets sets a STUN packet processing func that does nothing. | // ignoreSTUNPackets sets a STUN packet processing func that does nothing. | ||||||
| func (c *Conn) ignoreSTUNPackets() { | func (c *Conn) ignoreSTUNPackets() { | ||||||
| 	c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) | 	c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) | ||||||
| @@ -497,11 +497,11 @@ func (c *Conn) sendAddr(addr *net.UDPAddr, pubKey key.Public, b []byte) error { | |||||||
| 	if ch := c.derpWriteChanOfAddr(addr); ch != nil { | 	if ch := c.derpWriteChanOfAddr(addr); ch != nil { | ||||||
| 		errc := make(chan error, 1) | 		errc := make(chan error, 1) | ||||||
| 		select { | 		select { | ||||||
| 		case <-c.donec: | 		case <-c.donec(): | ||||||
| 			return errConnClosed | 			return errConnClosed | ||||||
| 		case ch <- derpWriteRequest{addr, pubKey, b, errc}: | 		case ch <- derpWriteRequest{addr, pubKey, b, errc}: | ||||||
| 			select { | 			select { | ||||||
| 			case <-c.donec: | 			case <-c.donec(): | ||||||
| 				return errConnClosed | 				return errConnClosed | ||||||
| 			case err := <-errc: | 			case err := <-errc: | ||||||
| 				return err // usually nil | 				return err // usually nil | ||||||
| @@ -595,7 +595,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc | |||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			select { | 			select { | ||||||
| 			case <-c.donec: | 			case <-c.donec(): | ||||||
| 				return | 				return | ||||||
| 			case <-ctx.Done(): | 			case <-ctx.Done(): | ||||||
| 				return | 				return | ||||||
| @@ -617,7 +617,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc | |||||||
| 			log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid]) | 			log.Printf("got derp %v packet: %q", derpFakeAddr, buf[:bufValid]) | ||||||
| 		} | 		} | ||||||
| 		select { | 		select { | ||||||
| 		case <-c.donec: | 		case <-c.donec(): | ||||||
| 			return | 			return | ||||||
| 		case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}: | 		case c.derpRecvCh <- derpReadResult{derpFakeAddr, bufValid, copyFn}: | ||||||
| 			<-didCopy | 			<-didCopy | ||||||
| @@ -639,7 +639,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc | |||||||
| 		select { | 		select { | ||||||
| 		case <-ctx.Done(): | 		case <-ctx.Done(): | ||||||
| 			return | 			return | ||||||
| 		case <-c.donec: | 		case <-c.donec(): | ||||||
| 			return | 			return | ||||||
| 		case wr := <-ch: | 		case wr := <-ch: | ||||||
| 			err := dc.Send(wr.pubKey, wr.b) | 			err := dc.Send(wr.pubKey, wr.b) | ||||||
| @@ -648,7 +648,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc | |||||||
| 			} | 			} | ||||||
| 			select { | 			select { | ||||||
| 			case wr.errc <- err: | 			case wr.errc <- err: | ||||||
| 			case <-c.donec: | 			case <-c.donec(): | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @@ -685,7 +685,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				select { | 				select { | ||||||
| 				case c.udpRecvCh <- udpReadResult{err: err}: | 				case c.udpRecvCh <- udpReadResult{err: err}: | ||||||
| 				case <-c.donec: | 				case <-c.donec(): | ||||||
| 				} | 				} | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| @@ -698,7 +698,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr | |||||||
| 			addr.IP = addr.IP.To4() | 			addr.IP = addr.IP.To4() | ||||||
| 			select { | 			select { | ||||||
| 			case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: | 			case c.udpRecvCh <- udpReadResult{n: n, addr: addr}: | ||||||
| 			case <-c.donec: | 			case <-c.donec(): | ||||||
| 			} | 			} | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| @@ -719,7 +719,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr | |||||||
| 			// The main point of this receive, though, is to make sure that the goroutine | 			// The main point of this receive, though, is to make sure that the goroutine | ||||||
| 			// is done with our b []byte buf. | 			// is done with our b []byte buf. | ||||||
| 			c.pconn.SetReadDeadline(time.Time{}) | 			c.pconn.SetReadDeadline(time.Time{}) | ||||||
| 		case <-c.donec: | 		case <-c.donec(): | ||||||
| 			return 0, nil, nil, errors.New("Conn closed") | 			return 0, nil, nil, errors.New("Conn closed") | ||||||
| 		} | 		} | ||||||
| 		n, addr = dm.n, dm.derpAddr | 		n, addr = dm.n, dm.derpAddr | ||||||
| @@ -753,6 +753,13 @@ func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error | |||||||
| 	return 0, nil, nil, syscall.EAFNOSUPPORT | 	return 0, nil, nil, syscall.EAFNOSUPPORT | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // SetPrivateKey sets the connection's private key. | ||||||
|  | // | ||||||
|  | // This is only used to be able prove our identity when connecting to | ||||||
|  | // DERP servers. | ||||||
|  | // | ||||||
|  | // If the private key changes, any DERP connections are torn down & | ||||||
|  | // recreated when needed. | ||||||
| func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { | func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { | ||||||
| 	c.derpMu.Lock() | 	c.derpMu.Lock() | ||||||
| 	defer c.derpMu.Unlock() | 	defer c.derpMu.Unlock() | ||||||
| @@ -768,6 +775,13 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Key changed. Close any DERP connections. | 	// Key changed. Close any DERP connections. | ||||||
|  | 	c.closeAllDerpLocked() | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // c.derpMu must be held. | ||||||
|  | func (c *Conn) closeAllDerpLocked() { | ||||||
| 	for _, c := range c.derpConn { | 	for _, c := range c.derpConn { | ||||||
| 		go c.Close() | 		go c.Close() | ||||||
| 	} | 	} | ||||||
| @@ -777,30 +791,31 @@ func (c *Conn) SetPrivateKey(privateKey wgcfg.PrivateKey) error { | |||||||
| 	c.derpConn = nil | 	c.derpConn = nil | ||||||
| 	c.derpCancel = nil | 	c.derpCancel = nil | ||||||
| 	c.derpWriteCh = nil | 	c.derpWriteCh = nil | ||||||
| 	return nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Conn) SetMark(value uint32) error { return nil } | func (c *Conn) SetMark(value uint32) error { return nil } | ||||||
| func (c *Conn) LastMark() uint32           { return 0 } | func (c *Conn) LastMark() uint32           { return 0 } | ||||||
|  |  | ||||||
| func (c *Conn) Close() error { | func (c *Conn) Close() error { | ||||||
|  | 	// TODO: make this safe for concurrent Close? it's safe now only if Close calls are serialized. | ||||||
| 	select { | 	select { | ||||||
| 	case <-c.donec: | 	case <-c.donec(): | ||||||
| 		return nil | 		return nil | ||||||
| 	default: | 	default: | ||||||
| 	} | 	} | ||||||
| 	close(c.donec) | 	c.connCtxCancel() | ||||||
| 	c.epUpdateCancel() |  | ||||||
| 	for _, dc := range c.derpConn { | 	c.derpMu.Lock() | ||||||
| 		dc.Close() | 	c.closeAllDerpLocked() | ||||||
| 	} | 	c.derpMu.Unlock() | ||||||
|  |  | ||||||
| 	return c.pconn.Close() | 	return c.pconn.Close() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *Conn) reSTUN() { | func (c *Conn) reSTUN() { | ||||||
| 	select { | 	select { | ||||||
| 	case c.startEpUpdate <- struct{}{}: | 	case c.startEpUpdate <- struct{}{}: | ||||||
| 	case <-c.epUpdateCtx.Done(): | 	case <-c.donec(): | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Brad Fitzpatrick
					Brad Fitzpatrick