From 7354547bd849352acb2bdb15ea99c5088d635568 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Fri, 14 Jun 2024 12:50:28 -0600 Subject: [PATCH] util/winutil: update UserProfile to ensure any environment variables in the roaming profile path are expanded Updates #12383 Signed-off-by: Aaron Klotz --- util/winutil/mksyscall.go | 1 + util/winutil/userprofile_windows.go | 23 +++++++++++------------ util/winutil/zsyscall_windows.go | 29 +++++++++++++++++++---------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go index 1bfdffa1a..5fb915b41 100644 --- a/util/winutil/mksyscall.go +++ b/util/winutil/mksyscall.go @@ -6,6 +6,7 @@ //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go +//sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW //sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings //sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W diff --git a/util/winutil/userprofile_windows.go b/util/winutil/userprofile_windows.go index 99fb99d22..6bedf420b 100644 --- a/util/winutil/userprofile_windows.go +++ b/util/winutil/userprofile_windows.go @@ -80,7 +80,7 @@ func LoadUserProfile(token windows.Token, u *user.User) (up *UserProfile, err er var roamingProfilePath *uint16 if winenv.IsDomainJoined() { - roamingProfilePath, err = getRoamingProfilePath(nil, computerName, userName) + roamingProfilePath, err = getRoamingProfilePath(nil, token, computerName, userName) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func (up *UserProfile) Close() error { return nil } -func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (path *uint16, err error) { +func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) { // logf is for debugging/testing. if logf == nil { logf = logger.Discard @@ -152,19 +152,18 @@ func getRoamingProfilePath(logf logger.Logf, computerName, userName *uint16) (pa if profilePath == nil { return nil, nil } - - var sz int - for ptr := unsafe.Pointer(profilePath); *(*uint16)(ptr) != 0; sz++ { - ptr = unsafe.Pointer(uintptr(ptr) + unsafe.Sizeof(*profilePath)) - } - - if sz == 0 { + if *profilePath == 0 { + // Empty string return nil, nil } - buf := unsafe.Slice(profilePath, sz+1) - cp := append([]uint16{}, buf...) - return unsafe.SliceData(cp), nil + var expanded [windows.MAX_PATH + 1]uint16 + if err := expandEnvironmentStringsForUser(token, profilePath, &expanded[0], uint32(len(expanded))); err != nil { + return nil, err + } + + // This buffer is only used briefly, so we don't bother copying it into a shorter slice. + return &expanded[0], nil } func getComputerAndUserName(token windows.Token, u *user.User) (computerName *uint16, userName *uint16, err error) { diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index 8bb0091f7..d5d2d8721 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -45,16 +45,17 @@ func errnoErr(e syscall.Errno) error { modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll") moduserenv = windows.NewLazySystemDLL("userenv.dll") - procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") - procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") - procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") - procRmEndSession = modrstrtmgr.NewProc("RmEndSession") - procRmGetList = modrstrtmgr.NewProc("RmGetList") - procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") - procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources") - procRmStartSession = modrstrtmgr.NewProc("RmStartSession") - procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW") - procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile") + procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") + procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") + procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procRmEndSession = modrstrtmgr.NewProc("RmEndSession") + procRmGetList = modrstrtmgr.NewProc("RmGetList") + procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") + procRmRegisterResources = modrstrtmgr.NewProc("RmRegisterResources") + procRmStartSession = modrstrtmgr.NewProc("RmStartSession") + procExpandEnvironmentStringsForUserW = moduserenv.NewProc("ExpandEnvironmentStringsForUserW") + procLoadUserProfileW = moduserenv.NewProc("LoadUserProfileW") + procUnloadUserProfile = moduserenv.NewProc("UnloadUserProfile") ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { @@ -117,6 +118,14 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret return } +func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) { r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0) if int32(r1) == 0 {