mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-03 02:21:58 +00:00
net/dns: update Windows split DNS settings to work alongside other NRPT entries set by group policy.
When there are group policy entries for the NRPT that do not belong to Tailscale, we recognize that we need to add ourselves to group policy and use that registry key instead of the local one. We also refresh the group policy settings as necessary to ensure that our changes take effect immediately. Fixes https://github.com/tailscale/tailscale/issues/4607 Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -17,11 +18,62 @@ import (
|
||||
"tailscale.com/util/winutil"
|
||||
)
|
||||
|
||||
func TestManagerWindows(t *testing.T) {
|
||||
if !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user")
|
||||
const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
|
||||
|
||||
var (
|
||||
procRegisterGPNotification = libUserenv.NewProc("RegisterGPNotification")
|
||||
procUnregisterGPNotification = libUserenv.NewProc("UnregisterGPNotification")
|
||||
)
|
||||
|
||||
func TestManagerWindowsLocal(t *testing.T) {
|
||||
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user on Windows 10+")
|
||||
}
|
||||
|
||||
runTest(t, true)
|
||||
}
|
||||
|
||||
func TestManagerWindowsGP(t *testing.T) {
|
||||
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
||||
t.Skipf("test requires running as elevated user on Windows 10+")
|
||||
}
|
||||
|
||||
checkGPNotificationsWork(t)
|
||||
|
||||
// Make sure group policy is refreshed before this test exits but after we've
|
||||
// cleaned everything else up.
|
||||
defer procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
||||
|
||||
err := createFakeGPKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Creating fake GP key: %v\n", err)
|
||||
}
|
||||
defer deleteFakeGPKey(t)
|
||||
|
||||
runTest(t, false)
|
||||
}
|
||||
|
||||
func checkGPNotificationsWork(t *testing.T) {
|
||||
// Test to ensure that RegisterGPNotification work on this machine,
|
||||
// otherwise this test will fail.
|
||||
trk, err := newGPNotificationTracker()
|
||||
if err != nil {
|
||||
t.Skipf("newGPNotificationTracker error: %v\n", err)
|
||||
}
|
||||
defer trk.Close()
|
||||
|
||||
r, _, err := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
||||
if r == 0 {
|
||||
t.Fatalf("RefreshPolicyEx error: %v\n", err)
|
||||
}
|
||||
|
||||
timeout := uint32(10000) // Milliseconds
|
||||
if !trk.DidRefreshTimeout(timeout) {
|
||||
t.Skipf("GP notifications are not working on this machine\n")
|
||||
}
|
||||
}
|
||||
|
||||
func runTest(t *testing.T, isLocal bool) {
|
||||
logf := func(format string, args ...any) {
|
||||
t.Logf(format, args...)
|
||||
}
|
||||
@@ -37,6 +89,11 @@ func TestManagerWindows(t *testing.T) {
|
||||
}
|
||||
mgr := cfg.(windowsManager)
|
||||
|
||||
usingGP := mgr.nrptDB.writeAsGP
|
||||
if isLocal == usingGP {
|
||||
t.Fatalf("usingGP %v, want %v\n", usingGP, !usingGP)
|
||||
}
|
||||
|
||||
// Upon initialization of cfg, we should not have any NRPT rules
|
||||
ensureNoRules(t)
|
||||
|
||||
@@ -74,14 +131,61 @@ func TestManagerWindows(t *testing.T) {
|
||||
51,
|
||||
}
|
||||
|
||||
for _, n := range cases {
|
||||
var regBaseValidate string
|
||||
var regBaseEnsure string
|
||||
if isLocal {
|
||||
regBaseValidate = nrptBaseLocal
|
||||
regBaseEnsure = nrptBaseGP
|
||||
} else {
|
||||
regBaseValidate = nrptBaseGP
|
||||
regBaseEnsure = nrptBaseLocal
|
||||
}
|
||||
|
||||
var trk *gpNotificationTracker
|
||||
if isLocal {
|
||||
// (dblohm7) When isLocal == true, we keep trk active through the entire
|
||||
// sequence of test cases, and then we verify that no policy notifications
|
||||
// occurred. Because policy notifications are scoped to the entire computer,
|
||||
// this check could potentially fail if another process concurrently modifies
|
||||
// group policies while this test is running. I don't expect this to be an
|
||||
// issue on any computer on which we run this test, but something to keep in
|
||||
// mind if we start seeing flakiness around these GP notifications.
|
||||
trk, err = newGPNotificationTracker()
|
||||
if err != nil {
|
||||
t.Fatalf("newGPNotificationTracker: %v\n", err)
|
||||
}
|
||||
defer trk.Close()
|
||||
}
|
||||
|
||||
runCase := func(n int) {
|
||||
t.Logf("Test case: %d domains\n", n)
|
||||
if !isLocal {
|
||||
// When !isLocal, we want to check that a GP notification occured for
|
||||
// every single test case.
|
||||
trk, err = newGPNotificationTracker()
|
||||
if err != nil {
|
||||
t.Fatalf("newGPNotificationTracker: %v\n", err)
|
||||
}
|
||||
defer trk.Close()
|
||||
}
|
||||
caseDomains := domains[:n]
|
||||
err := mgr.setSplitDNS(resolvers, caseDomains)
|
||||
err = mgr.setSplitDNS(resolvers, caseDomains)
|
||||
if err != nil {
|
||||
t.Fatalf("setSplitDNS: %v\n", err)
|
||||
}
|
||||
validateRegistry(t, caseDomains)
|
||||
validateRegistry(t, regBaseValidate, caseDomains)
|
||||
ensureNoRulesInSubkey(t, regBaseEnsure)
|
||||
if !isLocal && !trk.DidRefresh(true) {
|
||||
t.Fatalf("DidRefresh false, want true\n")
|
||||
}
|
||||
}
|
||||
|
||||
for _, n := range cases {
|
||||
runCase(n)
|
||||
}
|
||||
|
||||
if isLocal && trk.DidRefresh(false) {
|
||||
t.Errorf("DidRefresh true, want false\n")
|
||||
}
|
||||
|
||||
t.Logf("Test case: nil resolver\n")
|
||||
@@ -92,23 +196,92 @@ func TestManagerWindows(t *testing.T) {
|
||||
ensureNoRules(t)
|
||||
}
|
||||
|
||||
func createFakeGPKey() error {
|
||||
keyStr := nrptBaseGP + `\` + testGPRuleID
|
||||
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening %s: %w", keyStr, err)
|
||||
}
|
||||
defer key.Close()
|
||||
if err := key.SetDWordValue("Version", 1); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := key.SetStringsValue("Name", []string{"._setbygp_.example.com"}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := key.SetStringValue("GenericDNSServers", "1.1.1.1"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteFakeGPKey(t *testing.T) {
|
||||
keyName := nrptBaseGP + `\` + testGPRuleID
|
||||
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist {
|
||||
t.Fatalf("Error deleting NRPT rule key %q: %v\n", keyName, err)
|
||||
}
|
||||
|
||||
isEmpty, err := isPolicyConfigSubkeyEmpty()
|
||||
if err != nil {
|
||||
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
|
||||
}
|
||||
|
||||
if !isEmpty {
|
||||
return
|
||||
}
|
||||
|
||||
if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil {
|
||||
t.Fatalf("Deleting DnsPolicyKey Subkey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureNoRules(t *testing.T) {
|
||||
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
||||
if ruleIDs != nil {
|
||||
t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs)
|
||||
}
|
||||
|
||||
legacyKeyPath := nrptBase + nrptSingleRuleID
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, legacyKeyPath, registry.READ)
|
||||
for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
|
||||
ensureNoSingleRule(t, base)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureNoRulesInSubkey(t *testing.T, base string) {
|
||||
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
||||
if ruleIDs == nil {
|
||||
for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
|
||||
ensureNoSingleRule(t, base)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, ruleID := range ruleIDs {
|
||||
keyName := base + `\` + ruleID
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
|
||||
if err == nil {
|
||||
key.Close()
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ensureNoSingleRule(t *testing.T, base string) {
|
||||
singleKeyPath := base + `\` + nrptSingleRuleID
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, singleKeyPath, registry.READ)
|
||||
if err == nil {
|
||||
key.Close()
|
||||
}
|
||||
if err != registry.ErrNotExist {
|
||||
t.Errorf("%s: %q, want %q\n", legacyKeyPath, err, registry.ErrNotExist)
|
||||
t.Fatalf("%s: %q, want %q\n", singleKeyPath, err, registry.ErrNotExist)
|
||||
}
|
||||
}
|
||||
|
||||
func validateRegistry(t *testing.T, domains []dnsname.FQDN) {
|
||||
func validateRegistry(t *testing.T, nrptBase string, domains []dnsname.FQDN) {
|
||||
q := len(domains) / nrptMaxDomainsPerRule
|
||||
r := len(domains) % nrptMaxDomainsPerRule
|
||||
numRules := q
|
||||
@@ -124,9 +297,9 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) {
|
||||
}
|
||||
|
||||
for i, ruleID := range ruleIDs {
|
||||
savedDomains, err := getSavedDomainsForRule(ruleID)
|
||||
savedDomains, err := getSavedDomainsForRule(nrptBase, ruleID)
|
||||
if err != nil {
|
||||
t.Fatalf("getSavedDomainsForRule(%q): %v\n", ruleID, err)
|
||||
t.Fatalf("getSavedDomainsForRule(%q, %q): %v\n", nrptBase, ruleID, err)
|
||||
}
|
||||
|
||||
start := i * nrptMaxDomainsPerRule
|
||||
@@ -148,8 +321,8 @@ func validateRegistry(t *testing.T, domains []dnsname.FQDN) {
|
||||
}
|
||||
}
|
||||
|
||||
func getSavedDomainsForRule(ruleID string) ([]string, error) {
|
||||
keyPath := nrptBase + ruleID
|
||||
func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
|
||||
keyPath := base + `\` + ruleID
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -158,3 +331,56 @@ func getSavedDomainsForRule(ruleID string) ([]string, error) {
|
||||
result, _, err := key.GetStringsValue("Name")
|
||||
return result, err
|
||||
}
|
||||
|
||||
// gpNotificationTracker registers with the Windows policy engine and receives
|
||||
// notifications when policy refreshes occur.
|
||||
type gpNotificationTracker struct {
|
||||
event windows.Handle
|
||||
}
|
||||
|
||||
func newGPNotificationTracker() (*gpNotificationTracker, error) {
|
||||
var err error
|
||||
evt, err := windows.CreateEvent(nil, 0, 0, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
windows.CloseHandle(evt)
|
||||
}
|
||||
}()
|
||||
|
||||
ok, _, e := procRegisterGPNotification.Call(
|
||||
uintptr(evt),
|
||||
uintptr(1), // We want computer policy changes, not user policy changes.
|
||||
)
|
||||
if ok == 0 {
|
||||
err = e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &gpNotificationTracker{evt}, nil
|
||||
}
|
||||
|
||||
func (trk *gpNotificationTracker) DidRefresh(isExpected bool) bool {
|
||||
// If we're not expecting a refresh event, then we need to use a timeout.
|
||||
timeout := uint32(1000) // 1 second (in milliseconds)
|
||||
if isExpected {
|
||||
// Otherwise, since it is imperative that we see an event, we wait infinitely.
|
||||
timeout = windows.INFINITE
|
||||
}
|
||||
|
||||
return trk.DidRefreshTimeout(timeout)
|
||||
}
|
||||
|
||||
func (trk *gpNotificationTracker) DidRefreshTimeout(timeout uint32) bool {
|
||||
waitCode, _ := windows.WaitForSingleObject(trk.event, timeout)
|
||||
return waitCode == windows.WAIT_OBJECT_0
|
||||
}
|
||||
|
||||
func (trk *gpNotificationTracker) Close() error {
|
||||
procUnregisterGPNotification.Call(uintptr(trk.event))
|
||||
windows.CloseHandle(trk.event)
|
||||
trk.event = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user