diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index df62185b3..b71abd330 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -20,12 +20,10 @@ "tailscale.com/envknob" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/winutil" ) const ( - ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` - ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` - versionKey = `SOFTWARE\Microsoft\Windows NT\CurrentVersion` ) @@ -59,24 +57,15 @@ func NewOSConfigurator(logf logger.Logf, interfaceName string) (OSConfigurator, return ret, nil } -// keyOpenTimeout is how long we wait for a registry key to -// appear. For some reason, registry keys tied to ephemeral interfaces -// can take a long while to appear after interface creation, and we -// can end up racing with that. -const keyOpenTimeout = 20 * time.Second - -func (m windowsManager) openKey(path string) (registry.Key, error) { - key, err := openKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE, keyOpenTimeout) +func (m windowsManager) openInterfaceKey(pfx winutil.RegistryPathPrefix) (registry.Key, error) { + path := pfx.WithSuffix(m.guid) + key, err := winutil.OpenKeyWait(registry.LOCAL_MACHINE, path, registry.SET_VALUE) if err != nil { return 0, fmt.Errorf("opening %s: %w", path, err) } return key, nil } -func (m windowsManager) ifPath(basePath string) string { - return fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) -} - func delValue(key registry.Key, name string) error { if err := key.DeleteValue(name); err != nil && err != registry.ErrNotExist { return err @@ -134,7 +123,7 @@ func (m windowsManager) setPrimaryDNS(resolvers []netip.Addr, domains []dnsname. domStrs = append(domStrs, dom.WithoutTrailingDot()) } - key4, err := m.openKey(m.ifPath(ipv4RegBase)) + key4, err := m.openInterfaceKey(winutil.IPv4TCPIPInterfacePrefix) if err != nil { return err } @@ -156,7 +145,7 @@ func (m windowsManager) setPrimaryDNS(resolvers []netip.Addr, domains []dnsname. return err } - key6, err := m.openKey(m.ifPath(ipv6RegBase)) + key6, err := m.openInterfaceKey(winutil.IPv6TCPIPInterfacePrefix) if err != nil { return err } @@ -308,25 +297,26 @@ func (m windowsManager) Close() error { // Windows DHCP client from sending dynamic DNS updates for our interface to // AD domain controllers. func (m windowsManager) disableDynamicUpdates() error { - setRegValue := func(regBase string) error { - key, err := m.openKey(m.ifPath(regBase)) - if err != nil { - return err - } - defer key.Close() - - return key.SetDWordValue("DisableDynamicUpdate", 1) + if err := m.setSingleDWORD(winutil.IPv4TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil { + return err } - - for _, regBase := range []string{ipv4RegBase, ipv6RegBase} { - if err := setRegValue(regBase); err != nil { - return err - } + if err := m.setSingleDWORD(winutil.IPv6TCPIPInterfacePrefix, "EnableDNSUpdate", 0); err != nil { + return err } - return nil } +// setSingleDWORD opens the Registry Key in HKLM for the interface associated +// with the windowsManager and sets the "keyPrefix\value" to data. +func (m windowsManager) setSingleDWORD(prefix winutil.RegistryPathPrefix, value string, data uint32) error { + k, err := m.openInterfaceKey(prefix) + if err != nil { + return err + } + defer k.Close() + return k.SetDWordValue(value, data) +} + func (m windowsManager) GetBaseConfig() (OSConfig, error) { resolvers, err := m.getBasePrimaryResolver() if err != nil { diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index d8efdbabd..7ec502cbe 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -339,11 +339,12 @@ func deleteFakeGPKey(t *testing.T) { } func createFakeInterfaceKey(t *testing.T, guid windows.GUID) (func(), error) { - basePaths := []string{ipv4RegBase, ipv6RegBase} + basePaths := []winutil.RegistryPathPrefix{winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix} keyPaths := make([]string, 0, len(basePaths)) + guidStr := guid.String() for _, basePath := range basePaths { - keyPath := fmt.Sprintf(`%s\Interfaces\%s`, basePath, guid) + keyPath := string(basePath.WithSuffix(guidStr)) key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.SET_VALUE) if err != nil { return nil, err diff --git a/net/dns/registry_windows.go b/net/dns/registry_windows.go deleted file mode 100644 index f8e1f514a..000000000 --- a/net/dns/registry_windows.go +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -// -// The code in this file originates from https://git.zx2c4.com/wireguard-go: -// Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. -// Copying license: https://git.zx2c4.com/wireguard-go/tree/COPYING - -package dns - -import ( - "fmt" - "runtime" - "strings" - "time" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -func openKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - deadline := time.Now().Add(timeout) - pathSpl := strings.Split(path, "\\") - for i := 0; ; i++ { - keyName := pathSpl[i] - isLast := i+1 == len(pathSpl) - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return 0, fmt.Errorf("windows.CreateEvent: %v", err) - } - defer windows.CloseHandle(event) - - var key registry.Key - for { - err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true) - if err != nil { - return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %v", err) - } - - var accessFlags uint32 - if isLast { - accessFlags = access - } else { - accessFlags = registry.NOTIFY - } - key, err = registry.OpenKey(k, keyName, accessFlags) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return 0, fmt.Errorf("windows.WaitForSingleObject: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return 0, fmt.Errorf("timeout waiting for registry key") - } - } else if err != nil { - return 0, fmt.Errorf("registry.OpenKey(%v): %v", path, err) - } else { - if isLast { - return key, nil - } - defer key.Close() - break - } - } - - k = key - } -} diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 08f212225..210d7ef02 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -10,7 +10,9 @@ "log" "os/exec" "runtime" + "strings" "syscall" + "time" "unsafe" "golang.org/x/sys/windows" @@ -391,3 +393,93 @@ func IsCurrentProcessElevated() bool { return token.IsElevated() } + +// keyOpenTimeout is how long we wait for a registry key to appear. For some +// reason, registry keys tied to ephemeral interfaces can take a long while to +// appear after interface creation, and we can end up racing with that. +const keyOpenTimeout = 20 * time.Second + +// RegistryPath represents a path inside a root registry.Key. +type RegistryPath string + +// RegistryPathPrefix specifies a RegistryPath prefix that must be suffixed with +// another RegistryPath to make a valid RegistryPath. +type RegistryPathPrefix string + +// WithSuffix returns a RegistryPath with the given suffix appended. +func (p RegistryPathPrefix) WithSuffix(suf string) RegistryPath { + return RegistryPath(string(p) + suf) +} + +const ( + IPv4TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` + IPv6TCPIPBase RegistryPath = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` + NetBTBase RegistryPath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters` + + IPv4TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + IPv6TCPIPInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + NetBTInterfacePrefix RegistryPathPrefix = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces\Tcpip_` +) + +// ErrKeyWaitTimeout is returned by OpenKeyWait when calls timeout. +var ErrKeyWaitTimeout = errors.New("timeout waiting for registry key") + +// OpenKeyWait opens a registry key, waiting for it to appear if necessary. It +// returns the opened key, or ErrKeyWaitTimeout if the key does not appear +// within 20s. The caller must call Close on the returned key. +func OpenKeyWait(k registry.Key, path RegistryPath, access uint32) (registry.Key, error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + deadline := time.Now().Add(keyOpenTimeout) + pathSpl := strings.Split(string(path), "\\") + for i := 0; ; i++ { + keyName := pathSpl[i] + isLast := i+1 == len(pathSpl) + + event, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, fmt.Errorf("windows.CreateEvent: %w", err) + } + defer windows.CloseHandle(event) + + var key registry.Key + for { + err = windows.RegNotifyChangeKeyValue(windows.Handle(k), false, windows.REG_NOTIFY_CHANGE_NAME, event, true) + if err != nil { + return 0, fmt.Errorf("windows.RegNotifyChangeKeyValue: %w", err) + } + + var accessFlags uint32 + if isLast { + accessFlags = access + } else { + accessFlags = registry.NOTIFY + } + key, err = registry.OpenKey(k, keyName, accessFlags) + if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { + timeout := time.Until(deadline) / time.Millisecond + if timeout < 0 { + timeout = 0 + } + s, err := windows.WaitForSingleObject(event, uint32(timeout)) + if err != nil { + return 0, fmt.Errorf("windows.WaitForSingleObject: %w", err) + } + if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows + return 0, ErrKeyWaitTimeout + } + } else if err != nil { + return 0, fmt.Errorf("registry.OpenKey(%v): %w", path, err) + } else { + if isLast { + return key, nil + } + defer key.Close() + break + } + } + + k = key + } +}