util/winutil: ensure domain controller address is used when retrieving remote profile information

We cannot directly pass a flat domain name into NetUserGetInfo; we must
resolve the address of a domain controller first.

This PR implements the appropriate resolution mechanisms to do that, and
also exposes a couple of new utility APIs for future needs.

Fixes #12627

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2024-06-26 12:08:38 -06:00
parent 0323dd01b2
commit 5f177090e3
5 changed files with 226 additions and 4 deletions

View File

@ -6,9 +6,11 @@
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession

View File

@ -135,9 +135,36 @@ func (up *UserProfile) Close() error {
}
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
// logf is for debugging/testing.
if logf == nil {
logf = logger.Discard
// logf is for debugging/testing. While we would normally replace a nil logf
// with logger.Discard, we're using explicit checks within this func so that
// we don't waste time allocating and converting UTF-16 strings unnecessarily.
var comp string
if logf != nil {
comp = windows.UTF16PtrToString(computerName)
user := windows.UTF16PtrToString(userName)
logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user)
defer logf("END getRoamingProfilePath(%q, %q)", comp, user)
}
isDomainName, err := isDomainName(computerName)
if err != nil {
return nil, err
}
if isDomainName {
if logf != nil {
logf("computerName %q is a domain, resolving...", comp)
}
dcInfo, err := resolveDomainController(computerName, nil)
if err != nil {
return nil, err
}
defer dcInfo.Close()
computerName = dcInfo.DomainControllerName
if logf != nil {
dom := windows.UTF16PtrToString(computerName)
logf("%q resolved to %q", comp, dom)
}
}
var pbuf *byte
@ -147,7 +174,9 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
defer windows.NetApiBufferFree(pbuf)
ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
logf("getRoamingProfilePath: got %#v", *ui4)
if logf != nil {
logf("getRoamingProfilePath: got %#v", *ui4)
}
profilePath := ui4.Profile
if profilePath == nil {
return nil, nil
@ -162,6 +191,10 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
return nil, err
}
if logf != nil {
logf("returning %q", windows.UTF16ToString(expanded[:]))
}
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
return &expanded[0], nil
}

View File

@ -0,0 +1,24 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"testing"
"golang.org/x/sys/windows"
)
func TestGetRoamingProfilePath(t *testing.T) {
token := windows.GetCurrentProcessToken()
computerName, userName, err := getComputerAndUserName(token, nil)
if err != nil {
t.Fatal(err)
}
if _, err := getRoamingProfilePath(t.Logf, token, computerName, userName); err != nil {
t.Error(err)
}
// TODO(aaron): Flesh out better once can run tests under domain accounts.
}

View File

@ -784,3 +784,147 @@ func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
panic("unknown type")
}
}
type domainControllerAddressType uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_INET_ADDRESS domainControllerAddressType = 1
_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
)
type domainControllerFlag uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_PDC_FLAG domainControllerFlag = 0x00000001
_DS_GC_FLAG domainControllerFlag = 0x00000004
_DS_LDAP_FLAG domainControllerFlag = 0x00000008
_DS_DS_FLAG domainControllerFlag = 0x00000010
_DS_KDC_FLAG domainControllerFlag = 0x00000020
_DS_TIMESERV_FLAG domainControllerFlag = 0x00000040
_DS_CLOSEST_FLAG domainControllerFlag = 0x00000080
_DS_WRITABLE_FLAG domainControllerFlag = 0x00000100
_DS_GOOD_TIMESERV_FLAG domainControllerFlag = 0x00000200
_DS_NDNC_FLAG domainControllerFlag = 0x00000400
_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
_DS_FULL_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00001000
_DS_WS_FLAG domainControllerFlag = 0x00002000
_DS_DS_8_FLAG domainControllerFlag = 0x00004000
_DS_DS_9_FLAG domainControllerFlag = 0x00008000
_DS_DS_10_FLAG domainControllerFlag = 0x00010000
_DS_KEY_LIST_FLAG domainControllerFlag = 0x00020000
_DS_PING_FLAGS domainControllerFlag = 0x000FFFFF
_DS_DNS_CONTROLLER_FLAG domainControllerFlag = 0x20000000
_DS_DNS_DOMAIN_FLAG domainControllerFlag = 0x40000000
_DS_DNS_FOREST_FLAG domainControllerFlag = 0x80000000
)
type _DOMAIN_CONTROLLER_INFO struct {
DomainControllerName *uint16
DomainControllerAddress *uint16
DomainControllerAddressType domainControllerAddressType
DomainGuid windows.GUID
DomainName *uint16
DnsForestName *uint16
Flags domainControllerFlag
DcSiteName *uint16
ClientSiteName *uint16
}
func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
if dci == nil {
return nil
}
return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
}
type dsGetDcNameFlag uint32
const (
//lint:ignore U1000 maps to a win32 API
_DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001
_DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010
_DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020
_DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040
_DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080
_DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100
_DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200
_DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400
_DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800
_DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000
_DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000
_DS_AVOID_SELF dsGetDcNameFlag = 0x00004000
_DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000
_DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000
_DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000
_DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000
_DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000
_DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000
_DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000
_DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000
_DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
_DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000
_DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000
_DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000
)
func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
var dcInfo *_DOMAIN_CONTROLLER_INFO
if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
return nil, err
}
return dcInfo, nil
}
// ResolveDomainController resolves the DNS name of the nearest available
// domain controller for the domain specified by domainName.
func ResolveDomainController(domainName string) (string, error) {
domainName16, err := windows.UTF16PtrFromString(domainName)
if err != nil {
return "", err
}
dcInfo, err := resolveDomainController(domainName16, nil)
if err != nil {
return "", err
}
defer dcInfo.Close()
return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
}
type _NETSETUP_NAME_TYPE int32
const (
_NetSetupUnknown _NETSETUP_NAME_TYPE = 0
_NetSetupMachine _NETSETUP_NAME_TYPE = 1
_NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2
_NetSetupDomain _NETSETUP_NAME_TYPE = 3
_NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
_NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5
)
func isDomainName(name *uint16) (bool, error) {
err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
switch err {
case nil:
return true, nil
case windows.ERROR_NO_SUCH_DOMAIN:
return false, nil
default:
return false, err
}
}
// IsDomainName checks whether name represents an existing domain reachable by
// the current machine.
func IsDomainName(name string) (bool, error) {
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return false, err
}
return isDomainName(name16)
}

View File

@ -42,12 +42,15 @@ func errnoErr(e syscall.Errno) error {
var (
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modnetapi32 = windows.NewLazySystemDLL("netapi32.dll")
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
moduserenv = windows.NewLazySystemDLL("userenv.dll")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
procDsGetDcNameW = modnetapi32.NewProc("DsGetDcNameW")
procNetValidateName = modnetapi32.NewProc("NetValidateName")
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
procRmGetList = modrstrtmgr.NewProc("RmGetList")
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
@ -78,6 +81,22 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w
return
}
func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) {
r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) {
r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func rmEndSession(session _RMHANDLE) (ret error) {
r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
if r0 != 0 {