tailscale/util/winutil/svcdiag_windows.go
Aaron Klotz 5812093d31 util/winutil: publicize existing functions for opening read-only connections to the Windows Service Control Manager
We're going to need to access these from code outside winutil.

Updates #10215

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
2023-12-22 10:52:50 -08:00

304 lines
8.5 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
"tailscale.com/types/logger"
"tailscale.com/util/set"
)
// LogSvcState obtains the state of the Windows service named rootSvcName and
// all of its dependencies, and then emits that state to logf as a single
// indented JSON array. If the dependency walk or JSON encoding fails, only the
// error is logged.
func LogSvcState(logf logger.Logf, rootSvcName string) {
	// Start from an empty (non-nil) slice so an empty walk encodes as "[]"
	// rather than "null".
	logEntries := []svcStateLogEntry{}

	walkFn := func(s *mgr.Service, config mgr.Config) {
		status, err := s.Query()
		if err != nil {
			// Best effort: still record the entry so the service's
			// configuration is logged even without a valid status.
			logf("Failed retrieving Status for service %q: %v", s.Name, err)
		}
		logEntries = append(logEntries, makeLogEntry(s, status, config))
	}

	if err := walkServices(rootSvcName, walkFn); err != nil {
		logf("LogSvcState error: %v", err)
		return
	}

	entriesJSON, err := json.MarshalIndent(logEntries, "", " ")
	if err != nil {
		logf("Error marshaling service log entries: %v", err)
		return
	}

	var builder strings.Builder
	fmt.Fprintf(&builder, "State of service %q and its dependencies:\n", rootSvcName)
	builder.Write(entriesJSON)
	builder.WriteString("\n")

	// Pass the assembled text as an argument, not as the format string: the
	// JSON may contain '%' (e.g. in service names or error text), which logf
	// would otherwise interpret as formatting verbs.
	logf("%s", builder.String())
}
// walkSvcFunc is the type of the callback function invoked by walkServices for
// each service it visits, along with that service's configuration. The
// *mgr.Service is closed by walkServices as soon as the callback returns, so
// the callback must not retain it.
type walkSvcFunc func(*mgr.Service, mgr.Config)
// walkServices opens the service named rootSvcName and walks its dependency
// graph depth-first, invoking callback exactly once for each service reached
// (including the root itself).
//
// Service names are compared case-insensitively, matching how the Service
// Control Manager compares them, so a service referenced under differently
// cased names is still visited only once. Load-ordering group dependencies
// (names prefixed with SC_GROUP_IDENTIFIER, '+') are not services and are
// skipped.
//
// Each *mgr.Service passed to callback is closed once callback returns. The
// walk stops at the first service whose configuration cannot be read, or whose
// dependency cannot be opened, and returns that error.
func walkServices(rootSvcName string, callback walkSvcFunc) error {
	scm, err := ConnectToLocalSCMForRead()
	if err != nil {
		return fmt.Errorf("connecting to Service Control Manager: %w", err)
	}
	defer scm.Disconnect()

	rootSvc, err := OpenServiceForRead(scm, rootSvcName)
	if err != nil {
		return fmt.Errorf("opening service %q: %w", rootSvcName, err)
	}

	// pending is a LIFO stack of services that are open but not yet visited.
	pending := []*mgr.Service{rootSvc}
	defer func() {
		// Any service still pending when we return is open and must be closed.
		for _, s := range pending {
			s.Close()
		}
	}()

	// queued holds the lowercased name of every service ever pushed onto
	// pending. Marking at push time (rather than when visited) keeps a
	// service from being pushed twice when several services still on the
	// stack depend on it.
	queued := set.Set[string]{}
	queued.Add(strings.ToLower(rootSvcName))

	for len(pending) > 0 {
		cur := pending[len(pending)-1]
		pending = pending[:len(pending)-1]

		err := func() error {
			defer cur.Close()

			cfg, err := cur.Config()
			if err != nil {
				return fmt.Errorf("retrieving Config for service %q: %w", cur.Name, err)
			}
			callback(cur, cfg)

			for _, depName := range cfg.Dependencies {
				if strings.HasPrefix(depName, "+") {
					// SC_GROUP_IDENTIFIER prefix: a load-ordering group, not
					// a service, so OpenService would always fail on it.
					continue
				}
				key := strings.ToLower(depName)
				if queued.Contains(key) {
					continue
				}
				depSvc, err := OpenServiceForRead(scm, depName)
				if err != nil {
					return fmt.Errorf("opening service %q: %w", depName, err)
				}
				queued.Add(key)
				pending = append(pending, depSvc)
			}
			return nil
		}()
		if err != nil {
			return err
		}
	}
	return nil
}
// svcStateLogEntry is the per-service record emitted by LogSvcState. Its
// configuration constants are stored as human-readable strings.
type svcStateLogEntry struct {
	ServiceName   string                 `json:"serviceName"`
	ServiceType   string                 `json:"serviceType"`
	State         string                 `json:"state"`
	StartupType   string                 `json:"startupType"`
	Triggers      *_SERVICE_TRIGGER_INFO `json:"triggers,omitempty"`
	TriggersError error                  `json:"triggersError,omitempty"`
}

// MarshalJSON encodes e with TriggersError rendered as its Error() text.
// Without this, encoding/json serializes an error by its concrete type, which
// yields e.g. a bare number for windows.Errno or "{}" for error types that
// have only unexported fields — neither is useful in a log.
func (e svcStateLogEntry) MarshalJSON() ([]byte, error) {
	// plain has the same fields but no methods, so marshaling it does not
	// recurse back into this MarshalJSON.
	type plain svcStateLogEntry
	out := struct {
		plain
		// Shadows the embedded field with the same JSON name; encoding/json
		// prefers the shallower field.
		TriggersError string `json:"triggersError,omitempty"`
	}{plain: plain(e)}
	if e.TriggersError != nil {
		out.TriggersError = e.TriggersError.Error()
	}
	return json.Marshal(out)
}
// _SERVICE_TRIGGER_SPECIFIC_DATA_ITEM mirrors the Win32
// SERVICE_TRIGGER_SPECIFIC_DATA_ITEM structure. Values of this type are read
// directly out of the buffer filled by getSvcTriggerInfo, so the field order
// and types must match the native layout exactly.
type _SERVICE_TRIGGER_SPECIFIC_DATA_ITEM struct {
	dataType uint32
	cbData   uint32 // length in bytes of the buffer at data
	data     *byte
}
// serviceTriggerSpecificDataItemJSONMarshal is the JSON shape produced by
// (*_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM).MarshalJSON.
type serviceTriggerSpecificDataItemJSONMarshal struct {
	DataType uint32 `json:"dataType"`
	// Data is the item's payload as hex, possibly truncated; omitted when empty.
	Data string `json:"data,omitempty"`
}
// MarshalJSON encodes the data item as its type plus a hex dump of its payload.
// Payloads longer than maxDataLen bytes are truncated, with a note recording
// how many bytes were dropped.
func (tsdi *_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM) MarshalJSON() ([]byte, error) {
	m := serviceTriggerSpecificDataItemJSONMarshal{DataType: tsdi.dataType}

	// Cap the output so a single large payload cannot bloat the log.
	const maxDataLen = 128

	// unsafe.Slice panics for a nil pointer with a non-zero length. Treat that
	// malformed combination as "no data" rather than crashing the logger.
	var data []byte
	if tsdi.data != nil {
		data = unsafe.Slice(tsdi.data, tsdi.cbData)
	}

	if n := len(data); n > maxDataLen {
		m.Data = fmt.Sprintf("%s... (truncated %d bytes)", hex.EncodeToString(data[:maxDataLen]), n-maxDataLen)
	} else {
		m.Data = hex.EncodeToString(data)
	}
	return json.Marshal(m)
}
// _SERVICE_TRIGGER mirrors the Win32 SERVICE_TRIGGER structure, describing a
// single trigger registered for a service. Values of this type are read
// directly out of the buffer filled by getSvcTriggerInfo, so the field order
// and types must match the native layout exactly.
type _SERVICE_TRIGGER struct {
	triggerType    uint32
	action         uint32
	triggerSubtype *windows.GUID // may be nil
	cDataItems     uint32        // number of elements at pDataItems
	pDataItems     *_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM
}
// serviceTriggerJSONMarshal is the JSON shape produced by
// (*_SERVICE_TRIGGER).MarshalJSON.
type serviceTriggerJSONMarshal struct {
	TriggerType uint32 `json:"triggerType"`
	Action      uint32 `json:"action"`
	// TriggerSubtype is the subtype GUID in string form; omitted when absent.
	TriggerSubtype string                                `json:"triggerSubtype,omitempty"`
	DataItems      []_SERVICE_TRIGGER_SPECIFIC_DATA_ITEM `json:"dataItems"`
}
// MarshalJSON encodes the trigger, including its subtype GUID (when present)
// and every data item it references.
func (ti *_SERVICE_TRIGGER) MarshalJSON() ([]byte, error) {
	m := serviceTriggerJSONMarshal{
		TriggerType: ti.triggerType,
		Action:      ti.action,
	}
	// unsafe.Slice panics for a nil pointer with a non-zero length; a nil
	// array is encoded as no data items instead.
	if ti.pDataItems != nil {
		m.DataItems = unsafe.Slice(ti.pDataItems, ti.cDataItems)
	}
	if ti.triggerSubtype != nil {
		m.TriggerSubtype = ti.triggerSubtype.String()
	}
	return json.Marshal(m)
}
// _SERVICE_TRIGGER_INFO mirrors the Win32 SERVICE_TRIGGER_INFO structure. It
// is the type getSvcTriggerInfo overlays on the start of its result buffer, so
// the field order and types must match the native layout exactly.
type _SERVICE_TRIGGER_INFO struct {
	cTriggers uint32            // number of elements at pTriggers
	pTriggers *_SERVICE_TRIGGER
	_         *byte // pReserved
}
// MarshalJSON encodes the trigger info as a JSON array of its triggers.
func (sti *_SERVICE_TRIGGER_INFO) MarshalJSON() ([]byte, error) {
	// unsafe.Slice panics for a nil pointer with a non-zero length; a nil
	// array is encoded as no triggers (null) instead.
	var triggers []_SERVICE_TRIGGER
	if sti.pTriggers != nil {
		triggers = unsafe.Slice(sti.pTriggers, sti.cTriggers)
	}
	return json.Marshal(triggers)
}
// getSvcTriggerInfo queries the trigger configuration of svc, i.e. the system
// events that may be used to start it. Only relevant for demand-start (aka
// manual) services.
//
// The result is a pointer into a freshly allocated Go byte buffer sized by a
// preliminary query. The structure's internal pointers are expected to refer
// to data packed later in that same buffer (NOTE(review): per the Win32
// QueryServiceConfig2 contract), which the returned pointer keeps alive.
func getSvcTriggerInfo(svc *mgr.Service) (*_SERVICE_TRIGGER_INFO, error) {
	// Probe with an empty buffer. The expected outcome is
	// ERROR_INSUFFICIENT_BUFFER plus the required size in needed. Any other
	// result, including success, is returned as-is.
	var needed uint32
	if probeErr := queryServiceConfig2(svc.Handle, windows.SERVICE_CONFIG_TRIGGER_INFO, nil, 0, &needed); probeErr != windows.ERROR_INSUFFICIENT_BUFFER {
		return nil, probeErr
	}

	raw := make([]byte, needed)
	if err := queryServiceConfig2(svc.Handle, windows.SERVICE_CONFIG_TRIGGER_INFO, &raw[0], needed, &needed); err != nil {
		return nil, err
	}

	info := (*_SERVICE_TRIGGER_INFO)(unsafe.Pointer(&raw[0]))
	return info, nil
}
// makeLogEntry consolidates relevant service information into a svcStateLogEntry.
// We record the values of various service configuration constants as strings
// so that the log entries are easy to interpret at a glance by humans.
//
// For demand-start services it also fetches the trigger configuration. A
// failure there is stored in the entry's TriggersError rather than returned.
func makeLogEntry(svc *mgr.Service, status svc.Status, cfg mgr.Config) (entry svcStateLogEntry) {
	entry.ServiceName = svc.Name

	switch status.State {
	case windows.SERVICE_STOPPED:
		entry.State = "STOPPED"
	case windows.SERVICE_START_PENDING:
		entry.State = "START_PENDING"
	case windows.SERVICE_STOP_PENDING:
		entry.State = "STOP_PENDING"
	case windows.SERVICE_RUNNING:
		entry.State = "RUNNING"
	case windows.SERVICE_CONTINUE_PENDING:
		entry.State = "CONTINUE_PENDING"
	case windows.SERVICE_PAUSE_PENDING:
		entry.State = "PAUSE_PENDING"
	case windows.SERVICE_PAUSED:
		entry.State = "PAUSED"
	case windows.SERVICE_NO_CHANGE:
		entry.State = "NO_CHANGE"
	default:
		entry.State = fmt.Sprintf("Unknown constant %d", status.State)
	}

	// SERVICE_INTERACTIVE_PROCESS is a modifier bit that may be OR'd into a
	// Win32 service type. Match on the type with that bit cleared, so such
	// services are not reported as unknown, then note the bit separately.
	interactive := cfg.ServiceType&windows.SERVICE_INTERACTIVE_PROCESS != 0
	switch cfg.ServiceType &^ windows.SERVICE_INTERACTIVE_PROCESS {
	case windows.SERVICE_FILE_SYSTEM_DRIVER:
		entry.ServiceType = "FILE_SYSTEM_DRIVER"
	case windows.SERVICE_KERNEL_DRIVER:
		entry.ServiceType = "KERNEL_DRIVER"
	case windows.SERVICE_WIN32_OWN_PROCESS, windows.SERVICE_WIN32_SHARE_PROCESS:
		entry.ServiceType = "WIN32"
	default:
		// The raw value (modifier bit included) already tells the whole story.
		entry.ServiceType = fmt.Sprintf("Unknown constant %d", cfg.ServiceType)
		interactive = false
	}
	if interactive {
		entry.ServiceType += "|INTERACTIVE_PROCESS"
	}

	switch cfg.StartType {
	case windows.SERVICE_BOOT_START:
		entry.StartupType = "BOOT_START"
	case windows.SERVICE_SYSTEM_START:
		entry.StartupType = "SYSTEM_START"
	case windows.SERVICE_AUTO_START:
		if cfg.DelayedAutoStart {
			entry.StartupType = "DELAYED_AUTO_START"
		} else {
			entry.StartupType = "AUTO_START"
		}
	case windows.SERVICE_DEMAND_START:
		entry.StartupType = "DEMAND_START"
		triggerInfo, err := getSvcTriggerInfo(svc)
		if err == nil {
			entry.Triggers = triggerInfo
		} else {
			entry.TriggersError = err
		}
	case windows.SERVICE_DISABLED:
		entry.StartupType = "DISABLED"
	default:
		entry.StartupType = fmt.Sprintf("Unknown constant %d", cfg.StartType)
	}

	return entry
}
// ConnectToLocalSCMForRead opens a connection to the local machine's Windows
// Service Control Manager requesting only GENERIC_READ access. This exists
// because x/sys/windows/svc/mgr.Connect requests read+write access, which
// requires Administrative access rights. The caller is responsible for
// calling Disconnect on the returned *mgr.Mgr.
func ConnectToLocalSCMForRead() (*mgr.Mgr, error) {
	var scm *mgr.Mgr
	handle, err := windows.OpenSCManager(nil, nil, windows.GENERIC_READ)
	if err == nil {
		scm = &mgr.Mgr{Handle: handle}
	}
	return scm, err
}
// OpenServiceForRead opens the service named svcName via scm, requesting only
// GENERIC_READ access. This exists because
// x/sys/windows/svc/mgr.(*Mgr).OpenService requests read+write access, which
// requires Administrative access rights. The caller is responsible for calling
// Close on the returned *mgr.Service.
func OpenServiceForRead(scm *mgr.Mgr, svcName string) (*mgr.Service, error) {
	wideName, err := windows.UTF16PtrFromString(svcName)
	if err != nil {
		// svcName cannot be represented as a NUL-terminated UTF-16 string.
		return nil, err
	}

	var service *mgr.Service
	handle, err := windows.OpenService(scm.Handle, wideName, windows.GENERIC_READ)
	if err == nil {
		service = &mgr.Service{Handle: handle, Name: svcName}
	}
	return service, err
}