diff --git a/cmd/dnsapc/dnsapc.go b/cmd/dnsapc/dnsapc.go index 5f2ccbc1b..1c77c8672 100644 --- a/cmd/dnsapc/dnsapc.go +++ b/cmd/dnsapc/dnsapc.go @@ -10,13 +10,14 @@ ) func main() { - result, err := winutil.DnsQuery("www.tailscale.com", windows.DNS_TYPE_A, winutil.DNS_QUERY_STANDARD, nil, 0) + result, err := winutil.DNSQuery("www.tailscale.com", windows.DNS_TYPE_A, winutil.DNS_QUERY_STANDARD, nil, 0) if err != nil { fmt.Fprintln(os.Stderr, err) return } - finalStatus := result.QueryStatus + qresult := result.Wait() + finalStatus := qresult.QueryStatus fmt.Printf("Query status: %v", finalStatus) if finalStatus != 0 { fmt.Printf(" (%v)\n", windows.Errno(finalStatus)) @@ -25,7 +26,7 @@ func main() { } count := 0 - for rec := result.QueryRecords; rec != nil; rec = rec.Next { + for rec := qresult.QueryRecords; rec != nil; rec = rec.Next { name := windows.UTF16PtrToString(rec.Name) fmt.Printf("Record %d: %s, type %v", count, name, rec.Type) switch rec.Type { @@ -42,5 +43,5 @@ func main() { count++ } - result.Close() + qresult.Close() } diff --git a/util/winutil/dnsapc_windows.go b/util/winutil/dnsapc_windows.go index 32245c530..257a72cf0 100644 --- a/util/winutil/dnsapc_windows.go +++ b/util/winutil/dnsapc_windows.go @@ -41,7 +41,7 @@ type DNSServerList struct { apcCallback uintptr ) -func newDnsInvoker(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, ifaceIdx uint32) (*invoker, error) { +func newDNSInvoker(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, ifaceIdx uint32) (*invoker, error) { once.Do(func() { cbInfo := APCCallbackInfo{reflect.TypeOf(dnsQueryExApc), resolver{}} apcCallback = RegisterAPCCallback(cbInfo) @@ -87,16 +87,22 @@ func (i *invoker) Begin() *APCChannel { return &i.done } -func (i *invoker) Wait() { +func (i *invoker) Wait() *DNSQueryResult { <-i.done + return &i.result } func (i *invoker) Cancel() error { return DnsCancelQuery(&i.cancel) } -func DnsQuery(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, interfaceIdx uint32) (*DNSQueryResult, error) { - inv, err := newDnsInvoker(qname, qtype, qoptions, srvList, interfaceIdx) +type DNSResult interface { + Wait() *DNSQueryResult + Cancel() error +} + +func DNSQuery(qname string, qtype uint16, qoptions uint64, srvList *DNSServerList, interfaceIdx uint32) (DNSResult, error) { + inv, err := newDNSInvoker(qname, qtype, qoptions, srvList, interfaceIdx) if err != nil { return nil, fmt.Errorf("Failed creating DNS invoker: %w", err) } @@ -106,6 +112,5 @@ func DnsQuery(qname string, qtype uint16, qoptions uint64, srvList *DNSServerLis return nil, fmt.Errorf("Failed submitting work: %w", err) } - inv.Wait() - return &inv.result, nil + return inv, nil }