diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index cad81b82c..4a8335133 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -287,6 +287,25 @@ func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.Round return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") } +// contextErr is an error that wraps another error and is used to indicate that +// the error occurred because a context was canceled or its deadline expired. +type contextErr struct { + err error +} + +func (e contextErr) Error() string { + return e.err.Error() +} + +func (e contextErr) Unwrap() error { + return e.err +} + +// getConn returns a noiseConn that can be used to make requests to the +// coordination server. It may return a cached connection or create a new one. +// Dials are singleflighted, so concurrent calls to getConn may only dial once. +// As such, context values may not be respected as there are no guarantees that +// the context passed to getConn is the same as the context passed to dial. func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { nc.mu.Lock() if last := nc.last; last != nil && last.canTakeNewRequest() { @@ -295,11 +314,35 @@ func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { } nc.mu.Unlock() - conn, err, _ := nc.sfDial.Do(struct{}{}, nc.dial) - if err != nil { - return nil, err + for { + // We singleflight the dial to avoid making multiple connections, however + // that means that we can't simply cancel the dial if the context is + // canceled. Instead, we have to additionally check that the context + // which was canceled is our context and retry if our context is still + // valid. 
+ conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseConn, error) { + c, err := nc.dial(ctx) + if err != nil { + if ctx.Err() != nil { + return nil, contextErr{ctx.Err()} + } + return nil, err + } + return c, nil + }) + var ce contextErr + if err == nil || !errors.As(err, &ce) { + return conn, err + } + if ctx.Err() == nil { + // The dial failed because of a context error, but our context + // is still valid, so the failed context belonged to another caller sharing the dial. Retry. + continue + } + // A context error was reported and our own context is done too. Return the + // underlying context error carried by contextErr (it may come from the caller whose dial ran). + return nil, ce.Unwrap() } - return conn, nil } func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) { @@ -344,7 +387,7 @@ func (nc *NoiseClient) Close() error { // dial opens a new connection to tailcontrol, fetching the server noise key // if not cached. -func (nc *NoiseClient) dial() (*noiseConn, error) { +func (nc *NoiseClient) dial(ctx context.Context) (*noiseConn, error) { nc.mu.Lock() connID := nc.nextID nc.nextID++ @@ -392,7 +435,7 @@ func (nc *NoiseClient) dial() (*noiseConn, error) { } timeout := time.Duration(timeoutSec * float64(time.Second)) - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() clientConn, err := (&controlhttp.Dialer{