mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 08:07:42 +00:00
71029cea2d
This updates all source files to use a new standard header for copyright and license declaration. Notably, copyright no longer includes a date, and we now use the standard SPDX-License-Identifier header. This commit was done almost entirely mechanically with perl, and then some minimal manual fixes. Updates #6865 Signed-off-by: Will Norris <will@tailscale.com>
679 lines
17 KiB
Go
679 lines
17 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package dns
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"math/rand"
|
|
"net/netip"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
"tailscale.com/util/dnsname"
|
|
"tailscale.com/util/winutil"
|
|
)
|
|
|
|
// testGPRuleID is the registry subkey name that createFakeGPKey creates (and
// deleteFakeGPKey removes) beneath nrptBaseGP to stand in for a rule set by
// group policy.
const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}"
|
|
|
|
func TestHostFileNewLines(t *testing.T) {
|
|
in := []byte("#foo\r\n#bar\n#baz\n")
|
|
want := []byte("#foo\r\n#bar\r\n#baz\r\n")
|
|
|
|
got, err := setTailscaleHosts(in, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Equal(got, want) {
|
|
t.Errorf("got %q, want %q\n", got, want)
|
|
}
|
|
}
|
|
|
|
func TestManagerWindowsLocal(t *testing.T) {
|
|
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
|
t.Skipf("test requires running as elevated user on Windows 10+")
|
|
}
|
|
|
|
runTest(t, true)
|
|
}
|
|
|
|
func TestManagerWindowsGP(t *testing.T) {
|
|
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
|
t.Skipf("test requires running as elevated user on Windows 10+")
|
|
}
|
|
|
|
checkGPNotificationsWork(t)
|
|
|
|
// Make sure group policy is refreshed before this test exits but after we've
|
|
// cleaned everything else up.
|
|
defer procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
|
|
|
err := createFakeGPKey()
|
|
if err != nil {
|
|
t.Fatalf("Creating fake GP key: %v\n", err)
|
|
}
|
|
defer deleteFakeGPKey(t)
|
|
|
|
runTest(t, false)
|
|
}
|
|
|
|
func TestManagerWindowsGPMove(t *testing.T) {
|
|
if !isWindows10OrBetter() || !winutil.IsCurrentProcessElevated() {
|
|
t.Skipf("test requires running as elevated user on Windows 10+")
|
|
}
|
|
|
|
checkGPNotificationsWork(t)
|
|
|
|
logf := func(format string, args ...any) {
|
|
t.Logf(format, args...)
|
|
}
|
|
|
|
fakeInterface, err := windows.GenerateGUID()
|
|
if err != nil {
|
|
t.Fatalf("windows.GenerateGUID: %v\n", err)
|
|
}
|
|
|
|
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
|
|
if err != nil {
|
|
t.Fatalf("createFakeInterfaceKey: %v\n", err)
|
|
}
|
|
defer delIfKey()
|
|
|
|
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
|
|
if err != nil {
|
|
t.Fatalf("NewOSConfigurator: %v\n", err)
|
|
}
|
|
mgr := cfg.(*windowsManager)
|
|
defer mgr.Close()
|
|
|
|
usingGP := mgr.nrptDB.writeAsGP
|
|
if usingGP {
|
|
t.Fatalf("usingGP %v, want %v\n", usingGP, false)
|
|
}
|
|
|
|
regWatcher, err := newRegKeyWatcher()
|
|
if err != nil {
|
|
t.Fatalf("newRegKeyWatcher error %v\n", err)
|
|
}
|
|
|
|
// Upon initialization of cfg, we should not have any NRPT rules
|
|
ensureNoRules(t)
|
|
|
|
resolvers := []netip.Addr{netip.MustParseAddr("1.1.1.1")}
|
|
domains := genRandomSubdomains(t, 1)
|
|
|
|
// 1. Populate local NRPT
|
|
err = mgr.setSplitDNS(resolvers, domains)
|
|
if err != nil {
|
|
t.Fatalf("setSplitDNS: %v\n", err)
|
|
}
|
|
|
|
t.Logf("Validating that local NRPT is populated...\n")
|
|
validateRegistry(t, nrptBaseLocal, domains)
|
|
ensureNoRulesInSubkey(t, nrptBaseGP)
|
|
|
|
// 2. Create fake GP key and refresh
|
|
t.Logf("Creating fake group policy key and refreshing...\n")
|
|
err = createFakeGPKey()
|
|
if err != nil {
|
|
t.Fatalf("createFakeGPKey: %v\n", err)
|
|
}
|
|
|
|
err = regWatcher.watch()
|
|
if err != nil {
|
|
t.Fatalf("regWatcher.watch: %v\n", err)
|
|
}
|
|
|
|
err = testDoRefresh()
|
|
if err != nil {
|
|
t.Fatalf("testDoRefresh: %v\n", err)
|
|
}
|
|
|
|
err = regWatcher.wait()
|
|
if err != nil {
|
|
t.Fatalf("regWatcher.wait: %v\n", err)
|
|
}
|
|
|
|
// 3. Check that local NRPT is empty and GP is populated
|
|
t.Logf("Validating that group policy NRPT is populated...\n")
|
|
validateRegistry(t, nrptBaseGP, domains)
|
|
ensureNoRulesInSubkey(t, nrptBaseLocal)
|
|
|
|
// 4. Delete fake GP key and refresh
|
|
t.Logf("Deleting fake group policy key and refreshing...\n")
|
|
deleteFakeGPKey(t)
|
|
|
|
err = regWatcher.watch()
|
|
if err != nil {
|
|
t.Fatalf("regWatcher.watch: %v\n", err)
|
|
}
|
|
|
|
err = testDoRefresh()
|
|
if err != nil {
|
|
t.Fatalf("testDoRefresh: %v\n", err)
|
|
}
|
|
|
|
err = regWatcher.wait()
|
|
if err != nil {
|
|
t.Fatalf("regWatcher.wait: %v\n", err)
|
|
}
|
|
|
|
// 5. Check that local NRPT is populated and GP is empty
|
|
t.Logf("Validating that local NRPT is populated...\n")
|
|
validateRegistry(t, nrptBaseLocal, domains)
|
|
ensureNoRulesInSubkey(t, nrptBaseGP)
|
|
|
|
// 6. Cleanup
|
|
t.Logf("Cleaning up...\n")
|
|
err = mgr.setSplitDNS(nil, domains)
|
|
if err != nil {
|
|
t.Fatalf("setSplitDNS: %v\n", err)
|
|
}
|
|
ensureNoRules(t)
|
|
}
|
|
|
|
func checkGPNotificationsWork(t *testing.T) {
|
|
// Test to ensure that RegisterGPNotification work on this machine,
|
|
// otherwise this test will fail.
|
|
trk, err := newGPNotificationTracker()
|
|
if err != nil {
|
|
t.Skipf("newGPNotificationTracker error: %v\n", err)
|
|
}
|
|
defer trk.Close()
|
|
|
|
r, _, err := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
|
if r == 0 {
|
|
t.Fatalf("RefreshPolicyEx error: %v\n", err)
|
|
}
|
|
|
|
timeout := uint32(10000) // Milliseconds
|
|
if !trk.DidRefreshTimeout(timeout) {
|
|
t.Skipf("GP notifications are not working on this machine\n")
|
|
}
|
|
}
|
|
|
|
func runTest(t *testing.T, isLocal bool) {
|
|
logf := func(format string, args ...any) {
|
|
t.Logf(format, args...)
|
|
}
|
|
|
|
fakeInterface, err := windows.GenerateGUID()
|
|
if err != nil {
|
|
t.Fatalf("windows.GenerateGUID: %v\n", err)
|
|
}
|
|
|
|
delIfKey, err := createFakeInterfaceKey(t, fakeInterface)
|
|
if err != nil {
|
|
t.Fatalf("createFakeInterfaceKey: %v\n", err)
|
|
}
|
|
defer delIfKey()
|
|
|
|
cfg, err := NewOSConfigurator(logf, fakeInterface.String())
|
|
if err != nil {
|
|
t.Fatalf("NewOSConfigurator: %v\n", err)
|
|
}
|
|
mgr := cfg.(*windowsManager)
|
|
defer mgr.Close()
|
|
|
|
usingGP := mgr.nrptDB.writeAsGP
|
|
if isLocal == usingGP {
|
|
t.Fatalf("usingGP %v, want %v\n", usingGP, !usingGP)
|
|
}
|
|
|
|
// Upon initialization of cfg, we should not have any NRPT rules
|
|
ensureNoRules(t)
|
|
|
|
resolvers := []netip.Addr{netip.MustParseAddr("1.1.1.1")}
|
|
|
|
domains := genRandomSubdomains(t, 2*nrptMaxDomainsPerRule+1)
|
|
|
|
cases := []int{
|
|
1,
|
|
50,
|
|
51,
|
|
100,
|
|
101,
|
|
100,
|
|
50,
|
|
1,
|
|
51,
|
|
}
|
|
|
|
var regBaseValidate string
|
|
var regBaseEnsure string
|
|
if isLocal {
|
|
regBaseValidate = nrptBaseLocal
|
|
regBaseEnsure = nrptBaseGP
|
|
} else {
|
|
regBaseValidate = nrptBaseGP
|
|
regBaseEnsure = nrptBaseLocal
|
|
}
|
|
|
|
var trk *gpNotificationTracker
|
|
if isLocal {
|
|
// (dblohm7) When isLocal == true, we keep trk active through the entire
|
|
// sequence of test cases, and then we verify that no policy notifications
|
|
// occurred. Because policy notifications are scoped to the entire computer,
|
|
// this check could potentially fail if another process concurrently modifies
|
|
// group policies while this test is running. I don't expect this to be an
|
|
// issue on any computer on which we run this test, but something to keep in
|
|
// mind if we start seeing flakiness around these GP notifications.
|
|
trk, err = newGPNotificationTracker()
|
|
if err != nil {
|
|
t.Fatalf("newGPNotificationTracker: %v\n", err)
|
|
}
|
|
defer trk.Close()
|
|
}
|
|
|
|
runCase := func(n int) {
|
|
t.Logf("Test case: %d domains\n", n)
|
|
if !isLocal {
|
|
// When !isLocal, we want to check that a GP notification occurred for
|
|
// every single test case.
|
|
trk, err = newGPNotificationTracker()
|
|
if err != nil {
|
|
t.Fatalf("newGPNotificationTracker: %v\n", err)
|
|
}
|
|
defer trk.Close()
|
|
}
|
|
caseDomains := domains[:n]
|
|
err = mgr.setSplitDNS(resolvers, caseDomains)
|
|
if err != nil {
|
|
t.Fatalf("setSplitDNS: %v\n", err)
|
|
}
|
|
validateRegistry(t, regBaseValidate, caseDomains)
|
|
ensureNoRulesInSubkey(t, regBaseEnsure)
|
|
if !isLocal && !trk.DidRefresh(true) {
|
|
t.Fatalf("DidRefresh false, want true\n")
|
|
}
|
|
}
|
|
|
|
for _, n := range cases {
|
|
runCase(n)
|
|
}
|
|
|
|
if isLocal && trk.DidRefresh(false) {
|
|
t.Errorf("DidRefresh true, want false\n")
|
|
}
|
|
|
|
t.Logf("Test case: nil resolver\n")
|
|
err = mgr.setSplitDNS(nil, domains)
|
|
if err != nil {
|
|
t.Fatalf("setSplitDNS: %v\n", err)
|
|
}
|
|
ensureNoRules(t)
|
|
}
|
|
|
|
func createFakeGPKey() error {
|
|
keyStr := nrptBaseGP + `\` + testGPRuleID
|
|
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyStr, registry.SET_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("opening %s: %w", keyStr, err)
|
|
}
|
|
defer key.Close()
|
|
if err := key.SetDWordValue("Version", 1); err != nil {
|
|
return err
|
|
}
|
|
if err := key.SetStringsValue("Name", []string{"._setbygp_.example.com"}); err != nil {
|
|
return err
|
|
}
|
|
if err := key.SetStringValue("GenericDNSServers", "1.1.1.1"); err != nil {
|
|
return err
|
|
}
|
|
if err := key.SetDWordValue("ConfigOptions", nrptOverrideDNS); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func deleteFakeGPKey(t *testing.T) {
|
|
keyName := nrptBaseGP + `\` + testGPRuleID
|
|
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyName); err != nil && err != registry.ErrNotExist {
|
|
t.Fatalf("Error deleting NRPT rule key %q: %v\n", keyName, err)
|
|
}
|
|
|
|
isEmpty, err := isPolicyConfigSubkeyEmpty()
|
|
if err != nil {
|
|
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
|
|
}
|
|
|
|
if !isEmpty {
|
|
return
|
|
}
|
|
|
|
if err := registry.DeleteKey(registry.LOCAL_MACHINE, nrptBaseGP); err != nil {
|
|
t.Fatalf("Deleting DnsPolicyKey Subkey: %v", err)
|
|
}
|
|
}
|
|
|
|
func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) {
|
|
basePaths := []winutil.RegistryPathPrefix{winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix}
|
|
keyPaths := make([]string, 0, len(basePaths))
|
|
|
|
guidStr := guid.String()
|
|
for _, basePath := range basePaths {
|
|
keyPath := string(basePath.WithSuffix(guidStr))
|
|
key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
key.Close()
|
|
|
|
keyPaths = append(keyPaths, keyPath)
|
|
}
|
|
|
|
result := func() {
|
|
for _, keyPath := range keyPaths {
|
|
if err := registry.DeleteKey(registry.LOCAL_MACHINE, keyPath); err != nil {
|
|
t.Fatalf("deleting fake interface key \"%s\": %v\n", keyPath, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func ensureNoRules(t *testing.T) {
|
|
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
|
if ruleIDs != nil {
|
|
t.Errorf("%s: %v, want nil\n", nrptRuleIDValueName, ruleIDs)
|
|
}
|
|
|
|
for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
|
|
ensureNoSingleRule(t, base)
|
|
}
|
|
}
|
|
|
|
func ensureNoRulesInSubkey(t *testing.T, base string) {
|
|
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
|
if ruleIDs == nil {
|
|
for _, base := range []string{nrptBaseLocal, nrptBaseGP} {
|
|
ensureNoSingleRule(t, base)
|
|
}
|
|
return
|
|
}
|
|
|
|
for _, ruleID := range ruleIDs {
|
|
keyName := base + `\` + ruleID
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyName, registry.READ)
|
|
if err == nil {
|
|
key.Close()
|
|
} else if err != registry.ErrNotExist {
|
|
t.Fatalf("%s: %q, want %q\n", keyName, err, registry.ErrNotExist)
|
|
}
|
|
}
|
|
|
|
if base == nrptBaseGP {
|
|
// When dealing with the group policy subkey, we want the base key to
|
|
// also be absent.
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, base, registry.READ)
|
|
if err == nil {
|
|
key.Close()
|
|
|
|
isEmpty, err := isPolicyConfigSubkeyEmpty()
|
|
if err != nil {
|
|
t.Fatalf("isPolicyConfigSubkeyEmpty: %v", err)
|
|
}
|
|
if isEmpty {
|
|
t.Errorf("Unexpectedly found group policy key\n")
|
|
}
|
|
} else if err != registry.ErrNotExist {
|
|
t.Errorf("Group policy key error: %q, want %q\n", err, registry.ErrNotExist)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ensureNoSingleRule(t *testing.T, base string) {
|
|
singleKeyPath := base + `\` + nrptSingleRuleID
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, singleKeyPath, registry.READ)
|
|
if err == nil {
|
|
key.Close()
|
|
}
|
|
if err != registry.ErrNotExist {
|
|
t.Fatalf("%s: %q, want %q\n", singleKeyPath, err, registry.ErrNotExist)
|
|
}
|
|
}
|
|
|
|
func validateRegistry(t *testing.T, nrptBase string, domains []dnsname.FQDN) {
|
|
q := len(domains) / nrptMaxDomainsPerRule
|
|
r := len(domains) % nrptMaxDomainsPerRule
|
|
numRules := q
|
|
if r > 0 {
|
|
numRules++
|
|
}
|
|
|
|
ruleIDs := winutil.GetRegStrings(nrptRuleIDValueName, nil)
|
|
if ruleIDs == nil {
|
|
ruleIDs = []string{nrptSingleRuleID}
|
|
} else if len(ruleIDs) != numRules {
|
|
t.Errorf("%s for %d domains: %d, want %d\n", nrptRuleIDValueName, len(domains), len(ruleIDs), numRules)
|
|
}
|
|
|
|
for i, ruleID := range ruleIDs {
|
|
savedDomains, err := getSavedDomainsForRule(nrptBase, ruleID)
|
|
if err != nil {
|
|
t.Fatalf("getSavedDomainsForRule(%q, %q): %v\n", nrptBase, ruleID, err)
|
|
}
|
|
|
|
start := i * nrptMaxDomainsPerRule
|
|
end := start + nrptMaxDomainsPerRule
|
|
if i == len(ruleIDs)-1 && r > 0 {
|
|
end = start + r
|
|
}
|
|
|
|
checkDomains := domains[start:end]
|
|
if len(checkDomains) != len(savedDomains) {
|
|
t.Errorf("len(checkDomains) != len(savedDomains): %d, want %d\n", len(savedDomains), len(checkDomains))
|
|
}
|
|
for j, cd := range checkDomains {
|
|
sd := strings.TrimPrefix(savedDomains[j], ".")
|
|
if string(cd.WithoutTrailingDot()) != sd {
|
|
t.Errorf("checkDomain differs savedDomain: %s, want %s\n", sd, cd.WithoutTrailingDot())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func getSavedDomainsForRule(base, ruleID string) ([]string, error) {
|
|
keyPath := base + `\` + ruleID
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer key.Close()
|
|
result, _, err := key.GetStringsValue("Name")
|
|
return result, err
|
|
}
|
|
|
|
func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN {
|
|
domains := make([]dnsname.FQDN, 0, n)
|
|
|
|
seed := time.Now().UnixNano()
|
|
t.Logf("genRandomSubdomains(%d) seed: %v\n", n, seed)
|
|
|
|
r := rand.New(rand.NewSource(seed))
|
|
const charset = "abcdefghijklmnopqrstuvwxyz"
|
|
|
|
for len(domains) < cap(domains) {
|
|
l := r.Intn(19) + 1
|
|
b := make([]byte, l)
|
|
for i, _ := range b {
|
|
b[i] = charset[r.Intn(len(charset))]
|
|
}
|
|
d := string(b) + ".example.com"
|
|
fqdn, err := dnsname.ToFQDN(d)
|
|
if err != nil {
|
|
t.Fatalf("dnsname.ToFQDN: %v\n", err)
|
|
}
|
|
domains = append(domains, fqdn)
|
|
}
|
|
|
|
return domains
|
|
}
|
|
|
|
func testDoRefresh() (err error) {
|
|
r, _, e := procRefreshPolicyEx.Call(uintptr(1), uintptr(_RP_FORCE))
|
|
if r == 0 {
|
|
err = e
|
|
}
|
|
return err
|
|
}
|
|
|
|
// gpNotificationTracker registers with the Windows policy engine and receives
|
|
// notifications when policy refreshes occur.
|
|
type gpNotificationTracker struct {
|
|
event windows.Handle
|
|
}
|
|
|
|
func newGPNotificationTracker() (*gpNotificationTracker, error) {
|
|
var err error
|
|
evt, err := windows.CreateEvent(nil, 0, 0, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
windows.CloseHandle(evt)
|
|
}
|
|
}()
|
|
|
|
ok, _, e := procRegisterGPNotification.Call(
|
|
uintptr(evt),
|
|
uintptr(1), // We want computer policy changes, not user policy changes.
|
|
)
|
|
if ok == 0 {
|
|
err = e
|
|
return nil, err
|
|
}
|
|
|
|
return &gpNotificationTracker{evt}, nil
|
|
}
|
|
|
|
func (trk *gpNotificationTracker) DidRefresh(isExpected bool) bool {
|
|
// If we're not expecting a refresh event, then we need to use a timeout.
|
|
timeout := uint32(1000) // 1 second (in milliseconds)
|
|
if isExpected {
|
|
// Otherwise, since it is imperative that we see an event, we wait infinitely.
|
|
timeout = windows.INFINITE
|
|
}
|
|
|
|
return trk.DidRefreshTimeout(timeout)
|
|
}
|
|
|
|
func (trk *gpNotificationTracker) DidRefreshTimeout(timeout uint32) bool {
|
|
waitCode, _ := windows.WaitForSingleObject(trk.event, timeout)
|
|
return waitCode == windows.WAIT_OBJECT_0
|
|
}
|
|
|
|
func (trk *gpNotificationTracker) Close() error {
|
|
procUnregisterGPNotification.Call(uintptr(trk.event))
|
|
windows.CloseHandle(trk.event)
|
|
trk.event = 0
|
|
return nil
|
|
}
|
|
|
|
type regKeyWatcher struct {
|
|
keyLocal registry.Key
|
|
keyGP registry.Key
|
|
evtLocal windows.Handle
|
|
evtGP windows.Handle
|
|
}
|
|
|
|
func newRegKeyWatcher() (*regKeyWatcher, error) {
|
|
var err error
|
|
|
|
keyLocal, _, err := registry.CreateKey(registry.LOCAL_MACHINE, nrptBaseLocal, registry.READ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
keyLocal.Close()
|
|
}
|
|
}()
|
|
|
|
// Monitor dnsBaseGP instead of nrptBaseGP, since the latter will be
|
|
// repeatedly created and destroyed throughout the course of the test.
|
|
keyGP, _, err := registry.CreateKey(registry.LOCAL_MACHINE, dnsBaseGP, registry.READ)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
keyGP.Close()
|
|
}
|
|
}()
|
|
|
|
evtLocal, err := windows.CreateEvent(nil, 0, 0, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
windows.CloseHandle(evtLocal)
|
|
}
|
|
}()
|
|
|
|
evtGP, err := windows.CreateEvent(nil, 0, 0, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := ®KeyWatcher{
|
|
keyLocal: keyLocal,
|
|
keyGP: keyGP,
|
|
evtLocal: evtLocal,
|
|
evtGP: evtGP,
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (rw *regKeyWatcher) watch() error {
|
|
// We can make these waits thread-agnostic because the tests that use this code must already run on Windows 10+
|
|
err := windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyLocal), true,
|
|
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtLocal, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return windows.RegNotifyChangeKeyValue(windows.Handle(rw.keyGP), true,
|
|
windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_THREAD_AGNOSTIC, rw.evtGP, true)
|
|
}
|
|
|
|
func (rw *regKeyWatcher) wait() error {
|
|
handles := []windows.Handle{
|
|
rw.evtLocal,
|
|
rw.evtGP,
|
|
}
|
|
|
|
waitCode, err := windows.WaitForMultipleObjects(
|
|
handles,
|
|
true, // Wait for both events to signal before resuming.
|
|
10000, // 10 seconds (as milliseconds)
|
|
)
|
|
|
|
const WAIT_TIMEOUT = 0x102
|
|
switch waitCode {
|
|
case WAIT_TIMEOUT:
|
|
return context.DeadlineExceeded
|
|
case windows.WAIT_FAILED:
|
|
return err
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (rw *regKeyWatcher) Close() error {
|
|
rw.keyLocal.Close()
|
|
rw.keyGP.Close()
|
|
windows.CloseHandle(rw.evtLocal)
|
|
windows.CloseHandle(rw.evtGP)
|
|
return nil
|
|
}
|