mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 16:17:41 +00:00
util/winutil: ensure domain controller address is used when retrieving remote profile information
We cannot directly pass a flat domain name into NetUserGetInfo; we must resolve the address of a domain controller first. This PR implements the appropriate resolution mechanisms to do that, and also exposes a couple of new utility APIs for future needs. Fixes #12627 Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
parent
0323dd01b2
commit
5f177090e3
@ -6,9 +6,11 @@
|
|||||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
|
||||||
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
|
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
|
||||||
|
|
||||||
|
//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW
|
||||||
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
|
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
|
||||||
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
|
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
|
||||||
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
|
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
|
||||||
|
//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName
|
||||||
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
|
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W
|
||||||
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
|
//sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart
|
||||||
//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession
|
//sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession
|
||||||
|
@ -135,9 +135,36 @@ func (up *UserProfile) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
|
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
|
||||||
// logf is for debugging/testing.
|
// logf is for debugging/testing. While we would normally replace a nil logf
|
||||||
if logf == nil {
|
// with logger.Discard, we're using explicit checks within this func so that
|
||||||
logf = logger.Discard
|
// we don't waste time allocating and converting UTF-16 strings unnecessarily.
|
||||||
|
var comp string
|
||||||
|
if logf != nil {
|
||||||
|
comp = windows.UTF16PtrToString(computerName)
|
||||||
|
user := windows.UTF16PtrToString(userName)
|
||||||
|
logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user)
|
||||||
|
defer logf("END getRoamingProfilePath(%q, %q)", comp, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
isDomainName, err := isDomainName(computerName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if isDomainName {
|
||||||
|
if logf != nil {
|
||||||
|
logf("computerName %q is a domain, resolving...", comp)
|
||||||
|
}
|
||||||
|
dcInfo, err := resolveDomainController(computerName, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer dcInfo.Close()
|
||||||
|
|
||||||
|
computerName = dcInfo.DomainControllerName
|
||||||
|
if logf != nil {
|
||||||
|
dom := windows.UTF16PtrToString(computerName)
|
||||||
|
logf("%q resolved to %q", comp, dom)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var pbuf *byte
|
var pbuf *byte
|
||||||
@ -147,7 +174,9 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
|
|||||||
defer windows.NetApiBufferFree(pbuf)
|
defer windows.NetApiBufferFree(pbuf)
|
||||||
|
|
||||||
ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
|
ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf))
|
||||||
logf("getRoamingProfilePath: got %#v", *ui4)
|
if logf != nil {
|
||||||
|
logf("getRoamingProfilePath: got %#v", *ui4)
|
||||||
|
}
|
||||||
profilePath := ui4.Profile
|
profilePath := ui4.Profile
|
||||||
if profilePath == nil {
|
if profilePath == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -162,6 +191,10 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if logf != nil {
|
||||||
|
logf("returning %q", windows.UTF16ToString(expanded[:]))
|
||||||
|
}
|
||||||
|
|
||||||
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
|
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
|
||||||
return &expanded[0], nil
|
return &expanded[0], nil
|
||||||
}
|
}
|
||||||
|
24
util/winutil/userprofile_windows_test.go
Normal file
24
util/winutil/userprofile_windows_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package winutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRoamingProfilePath(t *testing.T) {
|
||||||
|
token := windows.GetCurrentProcessToken()
|
||||||
|
computerName, userName, err := getComputerAndUserName(token, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := getRoamingProfilePath(t.Logf, token, computerName, userName); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(aaron): Flesh out better once can run tests under domain accounts.
|
||||||
|
}
|
@ -784,3 +784,147 @@ func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
|
|||||||
panic("unknown type")
|
panic("unknown type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type domainControllerAddressType uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
//lint:ignore U1000 maps to a win32 API
|
||||||
|
_DS_INET_ADDRESS domainControllerAddressType = 1
|
||||||
|
_DS_NETBIOS_ADDRESS domainControllerAddressType = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type domainControllerFlag uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
//lint:ignore U1000 maps to a win32 API
|
||||||
|
_DS_PDC_FLAG domainControllerFlag = 0x00000001
|
||||||
|
_DS_GC_FLAG domainControllerFlag = 0x00000004
|
||||||
|
_DS_LDAP_FLAG domainControllerFlag = 0x00000008
|
||||||
|
_DS_DS_FLAG domainControllerFlag = 0x00000010
|
||||||
|
_DS_KDC_FLAG domainControllerFlag = 0x00000020
|
||||||
|
_DS_TIMESERV_FLAG domainControllerFlag = 0x00000040
|
||||||
|
_DS_CLOSEST_FLAG domainControllerFlag = 0x00000080
|
||||||
|
_DS_WRITABLE_FLAG domainControllerFlag = 0x00000100
|
||||||
|
_DS_GOOD_TIMESERV_FLAG domainControllerFlag = 0x00000200
|
||||||
|
_DS_NDNC_FLAG domainControllerFlag = 0x00000400
|
||||||
|
_DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800
|
||||||
|
_DS_FULL_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00001000
|
||||||
|
_DS_WS_FLAG domainControllerFlag = 0x00002000
|
||||||
|
_DS_DS_8_FLAG domainControllerFlag = 0x00004000
|
||||||
|
_DS_DS_9_FLAG domainControllerFlag = 0x00008000
|
||||||
|
_DS_DS_10_FLAG domainControllerFlag = 0x00010000
|
||||||
|
_DS_KEY_LIST_FLAG domainControllerFlag = 0x00020000
|
||||||
|
_DS_PING_FLAGS domainControllerFlag = 0x000FFFFF
|
||||||
|
_DS_DNS_CONTROLLER_FLAG domainControllerFlag = 0x20000000
|
||||||
|
_DS_DNS_DOMAIN_FLAG domainControllerFlag = 0x40000000
|
||||||
|
_DS_DNS_FOREST_FLAG domainControllerFlag = 0x80000000
|
||||||
|
)
|
||||||
|
|
||||||
|
type _DOMAIN_CONTROLLER_INFO struct {
|
||||||
|
DomainControllerName *uint16
|
||||||
|
DomainControllerAddress *uint16
|
||||||
|
DomainControllerAddressType domainControllerAddressType
|
||||||
|
DomainGuid windows.GUID
|
||||||
|
DomainName *uint16
|
||||||
|
DnsForestName *uint16
|
||||||
|
Flags domainControllerFlag
|
||||||
|
DcSiteName *uint16
|
||||||
|
ClientSiteName *uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dci *_DOMAIN_CONTROLLER_INFO) Close() error {
|
||||||
|
if dci == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci)))
|
||||||
|
}
|
||||||
|
|
||||||
|
type dsGetDcNameFlag uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
//lint:ignore U1000 maps to a win32 API
|
||||||
|
_DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001
|
||||||
|
_DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010
|
||||||
|
_DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020
|
||||||
|
_DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040
|
||||||
|
_DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080
|
||||||
|
_DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100
|
||||||
|
_DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200
|
||||||
|
_DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400
|
||||||
|
_DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800
|
||||||
|
_DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000
|
||||||
|
_DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000
|
||||||
|
_DS_AVOID_SELF dsGetDcNameFlag = 0x00004000
|
||||||
|
_DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000
|
||||||
|
_DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000
|
||||||
|
_DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000
|
||||||
|
_DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000
|
||||||
|
_DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000
|
||||||
|
_DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000
|
||||||
|
_DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000
|
||||||
|
_DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000
|
||||||
|
_DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000
|
||||||
|
_DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000
|
||||||
|
_DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000
|
||||||
|
_DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000
|
||||||
|
)
|
||||||
|
|
||||||
|
func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) {
|
||||||
|
const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME
|
||||||
|
var dcInfo *_DOMAIN_CONTROLLER_INFO
|
||||||
|
if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dcInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveDomainController resolves the DNS name of the nearest available
|
||||||
|
// domain controller for the domain specified by domainName.
|
||||||
|
func ResolveDomainController(domainName string) (string, error) {
|
||||||
|
domainName16, err := windows.UTF16PtrFromString(domainName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
dcInfo, err := resolveDomainController(domainName16, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer dcInfo.Close()
|
||||||
|
|
||||||
|
return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type _NETSETUP_NAME_TYPE int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
_NetSetupUnknown _NETSETUP_NAME_TYPE = 0
|
||||||
|
_NetSetupMachine _NETSETUP_NAME_TYPE = 1
|
||||||
|
_NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2
|
||||||
|
_NetSetupDomain _NETSETUP_NAME_TYPE = 3
|
||||||
|
_NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4
|
||||||
|
_NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
func isDomainName(name *uint16) (bool, error) {
|
||||||
|
err := netValidateName(nil, name, nil, nil, _NetSetupDomain)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return true, nil
|
||||||
|
case windows.ERROR_NO_SUCH_DOMAIN:
|
||||||
|
return false, nil
|
||||||
|
default:
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDomainName checks whether name represents an existing domain reachable by
|
||||||
|
// the current machine.
|
||||||
|
func IsDomainName(name string) (bool, error) {
|
||||||
|
name16, err := windows.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return isDomainName(name16)
|
||||||
|
}
|
||||||
|
@ -42,12 +42,15 @@ func errnoErr(e syscall.Errno) error {
|
|||||||
var (
|
var (
|
||||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
modnetapi32 = windows.NewLazySystemDLL("netapi32.dll")
|
||||||
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
|
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
|
||||||
moduserenv = windows.NewLazySystemDLL("userenv.dll")
|
moduserenv = windows.NewLazySystemDLL("userenv.dll")
|
||||||
|
|
||||||
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
|
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
|
||||||
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
|
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
|
||||||
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
|
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
|
||||||
|
procDsGetDcNameW = modnetapi32.NewProc("DsGetDcNameW")
|
||||||
|
procNetValidateName = modnetapi32.NewProc("NetValidateName")
|
||||||
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
|
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
|
||||||
procRmGetList = modrstrtmgr.NewProc("RmGetList")
|
procRmGetList = modrstrtmgr.NewProc("RmGetList")
|
||||||
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
|
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
|
||||||
@ -78,6 +81,22 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo)))
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func rmEndSession(session _RMHANDLE) (ret error) {
|
func rmEndSession(session _RMHANDLE) (ret error) {
|
||||||
r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
|
r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0)
|
||||||
if r0 != 0 {
|
if r0 != 0 {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user