mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-05 23:07:44 +00:00
util/winutil: ensure domain controller address is used when retrieving remote profile information
We cannot directly pass a flat domain name into NetUserGetInfo; we must resolve the address of a domain controller first. This PR implements the appropriate resolution mechanisms to do that, and also exposes a couple of new utility APIs for future needs. Fixes #12627 Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
parent
0323dd01b2
commit
5f177090e3
@ -6,9 +6,11 @@
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
|
||||
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
|
||||
|
||||
//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW
|
||||
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
|
||||
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
|
||||
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
|
||||
//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName
|
||||
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
|
||||
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
|
||||
//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession
|
||||
|
@ -135,9 +135,36 @@ func (up *UserProfile) Close() error {
|
||||
}
|
||||
|
||||
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
|
||||
// logf is for debugging/testing.
|
||||
if logf == nil {
|
||||
logf = logger.Discard
|
||||
// logf is for debugging/testing. While we would normally replace a nil logf
|
||||
// with logger.Discard, we're using explicit checks within this func so that
|
||||
// we don't waste time allocating and converting UTF-16 strings unnecessarily.
|
||||
var comp string
|
||||
if logf != nil {
|
||||
comp = windows.UTF16PtrToString(computerName)
|
||||
user := windows.UTF16PtrToString(userName)
|
||||
logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user)
|
||||
defer logf("END getRoamingProfilePath(%q, %q)", comp, user)
|
||||
}
|
||||
|
||||
isDomainName, err := isDomainName(computerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isDomainName {
|
||||
if logf != nil {
|
||||
logf("computerName %q is a domain, resolving...", comp)
|
||||
}
|
||||
dcInfo, err := resolveDomainController(computerName, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dcInfo.Close()
|
||||
|
||||
computerName = dcInfo.DomainControllerName
|
||||
if logf != nil {
|
||||
dom := windows.UTF16PtrToString(computerName)
|
||||
logf("%q resolved to %q", comp, dom)
|
||||
}
|
||||
}
|
||||
|
||||
var pbuf *byte
|
||||
@ -147,7 +174,9 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
|
||||
defer windows.NetApiBufferFree(pbuf)
|
||||
|
||||
ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
|
||||
logf("getRoamingProfilePath: got %#v", *ui4)
|
||||
if logf != nil {
|
||||
logf("getRoamingProfilePath: got %#v", *ui4)
|
||||
}
|
||||
profilePath := ui4.Profile
|
||||
if profilePath == nil {
|
||||
return nil, nil
|
||||
@ -162,6 +191,10 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if logf != nil {
|
||||
logf("returning %q", windows.UTF16ToString(expanded[:]))
|
||||
}
|
||||
|
||||
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
|
||||
return &expanded[0], nil
|
||||
}
|
||||
|
24
util/winutil/userprofile_windows_test.go
Normal file
24
util/winutil/userprofile_windows_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package winutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestGetRoamingProfilePath(t *testing.T) {
|
||||
token := windows.GetCurrentProcessToken()
|
||||
computerName, userName, err := getComputerAndUserName(token, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := getRoamingProfilePath(t.Logf, token, computerName, userName); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// TODO(aaron): Flesh out better once can run tests under domain accounts.
|
||||
}
|
@ -784,3 +784,147 @@ func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
|
||||
panic("unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
type domainControllerAddressType uint32
|
||||
|
||||
const (
|
||||
//lint:ignore U1000 maps to a win32 API
|
||||
_DS_INET_ADDRESS domainControllerAddressType = 1
|
||||
_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
|
||||
)
|
||||
|
||||
type domainControllerFlag uint32
|
||||
|
||||
const (
|
||||
//lint:ignore U1000 maps to a win32 API
|
||||
_DS_PDC_FLAG domainControllerFlag = 0x00000001
|
||||
_DS_GC_FLAG domainControllerFlag = 0x00000004
|
||||
_DS_LDAP_FLAG domainControllerFlag = 0x00000008
|
||||
_DS_DS_FLAG domainControllerFlag = 0x00000010
|
||||
_DS_KDC_FLAG domainControllerFlag = 0x00000020
|
||||
_DS_TIMESERV_FLAG domainControllerFlag = 0x00000040
|
||||
_DS_CLOSEST_FLAG domainControllerFlag = 0x00000080
|
||||
_DS_WRITABLE_FLAG domainControllerFlag = 0x00000100
|
||||
_DS_GOOD_TIMESERV_FLAG domainControllerFlag = 0x00000200
|
||||
_DS_NDNC_FLAG domainControllerFlag = 0x00000400
|
||||
_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
|
||||
_DS_FULL_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00001000
|
||||
_DS_WS_FLAG domainControllerFlag = 0x00002000
|
||||
_DS_DS_8_FLAG domainControllerFlag = 0x00004000
|
||||
_DS_DS_9_FLAG domainControllerFlag = 0x00008000
|
||||
_DS_DS_10_FLAG domainControllerFlag = 0x00010000
|
||||
_DS_KEY_LIST_FLAG domainControllerFlag = 0x00020000
|
||||
_DS_PING_FLAGS domainControllerFlag = 0x000FFFFF
|
||||
_DS_DNS_CONTROLLER_FLAG domainControllerFlag = 0x20000000
|
||||
_DS_DNS_DOMAIN_FLAG domainControllerFlag = 0x40000000
|
||||
_DS_DNS_FOREST_FLAG domainControllerFlag = 0x80000000
|
||||
)
|
||||
|
||||
type _DOMAIN_CONTROLLER_INFO struct {
|
||||
DomainControllerName *uint16
|
||||
DomainControllerAddress *uint16
|
||||
DomainControllerAddressType domainControllerAddressType
|
||||
DomainGuid windows.GUID
|
||||
DomainName *uint16
|
||||
DnsForestName *uint16
|
||||
Flags domainControllerFlag
|
||||
DcSiteName *uint16
|
||||
ClientSiteName *uint16
|
||||
}
|
||||
|
||||
func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
|
||||
if dci == nil {
|
||||
return nil
|
||||
}
|
||||
return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
|
||||
}
|
||||
|
||||
type dsGetDcNameFlag uint32
|
||||
|
||||
const (
|
||||
//lint:ignore U1000 maps to a win32 API
|
||||
_DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001
|
||||
_DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010
|
||||
_DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020
|
||||
_DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040
|
||||
_DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080
|
||||
_DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100
|
||||
_DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200
|
||||
_DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400
|
||||
_DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800
|
||||
_DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000
|
||||
_DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000
|
||||
_DS_AVOID_SELF dsGetDcNameFlag = 0x00004000
|
||||
_DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000
|
||||
_DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000
|
||||
_DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000
|
||||
_DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000
|
||||
_DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000
|
||||
_DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000
|
||||
_DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000
|
||||
_DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000
|
||||
_DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
|
||||
_DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000
|
||||
_DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000
|
||||
_DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000
|
||||
)
|
||||
|
||||
func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
|
||||
const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
|
||||
var dcInfo *_DOMAIN_CONTROLLER_INFO
|
||||
if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dcInfo, nil
|
||||
}
|
||||
|
||||
// ResolveDomainController resolves the DNS name of the nearest available
|
||||
// domain controller for the domain specified by domainName.
|
||||
func ResolveDomainController(domainName string) (string, error) {
|
||||
domainName16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
dcInfo, err := resolveDomainController(domainName16, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer dcInfo.Close()
|
||||
|
||||
return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
|
||||
}
|
||||
|
||||
type _NETSETUP_NAME_TYPE int32
|
||||
|
||||
const (
|
||||
_NetSetupUnknown _NETSETUP_NAME_TYPE = 0
|
||||
_NetSetupMachine _NETSETUP_NAME_TYPE = 1
|
||||
_NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2
|
||||
_NetSetupDomain _NETSETUP_NAME_TYPE = 3
|
||||
_NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
|
||||
_NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5
|
||||
)
|
||||
|
||||
func isDomainName(name *uint16) (bool, error) {
|
||||
err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
|
||||
switch err {
|
||||
case nil:
|
||||
return true, nil
|
||||
case windows.ERROR_NO_SUCH_DOMAIN:
|
||||
return false, nil
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
// IsDomainName checks whether name represents an existing domain reachable by
|
||||
// the current machine.
|
||||
func IsDomainName(name string) (bool, error) {
|
||||
name16, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return isDomainName(name16)
|
||||
}
|
||||
|
@ -42,12 +42,15 @@ func errnoErr(e syscall.Errno) error {
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
modnetapi32 = windows.NewLazySystemDLL("netapi32.dll")
|
||||
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
|
||||
moduserenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
|
||||
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
|
||||
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
|
||||
procDsGetDcNameW = modnetapi32.NewProc("DsGetDcNameW")
|
||||
procNetValidateName = modnetapi32.NewProc("NetValidateName")
|
||||
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
|
||||
procRmGetList = modrstrtmgr.NewProc("RmGetList")
|
||||
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
|
||||
@ -78,6 +81,22 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w
|
||||
return
|
||||
}
|
||||
|
||||
func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) {
|
||||
r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo)))
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) {
|
||||
r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func rmEndSession(session _RMHANDLE) (ret error) {
|
||||
r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
|
||||
if r0 != 0 {
|
||||
|
Loading…
x
Reference in New Issue
Block a user