util/winutil: update UserProfile to ensure any environment variables in the roaming profile path are expanded

Updates #12383

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2024-06-14 12:50:28 -06:00
parent a8ee83e2c5
commit 7354547bd8
3 changed files with 31 additions and 22 deletions

View File

@ -6,6 +6,7 @@
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go
//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW
//sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings
//sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW
//sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W

View File

@ -80,7 +80,7 @@ func LoadUserProfile(token windows.Token, u *user.User) (up *UserProfile, err er
var roamingProfilePath *uint16
if winenv.IsDomainJoined() {
roamingProfilePath, err = getRoamingProfilePath(nil, computerName, userName)
roamingProfilePath, err = getRoamingProfilePath(nil, token, computerName, userName)
if err != nil {
return nil, err
}
@ -134,7 +134,7 @@ func (up *UserProfile) Close() error {
return nil
}
func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (path *uint16, err error) {
func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) {
// logf is for debugging/testing.
if logf == nil {
logf = logger.Discard
@ -152,19 +152,18 @@ func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (pa
if profilePath == nil {
return nil, nil
}
var sz int
for ptr := unsafe.Pointer(profilePath); *(*uint16)(ptr) != 0; sz++ {
ptr = unsafe.Pointer(uintptr(ptr) + unsafe.Sizeof(*profilePath))
}
if sz == 0 {
if *profilePath == 0 {
// Empty string
return nil, nil
}
buf := unsafe.Slice(profilePath, sz+1)
cp := append([]uint16{}, buf...)
return unsafe.SliceData(cp), nil
var expanded [windows.MAX_PATH + 1]uint16
if err := expandEnvironmentStringsForUser(token, profilePath, &expanded[0], uint32(len(expanded))); err != nil {
return nil, err
}
// This buffer is only used briefly, so we don't bother copying it into a shorter slice.
return &expanded[0], nil
}
func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) {

View File

@ -45,16 +45,17 @@ func errnoErr(e syscall.Errno) error {
modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll")
moduserenv = windows.NewLazySystemDLL("userenv.dll")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
procRmGetList = modrstrtmgr.NewProc("RmGetList")
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources")
procRmStartSession = modrstrtmgr.NewProc("RmStartSession")
procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW")
procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile")
procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W")
procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings")
procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart")
procRmEndSession = modrstrtmgr.NewProc("RmEndSession")
procRmGetList = modrstrtmgr.NewProc("RmGetList")
procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession")
procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources")
procRmStartSession = modrstrtmgr.NewProc("RmStartSession")
procExpandEnvironmentStringsForUserW = moduserenv.NewProc("ExpandEnvironmentStringsForUserW")
procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW")
procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile")
)
func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) {
@ -117,6 +118,14 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret
return
}
func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0)
if int32(r1) == 0 {
err = errnoErr(e1)
}
return
}
func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) {
r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0)
if int32(r1) == 0 {