mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-08 09:07:44 +00:00
ecea6cb994
Simplify the ability to reason about the DoH dialing code by reusing the dnscache's dialer we already have. Also, reduce the scope of the "ip" variable we don't want to close over. This necessarily adds a new field to dnscache.Resolver: SingleHostStaticResult, for when the caller already knows the IPs to be returned. Change-Id: I9f2aef7926f649137a5a3e63eebad6a3fffa48c0 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
144 lines
3.3 KiB
Go
144 lines
3.3 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dnscache
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"net"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"inet.af/netaddr"
|
|
)
|
|
|
|
var (
	// dialTest holds the value of the --dial-test flag: an addr:port that
	// TestDialer dials. It defaults to empty, in which case TestDialer skips.
	dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
)
|
|
|
|
func TestDialer(t *testing.T) {
|
|
if *dialTest == "" {
|
|
t.Skip("skipping; --dial-test is blank")
|
|
}
|
|
r := new(Resolver)
|
|
var std net.Dialer
|
|
dialer := Dialer(std.DialContext, r)
|
|
t0 := time.Now()
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
c, err := dialer(ctx, "tcp", *dialTest)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("dialed in %v", time.Since(t0))
|
|
c.Close()
|
|
}
|
|
|
|
func TestDialCall_DNSWasTrustworthy(t *testing.T) {
|
|
type step struct {
|
|
ip netaddr.IP // IP we pretended to dial
|
|
err error // the dial error or nil for success
|
|
}
|
|
mustIP := netaddr.MustParseIP
|
|
errFail := errors.New("some connect failure")
|
|
tests := []struct {
|
|
name string
|
|
steps []step
|
|
want bool
|
|
}{
|
|
{
|
|
name: "no-info",
|
|
want: false,
|
|
},
|
|
{
|
|
name: "previous-dial",
|
|
steps: []step{
|
|
{mustIP("2003::1"), nil},
|
|
{mustIP("2003::1"), errFail},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
name: "no-previous-dial",
|
|
steps: []step{
|
|
{mustIP("2003::1"), errFail},
|
|
},
|
|
want: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
d := &dialer{
|
|
pastConnect: map[netaddr.IP]time.Time{},
|
|
}
|
|
dc := &dialCall{
|
|
d: d,
|
|
}
|
|
for _, st := range tt.steps {
|
|
dc.noteDialResult(st.ip, st.err)
|
|
}
|
|
got := dc.dnsWasTrustworthy()
|
|
if got != tt.want {
|
|
t.Errorf("got %v; want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDialCall_uniqueIPs(t *testing.T) {
|
|
dc := &dialCall{}
|
|
mustIP := netaddr.MustParseIP
|
|
errFail := errors.New("some connect failure")
|
|
dc.noteDialResult(mustIP("2003::1"), errFail)
|
|
dc.noteDialResult(mustIP("2003::2"), errFail)
|
|
got := dc.uniqueIPs([]netaddr.IP{
|
|
mustIP("2003::1"),
|
|
mustIP("2003::2"),
|
|
mustIP("2003::2"),
|
|
mustIP("2003::3"),
|
|
mustIP("2003::3"),
|
|
mustIP("2003::4"),
|
|
mustIP("2003::4"),
|
|
})
|
|
want := []netaddr.IP{
|
|
mustIP("2003::3"),
|
|
mustIP("2003::4"),
|
|
}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("got %v; want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestResolverAllHostStaticResult(t *testing.T) {
|
|
r := &Resolver{
|
|
SingleHost: "foo.bar",
|
|
SingleHostStaticResult: []netaddr.IP{
|
|
netaddr.MustParseIP("2001:4860:4860::8888"),
|
|
netaddr.MustParseIP("2001:4860:4860::8844"),
|
|
netaddr.MustParseIP("8.8.8.8"),
|
|
netaddr.MustParseIP("8.8.4.4"),
|
|
},
|
|
}
|
|
ip4, ip6, allIPs, err := r.LookupIP(context.Background(), "foo.bar")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := ip4.String(), "8.8.8.8"; got != want {
|
|
t.Errorf("ip4 got %q; want %q", got, want)
|
|
}
|
|
if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
|
|
t.Errorf("ip4 got %q; want %q", got, want)
|
|
}
|
|
if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want {
|
|
t.Errorf("allIPs got %q; want %q", got, want)
|
|
}
|
|
|
|
_, _, _, err = r.LookupIP(context.Background(), "bad")
|
|
if got, want := fmt.Sprint(err), `dnscache: unexpected hostname "bad" doesn't match expected "foo.bar"`; got != want {
|
|
t.Errorf("bad dial error got %q; want %q", got, want)
|
|
}
|
|
}
|