control/controlclient: use ctx passed down to NoiseClient.getConn

Without this, the client would just get stuck dialing even if the
context was canceled.

Updates tailscale/corp#12590

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-07-07 10:44:42 -07:00 committed by Maisem Ali
parent 92fb80d55f
commit 9d1a3a995c

View File

@ -287,6 +287,25 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round
return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection")
}
// contextErr is an error that wraps another error and is used to indicate that
// the error was because a context expired.
type contextErr struct {
err error
}
func (e contextErr) Error() string {
return e.err.Error()
}
func (e contextErr) Unwrap() error {
return e.err
}
// getConn returns a noiseConn that can be used to make requests to the
// coordination server. It may return a cached connection or create a new one.
// Dials are singleflighted, so concurrent calls to getConn may only dial once.
// As such, context values may not be respected as there are no guarantees that
// the context passed to getConn is the same as the context passed to dial.
func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
nc.mu.Lock()
if last := nc.last; last != nil && last.canTakeNewRequest() {
@ -295,11 +314,35 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) {
}
nc.mu.Unlock()
conn, err, _ := nc.sfDial.Do(struct{}{}, nc.dial)
if err != nil {
return nil, err
for {
// We singeflight the dial to avoid making multiple connections, however
// that means that we can't simply cancel the dial if the context is
// canceled. Instead, we have to additionally check that the context
// which was canceled is our context and retry if our context is still
// valid.
conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) {
c, err := nc.dial(ctx)
if err != nil {
if ctx.Err() != nil {
return nil, contextErr{ctx.Err()}
}
return nil, err
}
return c, nil
})
var ce contextErr
if err == nil || !errors.As(err, &ce) {
return conn, err
}
if ctx.Err() == nil {
// The dial failed because of a context error, but our context
// is still valid. Retry.
continue
}
// The dial failed because our context was canceled. Return the
// underlying error.
return nil, ce.Unwrap()
}
return conn, nil
}
func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) {
@ -344,7 +387,7 @@ func (nc *NoiseClient) Close() error {
// dial opens a new connection to tailcontrol, fetching the server noise key
// if not cached.
func (nc *NoiseClient) dial() (*noiseConn, error) {
func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) {
nc.mu.Lock()
connID := nc.nextID
nc.nextID++
@ -392,7 +435,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) {
}
timeout := time.Duration(timeoutSec * float64(time.Second))
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
clientConn, err := (&controlhttp.Dialer{