From 17abf201a3f1196467f4eb33d648147fd3c65e01 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Tue, 7 Dec 2021 17:09:51 -0700 Subject: [PATCH] DNS Queries on Windows using DnsQueryEx and asynchronous procedure calls Signed-off-by: Aaron Klotz --- cmd/dnsapc/dnsapc.go | 46 +++++++ util/winutil/apcthread_windows.go | 213 ++++++++++++++++++++++++++++++ util/winutil/dnsapc_windows.go | 111 ++++++++++++++++ util/winutil/dnsq_windows.go | 161 ++++++++++++++++++++++ util/winutil/mksyscall.go | 11 ++ util/winutil/zsyscall_windows.go | 61 +++++++++ 6 files changed, 603 insertions(+) create mode 100644 cmd/dnsapc/dnsapc.go create mode 100644 util/winutil/apcthread_windows.go create mode 100644 util/winutil/dnsapc_windows.go create mode 100644 util/winutil/dnsq_windows.go create mode 100644 util/winutil/mksyscall.go create mode 100644 util/winutil/zsyscall_windows.go diff --git a/cmd/dnsapc/dnsapc.go b/cmd/dnsapc/dnsapc.go new file mode 100644 index 000000000..5f2ccbc1b --- /dev/null +++ b/cmd/dnsapc/dnsapc.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "os" + "unsafe" + + "golang.org/x/sys/windows" + "tailscale.com/util/winutil" +) + +func main() { + result, err := winutil.DnsQuery("www.tailscale.com", windows.DNS_TYPE_A, winutil.DNS_QUERY_STANDARD, nil, 0) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return + } + + finalStatus := result.QueryStatus + fmt.Printf("Query status: %v", finalStatus) + if finalStatus != 0 { + fmt.Printf(" (%v)\n", windows.Errno(finalStatus)) + } else { + fmt.Printf("\n") + } + + count := 0 + for rec := result.QueryRecords; rec != nil; rec = rec.Next { + name := windows.UTF16PtrToString(rec.Name) + fmt.Printf("Record %d: %s, type %v", count, name, rec.Type) + switch rec.Type { + case windows.DNS_TYPE_A: + rd := (*winutil.DNSAData)(unsafe.Pointer(&rec.Data[0])) + a := rd.IPv4Address + fmt.Printf(" (A): %v.%v.%v.%v\n", a[0], a[1], a[2], a[3]) + case windows.DNS_TYPE_CNAME: + rd := (*windows.DNSPTRData)(unsafe.Pointer(&rec.Data[0])) + fmt.Printf(" (CNAME): %s\n", windows.UTF16PtrToString(rd.Host)) + default: + fmt.Printf("\n") + } + count++ + } + + result.Close() +} diff --git a/util/winutil/apcthread_windows.go b/util/winutil/apcthread_windows.go new file mode 100644 index 000000000..0652d1006 --- /dev/null +++ b/util/winutil/apcthread_windows.go @@ -0,0 +1,213 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +import ( + "fmt" + "reflect" + "runtime" + "sync" + + "golang.org/x/sys/windows" +) + +var ( + k32 = windows.NewLazySystemDLL("kernel32.dll") + procWaitForSingleObjectEx = k32.NewProc("WaitForSingleObjectEx") +) + +func waitForSingleObjectEx(handle windows.Handle, timeout uint32, alertable bool) (uint32, error) { + var ua uintptr + if alertable { + ua = 1 + } + code, _, err := procWaitForSingleObjectEx.Call(uintptr(handle), uintptr(timeout), ua) + code32 := uint32(code) + if code32 == windows.WAIT_FAILED { + return code32, err + } + return code32, nil +} + +const ( + noPendingIdleTimeout = 30000 + reqChanBufSize = 64 +) + +type APCChannel chan []reflect.Value + +func MakeAPCChannel() APCChannel { + return make(APCChannel) +} + +type APCRequest interface { + Begin() *APCChannel +} + +type APCChannelResolver interface { + GetChannel([]reflect.Value) *APCChannel +} + +type APCCallbackInfo struct { + Function reflect.Type + Resolver APCChannelResolver +} + +type apcThread struct { + once sync.Once + event windows.Handle + reqChan chan APCRequest + pending map[*APCChannel]struct{} + mu sync.RWMutex // protects cbinfo + cbinfo map[APCCallbackInfo]uintptr +} + +func (t *apcThread) init() { + var err error + // Auto-reset event for signaling that new requests are present + t.event, err = windows.CreateEvent(nil, 1, 0, nil) + if err != nil { + panic(fmt.Sprintf("Creating apcThread event: %v", err)) + } + t.reqChan = make(chan APCRequest, reqChanBufSize) + t.pending = make(map[*APCChannel]struct{}) + t.cbinfo = make(map[APCCallbackInfo]uintptr) +} + +func (t *apcThread) submitWork(req APCRequest) error { + // Lazily start the goroutine + t.once.Do(func() { + go t.run() + }) + t.reqChan <- req + // We need to set an event to poke the APC thread into checking t.reqChan. + return windows.SetEvent(t.event) +} + +var thd apcThread + +func init() { + thd.init() +} + +func (t *apcThread) nextWaitTimeout() uint32 { + if len(t.pending) > 0 { + return windows.INFINITE + } else { + return noPendingIdleTimeout + } +} + +func (t *apcThread) beginRequest(req APCRequest) { + // Lock the OS thread before calling Begin, which will initiate the APC + // request on the current OS thread. + runtime.LockOSThread() + apcctx := req.Begin() + if apcctx == nil { + // Request failed, we don't need to lock anymore. + runtime.UnlockOSThread() + } else { + // Save the context so it doesn't get GC'd and we can track pending requests + t.pending[apcctx] = struct{}{} + } +} + +// run is the goroutine that executes APC requests. It is started lazily, but +// once it is running, it remains so for the remainder of the process's lifetime. +// Note that it only locks the OS thread while requests are in-flight; once +// all requests have been processed, it blocks on t.reqChan without consuming +// an OS thread. +// (Hi Brad! When this goroutine is 100% idle, it does not lock an OS thread. +// Is this acceptable, or do we want additional magic to make the +// entire goroutine shut down after an extended period of disuse?) +func (t *apcThread) run() { + for { + select { + case req := <-t.reqChan: + t.beginRequest(req) + continue + default: + // If nothing is pending, we can safely block indefinitely on the request channel. + // Otherwise we need to fall through into blocking on t.event so that we may process APCs. + if len(t.pending) == 0 { + req := <-t.reqChan + t.beginRequest(req) + continue + } + } + + var waitCode uint32 + var err error + for waitCode, err = waitForSingleObjectEx(t.event, t.nextWaitTimeout(), true); waitCode == windows.WAIT_IO_COMPLETION; { + // Drain queued APCs + } + switch waitCode { + case uint32(windows.WAIT_TIMEOUT): + // There are no more requests pending, we can just block on t.reqChan now + continue + case uint32(windows.WAIT_FAILED): + panic(fmt.Sprintf("apcThread waitForSingleObjectEx failed: %v", err)) + default: + // There are new requests in the channel. + windows.ResetEvent(t.event) + continue + } + } +} + +type apcHandler func([]reflect.Value) []reflect.Value + +// makeAPCHandler creates a handler function that an APC will invoke to complete +// its request. args contains the APC's arguments, which are then sent to the +// channel for processing by the API consumer. +func (t *apcThread) makeAPCHandler(resolver APCChannelResolver) apcHandler { + return func(args []reflect.Value) []reflect.Value { + apcchan := resolver.GetChannel(args) + delete(t.pending, apcchan) + runtime.UnlockOSThread() + *apcchan <- args + // APCs don't use return values, but we need to return this to satisfy + // Go's callback requirements. + return []reflect.Value{reflect.ValueOf(uintptr(0))} + } +} + +func (t *apcThread) registerCallback(cb APCCallbackInfo) uintptr { + // Common path: Callback is already registered + t.mu.RLock() + cookie, ok := t.cbinfo[cb] + t.mu.RUnlock() + if ok { + return cookie + } + + // Slower path: We need to register + t.mu.Lock() + // Check again to make sure we didn't lose a race + cookie, ok = t.cbinfo[cb] + if ok { + t.mu.Unlock() + return cookie + } + + handler := t.makeAPCHandler(cb.Resolver) + outer := reflect.MakeFunc(cb.Function, handler) + cbptr := windows.NewCallback(outer.Interface()) + t.cbinfo[cb] = cbptr + t.mu.Unlock() + return cbptr +} + +// RegisterAPCCallback must be called any time a new type of APC is going to be +// submitted. Ideally this would be called only once for each type (via sync.Once). +func RegisterAPCCallback(cb APCCallbackInfo) uintptr { + return thd.registerCallback(cb) +} + +// SubmitAPCWork is the main entry point for submitting work for APC processing. +// The APC type must have been previously registered via RegisterAPICallback. +func SubmitAPCWork(req APCRequest) error { + return thd.submitWork(req) +} diff --git a/util/winutil/dnsapc_windows.go b/util/winutil/dnsapc_windows.go new file mode 100644 index 000000000..32245c530 --- /dev/null +++ b/util/winutil/dnsapc_windows.go @@ -0,0 +1,111 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +import ( + "fmt" + "reflect" + "sync" + "unsafe" + + "golang.org/x/sys/windows" +) + +// Technically this function returns void, but we need it to return uintptr for NewCallback to work +func dnsQueryExApc(param uintptr, results *DNSQueryResult) uintptr { + return 0 +} + +type resolver struct{} + +func (r resolver) GetChannel(args []reflect.Value) *APCChannel { + return &((*invoker)(unsafe.Pointer(uintptr(args[0].Uint())))).done +} + +type invoker struct { + done APCChannel + req DNSQueryRequest + result DNSQueryResult + cancel DNSQueryCancel +} + +type DNSServerList struct { + Family uint16 + List []DNSAddr +} + +var ( + once sync.Once + apcCallback uintptr +) + +func newDnsInvoker(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, ifaceIdx uint32) (*invoker, error) { + once.Do(func() { + cbInfo := APCCallbackInfo{reflect.TypeOf(dnsQueryExApc), resolver{}} + apcCallback = RegisterAPCCallback(cbInfo) + }) + + var name *uint16 + var err error + if len(qname) > 0 { + name, err = windows.UTF16PtrFromString(qname) + if err != nil { + return nil, err + } + } + + var serverList *DNSAddrArray + if srvList != nil { + serverList = NewDNSAddrArray(srvList.Family, srvList.List) + } + + inv := &invoker{done: MakeAPCChannel(), + req: DNSQueryRequest{ + Version: DNS_QUERY_REQUEST_VERSION1, + QueryName: name, + QueryType: qtype, + QueryOptions: qoptions, + DNSServerList: serverList, + InterfaceIndex: ifaceIdx, + QueryCompletionCallback: apcCallback}, + result: DNSQueryResult{Version: DNS_QUERY_RESULTS_VERSION1}} + inv.req.QueryContext = uintptr(unsafe.Pointer(inv)) + + return inv, nil +} + +func (i *invoker) Begin() *APCChannel { + err := DnsQueryEx(&i.req, &i.result, &i.cancel) + if err != DNS_REQUEST_PENDING { + i.result.QueryStatus = DNSStatus(uintptr(err.(windows.Errno))) + close(i.done) + return nil + } + + return &i.done +} + +func (i *invoker) Wait() { + <-i.done +} + +func (i *invoker) Cancel() error { + return DnsCancelQuery(&i.cancel) +} + +func DnsQuery(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, interfaceIdx uint32) (*DNSQueryResult, error) { + inv, err := newDnsInvoker(qname, qtype, qoptions, srvList, interfaceIdx) + if err != nil { + return nil, fmt.Errorf("Failed creating DNS invoker: %w", err) + } + + err = SubmitAPCWork(inv) + if err != nil { + return nil, fmt.Errorf("Failed submitting work: %w", err) + } + + inv.Wait() + return &inv.result, nil +} diff --git a/util/winutil/dnsq_windows.go b/util/winutil/dnsq_windows.go new file mode 100644 index 000000000..d35e295ba --- /dev/null +++ b/util/winutil/dnsq_windows.go @@ -0,0 +1,161 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +type DNSAddr struct { + maxSa [32]byte /* DNS_ADDR_MAX_SOCKADDR_LENGTH */ + dnsAddrUserDword [8]uint32 +} + +func (a *DNSAddr) AsInet4() *windows.SockaddrInet4 { + return (*windows.SockaddrInet4)(unsafe.Pointer(&a.maxSa[0])) +} + +func (a *DNSAddr) AsInet6() *windows.SockaddrInet6 { + return (*windows.SockaddrInet6)(unsafe.Pointer(&a.maxSa[0])) +} + +type DNSAData struct { + IPv4Address [4]byte +} + +type DNSAddrArray struct { + MaxCount uint32 + AddrCount uint32 + tag uint32 + Family uint16 + wreserved uint16 + flags uint32 + matchFlag uint32 + reserved1 uint32 + reserved2 uint32 + AddrArray [1]DNSAddr +} + +// TODO: We can probably make this more efficient +func NewDNSAddrArray(family uint16, addrs []DNSAddr) *DNSAddrArray { + numBytes := unsafe.Sizeof(DNSAddrArray{}) + count := len(addrs) + if count > 1 { + numBytes += (uintptr(count) - 1) * unsafe.Sizeof(DNSAddr{}) + } + + buf := make([]byte, numBytes) + result := (*DNSAddrArray)(unsafe.Pointer(&buf[0])) + result.MaxCount = uint32(count) + result.AddrCount = uint32(count) + result.Family = family + + dstAddrs := unsafe.Slice(&result.AddrArray[0], count) + copy(dstAddrs, addrs) + + return result +} + +const ( + DNS_QUERY_REQUEST_VERSION1 = 1 + DNS_QUERY_REQUEST_VERSION3 = 3 +) + +const ( + DNS_QUERY_RESULTS_VERSION1 = 1 +) + +const ( + DNS_QUERY_STANDARD = 0x00000000 + DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE = 0x00000001 + DNS_QUERY_USE_TCP_ONLY = 0x00000002 + DNS_QUERY_NO_RECURSION = 0x00000004 + DNS_QUERY_BYPASS_CACHE = 0x00000008 + DNS_QUERY_NO_WIRE_QUERY = 0x00000010 + DNS_QUERY_NO_LOCAL_NAME = 0x00000020 + DNS_QUERY_NO_HOSTS_FILE = 0x00000040 + DNS_QUERY_NO_NETBT = 0x00000080 + DNS_QUERY_WIRE_ONLY = 0x00000100 + DNS_QUERY_RETURN_MESSAGE = 0x00000200 + DNS_QUERY_MULTICAST_ONLY = 0x00000400 + DNS_QUERY_NO_MULTICAST = 0x00000800 + DNS_QUERY_TREAT_AS_FQDN = 0x00001000 + DNS_QUERY_ADDRCONFIG = 0x00002000 + DNS_QUERY_DUAL_ADDR = 0x00004000 + DNS_QUERY_DONT_RESET_TTL_VALUES = 0x00100000 + DNS_QUERY_DISABLE_IDN_ENCODING = 0x00200000 + DNS_QUERY_APPEND_MULTILABEL = 0x00800000 + DNS_QUERY_DNSSEC_OK = 0x01000000 + DNS_QUERY_DNSSEC_CHECKING_DISABLED = 0x02000000 + DNS_QUERY_RESERVED = 0xf0000000 +) + +type DNSQueryRequest struct { + Version uint32 + QueryName *uint16 + QueryType uint16 + QueryOptions uint64 + DNSServerList *DNSAddrArray + InterfaceIndex uint32 + QueryCompletionCallback uintptr + QueryContext uintptr +} + +type DNSCustomServer struct { + ServerType uint32 + Flags uint64 + Template *uint16 + MaxSa [32]byte /* DNS_ADDR_MAX_SOCKADDR_LENGTH */ +} + +type DNSQueryRequest3 struct { + DNSQueryRequest + IsNetworkQueryRequired int32 /* BOOL */ + RequiredNetworkIndex uint32 + CCustomServers uint32 + PCustomServers *DNSCustomServer +} + +const ( + DNS_CUSTOM_SERVER_TYPE_UDP = 0x1 + DNS_CUSTOM_SERVER_TYPE_DOH = 0x2 +) + +const ( + DNS_CUSTOM_SERVER_UDP_FALLBACK = 0x1 +) + +var ( + DNS_REQUEST_PENDING windows.Errno = 0x00002522 +) + +type DNSStatus int32 + +type DNSQueryResult struct { + Version uint32 + QueryStatus DNSStatus + QueryOptions uint64 + QueryRecords *windows.DNSRecord + reserved uintptr +} + +const ( + DNSFreeFlat = 0 + DNSFreeRecordList = 1 + DNSFreeParsedMessageFields = 2 +) + +func (qr *DNSQueryResult) Close() error { + windows.DnsRecordListFree(qr.QueryRecords, DNSFreeRecordList) + qr.QueryRecords = nil + return nil +} + +type DNSQueryCancel struct { + // This is defined in C as 32 bytes, but it is also declared with 8 byte alignment + reserved [4]uint64 +} diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go new file mode 100644 index 000000000..779460b7f --- /dev/null +++ b/util/winutil/mksyscall.go @@ -0,0 +1,11 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package winutil + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +// Note: DO NOT use DnsQueryExW! It *is* exported from dnsapi.dll but is an internal function! +//sys DnsQueryEx(request *DNSQueryRequest, result *DNSQueryResult, cancelHandle *DNSQueryCancel) (status error) = dnsapi.DnsQueryEx +//sys DnsCancelQuery(cancelHandle *DNSQueryCancel) (status error) = dnsapi.DnsCancelQuery diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go new file mode 100644 index 000000000..2ab67b03c --- /dev/null +++ b/util/winutil/zsyscall_windows.go @@ -0,0 +1,61 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package winutil + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + moddnsapi = windows.NewLazySystemDLL("dnsapi.dll") + + procDnsCancelQuery = moddnsapi.NewProc("DnsCancelQuery") + procDnsQueryEx = moddnsapi.NewProc("DnsQueryEx") +) + +func DnsCancelQuery(cancelHandle *DNSQueryCancel) (status error) { + r0, _, _ := syscall.Syscall(procDnsCancelQuery.Addr(), 1, uintptr(unsafe.Pointer(cancelHandle)), 0, 0) + if r0 != 0 { + status = syscall.Errno(r0) + } + return +} + +func DnsQueryEx(request *DNSQueryRequest, result *DNSQueryResult, cancelHandle *DNSQueryCancel) (status error) { + r0, _, _ := syscall.Syscall(procDnsQueryEx.Addr(), 3, uintptr(unsafe.Pointer(request)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(cancelHandle))) + if r0 != 0 { + status = syscall.Errno(r0) + } + return +}