diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go index 86a99eda2..6fedbd855 100644 --- a/net/dnsfallback/dnsfallback.go +++ b/net/dnsfallback/dnsfallback.go @@ -29,7 +29,9 @@ import ( "tailscale.com/net/netns" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/types/logger" ) func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { @@ -77,16 +79,16 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { if err := ctx.Err(); err != nil { return nil, err } - log.Printf("trying bootstrapDNS(%q, %q) for %q ...", cand.dnsName, cand.ip, host) + logf("trying bootstrapDNS(%q, %q) for %q ...", cand.dnsName, cand.ip, host) ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() dm, err := bootstrapDNSMap(ctx, cand.dnsName, cand.ip, host) if err != nil { - log.Printf("bootstrapDNS(%q, %q) for %q error: %v", cand.dnsName, cand.ip, host, err) + logf("bootstrapDNS(%q, %q) for %q error: %v", cand.dnsName, cand.ip, host, err) continue } if ips := dm[host]; len(ips) > 0 { - log.Printf("bootstrapDNS(%q, %q) for %q = %v", cand.dnsName, cand.ip, host, ips) + logf("bootstrapDNS(%q, %q) for %q = %v", cand.dnsName, cand.ip, host, ips) return ips, nil } } @@ -99,7 +101,7 @@ func Lookup(ctx context.Context, host string) ([]netip.Addr, error) { // serverName and serverIP of are, say, "derpN.tailscale.com". // queryName is the name being sought (e.g. "controlplane.tailscale.com"), passed as hint. func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr, queryName string) (dnsMap, error) { - dialer := netns.NewDialer(log.Printf) + dialer := netns.NewDialer(logf) tr := http.DefaultTransport.(*http.Transport).Clone() tr.Proxy = tshttpproxy.ProxyFromEnvironment tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { @@ -201,7 +203,7 @@ func UpdateCache(c *tailcfg.DERPMap) { d, err := json.Marshal(c) if err != nil { - log.Printf("[v1] dnsfallback: UpdateCache error marshaling: %v", err) + logf("[v1] dnsfallback: UpdateCache error marshaling: %v", err) return } @@ -213,11 +215,10 @@ func UpdateCache(c *tailcfg.DERPMap) { if cachePath != "" { err = atomicfile.WriteFile(cachePath, d, 0600) if err != nil { - log.Printf("[v1] dnsfallback: UpdateCache error writing: %v", err) + logf("[v1] dnsfallback: UpdateCache error writing: %v", err) return } } - log.Printf("[v2] dnsfallback: UpdateCache succeeded") } // SetCachePath sets the path to the on-disk DERP map cache that we store and @@ -231,17 +232,34 @@ func SetCachePath(path string) { f, err := os.Open(path) if err != nil { - log.Printf("[v1] dnsfallback: SetCachePath error reading %q: %v", path, err) + logf("[v1] dnsfallback: SetCachePath error reading %q: %v", path, err) return } defer f.Close() dm := new(tailcfg.DERPMap) if err := json.NewDecoder(f).Decode(dm); err != nil { - log.Printf("[v1] dnsfallback: SetCachePath error decoding %q: %v", path, err) + logf("[v1] dnsfallback: SetCachePath error decoding %q: %v", path, err) return } cachedDERPMap.Store(dm) - log.Printf("[v2] dnsfallback: SetCachePath loaded cached DERP map") + logf("[v2] dnsfallback: SetCachePath loaded cached DERP map") +} + +// logfunc stores the logging function to use for this package. +var logfunc syncs.AtomicValue[logger.Logf] + +// SetLogger sets the logging function that this package will use. The default +// logger if this function is not called is 'log.Printf'. +func SetLogger(log logger.Logf) { + logfunc.Store(log) +} + +func logf(format string, args ...any) { + if lf := logfunc.Load(); lf != nil { + lf(format, args...) + } else { + log.Printf(format, args...) + } } diff --git a/net/dnsfallback/dnsfallback_test.go b/net/dnsfallback/dnsfallback_test.go index 656532546..42c12b0b0 100644 --- a/net/dnsfallback/dnsfallback_test.go +++ b/net/dnsfallback/dnsfallback_test.go @@ -25,6 +25,11 @@ func TestGetDERPMap(t *testing.T) { } func TestCache(t *testing.T) { + oldlog := logfunc.Load() + SetLogger(t.Logf) + t.Cleanup(func() { + SetLogger(oldlog) + }) cacheFile := filepath.Join(t.TempDir(), "cache.json") // Write initial cache value @@ -101,6 +106,11 @@ func TestCache(t *testing.T) { } func TestCacheUnchanged(t *testing.T) { + oldlog := logfunc.Load() + SetLogger(t.Logf) + t.Cleanup(func() { + SetLogger(oldlog) + }) cacheFile := filepath.Join(t.TempDir(), "cache.json") // Write initial cache value