tailscale/net/dnscache/dnscache_test.go
Brad Fitzpatrick ecea6cb994 net/dns/resolver: make DoH dialer use existing dnscache happy eyeball dialer
Make the DoH dialing code easier to reason about by reusing the
dnscache dialer we already have.

Also, reduce the scope of the "ip" variable we don't want to close over.

This necessarily adds a new field to dnscache.Resolver:
SingleHostStaticResult, for when the caller already knows the IPs to be
returned.

Change-Id: I9f2aef7926f649137a5a3e63eebad6a3fffa48c0
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2022-04-18 13:18:39 -07:00

144 lines
3.3 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnscache
import (
"context"
"errors"
"flag"
"fmt"
"net"
"reflect"
"testing"
"time"
"inet.af/netaddr"
)
// dialTest is the --dial-test flag. When non-empty it holds the addr:port
// that TestDialer dials; when empty, TestDialer is skipped.
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
// TestDialer is an opt-in, manual test: it dials the addr:port supplied via
// the --dial-test flag using the function returned by Dialer (built from a
// standard net.Dialer and a zero-value Resolver) and logs how long the dial
// took. It is skipped when the flag is blank.
func TestDialer(t *testing.T) {
	target := *dialTest
	if target == "" {
		t.Skip("skipping; --dial-test is blank")
	}

	resolver := new(Resolver)
	var fallback net.Dialer
	dial := Dialer(fallback.DialContext, resolver)

	start := time.Now()

	// Bound the whole dial attempt so an unreachable target cannot hang the test.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	conn, err := dial(ctx, "tcp", target)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("dialed in %v", time.Since(start))
	conn.Close()
}
// TestDialCall_DNSWasTrustworthy replays sequences of simulated dial results
// into a fresh dialCall (backed by a dialer with an empty pastConnect map)
// and checks what dnsWasTrustworthy reports afterwards.
func TestDialCall_DNSWasTrustworthy(t *testing.T) {
	// attempt is one simulated dial outcome fed to noteDialResult.
	type attempt struct {
		addr   netaddr.IP // IP we pretended to dial
		result error      // the dial error, or nil for success
	}

	parse := netaddr.MustParseIP
	connectErr := errors.New("some connect failure")

	cases := []struct {
		name     string
		attempts []attempt
		want     bool
	}{
		{
			// Nothing was dialed at all.
			name: "no-info",
			want: false,
		},
		{
			// A successful dial to an IP, followed by a failure to that same IP.
			name: "previous-dial",
			attempts: []attempt{
				{addr: parse("2003::1"), result: nil},
				{addr: parse("2003::1"), result: connectErr},
			},
			want: true,
		},
		{
			// A single failed dial with no earlier success.
			name: "no-previous-dial",
			attempts: []attempt{
				{addr: parse("2003::1"), result: connectErr},
			},
			want: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			call := &dialCall{
				d: &dialer{pastConnect: make(map[netaddr.IP]time.Time)},
			}
			for _, a := range tc.attempts {
				call.noteDialResult(a.addr, a.result)
			}
			if got := call.dnsWasTrustworthy(); got != tc.want {
				t.Errorf("got %v; want %v", got, tc.want)
			}
		})
	}
}
func TestDialCall_uniqueIPs(t *testing.T) {
dc := &dialCall{}
mustIP := netaddr.MustParseIP
errFail := errors.New("some connect failure")
dc.noteDialResult(mustIP("2003::1"), errFail)
dc.noteDialResult(mustIP("2003::2"), errFail)
got := dc.uniqueIPs([]netaddr.IP{
mustIP("2003::1"),
mustIP("2003::2"),
mustIP("2003::2"),
mustIP("2003::3"),
mustIP("2003::3"),
mustIP("2003::4"),
mustIP("2003::4"),
})
want := []netaddr.IP{
mustIP("2003::3"),
mustIP("2003::4"),
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v; want %v", got, want)
}
}
// TestResolverAllHostStaticResult exercises LookupIP on a Resolver configured
// with SingleHost and SingleHostStaticResult.
//
// For the configured hostname it expects:
//   - ip4 to be the first IPv4 address in the static list,
//   - ip6 to be the first IPv6 address in the static list,
//   - allIPs to contain every static address in the configured order.
//
// For any other hostname it expects LookupIP to return a mismatch error.
func TestResolverAllHostStaticResult(t *testing.T) {
	r := &Resolver{
		SingleHost: "foo.bar",
		SingleHostStaticResult: []netaddr.IP{
			netaddr.MustParseIP("2001:4860:4860::8888"),
			netaddr.MustParseIP("2001:4860:4860::8844"),
			netaddr.MustParseIP("8.8.8.8"),
			netaddr.MustParseIP("8.8.4.4"),
		},
	}

	ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
	if err != nil {
		t.Fatal(err)
	}
	if got, want := ip4.String(), "8.8.8.8"; got != want {
		t.Errorf("ip4 got %q; want %q", got, want)
	}
	// Label this failure as ip6 so it is not mistaken for the ip4 check above.
	if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
		t.Errorf("ip6 got %q; want %q", got, want)
	}
	if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want {
		t.Errorf("allIPs got %q; want %q", got, want)
	}

	// A hostname other than SingleHost must be rejected. If err is
	// unexpectedly nil, fmt.Sprint yields "<nil>" and the comparison fails.
	_, _, _, err = r.LookupIP(context.Background(), "bad")
	if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
		t.Errorf("bad dial error got %q; want %q", got, want)
	}
}