diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 2cbea6c0f..3cbb5ccd8 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -202,6 +202,20 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr r.dlogf("returning %d static results", len(allIPs)) return } + + // Hard-code this to avoid extra work, DNS fallbacks, etc. + if host == "localhost" { + r.dlogf("host is localhost") + + // TODO: @raggi mentioned that some distributions don't use + // 127.0.0.1 as the localhost IP; should we check the interface + // address to determine this? + ip = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + v6 = netip.IPv6Loopback() + allIPs = []netip.Addr{ip, v6} + err = nil + return + } if ip, err := netip.ParseAddr(host); err == nil { ip = ip.Unmap() r.dlogf("%q is an IP", host) diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index ef4249b74..ceada2edc 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -11,6 +11,7 @@ "net" "net/netip" "reflect" + "slices" "testing" "time" @@ -240,3 +241,51 @@ type step struct { }) } } + +func TestLocalhost(t *testing.T) { + tstest.Replace(t, &debug, func() bool { return true }) + + r := &Resolver{ + Logf: t.Logf, + Forward: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // always return an error to force fallback + return nil, errors.New("some error") + }, + }, + LookupIPFallback: func(ctx context.Context, host string) ([]netip.Addr, error) { + t.Errorf("unexpected call to LookupIPFallback(%q)", host) + return nil, errors.New("unimplemented") + }, + } + + // Just overriding the 'Dial' function in the *net.Resolver isn't + // enough, because the Go resolver will read /etc/hosts and return + // localhost from that. + // + // Abuse the IP cache to insert a fake localhost entry pointing to some + // invalid IP; if we get this back, we know that we didn't hit our + // hard-coded "localhost" logic. + invalid4 := netip.MustParseAddr("169.254.169.254") + invalid6 := netip.MustParseAddr("fe80::1") + r.addIPCache("localhost", invalid4, invalid6, []netip.Addr{invalid4, invalid6}, 24*time.Hour) + + ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "localhost") + if err != nil { + t.Fatal(err) + } + + localhost4 := netip.MustParseAddr("127.0.0.1") + localhost6 := netip.MustParseAddr("::1") + + if ip4 != localhost4 { + t.Errorf("ip4 got %q; want %q", ip4, localhost4) + } + if ip6 != localhost6 { + t.Errorf("ip6 got %q; want %q", ip6, localhost6) + } + if !slices.Equal(allIPs, []netip.Addr{localhost4, localhost6}) { + t.Errorf("allIPs got %q; want %q", allIPs, []netip.Addr{localhost4, localhost6}) + } +}