cmd/tailscaled: conditionally flush Windows DNS cache on SessionChange

For the service, all we need to do is handle the `svc.SessionChange` command.
Upon receipt of a `windows.WTS_SESSION_UNLOCK` event, we fire off a goroutine to flush the DNS cache.
(Windows expects responses to service requests to be quick, so we don't want to do that synchronously.)

This is gated on an integral registry value named `FlushDNSOnSessionUnlock`,
whose value we obtain during service initialization.

(See [this link](https://docs.microsoft.com/en-us/windows/win32/api/winsvc/nc-winsvc-lphandler_function_ex) for information re: handling `SERVICE_CONTROL_SESSIONCHANGE`.)

Fixes #2956

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2021-09-28 16:33:08 -06:00
parent 3386a86fe5
commit e016eaf410
2 changed files with 23 additions and 0 deletions

View File

@ -34,6 +34,7 @@
"tailscale.com/net/dns"
"tailscale.com/net/tstun"
"tailscale.com/types/logger"
"tailscale.com/util/winutil"
"tailscale.com/version"
"tailscale.com/wf"
"tailscale.com/wgengine"
@ -43,6 +44,8 @@
const serviceName = "Tailscale"
var flushDNSOnSessionUnlock bool
func isWindowsService() bool {
v, err := svc.IsWindowsService()
if err != nil {
@ -63,6 +66,7 @@ type ipnService struct {
func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) {
changes <- svc.Status{State: svc.StartPending}
flushDNSOnSessionUnlock = winutil.GetRegInteger("FlushDNSOnSessionUnlock", 0) != 0
ctx, cancel := context.WithCancel(context.Background())
doneCh := make(chan struct{})
go func() {
@ -82,6 +86,9 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch
cancel()
case svc.Interrogate:
changes <- cmd.CurrentStatus
case svc.SessionChange:
handleSessionChange(cmd)
changes <- cmd.CurrentStatus
}
}
}
@ -264,6 +271,21 @@ type engineOrError struct {
return err
}
func handleSessionChange(chgRequest svc.ChangeRequest) {
if chgRequest.Cmd != svc.SessionChange || chgRequest.EventType != windows.WTS_SESSION_UNLOCK ||
!flushDNSOnSessionUnlock {
return
}
log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.")
go func() {
err := dns.Flush()
if err != nil {
log.Printf("Error flushing DNS on session unlock: %v", err)
}
}()
}
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
getTickCount64Proc = kernel32.NewProc("GetTickCount64")

View File

@ -57,6 +57,7 @@
_ "tailscale.com/types/key"
_ "tailscale.com/types/logger"
_ "tailscale.com/util/osshare"
_ "tailscale.com/util/winutil"
_ "tailscale.com/version"
_ "tailscale.com/version/distro"
_ "tailscale.com/wf"