mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-26 03:25:35 +00:00
a583e498b0
Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
156 lines
3.9 KiB
Go
156 lines
3.9 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package dns
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/tailscale/wireguard-go/tun"
|
|
"golang.org/x/sys/windows/registry"
|
|
"tailscale.com/types/logger"
|
|
)
|
|
|
|
// Registry locations used by the Windows DNS manager. Every path is opened
// relative to HKEY_LOCAL_MACHINE by the helpers in this file.
const (
	// ipv4RegBase is the base key for the IPv4 settings: its SearchList value
	// is edited directly, and per-interface values live under its Interfaces
	// subkey.
	ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters`
	// ipv6RegBase is the IPv6 counterpart of ipv4RegBase.
	ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters`
	// tsRegBase is the key where the most recently applied search list is
	// recorded so it can be recognized on the next update.
	tsRegBase = `SOFTWARE\Tailscale IPN`
)
|
|
|
|
type windowsManager struct {
|
|
logf logger.Logf
|
|
guid string
|
|
}
|
|
|
|
func newManager(mconfig ManagerConfig) managerImpl {
|
|
return windowsManager{
|
|
logf: mconfig.Logf,
|
|
guid: tun.WintunGUID,
|
|
}
|
|
}
|
|
|
|
func setRegistryString(path, name, value string) error {
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("opening %s: %w", path, err)
|
|
}
|
|
defer key.Close()
|
|
|
|
err = key.SetStringValue(name, value)
|
|
if err != nil {
|
|
return fmt.Errorf("setting %s[%s]: %w", path, name, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getRegistryString(path, name string) (string, error) {
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ)
|
|
if err != nil {
|
|
return "", fmt.Errorf("opening %s: %w", path, err)
|
|
}
|
|
defer key.Close()
|
|
|
|
value, _, err := key.GetStringValue(name)
|
|
if err != nil {
|
|
return "", fmt.Errorf("getting %s[%s]: %w", path, name, err)
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func (m windowsManager) setNameservers(basePath string, nameservers []string) error {
|
|
path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid)
|
|
value := strings.Join(nameservers, ",")
|
|
return setRegistryString(path, "NameServer", value)
|
|
}
|
|
|
|
func (m windowsManager) setDomains(path string, oldDomains, newDomains []string) error {
|
|
// We reimplement setRegistryString to ensure that we hold the key for the whole operation.
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("opening %s: %w", path, err)
|
|
}
|
|
defer key.Close()
|
|
|
|
searchList, _, err := key.GetStringValue("SearchList")
|
|
if err != nil && err != registry.ErrNotExist {
|
|
return fmt.Errorf("getting %s[SearchList]: %w", path, err)
|
|
}
|
|
currentDomains := strings.Split(searchList, ",")
|
|
|
|
var domainsToSet []string
|
|
for _, domain := range currentDomains {
|
|
inOld, inNew := false, false
|
|
|
|
// The number of domains should be small,
|
|
// so this is probaly faster than constructing a map.
|
|
for _, oldDomain := range oldDomains {
|
|
if domain == oldDomain {
|
|
inOld = true
|
|
}
|
|
}
|
|
for _, newDomain := range newDomains {
|
|
if domain == newDomain {
|
|
inNew = true
|
|
}
|
|
}
|
|
|
|
if !inNew && !inOld {
|
|
domainsToSet = append(domainsToSet, domain)
|
|
}
|
|
}
|
|
domainsToSet = append(domainsToSet, newDomains...)
|
|
|
|
searchList = strings.Join(domainsToSet, ",")
|
|
if err := key.SetStringValue("SearchList", searchList); err != nil {
|
|
return fmt.Errorf("setting %s[SearchList]: %w", path, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m windowsManager) Up(config Config) error {
|
|
var ipsv4 []string
|
|
var ipsv6 []string
|
|
|
|
for _, ip := range config.Nameservers {
|
|
if ip.Is4() {
|
|
ipsv4 = append(ipsv4, ip.String())
|
|
} else {
|
|
ipsv6 = append(ipsv6, ip.String())
|
|
}
|
|
}
|
|
|
|
lastSearchList, err := getRegistryString(tsRegBase, "SearchList")
|
|
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
|
return err
|
|
}
|
|
lastDomains := strings.Split(lastSearchList, ",")
|
|
|
|
if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil {
|
|
return err
|
|
}
|
|
if err := m.setDomains(ipv4RegBase, lastDomains, config.Domains); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil {
|
|
return err
|
|
}
|
|
if err := m.setDomains(ipv6RegBase, lastDomains, config.Domains); err != nil {
|
|
return err
|
|
}
|
|
|
|
newSearchList := strings.Join(config.Domains, ",")
|
|
if err := setRegistryString(tsRegBase, "SearchList", newSearchList); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m windowsManager) Down() error {
|
|
return m.Up(Config{Nameservers: nil, Domains: nil})
|
|
}
|