tailscale/net/dns/manager_windows_test.go

161 lines
3.9 KiB
Go
Raw Normal View History

// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"math/rand"
"strings"
"testing"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"inet.af/netaddr"
"tailscale.com/util/dnsname"
"tailscale.com/util/winutil"
)
// TestManagerWindows drives windowsManager.setSplitDNS with domain lists of
// varying sizes and, after every call, checks the resulting NRPT state via
// validateRegistry. It finishes by passing nil resolvers and asserting that no
// rules remain.
//
// The test is skipped unless the current process is elevated.
func TestManagerWindows(t *testing.T) {
	if !winutil.IsCurrentProcessElevated() {
		t.Skipf("test requires running as elevated user")
	}

	logf := func(format string, args ...any) {
		t.Logf(format, args...)
	}

	// A freshly generated GUID stands in for an interface name.
	fakeInterface, err := windows.GenerateGUID()
	if err != nil {
		t.Fatalf("windows.GenerateGUID: %v\n", err)
	}

	cfg, err := NewOSConfigurator(logf, fakeInterface.String())
	if err != nil {
		t.Fatalf("NewOSConfigurator: %v\n", err)
	}
	// Checked assertion: report an unexpected implementation as a test
	// failure rather than a panic.
	mgr, ok := cfg.(windowsManager)
	if !ok {
		t.Fatalf("NewOSConfigurator returned %T, want windowsManager\n", cfg)
	}

	// Upon initialization of cfg, we should not have any NRPT rules
	ensureNoRules(t)

	resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}

	domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1)

	// Log the seed so that a failing run can be reproduced.
	seed := time.Now().UnixNano()
	t.Logf("random seed: %d\n", seed)
	r := rand.New(rand.NewSource(seed))
	const charset = "abcdefghijklmnopqrstuvwxyz"

	// Just generate a bunch of random subdomains
	for len(domains) < cap(domains) {
		l := r.Intn(19) + 1
		b := make([]byte, l)
		for i := range b {
			b[i] = charset[r.Intn(len(charset))]
		}
		d := string(b) + ".example.com"
		fqdn, err := dnsname.ToFQDN(d)
		if err != nil {
			t.Fatalf("dnsname.ToFQDN: %v\n", err)
		}
		domains = append(domains, fqdn)
	}

	// Best-effort cleanup: if an assertion below aborts the test, do not
	// leave rules behind in the registry. This is the same call the final
	// test case uses to clear all rules.
	t.Cleanup(func() {
		if err := mgr.setSplitDNS(nil, domains); err != nil {
			t.Logf("cleanup setSplitDNS: %v\n", err)
		}
	})

	// Domain counts step up and then back down so that successive calls
	// both grow and shrink the configured set. The 50/51 and 100/101 pairs
	// presumably straddle multiples of nrptMaxDomainsPerRule; confirm
	// against its definition.
	cases := []int{
		1,
		50,
		51,
		100,
		101,
		100,
		50,
		1,
		51,
	}

	for _, n := range cases {
		t.Logf("Test case: %d domains\n", n)
		caseDomains := domains[:n]
		err := mgr.setSplitDNS(resolvers, caseDomains)
		if err != nil {
			t.Fatalf("setSplitDNS: %v\n", err)
		}
		validateRegistry(t, caseDomains)
	}

	t.Logf("Test case: nil resolver\n")
	err = mgr.setSplitDNS(nil, domains)
	if err != nil {
		t.Fatalf("setSplitDNS: %v\n", err)
	}
	ensureNoRules(t)
}
// ensureNoRules reports a test error if any NRPT rule state is visible:
//   - winutil.GetRegStrings(nrptRuleIDValueName, nil) returns a non-nil list, or
//   - the key nrptBase+nrptSingleRuleID under HKEY_LOCAL_MACHINE can be
//     opened, or opening it fails with anything other than
//     registry.ErrNotExist.
//
// It uses t.Errorf, so the calling test continues after a failure.
func ensureNoRules(t *testing.T) {
	t.Helper()

	ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
	if ruleIDs != nil {
		t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs)
	}

	legacyKeyPath := nrptBase + nrptSingleRuleID
	key, err := registry.OpenKey(registry.LOCAL_MACHINE, legacyKeyPath, registry.READ)
	switch {
	case err == nil:
		// The key exists when it should not. Report this separately:
		// formatting a nil error with %q would print "%!q(<nil>)".
		key.Close()
		t.Errorf("%s: key exists, want %q\n", legacyKeyPath, registry.ErrNotExist)
	case err != registry.ErrNotExist:
		t.Errorf("%s: %q, want %q\n", legacyKeyPath, err, registry.ErrNotExist)
	}
}
// validateRegistry checks that the domains saved in the registry match
// domains, assuming they are split in order across rules holding at most
// nrptMaxDomainsPerRule domains each.
//
// The rule IDs are read with winutil.GetRegStrings(nrptRuleIDValueName, nil).
// If that returns nil, the single rule nrptSingleRuleID is checked instead.
// Mismatches are reported with t.Errorf; a failure to read a rule's saved
// domains aborts the test with t.Fatalf.
func validateRegistry(t *testing.T, domains []dnsname.FQDN) {
	t.Helper()

	// numRules is ceil(len(domains) / nrptMaxDomainsPerRule); r is the number
	// of domains expected in the final, partially filled rule (0 if full).
	q := len(domains) / nrptMaxDomainsPerRule
	r := len(domains) % nrptMaxDomainsPerRule
	numRules := q
	if r > 0 {
		numRules++
	}

	ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
	if ruleIDs == nil {
		ruleIDs = []string{nrptSingleRuleID}
	} else if len(ruleIDs) != numRules {
		t.Errorf("%s for %d domains: %d, want %d\n", nrptRuleIDValueName, len(domains), len(ruleIDs), numRules)
	}

	for i, ruleID := range ruleIDs {
		savedDomains, err := getSavedDomainsForRule(ruleID)
		if err != nil {
			t.Fatalf("getSavedDomainsForRule(%q): %v\n", ruleID, err)
		}

		start := i * nrptMaxDomainsPerRule
		end := start + nrptMaxDomainsPerRule
		if i == len(ruleIDs)-1 && r > 0 {
			end = start + r
		}
		// Clamp the window to len(domains). When the rule count is wrong
		// (already reported above) the unclamped bounds could otherwise
		// panic, or reach into spare capacity of the caller's slice and
		// compare against domains that were never requested.
		if end > len(domains) {
			end = len(domains)
		}
		if start > end {
			start = end
		}

		checkDomains := domains[start:end]
		if len(checkDomains) != len(savedDomains) {
			t.Errorf("len(checkDomains) != len(savedDomains): %d, want %d\n", len(savedDomains), len(checkDomains))
			// Element-wise comparison below indexes savedDomains by
			// position; skip it so a short savedDomains cannot panic.
			continue
		}

		for j, cd := range checkDomains {
			// Saved names may carry a leading dot; compare without it.
			sd := strings.TrimPrefix(savedDomains[j], ".")
			if string(cd.WithoutTrailingDot()) != sd {
				t.Errorf("checkDomain differs savedDomain: %s, want %s\n", sd, cd.WithoutTrailingDot())
			}
		}
	}
}
func getSavedDomainsForRule(ruleID string) ([]string, error) {
keyPath := nrptBase + ruleID
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ)
if err != nil {
return nil, err
}
defer key.Close()
result, _, err := key.GetStringsValue("Name")
return result, err
}