mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 03:49:52 +00:00 
			
		
		
		
	wgengine/magicsock: make portmapping async
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
		 Brad Fitzpatrick
					Brad Fitzpatrick
				
			
				
					committed by
					
						 Brad Fitzpatrick
						Brad Fitzpatrick
					
				
			
			
				
	
			
			
			 Brad Fitzpatrick
						Brad Fitzpatrick
					
				
			
						parent
						
							afbd35482d
						
					
				
				
					commit
					92077ae78c
				
			| @@ -50,7 +50,7 @@ var netcheckArgs struct { | ||||
| func runNetcheck(ctx context.Context, args []string) error { | ||||
| 	c := &netcheck.Client{ | ||||
| 		UDPBindAddr: os.Getenv("TS_DEBUG_NETCHECK_UDP_BIND"), | ||||
| 		PortMapper:  portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: ")), | ||||
| 		PortMapper:  portmapper.NewClient(logger.WithPrefix(log.Printf, "portmap: "), nil), | ||||
| 	} | ||||
| 	if netcheckArgs.verbose { | ||||
| 		c.Logf = logger.WithPrefix(log.Printf, "netcheck: ") | ||||
|   | ||||
| @@ -44,9 +44,15 @@ const trustServiceStillAvailableDuration = 10 * time.Minute | ||||
| type Client struct { | ||||
| 	logf         logger.Logf | ||||
| 	ipAndGateway func() (gw, ip netaddr.IP, ok bool) | ||||
| 	onChange     func() // or nil | ||||
|  | ||||
| 	mu sync.Mutex // guards following, and all fields thereof | ||||
|  | ||||
| 	// runningCreate is whether we're currently working on creating | ||||
| 	// a port mapping (whether GetCachedMappingOrStartCreatingOne kicked | ||||
| 	// off a createMapping goroutine). | ||||
| 	runningCreate bool | ||||
|  | ||||
| 	lastMyIP netaddr.IP | ||||
| 	lastGW   netaddr.IP | ||||
| 	closed   bool | ||||
| @@ -68,7 +74,7 @@ type Client struct { | ||||
| func (c *Client) HaveMapping() bool { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
| 	return c.pmpMapping != nil && c.pmpMapping.useUntil.After(time.Now()) | ||||
| 	return c.pmpMapping != nil && c.pmpMapping.goodUntil.After(time.Now()) | ||||
| } | ||||
|  | ||||
| // pmpMapping is an already-created PMP mapping. | ||||
| @@ -78,7 +84,8 @@ type pmpMapping struct { | ||||
| 	gw         netaddr.IP | ||||
| 	external   netaddr.IPPort | ||||
| 	internal   netaddr.IPPort | ||||
| 	useUntil time.Time // the mapping's lifetime minus renewal interval | ||||
| 	renewAfter time.Time // the time at which we want to renew the mapping | ||||
| 	goodUntil  time.Time // the mapping's total lifetime | ||||
| 	epoch      uint32 | ||||
| } | ||||
|  | ||||
| @@ -99,10 +106,15 @@ func (m *pmpMapping) release() { | ||||
| } | ||||
|  | ||||
| // NewClient returns a new portmapping client. | ||||
| func NewClient(logf logger.Logf) *Client { | ||||
| // | ||||
| // The optional onChange argument specifies a func to run in a new | ||||
| // goroutine whenever the port mapping status has changed. If nil, | ||||
| // it doesn't make a callback. | ||||
| func NewClient(logf logger.Logf, onChange func()) *Client { | ||||
| 	return &Client{ | ||||
| 		logf:         logf, | ||||
| 		ipAndGateway: interfaces.LikelyHomeRouterIP, | ||||
| 		onChange:     onChange, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -221,8 +233,7 @@ func closeCloserOnContextDone(ctx context.Context, c io.Closer) (stop func()) { | ||||
| 	return func() { close(stopWaitDone) } | ||||
| } | ||||
|  | ||||
| // NoMappingError is returned by CreateOrGetMapping when no NAT | ||||
| // mapping could be returned. | ||||
| // NoMappingError is returned when no NAT mapping could be done. | ||||
| type NoMappingError struct { | ||||
| 	err error | ||||
| } | ||||
| @@ -241,12 +252,62 @@ var ( | ||||
| 	ErrGatewayNotFound       = errors.New("failed to look up gateway address") | ||||
| ) | ||||
|  | ||||
| // CreateOrGetMapping either creates a new mapping or returns a cached | ||||
| // GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any. | ||||
| // If there's not one, it starts up a background goroutine to create one. | ||||
| // If the background goroutine ends up creating one, the onChange hook registered with the | ||||
| // NewClient constructor (if any) will fire. | ||||
| func (c *Client) GetCachedMappingOrStartCreatingOne() (external netaddr.IPPort, ok bool) { | ||||
| 	c.mu.Lock() | ||||
| 	defer c.mu.Unlock() | ||||
|  | ||||
| 	// Do we have an existing mapping that's valid? | ||||
| 	now := time.Now() | ||||
| 	if m := c.pmpMapping; m != nil { | ||||
| 		if now.Before(m.goodUntil) { | ||||
| 			if now.After(m.renewAfter) { | ||||
| 				c.maybeStartMappingLocked() | ||||
| 			} | ||||
| 			return m.external, true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	c.maybeStartMappingLocked() | ||||
| 	return netaddr.IPPort{}, false | ||||
| } | ||||
|  | ||||
| // maybeStartMappingLocked starts a createMapping goroutine up, if one isn't already running. | ||||
| // | ||||
| // c.mu must be held. | ||||
| func (c *Client) maybeStartMappingLocked() { | ||||
| 	if !c.runningCreate { | ||||
| 		c.runningCreate = true | ||||
| 		go c.createMapping() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *Client) createMapping() { | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	defer func() { | ||||
| 		c.mu.Lock() | ||||
| 		defer c.mu.Unlock() | ||||
| 		c.runningCreate = false | ||||
| 	}() | ||||
|  | ||||
| 	if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil { | ||||
| 		go c.onChange() | ||||
| 	} else if err != nil && !IsNoMappingError(err) { | ||||
| 		c.logf("createOrGetMapping: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // createOrGetMapping either creates a new mapping or returns a cached | ||||
| // valid one. | ||||
| // | ||||
| // If no mapping is available, the error will be of type | ||||
| // NoMappingError; see IsNoMappingError. | ||||
| func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) { | ||||
| func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPort, err error) { | ||||
| 	gw, myIP, ok := c.gatewayAndSelfIP() | ||||
| 	if !ok { | ||||
| 		return netaddr.IPPort{}, NoMappingError{ErrGatewayNotFound} | ||||
| @@ -266,7 +327,7 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor | ||||
| 	// Do we have an existing mapping that's valid? | ||||
| 	now := time.Now() | ||||
| 	if m := c.pmpMapping; m != nil { | ||||
| 		if now.Before(m.useUntil) { | ||||
| 		if now.Before(m.renewAfter) { | ||||
| 			defer c.mu.Unlock() | ||||
| 			return m.external, nil | ||||
| 		} | ||||
| @@ -342,8 +403,9 @@ func (c *Client) CreateOrGetMapping(ctx context.Context) (external netaddr.IPPor | ||||
| 			if pres.OpCode == pmpOpReply|pmpOpMapUDP { | ||||
| 				m.external = m.external.WithPort(pres.ExternalPort) | ||||
| 				d := time.Duration(pres.MappingValidSeconds) * time.Second | ||||
| 				d /= 2 // renew in half the time | ||||
| 				m.useUntil = time.Now().Add(d) | ||||
| 				now := time.Now() | ||||
| 				m.goodUntil = now.Add(d) | ||||
| 				m.renewAfter = now.Add(d / 2) // renew in half the time | ||||
| 				m.epoch = pres.SecondsSinceEpoch | ||||
| 			} | ||||
| 		} | ||||
|   | ||||
| @@ -16,13 +16,13 @@ func TestCreateOrGetMapping(t *testing.T) { | ||||
| 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { | ||||
| 		t.Skip("skipping test without HIT_NETWORK=1") | ||||
| 	} | ||||
| 	c := NewClient(t.Logf) | ||||
| 	c := NewClient(t.Logf, nil) | ||||
| 	c.SetLocalPort(1234) | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		if i > 0 { | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| 		} | ||||
| 		ext, err := c.CreateOrGetMapping(context.Background()) | ||||
| 		ext, err := c.createOrGetMapping(context.Background()) | ||||
| 		t.Logf("Got: %v, %v", ext, err) | ||||
| 	} | ||||
| } | ||||
| @@ -31,7 +31,7 @@ func TestClientProbe(t *testing.T) { | ||||
| 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { | ||||
| 		t.Skip("skipping test without HIT_NETWORK=1") | ||||
| 	} | ||||
| 	c := NewClient(t.Logf) | ||||
| 	c := NewClient(t.Logf, nil) | ||||
| 	for i := 0; i < 2; i++ { | ||||
| 		if i > 0 { | ||||
| 			time.Sleep(100 * time.Millisecond) | ||||
| @@ -45,10 +45,10 @@ func TestClientProbeThenMap(t *testing.T) { | ||||
| 	if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { | ||||
| 		t.Skip("skipping test without HIT_NETWORK=1") | ||||
| 	} | ||||
| 	c := NewClient(t.Logf) | ||||
| 	c := NewClient(t.Logf, nil) | ||||
| 	c.SetLocalPort(1234) | ||||
| 	res, err := c.Probe(context.Background()) | ||||
| 	t.Logf("Probe: %+v, %v", res, err) | ||||
| 	ext, err := c.CreateOrGetMapping(context.Background()) | ||||
| 	t.Logf("CreateOrGetMapping: %v, %v", ext, err) | ||||
| 	ext, err := c.createOrGetMapping(context.Background()) | ||||
| 	t.Logf("createOrGetMapping: %v, %v", ext, err) | ||||
| } | ||||
|   | ||||
| @@ -486,7 +486,7 @@ func NewConn(opts Options) (*Conn, error) { | ||||
| 	c.noteRecvActivity = opts.NoteRecvActivity | ||||
| 	c.simulatedNetwork = opts.SimulatedNetwork | ||||
| 	c.disableLegacy = opts.DisableLegacyNetworking | ||||
| 	c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: ")) | ||||
| 	c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), c.onPortMapChanged) | ||||
| 	if opts.LinkMonitor != nil { | ||||
| 		c.portMapper.SetGatewayLookupFunc(opts.LinkMonitor.GatewayAndSelfIP) | ||||
| 	} | ||||
| @@ -979,6 +979,8 @@ func (c *Conn) goDerpConnect(node int) { | ||||
| // | ||||
| // c.mu must NOT be held. | ||||
| func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) { | ||||
| 	portmapExt, havePortmap := c.portMapper.GetCachedMappingOrStartCreatingOne() | ||||
|  | ||||
| 	nr, err := c.updateNetInfo(ctx) | ||||
| 	if err != nil { | ||||
| 		c.logf("magicsock.Conn.determineEndpoints: updateNetInfo: %v", err) | ||||
| @@ -1002,11 +1004,13 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if ext, err := c.portMapper.CreateOrGetMapping(ctx); err == nil { | ||||
| 		addAddr(ext, tailcfg.EndpointPortmapped) | ||||
| 	// If we didn't have a portmap earlier, maybe it's done by now. | ||||
| 	if !havePortmap { | ||||
| 		portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne() | ||||
| 	} | ||||
| 	if havePortmap { | ||||
| 		addAddr(portmapExt, tailcfg.EndpointPortmapped) | ||||
| 		c.setNetInfoHavePortMap() | ||||
| 	} else if !portmapper.IsNoMappingError(err) { | ||||
| 		c.logf("portmapper: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if nr.GlobalV4 != "" { | ||||
| @@ -2563,6 +2567,8 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool { | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func (c *Conn) onPortMapChanged() { c.ReSTUN("portmap-changed") } | ||||
|  | ||||
| // ReSTUN triggers an address discovery. | ||||
| // The provided why string is for debug logging only. | ||||
| func (c *Conn) ReSTUN(why string) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user