net/dns, paths, util/winutil: change net/dns/windowsManager NRPT management to support more than 50 domains.

AFAICT this isn't documented on MSDN, but based on the issue referenced below,
NRPT rules are not working when a rule specifies > 50 domains.

This patch modifies our NRPT rule generator to split the list of domains
into chunks as necessary, and write a separate rule for each chunk.

For compatibility reasons, we continue to use the hard-coded rule ID, but
as additional rules are required, we generate new GUIDs. Those GUIDs are
stored under the Tailscale registry path so that we know which rules are ours.

I made some changes to winutils to add additional helper functions in support
of both the code and its test: I added additional registry accessors, and also
moved some token accessors from paths to util/winutil.

Fixes https://github.com/tailscale/coral/issues/63

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz
2022-05-25 15:51:54 -06:00
parent c16271fb46
commit b005b79236
6 changed files with 466 additions and 131 deletions

View File

@@ -20,18 +20,27 @@ import (
"tailscale.com/envknob"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/util/winutil"
)
const (
ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
// the GUID is randomly generated. At present, Tailscale installs
// zero or one NRPT rules, so hardcoding a single GUID everywhere
// is fine.
nrptBase = `SYSTEM\CurrentControlSet\services\Dnscache\Parameters\DnsPolicyConfig\{5abe529b-675b-4486-8459-25a634dacc23}`
nrptBase = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\`
nrptOverrideDNS = 0x8 // bitmask value for "use the provided override DNS resolvers"
// This is the legacy rule ID that previous versions used when we supported
// only a single rule. Now that we support multiple rules are required, we
// generate their GUIDs and store them under the Tailscale registry key.
nrptSingleRuleID = `{5abe529b-675b-4486-8459-25a634dacc23}`
// Apparently NRPT rules cannot handle > 50 domains.
nrptMaxDomainsPerRule = 50
// This is the name of the registry value we use to save Rule IDs under
// the Tailscale registry key.
nrptRuleIDValueName = `NRPTRuleIDs`
versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion`
)
@@ -44,6 +53,15 @@ type windowsManager struct {
wslManager *wslManager
}
func loadRuleSubkeyNames() []string {
result := winutil.GetRegStrings(nrptRuleIDValueName, nil)
if result == nil {
// Use the legacy rule ID if none are specified in our registry key
result = []string{nrptSingleRuleID}
}
return result
}
func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator, error) {
ret := windowsManager{
logf: logf,
@@ -59,7 +77,7 @@ func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator,
// boot up. The bootstrap resolver logic will save us, but it
// slows down start-up a bunch.
if ret.nrptWorks {
ret.delKey(nrptBase)
ret.delAllRuleKeys()
}
// Log WSL status once at startup.
@@ -90,10 +108,26 @@ func (m windowsManager) ifPath(basePath string) string {
return fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
}
func (m windowsManager) delKey(path string) error {
if err := registry.DeleteKey(registry.LOCAL_MACHINE, path); err != nil && err != registry.ErrNotExist {
func (m windowsManager) delAllRuleKeys() error {
nrptRuleIDs := loadRuleSubkeyNames()
if err := m.delRuleKeys(nrptRuleIDs); err != nil {
return err
}
if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil {
m.logf("Error deleting registry value %q: %v", nrptRuleIDValueName, err)
return err
}
return nil
}
func (m windowsManager) delRuleKeys(nrptRuleIDs []string) error {
for _, rid := range nrptRuleIDs {
keyName := nrptBase + rid
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist {
m.logf("Error deleting NRPT rule key %q: %v", keyName, err)
return err
}
}
return nil
}
@@ -104,31 +138,91 @@ func delValue(key registry.Key, name string) error {
return nil
}
// setSplitDNS configures an NRPT (Name Resolution Policy Table) rule
// setSplitDNS configures one or more NRPT (Name Resolution Policy Table) rules
// to resolve queries for domains using resolvers, rather than the
// system's "primary" resolver.
//
// If no resolvers are provided, the Tailscale NRPT rule is deleted.
// If no resolvers are provided, the Tailscale NRPT rules are deleted.
func (m windowsManager) setSplitDNS(resolvers []netaddr.IP, domains []dnsname.FQDN) error {
if len(resolvers) == 0 {
return m.delKey(nrptBase)
return m.delAllRuleKeys()
}
servers := make([]string, 0, len(resolvers))
for _, resolver := range resolvers {
servers = append(servers, resolver.String())
}
doms := make([]string, 0, len(domains))
for _, domain := range domains {
// NRPT rules must have a leading dot, which is not usual for
// DNS search paths.
doms = append(doms, "."+domain.WithoutTrailingDot())
// NRPT has an undocumented restriction that each rule may only be associated
// with a maximum of 50 domains. If we are setting rules for more domains
// than that, we need to split domains into chunks and write out a rule per chunk.
dq := len(domains) / nrptMaxDomainsPerRule
dr := len(domains) % nrptMaxDomainsPerRule
domainRulesLen := dq
if dr > 0 {
domainRulesLen++
}
nrptRuleIDs := loadRuleSubkeyNames()
for len(nrptRuleIDs) < domainRulesLen {
guid, err := windows.GenerateGUID()
if err != nil {
return err
}
nrptRuleIDs = append(nrptRuleIDs, guid.String())
}
// Remove any surplus rules that are no longer needed.
ruleIDsToRemove := nrptRuleIDs[domainRulesLen:]
m.delRuleKeys(ruleIDsToRemove)
// We need to save the list of rule IDs to our Tailscale registry key so that
// we know which rules are ours during subsequent modifications to NRPT rules.
ruleIDsToWrite := nrptRuleIDs[:domainRulesLen]
if len(ruleIDsToWrite) > 0 {
if err := winutil.SetRegStrings(nrptRuleIDValueName, ruleIDsToWrite); err != nil {
return err
}
} else {
if err := winutil.DeleteRegValue(nrptRuleIDValueName); err != nil {
return err
}
}
doms := make([]string, 0, nrptMaxDomainsPerRule)
for i := 0; i < domainRulesLen; i++ {
// Each iteration consumes nrptMaxDomainsPerRule domains...
curLen := nrptMaxDomainsPerRule
// Except for the final iteration: when we have a remainder, use that instead.
if i == domainRulesLen-1 && dr > 0 {
curLen = dr
}
// Obtain the slice of domains to consume within the current iteration.
start := i * nrptMaxDomainsPerRule
end := start + curLen
for _, domain := range domains[start:end] {
// NRPT rules must have a leading dot, which is not usual for
// DNS search paths.
doms = append(doms, "."+domain.WithoutTrailingDot())
}
if err := writeNRPTRule(nrptRuleIDs[i], doms, servers); err != nil {
return err
}
doms = doms[:0]
}
return nil
}
func writeNRPTRule(ruleID string, doms, servers []string) error {
// CreateKey is actually open-or-create, which suits us fine.
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBase, registry.SET_VALUE)
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBase+ruleID, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("opening %s: %w", nrptBase, err)
return fmt.Errorf("opening %s: %w", nrptBase+ruleID, err)
}
defer key.Close()
if err := key.SetDWordValue("Version", 1); err != nil {

View File

@@ -0,0 +1,160 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"math/rand"
"strings"
"testing"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"inet.af/netaddr"
"tailscale.com/util/dnsname"
"tailscale.com/util/winutil"
)
func TestManagerWindows(t *testing.T) {
if !winutil.IsCurrentProcessElevated() {
t.Skipf("test requires running as elevated user")
}
logf := func(format string, args ...any) {
t.Logf(format, args...)
}
fakeInterface, err := windows.GenerateGUID()
if err != nil {
t.Fatalf("windows.GenerateGUID: %v\n", err)
}
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
if err != nil {
t.Fatalf("NewOSConfigurator: %v\n", err)
}
mgr := cfg.(windowsManager)
// Upon initialization of cfg, we should not have any NRPT rules
ensureNoRules(t)
resolvers := []netaddr.IP{netaddr.MustParseIP("1.1.1.1")}
domains := make([]dnsname.FQDN, 0, 2*nrptMaxDomainsPerRule+1)
r := rand.New(rand.NewSource(time.Now().UnixNano()))
const charset = "abcdefghijklmnopqrstuvwxyz"
// Just generate a bunch of random subdomains
for len(domains) < cap(domains) {
l := r.Intn(19) + 1
b := make([]byte, l)
for i, _ := range b {
b[i] = charset[r.Intn(len(charset))]
}
d := string(b) + ".example.com"
fqdn, err := dnsname.ToFQDN(d)
if err != nil {
t.Fatalf("dnsname.ToFQDN: %v\n", err)
}
domains = append(domains, fqdn)
}
cases := []int{
1,
50,
51,
100,
101,
100,
50,
1,
51,
}
for _, n := range cases {
t.Logf("Test case: %d domains\n", n)
caseDomains := domains[:n]
err := mgr.setSplitDNS(resolvers, caseDomains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
validateRegistry(t, caseDomains)
}
t.Logf("Test case: nil resolver\n")
err = mgr.setSplitDNS(nil, domains)
if err != nil {
t.Fatalf("setSplitDNS: %v\n", err)
}
ensureNoRules(t)
}
func ensureNoRules(t *testing.T) {
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
if ruleIDs != nil {
t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs)
}
legacyKeyPath := nrptBase + nrptSingleRuleID
key, err := registry.OpenKey(registry.LOCAL_MACHINE, legacyKeyPath, registry.READ)
if err == nil {
key.Close()
}
if err != registry.ErrNotExist {
t.Errorf("%s: %q, want %q\n", legacyKeyPath, err, registry.ErrNotExist)
}
}
func validateRegistry(t *testing.T, domains []dnsname.FQDN) {
q := len(domains) / nrptMaxDomainsPerRule
r := len(domains) % nrptMaxDomainsPerRule
numRules := q
if r > 0 {
numRules++
}
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
if ruleIDs == nil {
ruleIDs = []string{nrptSingleRuleID}
} else if len(ruleIDs) != numRules {
t.Errorf("%s for %d domains: %d, want %d\n", nrptRuleIDValueName, len(domains), len(ruleIDs), numRules)
}
for i, ruleID := range ruleIDs {
savedDomains, err := getSavedDomainsForRule(ruleID)
if err != nil {
t.Fatalf("getSavedDomainsForRule(%q): %v\n", ruleID, err)
}
start := i * nrptMaxDomainsPerRule
end := start + nrptMaxDomainsPerRule
if i == len(ruleIDs)-1 && r > 0 {
end = start + r
}
checkDomains := domains[start:end]
if len(checkDomains) != len(savedDomains) {
t.Errorf("len(checkDomains) != len(savedDomains): %d, want %d\n", len(savedDomains), len(checkDomains))
}
for j, cd := range checkDomains {
sd := strings.TrimPrefix(savedDomains[j], ".")
if string(cd.WithoutTrailingDot()) != sd {
t.Errorf("checkDomain differs savedDomain: %s, want %s\n", sd, cd.WithoutTrailingDot())
}
}
}
}
func getSavedDomainsForRule(ruleID string) ([]string, error) {
keyPath := nrptBase + ruleID
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ)
if err != nil {
return nil, err
}
defer key.Close()
result, _, err := key.GetStringsValue("Name")
return result, err
}