From bc0cd512ee112d31643ad9e326099e92139aa301 Mon Sep 17 00:00:00 2001 From: Nick Khyl Date: Mon, 10 Feb 2025 17:47:10 -0600 Subject: [PATCH] ipn/desktop: add a new package for managing desktop sessions on Windows This PR adds a new package, ipn/desktop, which provides a platform-agnostic interface for enumerating desktop sessions and registering session callbacks. Currently, it is implemented only for Windows. Updates #14823 Signed-off-by: Nick Khyl --- ipn/desktop/doc.go | 6 + ipn/desktop/mksyscall.go | 24 ++ ipn/desktop/session.go | 58 +++ ipn/desktop/sessions.go | 60 +++ ipn/desktop/sessions_notwindows.go | 15 + ipn/desktop/sessions_windows.go | 672 +++++++++++++++++++++++++++++ ipn/desktop/zsyscall_windows.go | 159 +++++++ 7 files changed, 994 insertions(+) create mode 100644 ipn/desktop/doc.go create mode 100644 ipn/desktop/mksyscall.go create mode 100644 ipn/desktop/session.go create mode 100644 ipn/desktop/sessions.go create mode 100644 ipn/desktop/sessions_notwindows.go create mode 100644 ipn/desktop/sessions_windows.go create mode 100644 ipn/desktop/zsyscall_windows.go diff --git a/ipn/desktop/doc.go b/ipn/desktop/doc.go new file mode 100644 index 000000000..64a332792 --- /dev/null +++ b/ipn/desktop/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package desktop facilitates interaction with the desktop environment +// and user sessions. As of 2025-02-06, it is only implemented for Windows. +package desktop diff --git a/ipn/desktop/mksyscall.go b/ipn/desktop/mksyscall.go new file mode 100644 index 000000000..305138468 --- /dev/null +++ b/ipn/desktop/mksyscall.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys setLastError(dwErrorCode uint32) = kernel32.SetLastError + +//sys registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) [atom==0] = user32.RegisterClassExW +//sys createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) [hWnd==0] = user32.CreateWindowExW +//sys defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.DefWindowProcW +//sys setWindowLongPtr(hwnd windows.HWND, index int32, newLong uintptr) (res uintptr, err error) [res==0 && e1!=0] = user32.SetWindowLongPtrW +//sys getWindowLongPtr(hwnd windows.HWND, index int32) (res uintptr, err error) [res==0 && e1!=0] = user32.GetWindowLongPtrW +//sys sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.SendMessageW +//sys getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) = user32.GetMessageW +//sys translateMessage(lpMsg *_MSG) (res bool) = user32.TranslateMessage +//sys dispatchMessage(lpMsg *_MSG) (res uintptr) = user32.DispatchMessageW +//sys destroyWindow(hwnd windows.HWND) (err error) [int32(failretval)==0] = user32.DestroyWindow +//sys postQuitMessage(exitCode int32) = user32.PostQuitMessage + +//sys registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) [int32(failretval)==0] = wtsapi32.WTSRegisterSessionNotificationEx +//sys unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) [int32(failretval)==0] = wtsapi32.WTSUnRegisterSessionNotificationEx diff --git a/ipn/desktop/session.go b/ipn/desktop/session.go new file mode 100644 index 000000000..c95378914 --- /dev/null +++ b/ipn/desktop/session.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "fmt" + + "tailscale.com/ipn/ipnauth" +) + +// SessionID is a unique identifier of a desktop session. +type SessionID uint + +// SessionStatus is the status of a desktop session. +type SessionStatus int + +const ( + // ClosedSession is a session that does not exist, is not yet initialized by the OS, + // or has been terminated. + ClosedSession SessionStatus = iota + // ForegroundSession is a session that a user can interact with, + // such as when attached to a physical console or an active, + // unlocked RDP connection. + ForegroundSession + // BackgroundSession indicates that the session is locked, disconnected, + // or otherwise running without user presence or interaction. + BackgroundSession +) + +// String implements [fmt.Stringer]. +func (s SessionStatus) String() string { + switch s { + case ClosedSession: + return "Closed" + case ForegroundSession: + return "Foreground" + case BackgroundSession: + return "Background" + default: + panic("unreachable") + } +} + +// Session is a state of a desktop session at a given point in time. +type Session struct { + ID SessionID // Identifier of the session; can be reused after the session is closed. + Status SessionStatus // The status of the session, such as foreground or background. + User ipnauth.Actor // User logged into the session. +} + +// Description returns a human-readable description of the session. +func (s *Session) Description() string { + if maybeUsername, _ := s.User.Username(); maybeUsername != "" { // best effort + return fmt.Sprintf("Session %d - %q (%s)", s.ID, maybeUsername, s.Status) + } + return fmt.Sprintf("Session %d (%s)", s.ID, s.Status) +} diff --git a/ipn/desktop/sessions.go b/ipn/desktop/sessions.go new file mode 100644 index 000000000..8bf7a75e2 --- /dev/null +++ b/ipn/desktop/sessions.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "errors" + "runtime" +) + +// ErrNotImplemented is returned by [NewSessionManager] when it is not +// implemented for the current GOOS. +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +// SessionInitCallback is a function that is called once per [Session]. +// It returns an optional cleanup function that is called when the session +// is about to be destroyed, or nil if no cleanup is needed. +// It is not safe to call SessionManager methods from within the callback. +type SessionInitCallback func(session *Session) (cleanup func()) + +// SessionStateCallback is a function that reports the initial or updated +// state of a [Session], such as when it transitions between foreground and background. +// It is guaranteed to be called after all registered [SessionInitCallback] functions +// have completed, and before any cleanup functions are called for the same session. +// It is not safe to call SessionManager methods from within the callback. +type SessionStateCallback func(session *Session) + +// SessionManager is an interface that provides access to desktop sessions on the current platform. +// It is safe for concurrent use. +type SessionManager interface { + // Init explicitly initializes the receiver. + // Unless the receiver is explicitly initialized, it will be lazily initialized + // on the first call to any other method. + // It is safe to call Init multiple times. + Init() error + + // Sessions returns a session snapshot taken at the time of the call. + // Since sessions can be created or destroyed at any time, it may become + // outdated as soon as it is returned. + // + // It is primarily intended for logging and debugging. + // Prefer registering a [SessionInitCallback] or [SessionStateCallback] + // in contexts requiring stronger guarantees. + Sessions() (map[SessionID]*Session, error) + + // RegisterInitCallback registers a [SessionInitCallback] that is called for each existing session + // and for each new session that is created, until the returned unregister function is called. + // If the specified [SessionInitCallback] returns a cleanup function, it is called when the session + // is about to be destroyed. The callback function is guaranteed to be called once and only once + // for each existing and new session. + RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) + + // RegisterStateCallback registers a [SessionStateCallback] that is called for each existing session + // and every time the state of a session changes, until the returned unregister function is called. + RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) + + // Close waits for all registered callbacks to complete + // and releases resources associated with the receiver. + Close() error +} diff --git a/ipn/desktop/sessions_notwindows.go b/ipn/desktop/sessions_notwindows.go new file mode 100644 index 000000000..da3230a45 --- /dev/null +++ b/ipn/desktop/sessions_notwindows.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package desktop + +import "tailscale.com/types/logger" + +// NewSessionManager returns a new [SessionManager] for the current platform, +// [ErrNotImplemented] if the platform is not supported, or an error if the +// session manager could not be created. +func NewSessionManager(logger.Logf) (SessionManager, error) { + return nil, ErrNotImplemented +} diff --git a/ipn/desktop/sessions_windows.go b/ipn/desktop/sessions_windows.go new file mode 100644 index 000000000..f1b88d573 --- /dev/null +++ b/ipn/desktop/sessions_windows.go @@ -0,0 +1,672 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "tailscale.com/ipn/ipnauth" + "tailscale.com/types/logger" + "tailscale.com/util/must" + "tailscale.com/util/set" +) + +// wtsManager is a [SessionManager] implementation for Windows. +type wtsManager struct { + logf logger.Logf + ctx context.Context // cancelled when the manager is closed + ctxCancel context.CancelFunc + + initOnce func() error + watcher *sessionWatcher + + mu sync.Mutex + sessions map[SessionID]*wtsSession + initCbs set.HandleSet[SessionInitCallback] + stateCbs set.HandleSet[SessionStateCallback] +} + +// NewSessionManager returns a new [SessionManager] for the current platform, +func NewSessionManager(logf logger.Logf) (SessionManager, error) { + ctx, ctxCancel := context.WithCancel(context.Background()) + m := &wtsManager{ + logf: logf, + ctx: ctx, + ctxCancel: ctxCancel, + sessions: make(map[SessionID]*wtsSession), + } + m.watcher = newSessionWatcher(m.ctx, m.logf, m.sessionEventHandler) + + m.initOnce = sync.OnceValue(func() error { + if err := waitUntilWTSReady(m.ctx); err != nil { + return fmt.Errorf("WTS is not ready: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + if err := m.watcher.Start(); err != nil { + return fmt.Errorf("failed to start session watcher: %w", err) + } + + var err error + m.sessions, err = enumerateSessions() + return err // may be nil or non-nil + }) + return m, nil +} + +// Init implements [SessionManager]. +func (m *wtsManager) Init() error { + return m.initOnce() +} + +// Sessions implements [SessionManager]. +func (m *wtsManager) Sessions() (map[SessionID]*Session, error) { + if err := m.initOnce(); err != nil { + return nil, err + } + + m.mu.Lock() + defer m.mu.Unlock() + sessions := make(map[SessionID]*Session, len(m.sessions)) + for _, s := range m.sessions { + sessions[s.id] = s.AsSession() + } + return sessions, nil +} + +// RegisterInitCallback implements [SessionManager]. +func (m *wtsManager) RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.initCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + if cleanup := cb(s.AsSession()); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.initCbs, handle) + }, nil +} + +// RegisterStateCallback implements [SessionManager]. +func (m *wtsManager) RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.stateCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + cb(s.AsSession()) + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.stateCbs, handle) + }, nil +} + +func (m *wtsManager) sessionEventHandler(id SessionID, event uint32) { + m.mu.Lock() + defer m.mu.Unlock() + switch event { + case windows.WTS_SESSION_LOGON: + // The session may have been created after we started watching, + // but before the initial enumeration was performed. + // Do not create a new session if it already exists. + if _, _, err := m.getOrCreateSessionLocked(id); err != nil { + m.logf("[unexpected] getOrCreateSessionLocked(%d): %v", id, err) + } + case windows.WTS_SESSION_LOCK: + if err := m.setSessionStatusLocked(id, BackgroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, BackgroundSession): %v", id, err) + } + case windows.WTS_SESSION_UNLOCK: + if err := m.setSessionStatusLocked(id, ForegroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, ForegroundSession): %v", id, err) + } + case windows.WTS_SESSION_LOGOFF: + if err := m.deleteSessionLocked(id); err != nil { + m.logf("[unexpected] deleteSessionLocked(%d): %v", id, err) + } + } +} + +func (m *wtsManager) getOrCreateSessionLocked(id SessionID) (_ *wtsSession, created bool, err error) { + if s, ok := m.sessions[id]; ok { + return s, false, nil + } + + s, err := newWTSSession(id, ForegroundSession) + if err != nil { + return nil, false, err + } + m.sessions[id] = s + + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.initCbs { + if cleanup := cb(session); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + for _, cb := range m.stateCbs { + cb(session) + } + + return s, true, err +} + +func (m *wtsManager) setSessionStatusLocked(id SessionID, status SessionStatus) error { + s, _, err := m.getOrCreateSessionLocked(id) + if err != nil { + return err + } + if s.status == status { + return nil + } + + s.status = status + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.stateCbs { + cb(session) + } + return nil +} + +func (m *wtsManager) deleteSessionLocked(id SessionID) error { + s, ok := m.sessions[id] + if !ok { + return nil + } + + s.status = ClosedSession + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks (and [wtsSession.close]!) in a separate goroutine? + for _, cb := range m.stateCbs { + cb(session) + } + + delete(m.sessions, id) + return s.close() +} + +func (m *wtsManager) Close() error { + m.ctxCancel() + + if m.watcher != nil { + err := m.watcher.Stop() + if err != nil { + return err + } + m.watcher = nil + } + + m.mu.Lock() + defer m.mu.Unlock() + m.initCbs = nil + m.stateCbs = nil + errs := make([]error, 0, len(m.sessions)) + for _, s := range m.sessions { + errs = append(errs, s.close()) + } + m.sessions = nil + return errors.Join(errs...) +} + +type wtsSession struct { + id SessionID + user *ipnauth.WindowsActor + + status SessionStatus + + cleanup []func() +} + +func newWTSSession(id SessionID, status SessionStatus) (*wtsSession, error) { + var token windows.Token + if err := windows.WTSQueryUserToken(uint32(id), &token); err != nil { + return nil, err + } + user, err := ipnauth.NewWindowsActorWithToken(token) + if err != nil { + return nil, err + } + return &wtsSession{id, user, status, nil}, nil +} + +// enumerateSessions returns a map of all active WTS sessions. +func enumerateSessions() (map[SessionID]*wtsSession, error) { + const reserved, version uint32 = 0, 1 + var numSessions uint32 + var sessionInfos *windows.WTS_SESSION_INFO + if err := windows.WTSEnumerateSessions(_WTS_CURRENT_SERVER_HANDLE, reserved, version, &sessionInfos, &numSessions); err != nil { + return nil, fmt.Errorf("WTSEnumerateSessions failed: %w", err) + } + defer windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionInfos))) + + sessions := make(map[SessionID]*wtsSession, numSessions) + for _, si := range unsafe.Slice(sessionInfos, numSessions) { + status := _WTS_CONNECTSTATE_CLASS(si.State).ToSessionStatus() + if status == ClosedSession { + // The session does not exist as far as we're concerned. + // It may be in the process of being created or destroyed, + // or be a special "listener" session, etc. + continue + } + id := SessionID(si.SessionID) + session, err := newWTSSession(id, status) + if err != nil { + continue + } + sessions[id] = session + } + return sessions, nil +} + +func (s *wtsSession) AsSession() *Session { + return &Session{ + ID: s.id, + Status: s.status, + // wtsSession owns the user; don't let the caller close it + User: ipnauth.WithoutClose(s.user), + } +} + +func (m *wtsSession) close() error { + for _, cleanup := range m.cleanup { + cleanup() + } + m.cleanup = nil + + if m.user != nil { + if err := m.user.Close(); err != nil { + return err + } + m.user = nil + } + return nil +} + +type sessionEventHandler func(id SessionID, event uint32) + +// TODO(nickkhyl): implement a sessionWatcher that does not use the message queue. +// One possible approach is to have the tailscaled service register a HandlerEx function +// and stream SERVICE_CONTROL_SESSIONCHANGE events to the tailscaled subprocess +// (the actual tailscaled backend), exposing these events via [sessionWatcher]/[wtsManager]. +// +// See tailscale/corp#26477 for details and tracking. +type sessionWatcher struct { + logf logger.Logf + ctx context.Context // canceled to stop the watcher + ctxCancel context.CancelFunc // cancels the watcher + hWnd windows.HWND // window handle for receiving session change notifications + handler sessionEventHandler // called on session events + + mu sync.Mutex + doneCh chan error // written to when the watcher exits; nil if not started +} + +func newSessionWatcher(ctx context.Context, logf logger.Logf, handler sessionEventHandler) *sessionWatcher { + ctx, cancel := context.WithCancel(ctx) + return &sessionWatcher{logf: logf, ctx: ctx, ctxCancel: cancel, handler: handler} +} + +func (sw *sessionWatcher) Start() error { + sw.mu.Lock() + defer sw.mu.Unlock() + + select { + case <-sw.ctx.Done(): + return fmt.Errorf("sessionWatcher already stopped: %w", sw.ctx.Err()) + default: + } + + if sw.doneCh != nil { + // Already started. + return nil + } + sw.doneCh = make(chan error, 1) + + startedCh := make(chan error, 1) + go sw.run(startedCh) + if err := <-startedCh; err != nil { + return err + } + + // Signal the window to unsubscribe from session notifications + // and shut down gracefully when the sessionWatcher is stopped. + context.AfterFunc(sw.ctx, func() { + sendMessage(sw.hWnd, _WM_CLOSE, 0, 0) + }) + return nil +} + +func (sw *sessionWatcher) run(started chan<- error) { + runtime.LockOSThread() + defer func() { + runtime.UnlockOSThread() + close(sw.doneCh) + }() + err := sw.createMessageWindow() + started <- err + if err != nil { + return + } + pumpThreadMessages() +} + +// Stop stops the session watcher and waits for it to exit. +func (sw *sessionWatcher) Stop() error { + sw.ctxCancel() + + sw.mu.Lock() + doneCh := sw.doneCh + sw.doneCh = nil + sw.mu.Unlock() + + if doneCh != nil { + return <-doneCh + } + return nil +} + +const watcherWindowClassName = "Tailscale-SessionManager" + +var watcherWindowClassName16 = sync.OnceValue(func() *uint16 { + return must.Get(syscall.UTF16PtrFromString(watcherWindowClassName)) +}) + +var registerSessionManagerWindowClass = sync.OnceValue(func() error { + var hInst windows.Handle + if err := windows.GetModuleHandleEx(0, nil, &hInst); err != nil { + return fmt.Errorf("GetModuleHandle: %w", err) + } + wc := _WNDCLASSEX{ + CbSize: uint32(unsafe.Sizeof(_WNDCLASSEX{})), + HInstance: hInst, + LpfnWndProc: syscall.NewCallback(sessionWatcherWndProc), + LpszClassName: watcherWindowClassName16(), + } + if _, err := registerClassEx(&wc); err != nil { + return fmt.Errorf("RegisterClassEx(%q): %w", watcherWindowClassName, err) + } + return nil +}) + +func (sw *sessionWatcher) createMessageWindow() error { + if err := registerSessionManagerWindowClass(); err != nil { + return err + } + _, err := createWindowEx( + 0, // dwExStyle + watcherWindowClassName16(), // lpClassName + nil, // lpWindowName + 0, // dwStyle + 0, // x + 0, // y + 0, // nWidth + 0, // nHeight + _HWND_MESSAGE, // hWndParent; message-only window + 0, // hMenu + 0, // hInstance + unsafe.Pointer(sw), // lpParam + ) + if err != nil { + return fmt.Errorf("CreateWindowEx: %w", err) + } + return nil +} + +func (sw *sessionWatcher) wndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + switch msg { + case _WM_CREATE: + err := registerSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd, _NOTIFY_FOR_ALL_SESSIONS) + if err != nil { + sw.logf("[unexpected] failed to register for session notifications: %v", err) + return ^uintptr(0) + } + sw.logf("registered for session notifications") + case _WM_WTSSESSION_CHANGE: + sw.handler(SessionID(lParam), uint32(wParam)) + return 0 + case _WM_CLOSE: + if err := destroyWindow(hWnd); err != nil { + sw.logf("[unexpected] failed to destroy window: %v", err) + } + return 0 + case _WM_DESTROY: + err := unregisterSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd) + if err != nil { + sw.logf("[unexpected] failed to unregister session notifications callback: %v", err) + } + sw.logf("unregistered from session notifications") + return 0 + case _WM_NCDESTROY: + sw.hWnd = 0 + postQuitMessage(0) // quit the message loop for this thread + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func (sw *sessionWatcher) setHandle(hwnd windows.HWND) error { + sw.hWnd = hwnd + setLastError(0) + _, err := setWindowLongPtr(sw.hWnd, _GWLP_USERDATA, uintptr(unsafe.Pointer(sw))) + return err // may be nil or non-nil +} + +func sessionWatcherByHandle(hwnd windows.HWND) *sessionWatcher { + val, _ := getWindowLongPtr(hwnd, _GWLP_USERDATA) + return (*sessionWatcher)(unsafe.Pointer(val)) +} + +func sessionWatcherWndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + if msg == _WM_NCCREATE { + cs := (*_CREATESTRUCT)(unsafe.Pointer(lParam)) + sw := (*sessionWatcher)(unsafe.Pointer(cs.CreateParams)) + if sw == nil { + return 0 + } + if err := sw.setHandle(hWnd); err != nil { + return 0 + } + return defWindowProc(hWnd, msg, wParam, lParam) + } + if sw := sessionWatcherByHandle(hWnd); sw != nil { + return sw.wndProc(hWnd, msg, wParam, lParam) + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func pumpThreadMessages() { + var msg _MSG + for getMessage(&msg, 0, 0, 0) != 0 { + translateMessage(&msg) + dispatchMessage(&msg) + } +} + +// waitUntilWTSReady waits until the Windows Terminal Services (WTS) is ready. +// This is necessary because the WTS API functions may fail if called before +// the WTS is ready. +// +// https://web.archive.org/web/20250207011738/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/nf-wtsapi32-wtsregistersessionnotificationex +func waitUntilWTSReady(ctx context.Context) error { + eventName16, err := windows.UTF16PtrFromString(`Global\TermSrvReadyEvent`) + if err != nil { + return err + } + event, err := windows.OpenEvent(windows.SYNCHRONIZE, false, eventName16) + if err != nil { + return err + } + defer windows.CloseHandle(event) + return waitForContextOrHandle(ctx, event) +} + +// waitForContextOrHandle waits for either the context to be done or a handle to be signaled. +func waitForContextOrHandle(ctx context.Context, handle windows.Handle) error { + contextDoneEvent, cleanup, err := channelToEvent(ctx.Done()) + if err != nil { + return err + } + defer cleanup() + + handles := []windows.Handle{contextDoneEvent, handle} + waitCode, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE) + if err != nil { + return err + } + + waitCode -= windows.WAIT_OBJECT_0 + if waitCode == 0 { // contextDoneEvent + return ctx.Err() + } + return nil +} + +// channelToEvent returns an auto-reset event that is set when the channel +// becomes receivable, including when the channel is closed. +func channelToEvent[T any](c <-chan T) (evt windows.Handle, cleanup func(), err error) { + evt, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, nil, err + } + + cancel := make(chan struct{}) + + go func() { + select { + case <-cancel: + return + case <-c: + } + windows.SetEvent(evt) + }() + + cleanup = func() { + close(cancel) + windows.CloseHandle(evt) + } + + return evt, cleanup, nil +} + +type _WNDCLASSEX struct { + CbSize uint32 + Style uint32 + LpfnWndProc uintptr + CbClsExtra int32 + CbWndExtra int32 + HInstance windows.Handle + HIcon windows.Handle + HCursor windows.Handle + HbrBackground windows.Handle + LpszMenuName *uint16 + LpszClassName *uint16 + HIconSm windows.Handle +} + +type _CREATESTRUCT struct { + CreateParams uintptr + Instance windows.Handle + Menu windows.Handle + Parent windows.HWND + Cy int32 + Cx int32 + Y int32 + X int32 + Style int32 + Name *uint16 + ClassName *uint16 + ExStyle uint32 +} + +type _POINT struct { + X, Y int32 +} + +type _MSG struct { + HWnd windows.HWND + Message uint32 + WParam uintptr + LParam uintptr + Time uint32 + Pt _POINT +} + +const ( + _WM_CREATE = 1 + _WM_DESTROY = 2 + _WM_CLOSE = 16 + _WM_NCCREATE = 129 + _WM_QUIT = 18 + _WM_NCDESTROY = 130 + + // _WM_WTSSESSION_CHANGE is a message sent to windows that have registered + // for session change notifications, informing them of changes in session state. + // + // https://web.archive.org/web/20250207012421/https://learn.microsoft.com/en-us/windows/win32/termserv/wm-wtssession-change + _WM_WTSSESSION_CHANGE = 0x02B1 +) + +const _GWLP_USERDATA = -21 + +const _HWND_MESSAGE = ^windows.HWND(2) + +// _NOTIFY_FOR_ALL_SESSIONS indicates that the window should receive +// session change notifications for all sessions on the specified server. +const _NOTIFY_FOR_ALL_SESSIONS = 1 + +// _WTS_CURRENT_SERVER_HANDLE indicates that the window should receive +// session change notifications for the host itself rather than a remote server. +const _WTS_CURRENT_SERVER_HANDLE = windows.Handle(0) + +// _WTS_CONNECTSTATE_CLASS represents the connection state of a session. +// +// https://web.archive.org/web/20250206082427/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/ne-wtsapi32-wts_connectstate_class +type _WTS_CONNECTSTATE_CLASS int32 + +// ToSessionStatus converts cs to a [SessionStatus]. +func (cs _WTS_CONNECTSTATE_CLASS) ToSessionStatus() SessionStatus { + switch cs { + case windows.WTSActive: + return ForegroundSession + case windows.WTSDisconnected: + return BackgroundSession + default: + // The session does not exist as far as we're concerned. + return ClosedSession + } +} diff --git a/ipn/desktop/zsyscall_windows.go b/ipn/desktop/zsyscall_windows.go new file mode 100644 index 000000000..222ab49e5 --- /dev/null +++ b/ipn/desktop/zsyscall_windows.go @@ -0,0 +1,159 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package desktop + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + moduser32 = windows.NewLazySystemDLL("user32.dll") + modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll") + + procSetLastError = modkernel32.NewProc("SetLastError") + procCreateWindowExW = moduser32.NewProc("CreateWindowExW") + procDefWindowProcW = moduser32.NewProc("DefWindowProcW") + procDestroyWindow = moduser32.NewProc("DestroyWindow") + procDispatchMessageW = moduser32.NewProc("DispatchMessageW") + procGetMessageW = moduser32.NewProc("GetMessageW") + procGetWindowLongPtrW = moduser32.NewProc("GetWindowLongPtrW") + procPostQuitMessage = moduser32.NewProc("PostQuitMessage") + procRegisterClassExW = moduser32.NewProc("RegisterClassExW") + procSendMessageW = moduser32.NewProc("SendMessageW") + procSetWindowLongPtrW = moduser32.NewProc("SetWindowLongPtrW") + procTranslateMessage = moduser32.NewProc("TranslateMessage") + procWTSRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSRegisterSessionNotificationEx") + procWTSUnRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSUnRegisterSessionNotificationEx") +) + +func setLastError(dwErrorCode uint32) { + syscall.Syscall(procSetLastError.Addr(), 1, uintptr(dwErrorCode), 0, 0) + return +} + +func createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) { + r0, _, e1 := syscall.Syscall12(procCreateWindowExW.Addr(), 12, uintptr(dwExStyle), uintptr(unsafe.Pointer(lpClassName)), uintptr(unsafe.Pointer(lpWindowName)), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) + hWnd = windows.HWND(r0) + if hWnd == 0 { + err = errnoErr(e1) + } + return +} + +func defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.Syscall6(procDefWindowProcW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + res = uintptr(r0) + return +} + +func destroyWindow(hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.Syscall(procDestroyWindow.Addr(), 1, uintptr(hwnd), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func dispatchMessage(lpMsg *_MSG) (res uintptr) { + r0, _, _ := syscall.Syscall(procDispatchMessageW.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + res = uintptr(r0) + return +} + +func getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) { + r0, _, _ := syscall.Syscall6(procGetMessageW.Addr(), 4, uintptr(unsafe.Pointer(lpMsg)), uintptr(hwnd), uintptr(msgMin), uintptr(msgMax), 0, 0) + ret = int32(r0) + return +} + +func getWindowLongPtr(hwnd windows.HWND, index int32) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procGetWindowLongPtrW.Addr(), 2, uintptr(hwnd), uintptr(index), 0) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} + +func postQuitMessage(exitCode int32) { + syscall.Syscall(procPostQuitMessage.Addr(), 1, uintptr(exitCode), 0, 0) + return +} + +func registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) { + r0, _, e1 := syscall.Syscall(procRegisterClassExW.Addr(), 1, uintptr(unsafe.Pointer(windowClass)), 0, 0) + atom = uint16(r0) + if atom == 0 { + err = errnoErr(e1) + } + return +} + +func sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.Syscall6(procSendMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + res = uintptr(r0) + return +} + +func setWindowLongPtr(hwnd windows.HWND, index int32, newLong uintptr) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procSetWindowLongPtrW.Addr(), 3, uintptr(hwnd), uintptr(index), uintptr(newLong)) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} + +func translateMessage(lpMsg *_MSG) (res bool) { + r0, _, _ := syscall.Syscall(procTranslateMessage.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + res = r0 != 0 + return +} + +func registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procWTSRegisterSessionNotificationEx.Addr(), 3, uintptr(hServer), uintptr(hwnd), uintptr(flags)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.Syscall(procWTSUnRegisterSessionNotificationEx.Addr(), 2, uintptr(hServer), uintptr(hwnd), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +}