diff --git a/net/tshttpproxy/tshttpproxy.go b/net/tshttpproxy/tshttpproxy.go index 1f2298e2b..990789d3b 100644 --- a/net/tshttpproxy/tshttpproxy.go +++ b/net/tshttpproxy/tshttpproxy.go @@ -28,6 +28,7 @@ var ( noProxyUntil time.Time // if non-zero, time at which ProxyFromEnvironment should check again ) +// setNoProxyUntil stops calls to sysProxyEnv (if any) for the provided duration. func setNoProxyUntil(d time.Duration) { mu.Lock() defer mu.Unlock() @@ -41,7 +42,15 @@ var _ = setNoProxyUntil // quiet staticcheck; Windows uses the above, more might // For example, WPAD PAC files on Windows. var sysProxyFromEnv func(*http.Request) (*url.URL, error) +// ProxyFromEnvironment is like the standard library's http.ProxyFromEnvironment +// but additionally does OS-specific proxy lookups if the environment variables +// alone don't specify a proxy. func ProxyFromEnvironment(req *http.Request) (*url.URL, error) { + u, err := http.ProxyFromEnvironment(req) + if u != nil && err == nil { + return u, nil + } + mu.Lock() noProxyTime := noProxyUntil mu.Unlock() @@ -49,11 +58,6 @@ func ProxyFromEnvironment(req *http.Request) (*url.URL, error) { return nil, nil } - u, err := http.ProxyFromEnvironment(req) - if u != nil && err == nil { - return u, nil - } - if sysProxyFromEnv != nil { u, err := sysProxyFromEnv(req) if u != nil && err == nil { diff --git a/net/tshttpproxy/tshttpproxy_test.go b/net/tshttpproxy/tshttpproxy_test.go index 350068bd9..88e390847 100644 --- a/net/tshttpproxy/tshttpproxy_test.go +++ b/net/tshttpproxy/tshttpproxy_test.go @@ -5,10 +5,15 @@ package tshttpproxy import ( + "net/http" "net/url" + "os" "runtime" "strings" "testing" + "time" + + "tailscale.com/util/must" ) func TestGetAuthHeaderNoResult(t *testing.T) { @@ -51,3 +56,29 @@ func TestGetAuthHeaderBasicAuth(t *testing.T) { t.Fatalf("GetAuthHeader(%q) = %q; want %q", proxyURL, got, want) } } + +func TestProxyFromEnvironment_setNoProxyUntil(t *testing.T) { + const fakeProxyEnv = "10.1.2.3:456" + const fakeProxyFull = "http://" + fakeProxyEnv + + defer os.Setenv("HTTPS_PROXY", os.Getenv("HTTPS_PROXY")) + os.Setenv("HTTPS_PROXY", fakeProxyEnv) + + req := &http.Request{URL: must.Get(url.Parse("https://example.com/"))} + for i := 0; i < 3; i++ { + switch i { + case 1: + setNoProxyUntil(time.Minute) + case 2: + setNoProxyUntil(0) + } + got, err := ProxyFromEnvironment(req) + if err != nil { + t.Fatalf("[%d] ProxyFromEnvironment: %v", i, err) + } + if got == nil || got.String() != fakeProxyFull { + t.Errorf("[%d] Got proxy %v; want %v", i, got, fakeProxyFull) + } + } + +}