DNS Queries on Windows using DnsQueryEx and asynchronous procedure calls

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2021-12-07 17:09:51 -07:00
parent c0701b130d
commit 17abf201a3
6 changed files with 603 additions and 0 deletions

46
cmd/dnsapc/dnsapc.go Normal file
View File

@ -0,0 +1,46 @@
package main
import (
"fmt"
"os"
"unsafe"
"golang.org/x/sys/windows"
"tailscale.com/util/winutil"
)
func main() {
result, err := winutil.DnsQuery("www.tailscale.com", windows.DNS_TYPE_A, winutil.DNS_QUERY_STANDARD, nil, 0)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
finalStatus := result.QueryStatus
fmt.Printf("Query status: %v", finalStatus)
if finalStatus != 0 {
fmt.Printf(" (%v)\n", windows.Errno(finalStatus))
} else {
fmt.Printf("\n")
}
count := 0
for rec := result.QueryRecords; rec != nil; rec = rec.Next {
name := windows.UTF16PtrToString(rec.Name)
fmt.Printf("Record %d: %s, type %v", count, name, rec.Type)
switch rec.Type {
case windows.DNS_TYPE_A:
rd := (*winutil.DNSAData)(unsafe.Pointer(&rec.Data[0]))
a := rd.IPv4Address
fmt.Printf(" (A): %v.%v.%v.%v\n", a[0], a[1], a[2], a[3])
case windows.DNS_TYPE_CNAME:
rd := (*windows.DNSPTRData)(unsafe.Pointer(&rec.Data[0]))
fmt.Printf(" (CNAME): %s\n", windows.UTF16PtrToString(rd.Host))
default:
fmt.Printf("\n")
}
count++
}
result.Close()
}

View File

@ -0,0 +1,213 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winutil
import (
"fmt"
"reflect"
"runtime"
"sync"
"golang.org/x/sys/windows"
)
var (
k32 = windows.NewLazySystemDLL("kernel32.dll")
procWaitForSingleObjectEx = k32.NewProc("WaitForSingleObjectEx")
)
func waitForSingleObjectEx(handle windows.Handle, timeout uint32, alertable bool) (uint32, error) {
var ua uintptr
if alertable {
ua = 1
}
code, _, err := procWaitForSingleObjectEx.Call(uintptr(handle), uintptr(timeout), ua)
code32 := uint32(code)
if code32 == windows.WAIT_FAILED {
return code32, err
}
return code32, nil
}
const (
noPendingIdleTimeout = 30000
reqChanBufSize = 64
)
type APCChannel chan []reflect.Value
func MakeAPCChannel() APCChannel {
return make(APCChannel)
}
type APCRequest interface {
Begin() *APCChannel
}
type APCChannelResolver interface {
GetChannel([]reflect.Value) *APCChannel
}
type APCCallbackInfo struct {
Function reflect.Type
Resolver APCChannelResolver
}
type apcThread struct {
once sync.Once
event windows.Handle
reqChan chan APCRequest
pending map[*APCChannel]struct{}
mu sync.RWMutex // protects cbinfo
cbinfo map[APCCallbackInfo]uintptr
}
func (t *apcThread) init() {
var err error
// Auto-reset event for signaling that new requests are present
t.event, err = windows.CreateEvent(nil, 1, 0, nil)
if err != nil {
panic(fmt.Sprintf("Creating apcThread event: %v", err))
}
t.reqChan = make(chan APCRequest, reqChanBufSize)
t.pending = make(map[*APCChannel]struct{})
t.cbinfo = make(map[APCCallbackInfo]uintptr)
}
func (t *apcThread) submitWork(req APCRequest) error {
// Lazily start the goroutine
t.once.Do(func() {
go t.run()
})
t.reqChan <- req
// We need to set an event to poke the APC thread into checking t.reqChan.
return windows.SetEvent(t.event)
}
var thd apcThread
func init() {
thd.init()
}
func (t *apcThread) nextWaitTimeout() uint32 {
if len(t.pending) > 0 {
return windows.INFINITE
} else {
return noPendingIdleTimeout
}
}
func (t *apcThread) beginRequest(req APCRequest) {
// Lock the OS thread before calling Begin, which will initiate the APC
// request on the current OS thread.
runtime.LockOSThread()
apcctx := req.Begin()
if apcctx == nil {
// Request failed, we don't need to lock anymore.
runtime.UnlockOSThread()
} else {
// Save the context so it doesn't get GC'd and we can track pending requests
t.pending[apcctx] = struct{}{}
}
}
// run is the goroutine that executes APC requests. It is started lazily, but
// once it is running, it remains so for the remainder of the process's lifetime.
// Note that it only locks the OS thread while requests are in-flight; once
// all requests have been processed, it blocks on t.reqChan without consuming
// an OS thread.
// (Hi Brad! When this goroutine is 100% idle, it does not lock an OS thread.
// Is this acceptable, or do we want additional magic to make the
// entire goroutine shut down after an extended period of disuse?)
func (t *apcThread) run() {
for {
select {
case req := <-t.reqChan:
t.beginRequest(req)
continue
default:
// If nothing is pending, we can safely block indefinitely on the request channel.
// Otherwise we need to fall through into blocking on t.event so that we may process APCs.
if len(t.pending) == 0 {
req := <-t.reqChan
t.beginRequest(req)
continue
}
}
var waitCode uint32
var err error
for waitCode, err = waitForSingleObjectEx(t.event, t.nextWaitTimeout(), true); waitCode == windows.WAIT_IO_COMPLETION; {
// Drain queued APCs
}
switch waitCode {
case uint32(windows.WAIT_TIMEOUT):
// There are no more requests pending, we can just block on t.reqChan now
continue
case uint32(windows.WAIT_FAILED):
panic(fmt.Sprintf("apcThread waitForSingleObjectEx failed: %v", err))
default:
// There are new requests in the channel.
windows.ResetEvent(t.event)
continue
}
}
}
type apcHandler func([]reflect.Value) []reflect.Value
// makeAPCHandler creates a handler function that an APC will invoke to complete
// its request. args contains the APC's arguments, which are then sent to the
// channel for processing by the API consumer.
func (t *apcThread) makeAPCHandler(resolver APCChannelResolver) apcHandler {
return func(args []reflect.Value) []reflect.Value {
apcchan := resolver.GetChannel(args)
delete(t.pending, apcchan)
runtime.UnlockOSThread()
*apcchan <- args
// APCs don't use return values, but we need to return this to satisfy
// Go's callback requirements.
return []reflect.Value{reflect.ValueOf(uintptr(0))}
}
}
func (t *apcThread) registerCallback(cb APCCallbackInfo) uintptr {
// Common path: Callback is already registered
t.mu.RLock()
cookie, ok := t.cbinfo[cb]
t.mu.RUnlock()
if ok {
return cookie
}
// Slower path: We need to register
t.mu.Lock()
// Check again to make sure we didn't lose a race
cookie, ok = t.cbinfo[cb]
if ok {
t.mu.Unlock()
return cookie
}
handler := t.makeAPCHandler(cb.Resolver)
outer := reflect.MakeFunc(cb.Function, handler)
cbptr := windows.NewCallback(outer.Interface())
t.cbinfo[cb] = cbptr
t.mu.Unlock()
return cbptr
}
// RegisterAPCCallback must be called any time a new type of APC is going to be
// submitted. Ideally this would be called only once for each type (via sync.Once).
func RegisterAPCCallback(cb APCCallbackInfo) uintptr {
return thd.registerCallback(cb)
}
// SubmitAPCWork is the main entry point for submitting work for APC processing.
// The APC type must have been previously registered via RegisterAPICallback.
func SubmitAPCWork(req APCRequest) error {
return thd.submitWork(req)
}

View File

@ -0,0 +1,111 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winutil
import (
"fmt"
"reflect"
"sync"
"unsafe"
"golang.org/x/sys/windows"
)
// Technically this function returns void, but we need it to return uintptr for NewCallback to work
func dnsQueryExApc(param uintptr, results *DNSQueryResult) uintptr {
return 0
}
type resolver struct{}
func (r resolver) GetChannel(args []reflect.Value) *APCChannel {
return &((*invoker)(unsafe.Pointer(uintptr(args[0].Uint())))).done
}
type invoker struct {
done APCChannel
req DNSQueryRequest
result DNSQueryResult
cancel DNSQueryCancel
}
type DNSServerList struct {
Family uint16
List []DNSAddr
}
var (
once sync.Once
apcCallback uintptr
)
func newDnsInvoker(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, ifaceIdx uint32) (*invoker, error) {
once.Do(func() {
cbInfo := APCCallbackInfo{reflect.TypeOf(dnsQueryExApc), resolver{}}
apcCallback = RegisterAPCCallback(cbInfo)
})
var name *uint16
var err error
if len(qname) > 0 {
name, err = windows.UTF16PtrFromString(qname)
if err != nil {
return nil, err
}
}
var serverList *DNSAddrArray
if srvList != nil {
serverList = NewDNSAddrArray(srvList.Family, srvList.List)
}
inv := &invoker{done: MakeAPCChannel(),
req: DNSQueryRequest{
Version: DNS_QUERY_REQUEST_VERSION1,
QueryName: name,
QueryType: qtype,
QueryOptions: qoptions,
DNSServerList: serverList,
InterfaceIndex: ifaceIdx,
QueryCompletionCallback: apcCallback},
result: DNSQueryResult{Version: DNS_QUERY_RESULTS_VERSION1}}
inv.req.QueryContext = uintptr(unsafe.Pointer(inv))
return inv, nil
}
func (i *invoker) Begin() *APCChannel {
err := DnsQueryEx(&i.req, &i.result, &i.cancel)
if err != DNS_REQUEST_PENDING {
i.result.QueryStatus = DNSStatus(uintptr(err.(windows.Errno)))
close(i.done)
return nil
}
return &i.done
}
func (i *invoker) Wait() {
<-i.done
}
func (i *invoker) Cancel() error {
return DnsCancelQuery(&i.cancel)
}
func DnsQuery(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, interfaceIdx uint32) (*DNSQueryResult, error) {
inv, err := newDnsInvoker(qname, qtype, qoptions, srvList, interfaceIdx)
if err != nil {
return nil, fmt.Errorf("Failed creating DNS invoker: %w", err)
}
err = SubmitAPCWork(inv)
if err != nil {
return nil, fmt.Errorf("Failed submitting work: %w", err)
}
inv.Wait()
return &inv.result, nil
}

View File

@ -0,0 +1,161 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winutil
import (
"unsafe"
"golang.org/x/sys/windows"
)
type DNSAddr struct {
maxSa [32]byte /* DNS_ADDR_MAX_SOCKADDR_LENGTH */
dnsAddrUserDword [8]uint32
}
func (a *DNSAddr) AsInet4() *windows.SockaddrInet4 {
return (*windows.SockaddrInet4)(unsafe.Pointer(&a.maxSa[0]))
}
func (a *DNSAddr) AsInet6() *windows.SockaddrInet6 {
return (*windows.SockaddrInet6)(unsafe.Pointer(&a.maxSa[0]))
}
type DNSAData struct {
IPv4Address [4]byte
}
type DNSAddrArray struct {
MaxCount uint32
AddrCount uint32
tag uint32
Family uint16
wreserved uint16
flags uint32
matchFlag uint32
reserved1 uint32
reserved2 uint32
AddrArray [1]DNSAddr
}
// TODO: We can probably make this more efficient
func NewDNSAddrArray(family uint16, addrs []DNSAddr) *DNSAddrArray {
numBytes := unsafe.Sizeof(DNSAddrArray{})
count := len(addrs)
if count > 1 {
numBytes += (uintptr(count) - 1) * unsafe.Sizeof(DNSAddr{})
}
buf := make([]byte, numBytes)
result := (*DNSAddrArray)(unsafe.Pointer(&buf[0]))
result.MaxCount = uint32(count)
result.AddrCount = uint32(count)
result.Family = family
dstAddrs := unsafe.Slice(&result.AddrArray[0], count)
copy(dstAddrs, addrs)
return result
}
const (
DNS_QUERY_REQUEST_VERSION1 = 1
DNS_QUERY_REQUEST_VERSION3 = 3
)
const (
DNS_QUERY_RESULTS_VERSION1 = 1
)
const (
DNS_QUERY_STANDARD = 0x00000000
DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE = 0x00000001
DNS_QUERY_USE_TCP_ONLY = 0x00000002
DNS_QUERY_NO_RECURSION = 0x00000004
DNS_QUERY_BYPASS_CACHE = 0x00000008
DNS_QUERY_NO_WIRE_QUERY = 0x00000010
DNS_QUERY_NO_LOCAL_NAME = 0x00000020
DNS_QUERY_NO_HOSTS_FILE = 0x00000040
DNS_QUERY_NO_NETBT = 0x00000080
DNS_QUERY_WIRE_ONLY = 0x00000100
DNS_QUERY_RETURN_MESSAGE = 0x00000200
DNS_QUERY_MULTICAST_ONLY = 0x00000400
DNS_QUERY_NO_MULTICAST = 0x00000800
DNS_QUERY_TREAT_AS_FQDN = 0x00001000
DNS_QUERY_ADDRCONFIG = 0x00002000
DNS_QUERY_DUAL_ADDR = 0x00004000
DNS_QUERY_DONT_RESET_TTL_VALUES = 0x00100000
DNS_QUERY_DISABLE_IDN_ENCODING = 0x00200000
DNS_QUERY_APPEND_MULTILABEL = 0x00800000
DNS_QUERY_DNSSEC_OK = 0x01000000
DNS_QUERY_DNSSEC_CHECKING_DISABLED = 0x02000000
DNS_QUERY_RESERVED = 0xf0000000
)
type DNSQueryRequest struct {
Version uint32
QueryName *uint16
QueryType uint16
QueryOptions uint64
DNSServerList *DNSAddrArray
InterfaceIndex uint32
QueryCompletionCallback uintptr
QueryContext uintptr
}
type DNSCustomServer struct {
ServerType uint32
Flags uint64
Template *uint16
MaxSa [32]byte /* DNS_ADDR_MAX_SOCKADDR_LENGTH */
}
type DNSQueryRequest3 struct {
DNSQueryRequest
IsNetworkQueryRequired int32 /* BOOL */
RequiredNetworkIndex uint32
CCustomServers uint32
PCustomServers *DNSCustomServer
}
const (
DNS_CUSTOM_SERVER_TYPE_UDP = 0x1
DNS_CUSTOM_SERVER_TYPE_DOH = 0x2
)
const (
DNS_CUSTOM_SERVER_UDP_FALLBACK = 0x1
)
var (
DNS_REQUEST_PENDING windows.Errno = 0x00002522
)
type DNSStatus int32
type DNSQueryResult struct {
Version uint32
QueryStatus DNSStatus
QueryOptions uint64
QueryRecords *windows.DNSRecord
reserved uintptr
}
const (
DNSFreeFlat = 0
DNSFreeRecordList = 1
DNSFreeParsedMessageFields = 2
)
func (qr *DNSQueryResult) Close() error {
windows.DnsRecordListFree(qr.QueryRecords, DNSFreeRecordList)
qr.QueryRecords = nil
return nil
}
type DNSQueryCancel struct {
// This is defined in C as 32 bytes, but it is also declared with 8 byte alignment
reserved [4]uint64
}

11
util/winutil/mksyscall.go Normal file
View File

@ -0,0 +1,11 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package winutil
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go
// Note: DO NOT use DnsQueryExW! It *is* exported from dnsapi.dll but is an internal function!
//sys DnsQueryEx(request *DNSQueryRequest, result *DNSQueryResult, cancelHandle *DNSQueryCancel) (status error) = dnsapi.DnsQueryEx
//sys DnsCancelQuery(cancelHandle *DNSQueryCancel) (status error) = dnsapi.DnsCancelQuery

View File

@ -0,0 +1,61 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winutil
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
moddnsapi = windows.NewLazySystemDLL("dnsapi.dll")
procDnsCancelQuery = moddnsapi.NewProc("DnsCancelQuery")
procDnsQueryEx = moddnsapi.NewProc("DnsQueryEx")
)
func DnsCancelQuery(cancelHandle *DNSQueryCancel) (status error) {
r0, _, _ := syscall.Syscall(procDnsCancelQuery.Addr(), 1, uintptr(unsafe.Pointer(cancelHandle)), 0, 0)
if r0 != 0 {
status = syscall.Errno(r0)
}
return
}
func DnsQueryEx(request *DNSQueryRequest, result *DNSQueryResult, cancelHandle *DNSQueryCancel) (status error) {
r0, _, _ := syscall.Syscall(procDnsQueryEx.Addr(), 3, uintptr(unsafe.Pointer(request)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(cancelHandle)))
if r0 != 0 {
status = syscall.Errno(r0)
}
return
}