mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
tsdns: initial implementation of a Tailscale DNS resolver (#396)
Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
parent
5e1ee4be53
commit
511840b1f6
31
ipn/local.go
31
ipn/local.go
@ -26,6 +26,7 @@
|
|||||||
"tailscale.com/wgengine"
|
"tailscale.com/wgengine"
|
||||||
"tailscale.com/wgengine/filter"
|
"tailscale.com/wgengine/filter"
|
||||||
"tailscale.com/wgengine/router"
|
"tailscale.com/wgengine/router"
|
||||||
|
"tailscale.com/wgengine/tsdns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LocalBackend is the glue between the major pieces of the Tailscale
|
// LocalBackend is the glue between the major pieces of the Tailscale
|
||||||
@ -311,6 +312,7 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||||||
|
|
||||||
b.send(Notify{NetMap: newSt.NetMap})
|
b.send(Notify{NetMap: newSt.NetMap})
|
||||||
b.updateFilter(newSt.NetMap)
|
b.updateFilter(newSt.NetMap)
|
||||||
|
b.updateDNSMap(newSt.NetMap)
|
||||||
if disableDERP {
|
if disableDERP {
|
||||||
b.e.SetDERPMap(nil)
|
b.e.SetDERPMap(nil)
|
||||||
} else {
|
} else {
|
||||||
@ -427,6 +429,27 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) {
|
|||||||
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
|
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// updateDNSMap updates the domain map in the DNS resolver in wgengine
|
||||||
|
// based on the given netMap and user preferences.
|
||||||
|
func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
|
||||||
|
if netMap == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)}
|
||||||
|
for _, peer := range netMap.Peers {
|
||||||
|
if len(peer.Addresses) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
domain := peer.Hostinfo.Hostname
|
||||||
|
// Like PeerStatus.SimpleHostName()
|
||||||
|
domain = strings.TrimSuffix(domain, ".local")
|
||||||
|
domain = strings.TrimSuffix(domain, ".localdomain")
|
||||||
|
domain = domain + ".ipn.dev"
|
||||||
|
dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr)
|
||||||
|
}
|
||||||
|
b.e.SetDNSMap(dnsMap)
|
||||||
|
}
|
||||||
|
|
||||||
// readPoller is a goroutine that receives service lists from
|
// readPoller is a goroutine that receives service lists from
|
||||||
// b.portpoll and propagates them into the controlclient's HostInfo.
|
// b.portpoll and propagates them into the controlclient's HostInfo.
|
||||||
func (b *LocalBackend) readPoller() {
|
func (b *LocalBackend) readPoller() {
|
||||||
@ -667,6 +690,7 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
b.updateFilter(b.netMapCache)
|
b.updateFilter(b.netMapCache)
|
||||||
|
b.updateDNSMap(b.netMapCache)
|
||||||
|
|
||||||
if old.WantRunning != new.WantRunning {
|
if old.WantRunning != new.WantRunning {
|
||||||
b.stateMachine()
|
b.stateMachine()
|
||||||
@ -799,6 +823,13 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router.
|
|||||||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The Tailscale DNS IP.
|
||||||
|
// TODO(dmytro): make this configurable.
|
||||||
|
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||||
|
IP: netaddr.IPv4(100, 100, 100, 100),
|
||||||
|
Bits: 32,
|
||||||
|
})
|
||||||
|
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@
|
|||||||
package filter
|
package filter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -137,7 +136,7 @@ func maybeHexdump(flag RunFlags, b []byte) string {
|
|||||||
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
|
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
|
||||||
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
|
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
|
||||||
|
|
||||||
func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) {
|
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, r Response, why string) {
|
||||||
var verdict string
|
var verdict string
|
||||||
|
|
||||||
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
|
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
|
||||||
@ -151,36 +150,33 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacke
|
|||||||
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
|
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
|
||||||
// since it causes an allocation.
|
// since it causes an allocation.
|
||||||
if verdict != "" {
|
if verdict != "" {
|
||||||
var qs string
|
b := q.Buffer()
|
||||||
if q == nil {
|
f.logf("%s: %s %d %s\n%s", verdict, q.String(), len(b), why, maybeHexdump(runflags, b))
|
||||||
qs = fmt.Sprintf("(%d bytes)", len(b))
|
|
||||||
} else {
|
|
||||||
qs = q.String()
|
|
||||||
}
|
|
||||||
f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
// RunIn determines whether this node is allowed to receive q from a Tailscale peer.
|
||||||
r := f.pre(b, q, rf)
|
func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||||
|
r := f.pre(q, rf)
|
||||||
if r == Accept || r == Drop {
|
if r == Accept || r == Drop {
|
||||||
// already logged
|
// already logged
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
r, why := f.runIn(q)
|
r, why := f.runIn(q)
|
||||||
f.logRateLimit(rf, b, q, r, why)
|
f.logRateLimit(rf, q, r, why)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
// RunOut determines whether this node is allowed to send q to a Tailscale peer.
|
||||||
r := f.pre(b, q, rf)
|
func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||||
|
r := f.pre(q, rf)
|
||||||
if r == Drop || r == Accept {
|
if r == Drop || r == Accept {
|
||||||
// already logged
|
// already logged
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
r, why := f.runOut(q)
|
r, why := f.runOut(q)
|
||||||
f.logRateLimit(rf, b, q, r, why)
|
f.logRateLimit(rf, q, r, why)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,29 +247,28 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
|
|||||||
return Accept, "ok out"
|
return Accept, "ok out"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||||
if len(b) == 0 {
|
if len(q.Buffer()) == 0 {
|
||||||
// wireguard keepalive packet, always permit.
|
// wireguard keepalive packet, always permit.
|
||||||
return Accept
|
return Accept
|
||||||
}
|
}
|
||||||
if len(b) < 20 {
|
if len(q.Buffer()) < 20 {
|
||||||
f.logRateLimit(rf, b, nil, Drop, "too short")
|
f.logRateLimit(rf, q, Drop, "too short")
|
||||||
return Drop
|
return Drop
|
||||||
}
|
}
|
||||||
q.Decode(b)
|
|
||||||
|
|
||||||
switch q.IPProto {
|
switch q.IPProto {
|
||||||
case packet.Unknown:
|
case packet.Unknown:
|
||||||
// Unknown packets are dangerous; always drop them.
|
// Unknown packets are dangerous; always drop them.
|
||||||
f.logRateLimit(rf, b, q, Drop, "unknown")
|
f.logRateLimit(rf, q, Drop, "unknown")
|
||||||
return Drop
|
return Drop
|
||||||
case packet.IPv6:
|
case packet.IPv6:
|
||||||
f.logRateLimit(rf, b, q, Drop, "ipv6")
|
f.logRateLimit(rf, q, Drop, "ipv6")
|
||||||
return Drop
|
return Drop
|
||||||
case packet.Fragment:
|
case packet.Fragment:
|
||||||
// Fragments after the first always need to be passed through.
|
// Fragments after the first always need to be passed through.
|
||||||
// Very small fragments are considered Junk by ParsedPacket.
|
// Very small fragments are considered Junk by ParsedPacket.
|
||||||
f.logRateLimit(rf, b, q, Accept, "fragment")
|
f.logRateLimit(rf, q, Accept, "fragment")
|
||||||
return Accept
|
return Accept
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,11 +144,12 @@ func TestNoAllocs(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
got := int(testing.AllocsPerRun(1000, func() {
|
got := int(testing.AllocsPerRun(1000, func() {
|
||||||
var q ParsedPacket
|
q := &ParsedPacket{}
|
||||||
|
q.Decode(test.packet)
|
||||||
if test.in {
|
if test.in {
|
||||||
acl.RunIn(test.packet, &q, 0)
|
acl.RunIn(q, 0)
|
||||||
} else {
|
} else {
|
||||||
acl.RunOut(test.packet, &q, 0)
|
acl.RunOut(q, 0)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -187,12 +188,13 @@ func BenchmarkFilter(b *testing.B) {
|
|||||||
for _, bench := range benches {
|
for _, bench := range benches {
|
||||||
b.Run(bench.name, func(b *testing.B) {
|
b.Run(bench.name, func(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
var q ParsedPacket
|
q := &ParsedPacket{}
|
||||||
|
q.Decode(bench.packet)
|
||||||
// This branch seems to have no measurable impact on performance.
|
// This branch seems to have no measurable impact on performance.
|
||||||
if bench.in {
|
if bench.in {
|
||||||
acl.RunIn(bench.packet, &q, 0)
|
acl.RunIn(q, 0)
|
||||||
} else {
|
} else {
|
||||||
acl.RunOut(bench.packet, &q, 0)
|
acl.RunOut(q, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -215,7 +217,9 @@ func TestPreFilter(t *testing.T) {
|
|||||||
}
|
}
|
||||||
f := NewAllowNone(t.Logf)
|
f := NewAllowNone(t.Logf)
|
||||||
for _, testPacket := range packets {
|
for _, testPacket := range packets {
|
||||||
got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts)
|
p := &ParsedPacket{}
|
||||||
|
p.Decode(testPacket.b)
|
||||||
|
got := f.pre(p, LogDrops|LogAccepts)
|
||||||
if got != testPacket.want {
|
if got != testPacket.want {
|
||||||
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
|
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ func ipChecksum(b []byte) uint16 {
|
|||||||
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
|
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
|
||||||
// and shouldn't need any memory allocation.
|
// and shouldn't need any memory allocation.
|
||||||
func (q *ParsedPacket) Decode(b []byte) {
|
func (q *ParsedPacket) Decode(b []byte) {
|
||||||
q.b = nil
|
q.b = b
|
||||||
|
|
||||||
if len(b) < ipHeaderLength {
|
if len(b) < ipHeaderLength {
|
||||||
q.IPProto = Unknown
|
q.IPProto = Unknown
|
||||||
@ -170,7 +170,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||||||
}
|
}
|
||||||
q.SrcPort = 0
|
q.SrcPort = 0
|
||||||
q.DstPort = 0
|
q.DstPort = 0
|
||||||
q.b = b
|
|
||||||
q.dataofs = q.subofs + icmpHeaderLength
|
q.dataofs = q.subofs + icmpHeaderLength
|
||||||
return
|
return
|
||||||
case TCP:
|
case TCP:
|
||||||
@ -181,7 +180,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||||||
q.SrcPort = get16(sub[0:2])
|
q.SrcPort = get16(sub[0:2])
|
||||||
q.DstPort = get16(sub[2:4])
|
q.DstPort = get16(sub[2:4])
|
||||||
q.TCPFlags = sub[13] & 0x3F
|
q.TCPFlags = sub[13] & 0x3F
|
||||||
q.b = b
|
|
||||||
headerLength := (sub[12] & 0xF0) >> 2
|
headerLength := (sub[12] & 0xF0) >> 2
|
||||||
q.dataofs = q.subofs + int(headerLength)
|
q.dataofs = q.subofs + int(headerLength)
|
||||||
return
|
return
|
||||||
@ -192,7 +190,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||||||
}
|
}
|
||||||
q.SrcPort = get16(sub[0:2])
|
q.SrcPort = get16(sub[0:2])
|
||||||
q.DstPort = get16(sub[2:4])
|
q.DstPort = get16(sub[2:4])
|
||||||
q.b = b
|
|
||||||
q.dataofs = q.subofs + udpHeaderLength
|
q.dataofs = q.subofs + udpHeaderLength
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
@ -244,6 +241,11 @@ func (q *ParsedPacket) UDPHeader() UDPHeader {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Buffer returns the entire packet buffer.
|
||||||
|
func (q *ParsedPacket) Buffer() []byte {
|
||||||
|
return q.b
|
||||||
|
}
|
||||||
|
|
||||||
// Sub returns the IP subprotocol section.
|
// Sub returns the IP subprotocol section.
|
||||||
func (q *ParsedPacket) Sub(begin, n int) []byte {
|
func (q *ParsedPacket) Sub(begin, n int) []byte {
|
||||||
return q.b[q.subofs+begin : q.subofs+begin+n]
|
return q.b[q.subofs+begin : q.subofs+begin+n]
|
||||||
|
@ -90,6 +90,7 @@ func TestIPString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ipv6PacketDecode = ParsedPacket{
|
var ipv6PacketDecode = ParsedPacket{
|
||||||
|
b: ipv6PacketBuffer,
|
||||||
IPProto: IPv6,
|
IPProto: IPv6,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,6 +101,7 @@ func TestIPString(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var unknownPacketDecode = ParsedPacket{
|
var unknownPacketDecode = ParsedPacket{
|
||||||
|
b: unknownPacketBuffer,
|
||||||
IPProto: Unknown,
|
IPProto: Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
274
wgengine/tsdns/tsdns.go
Normal file
274
wgengine/tsdns/tsdns.go
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package tsdns provides a Resolver struct capable of resolving
|
||||||
|
// domains on a Tailscale network.
|
||||||
|
package tsdns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
dns "golang.org/x/net/dns/dnsmessage"
|
||||||
|
"inet.af/netaddr"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
"tailscale.com/wgengine/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultTTL is the TTL in seconds of all responses from Resolver.
|
||||||
|
const defaultTTL = 600
|
||||||
|
|
||||||
|
var (
|
||||||
|
errMapNotSet = errors.New("domain map not set")
|
||||||
|
errNoSuchDomain = errors.New("domain does not exist")
|
||||||
|
errNotImplemented = errors.New("query type not implemented")
|
||||||
|
errNotOurName = errors.New("not an *.ipn.dev domain")
|
||||||
|
errNotQuery = errors.New("not a DNS query")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
|
||||||
|
defaultPort = uint16(53)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Map is all the data Resolver needs to resolve DNS queries.
|
||||||
|
type Map struct {
|
||||||
|
// DomainToIP is a mapping of Tailscale domains to their IP addresses.
|
||||||
|
// For example, monitoring.ipn.dev -> 100.64.0.1.
|
||||||
|
DomainToIP map[string]netaddr.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolver is a DNS resolver for domain names of the form *.ipn.dev
|
||||||
|
// It is intended
|
||||||
|
type Resolver struct {
|
||||||
|
logf logger.Logf
|
||||||
|
|
||||||
|
// ip is the IP on which the resolver is listening.
|
||||||
|
ip packet.IP
|
||||||
|
// port is the port on which the resolver is listening.
|
||||||
|
port uint16
|
||||||
|
|
||||||
|
// mu guards the following fields from being updated while used.
|
||||||
|
mu sync.Mutex
|
||||||
|
// dnsMap is the map most recently received from the control server.
|
||||||
|
dnsMap *Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewResolver constructs a resolver with default parameters.
|
||||||
|
func NewResolver(logf logger.Logf) *Resolver {
|
||||||
|
r := &Resolver{
|
||||||
|
logf: logf,
|
||||||
|
ip: defaultIP,
|
||||||
|
port: defaultPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcceptsPacket determines if the given packet is
|
||||||
|
// directed to this resolver (by ip and port).
|
||||||
|
// We also require that UDP be used to simplify things for now.
|
||||||
|
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool {
|
||||||
|
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMap sets the resolver's DNS map.
|
||||||
|
func (r *Resolver) SetMap(m *Map) {
|
||||||
|
r.mu.Lock()
|
||||||
|
r.dnsMap = m
|
||||||
|
r.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve maps a given domain name to the IP address of the host that owns it.
|
||||||
|
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
|
||||||
|
// If not a subdomain of ipn.dev, then we must refuse this query.
|
||||||
|
// We do this before checking the map to distinguish beween nonexistent domains
|
||||||
|
// and misdirected queries.
|
||||||
|
if !strings.HasSuffix(domain, ".ipn.dev") {
|
||||||
|
return netaddr.IP{}, dns.RCodeRefused, errNotOurName
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
if r.dnsMap == nil {
|
||||||
|
r.mu.Unlock()
|
||||||
|
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
|
||||||
|
}
|
||||||
|
addr, found := r.dnsMap.DomainToIP[domain]
|
||||||
|
r.mu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain
|
||||||
|
}
|
||||||
|
return addr, dns.RCodeSuccess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type response struct {
|
||||||
|
Header dns.Header
|
||||||
|
ResourceHeader dns.ResourceHeader
|
||||||
|
Question dns.Question
|
||||||
|
// TODO(dmytro): support IPv6.
|
||||||
|
IP netaddr.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseQuery parses the query in given packet into a response struct.
|
||||||
|
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error {
|
||||||
|
var parser dns.Parser
|
||||||
|
var err error
|
||||||
|
|
||||||
|
resp.Header, err = parser.Start(query.Payload())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Header.Response {
|
||||||
|
return errNotQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Question, err = parser.Question()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeResponse resolves the question stored in resp and sets the answer fields.
|
||||||
|
func (r *Resolver) makeResponse(resp *response) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
name := resp.Question.Name.String()
|
||||||
|
if len(name) > 0 {
|
||||||
|
name = name[:len(name)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Question.Type == dns.TypeA {
|
||||||
|
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
|
||||||
|
resp.IP, resp.Header.RCode, err = r.Resolve(name)
|
||||||
|
} else {
|
||||||
|
resp.Header.RCode = dns.RCodeNotImplemented
|
||||||
|
err = errNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalAnswer serializes the answer record into an active builder.
|
||||||
|
func marshalAnswer(resp *response, builder *dns.Builder) error {
|
||||||
|
var answer dns.AResource
|
||||||
|
|
||||||
|
err := builder.StartAnswers()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
answerHeader := dns.ResourceHeader{
|
||||||
|
Name: resp.Question.Name,
|
||||||
|
Type: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
TTL: defaultTTL,
|
||||||
|
}
|
||||||
|
ip := resp.IP.As16()
|
||||||
|
copy(answer.A[:], ip[12:])
|
||||||
|
return builder.AResource(answerHeader, answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalResponse serializes the DNS response into an active builder.
|
||||||
|
func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) {
|
||||||
|
resp.Header.Response = true
|
||||||
|
resp.Header.Authoritative = true
|
||||||
|
if resp.Header.RecursionDesired {
|
||||||
|
resp.Header.RecursionAvailable = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err := builder.StartQuestions()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = builder.Question(resp.Question)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Header.RCode == dns.RCodeSuccess {
|
||||||
|
err = marshalAnswer(resp, builder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.Finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
|
||||||
|
udpHeader := query.UDPHeader()
|
||||||
|
udpHeader.ToResponse()
|
||||||
|
offset := udpHeader.Len()
|
||||||
|
|
||||||
|
// dns.Builder appends to the passed buffer (without reallocation when possible),
|
||||||
|
// so we pass in a zero-length slice starting at the point it should start writing.
|
||||||
|
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
|
||||||
|
|
||||||
|
// rbuf is the response slice with the correct length starting at offset.
|
||||||
|
rbuf, err := marshalResponse(resp, &builder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
end := offset + len(rbuf)
|
||||||
|
err = udpHeader.Marshal(buf[:end])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf[:end], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Respond writes a response to query into buf and returns buf trimmed to the response length.
|
||||||
|
// It is assumed that r.AcceptsPacket(query) is true.
|
||||||
|
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
|
||||||
|
var resp response
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 0. Verify that contract is upheld.
|
||||||
|
if !r.AcceptsPacket(query) {
|
||||||
|
r.logf("[unexpected] tsdns: Respond called on query not for this resolver")
|
||||||
|
resp.Header.RCode = dns.RCodeServerFailure
|
||||||
|
return marshalResponsePacket(query, &resp, buf)
|
||||||
|
}
|
||||||
|
// A DNS response is at least as long as the query
|
||||||
|
if len(buf) < len(query.Buffer()) {
|
||||||
|
r.logf("[unexpected] tsdns: response buffer is too small")
|
||||||
|
resp.Header.RCode = dns.RCodeServerFailure
|
||||||
|
return marshalResponsePacket(query, &resp, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Parse query packet.
|
||||||
|
err = r.parseQuery(query, &resp)
|
||||||
|
// We will not return this error: it is the sender's fault.
|
||||||
|
if err != nil {
|
||||||
|
r.logf("tsdns: error during query parsing: %v", err)
|
||||||
|
resp.Header.RCode = dns.RCodeFormatError
|
||||||
|
return marshalResponsePacket(query, &resp, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Service the query.
|
||||||
|
err = r.makeResponse(&resp)
|
||||||
|
// We will not return this error: it is the sender's fault.
|
||||||
|
if err != nil {
|
||||||
|
r.logf("tsdns: error during name resolution: %v", err)
|
||||||
|
return marshalResponsePacket(query, &resp, buf)
|
||||||
|
}
|
||||||
|
// For now, we require IPv4 in all cases.
|
||||||
|
// If we somehow came up with a non-IPv4 address, it's our fault.
|
||||||
|
if !resp.IP.Is4() {
|
||||||
|
resp.Header.RCode = dns.RCodeServerFailure
|
||||||
|
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Serialize the response.
|
||||||
|
return marshalResponsePacket(query, &resp, buf)
|
||||||
|
}
|
@ -10,6 +10,7 @@
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/tailscale/wireguard-go/device"
|
"github.com/tailscale/wireguard-go/device"
|
||||||
@ -19,10 +20,12 @@
|
|||||||
"tailscale.com/wgengine/packet"
|
"tailscale.com/wgengine/packet"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const maxBufferSize = device.MaxMessageSize
|
||||||
readMaxSize = device.MaxMessageSize
|
|
||||||
readOffset = device.MessageTransportHeaderSize
|
// PacketStartOffset is the minimal amount of leading space that must exist
|
||||||
)
|
// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
|
||||||
|
// This is necessary to avoid reallocation in wireguard-go internals.
|
||||||
|
const PacketStartOffset = device.MessageTransportHeaderSize
|
||||||
|
|
||||||
// MaxPacketSize is the maximum size (in bytes)
|
// MaxPacketSize is the maximum size (in bytes)
|
||||||
// of a packet that can be injected into a tstun.TUN.
|
// of a packet that can be injected into a tstun.TUN.
|
||||||
@ -35,7 +38,15 @@
|
|||||||
ErrFiltered = errors.New("packet dropped by filter")
|
ErrFiltered = errors.New("packet dropped by filter")
|
||||||
)
|
)
|
||||||
|
|
||||||
var errPacketTooBig = errors.New("packet too big")
|
var (
|
||||||
|
errPacketTooBig = errors.New("packet too big")
|
||||||
|
errOffsetTooBig = errors.New("offset larger than buffer length")
|
||||||
|
errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
|
||||||
|
)
|
||||||
|
|
||||||
|
// FilterFunc is a packet-filtering function with access to the TUN device.
|
||||||
|
// It must not hold onto the packet struct, as its backing storage will be reused.
|
||||||
|
type FilterFunc func(*packet.ParsedPacket, *TUN) filter.Response
|
||||||
|
|
||||||
// TUN wraps a tun.Device from wireguard-go,
|
// TUN wraps a tun.Device from wireguard-go,
|
||||||
// augmenting it with filtering and packet injection.
|
// augmenting it with filtering and packet injection.
|
||||||
@ -47,10 +58,14 @@ type TUN struct {
|
|||||||
tdev tun.Device
|
tdev tun.Device
|
||||||
|
|
||||||
// buffer stores the oldest unconsumed packet from tdev.
|
// buffer stores the oldest unconsumed packet from tdev.
|
||||||
// It is made a static buffer in order to avoid graticious allocation.
|
// It is made a static buffer in order to avoid allocations.
|
||||||
buffer [readMaxSize]byte
|
buffer [maxBufferSize]byte
|
||||||
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
|
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
|
||||||
bufferConsumed chan struct{}
|
bufferConsumed chan struct{}
|
||||||
|
// parsedPacketPool holds a pool of ParsedPacket structs for use in filtering.
|
||||||
|
// This is needed because escape analysis cannot see that parsed packets
|
||||||
|
// do not escape through {Pre,Post}Filter{In,Out}.
|
||||||
|
parsedPacketPool sync.Pool // of *packet.ParsedPacket
|
||||||
|
|
||||||
// closed signals poll (by closing) when the device is closed.
|
// closed signals poll (by closing) when the device is closed.
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
@ -73,8 +88,19 @@ type TUN struct {
|
|||||||
// filterFlags control the verbosity of logging packet drops/accepts.
|
// filterFlags control the verbosity of logging packet drops/accepts.
|
||||||
filterFlags filter.RunFlags
|
filterFlags filter.RunFlags
|
||||||
|
|
||||||
// insecure disables all filtering when set. This is useful in tests.
|
// PreFilterIn is the inbound filter function that runs before the main filter
|
||||||
insecure bool
|
// and therefore sees the packets that may be later dropped by it.
|
||||||
|
PreFilterIn FilterFunc
|
||||||
|
// PostFilterIn is the inbound filter function that runs after the main filter.
|
||||||
|
PostFilterIn FilterFunc
|
||||||
|
// PreFilterOut is the outbound filter function that runs before the main filter
|
||||||
|
// and therefore sees the packets that may be later dropped by it.
|
||||||
|
PreFilterOut FilterFunc
|
||||||
|
// PostFilterOut is the outbound filter function that runs after the main filter.
|
||||||
|
PostFilterOut FilterFunc
|
||||||
|
|
||||||
|
// disableFilter disables all filtering when set. This should only be used in tests.
|
||||||
|
disableFilter bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
||||||
@ -87,8 +113,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
|||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
errors: make(chan error),
|
errors: make(chan error),
|
||||||
outbound: make(chan []byte),
|
outbound: make(chan []byte),
|
||||||
filterFlags: filter.LogAccepts | filter.LogDrops,
|
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
|
||||||
|
filterFlags: filter.LogAccepts | filter.LogDrops,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tun.parsedPacketPool.New = func() interface{} {
|
||||||
|
return new(packet.ParsedPacket)
|
||||||
|
}
|
||||||
|
|
||||||
go tun.poll()
|
go tun.poll()
|
||||||
// The buffer starts out consumed.
|
// The buffer starts out consumed.
|
||||||
tun.bufferConsumed <- struct{}{}
|
tun.bufferConsumed <- struct{}{}
|
||||||
@ -140,10 +172,10 @@ func (t *TUN) poll() {
|
|||||||
// continue
|
// continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read may use memory in t.buffer before readOffset for mandatory headers.
|
// Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
|
||||||
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
|
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
|
||||||
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
|
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
|
||||||
n, err := t.tdev.Read(t.buffer[:], readOffset)
|
n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
select {
|
||||||
case <-t.closed:
|
case <-t.closed:
|
||||||
@ -165,26 +197,41 @@ func (t *TUN) poll() {
|
|||||||
select {
|
select {
|
||||||
case <-t.closed:
|
case <-t.closed:
|
||||||
return
|
return
|
||||||
case t.outbound <- t.buffer[readOffset : readOffset+n]:
|
case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
|
||||||
// continue
|
// continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TUN) filterOut(buf []byte) filter.Response {
|
func (t *TUN) filterOut(buf []byte) filter.Response {
|
||||||
|
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
|
||||||
|
defer t.parsedPacketPool.Put(p)
|
||||||
|
p.Decode(buf)
|
||||||
|
|
||||||
|
if t.PreFilterOut != nil {
|
||||||
|
if t.PreFilterOut(p, t) == filter.Drop {
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
filt, _ := t.filter.Load().(*filter.Filter)
|
filt, _ := t.filter.Load().(*filter.Filter)
|
||||||
|
|
||||||
if filt == nil {
|
if filt == nil {
|
||||||
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
|
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
|
||||||
return filter.Drop
|
return filter.Drop
|
||||||
}
|
}
|
||||||
|
|
||||||
var p packet.ParsedPacket
|
if filt.RunOut(p, t.filterFlags) != filter.Accept {
|
||||||
if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept {
|
return filter.Drop
|
||||||
return filter.Accept
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.Drop
|
if t.PostFilterOut != nil {
|
||||||
|
if t.PostFilterOut(p, t) == filter.Drop {
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.Accept
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
||||||
@ -200,12 +247,16 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
|||||||
// t.buffer has a fixed location in memory,
|
// t.buffer has a fixed location in memory,
|
||||||
// so this is the easiest way to tell when it has been consumed.
|
// so this is the easiest way to tell when it has been consumed.
|
||||||
// &packet[0] can be used because empty packets do not reach t.outbound.
|
// &packet[0] can be used because empty packets do not reach t.outbound.
|
||||||
if &packet[0] == &t.buffer[readOffset] {
|
if &packet[0] == &t.buffer[PacketStartOffset] {
|
||||||
t.bufferConsumed <- struct{}{}
|
t.bufferConsumed <- struct{}{}
|
||||||
|
} else {
|
||||||
|
// If the packet is not from t.buffer, then it is an injected packet.
|
||||||
|
// In this case, we return eary to bypass filtering
|
||||||
|
return n, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !t.insecure {
|
if !t.disableFilter {
|
||||||
response := t.filterOut(buf[offset : offset+n])
|
response := t.filterOut(buf[offset : offset+n])
|
||||||
if response != filter.Accept {
|
if response != filter.Accept {
|
||||||
// Wireguard considers read errors fatal; pretend nothing was read
|
// Wireguard considers read errors fatal; pretend nothing was read
|
||||||
@ -217,35 +268,38 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TUN) filterIn(buf []byte) filter.Response {
|
func (t *TUN) filterIn(buf []byte) filter.Response {
|
||||||
|
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
|
||||||
|
defer t.parsedPacketPool.Put(p)
|
||||||
|
p.Decode(buf)
|
||||||
|
|
||||||
|
if t.PreFilterIn != nil {
|
||||||
|
if t.PreFilterIn(p, t) == filter.Drop {
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
filt, _ := t.filter.Load().(*filter.Filter)
|
filt, _ := t.filter.Load().(*filter.Filter)
|
||||||
|
|
||||||
if filt == nil {
|
if filt == nil {
|
||||||
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
|
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
|
||||||
return filter.Drop
|
return filter.Drop
|
||||||
}
|
}
|
||||||
|
|
||||||
var p packet.ParsedPacket
|
if filt.RunIn(p, t.filterFlags) != filter.Accept {
|
||||||
if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept {
|
return filter.Drop
|
||||||
// Only in fake mode, answer any incoming pings.
|
|
||||||
if p.IsEchoRequest() {
|
|
||||||
ft, ok := t.tdev.(*fakeTUN)
|
|
||||||
if ok {
|
|
||||||
header := p.ICMPHeader()
|
|
||||||
header.ToResponse()
|
|
||||||
packet := packet.Generate(&header, p.Payload())
|
|
||||||
ft.Write(packet, 0)
|
|
||||||
// We already handled it, stop.
|
|
||||||
return filter.Drop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filter.Accept
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.Drop
|
if t.PostFilterIn != nil {
|
||||||
|
if t.PostFilterIn(p, t) == filter.Drop {
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.Accept
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TUN) Write(buf []byte, offset int) (int, error) {
|
func (t *TUN) Write(buf []byte, offset int) (int, error) {
|
||||||
if !t.insecure {
|
if !t.disableFilter {
|
||||||
response := t.filterIn(buf[offset:])
|
response := t.filterIn(buf[offset:])
|
||||||
if response != filter.Accept {
|
if response != filter.Accept {
|
||||||
return 0, ErrFiltered
|
return 0, ErrFiltered
|
||||||
@ -264,24 +318,53 @@ func (t *TUN) SetFilter(filt *filter.Filter) {
|
|||||||
t.filter.Store(filt)
|
t.filter.Store(filt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectInbound makes the TUN device behave as if a packet
|
// InjectInboundDirect makes the TUN device behave as if a packet
|
||||||
// with the given contents was received from the network.
|
// with the given contents was received from the network.
|
||||||
// It blocks and does not take ownership of the packet.
|
// It blocks and does not take ownership of the packet.
|
||||||
// Injecting an empty packet is a no-op.
|
// The injected packet will not pass through inbound filters.
|
||||||
func (t *TUN) InjectInbound(packet []byte) error {
|
//
|
||||||
|
// The packet contents are to start at &buf[offset].
|
||||||
|
// offset must be greater or equal to PacketStartOffset.
|
||||||
|
// The space before &buf[offset] will be used by Wireguard.
|
||||||
|
func (t *TUN) InjectInboundDirect(buf []byte, offset int) error {
|
||||||
|
if len(buf) > MaxPacketSize {
|
||||||
|
return errPacketTooBig
|
||||||
|
}
|
||||||
|
if len(buf) < offset {
|
||||||
|
return errOffsetTooBig
|
||||||
|
}
|
||||||
|
if offset < PacketStartOffset {
|
||||||
|
return errOffsetTooSmall
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to the underlying device to skip filters.
|
||||||
|
_, err := t.tdev.Write(buf, offset)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectInboundCopy takes a packet without leading space,
|
||||||
|
// reallocates it to conform to the InjectInbondDirect interface
|
||||||
|
// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
|
||||||
|
func (t *TUN) InjectInboundCopy(packet []byte) error {
|
||||||
|
// We duplicate this check from InjectInboundDirect here
|
||||||
|
// to avoid wasting an allocation on an oversized packet.
|
||||||
if len(packet) > MaxPacketSize {
|
if len(packet) > MaxPacketSize {
|
||||||
return errPacketTooBig
|
return errPacketTooBig
|
||||||
}
|
}
|
||||||
if len(packet) == 0 {
|
if len(packet) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
_, err := t.Write(packet, 0)
|
|
||||||
return err
|
buf := make([]byte, PacketStartOffset+len(packet))
|
||||||
|
copy(buf[PacketStartOffset:], packet)
|
||||||
|
|
||||||
|
return t.InjectInboundDirect(buf, PacketStartOffset)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectOutbound makes the TUN device behave as if a packet
|
// InjectOutbound makes the TUN device behave as if a packet
|
||||||
// with the given contents was sent to the network.
|
// with the given contents was sent to the network.
|
||||||
// It does not block, but takes ownership of the packet.
|
// It does not block, but takes ownership of the packet.
|
||||||
|
// The injected packet will not pass through outbound filters.
|
||||||
// Injecting an empty packet is a no-op.
|
// Injecting an empty packet is a no-op.
|
||||||
func (t *TUN) InjectOutbound(packet []byte) error {
|
func (t *TUN) InjectOutbound(packet []byte) error {
|
||||||
if len(packet) > MaxPacketSize {
|
if len(packet) > MaxPacketSize {
|
||||||
|
@ -58,7 +58,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
|
|||||||
if secure {
|
if secure {
|
||||||
setfilter(logf, tun)
|
setfilter(logf, tun)
|
||||||
} else {
|
} else {
|
||||||
tun.insecure = true
|
tun.disableFilter = true
|
||||||
}
|
}
|
||||||
return chtun, tun
|
return chtun, tun
|
||||||
}
|
}
|
||||||
@ -69,7 +69,7 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
|
|||||||
if secure {
|
if secure {
|
||||||
setfilter(logf, tun)
|
setfilter(logf, tun)
|
||||||
} else {
|
} else {
|
||||||
tun.insecure = true
|
tun.disableFilter = true
|
||||||
}
|
}
|
||||||
return ftun.(*fakeTUN), tun
|
return ftun.(*fakeTUN), tun
|
||||||
}
|
}
|
||||||
@ -151,7 +151,7 @@ func TestWriteAndInject(t *testing.T) {
|
|||||||
for _, packet := range injected {
|
for _, packet := range injected {
|
||||||
go func(packet string) {
|
go func(packet string) {
|
||||||
payload := []byte(packet)
|
payload := []byte(packet)
|
||||||
err := tun.InjectInbound(payload)
|
err := tun.InjectInboundCopy(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: error: %v", packet, err)
|
t.Errorf("%s: error: %v", packet, err)
|
||||||
}
|
}
|
||||||
|
@ -34,6 +34,7 @@
|
|||||||
"tailscale.com/wgengine/monitor"
|
"tailscale.com/wgengine/monitor"
|
||||||
"tailscale.com/wgengine/packet"
|
"tailscale.com/wgengine/packet"
|
||||||
"tailscale.com/wgengine/router"
|
"tailscale.com/wgengine/router"
|
||||||
|
"tailscale.com/wgengine/tsdns"
|
||||||
"tailscale.com/wgengine/tstun"
|
"tailscale.com/wgengine/tstun"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -54,6 +55,7 @@ type userspaceEngine struct {
|
|||||||
tundev *tstun.TUN
|
tundev *tstun.TUN
|
||||||
wgdev *device.Device
|
wgdev *device.Device
|
||||||
router router.Router
|
router router.Router
|
||||||
|
resolver *tsdns.Resolver
|
||||||
magicConn *magicsock.Conn
|
magicConn *magicsock.Conn
|
||||||
linkMon *monitor.Mon
|
linkMon *monitor.Mon
|
||||||
|
|
||||||
@ -73,6 +75,28 @@ type userspaceEngine struct {
|
|||||||
// Lock ordering: wgLock, then mu.
|
// Lock ordering: wgLock, then mu.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RouterGen is the signature for a function that creates a
|
||||||
|
// router.Router.
|
||||||
|
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
|
||||||
|
|
||||||
|
type EngineConfig struct {
|
||||||
|
// Logf is the logging function used by the engine.
|
||||||
|
Logf logger.Logf
|
||||||
|
// TUN is the tun device used by the engine.
|
||||||
|
TUN tun.Device
|
||||||
|
// RouterGen is the function used to instantiate the router.
|
||||||
|
RouterGen RouterGen
|
||||||
|
// ListenPort is the port on which the engine will listen.
|
||||||
|
ListenPort uint16
|
||||||
|
// EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers
|
||||||
|
// will be intercepted and responded to, regardless of the source host.
|
||||||
|
EchoRespondToAll bool
|
||||||
|
// UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev
|
||||||
|
// directed to the designated Taislcale DNS address (see wgengine/tsdns)
|
||||||
|
// will be intercepted and resolved by a tsdns.Resolver.
|
||||||
|
UseTailscaleDNS bool
|
||||||
|
}
|
||||||
|
|
||||||
type Loggify struct {
|
type Loggify struct {
|
||||||
f logger.Logf
|
f logger.Logf
|
||||||
}
|
}
|
||||||
@ -84,8 +108,14 @@ func (l *Loggify) Write(b []byte) (int, error) {
|
|||||||
|
|
||||||
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
|
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
|
||||||
logf("Starting userspace wireguard engine (FAKE tuntap device).")
|
logf("Starting userspace wireguard engine (FAKE tuntap device).")
|
||||||
tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN())
|
conf := EngineConfig{
|
||||||
return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort)
|
Logf: logf,
|
||||||
|
TUN: tstun.NewFakeTUN(),
|
||||||
|
RouterGen: router.NewFake,
|
||||||
|
ListenPort: listenPort,
|
||||||
|
EchoRespondToAll: true,
|
||||||
|
}
|
||||||
|
return NewUserspaceEngineAdvanced(conf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUserspaceEngine creates the named tun device and returns a
|
// NewUserspaceEngine creates the named tun device and returns a
|
||||||
@ -104,38 +134,53 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
logf("CreateTUN ok.")
|
logf("CreateTUN ok.")
|
||||||
tundev := tstun.WrapTUN(logf, tun)
|
|
||||||
|
|
||||||
e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort)
|
conf := EngineConfig{
|
||||||
|
Logf: logf,
|
||||||
|
TUN: tun,
|
||||||
|
RouterGen: router.New,
|
||||||
|
ListenPort: listenPort,
|
||||||
|
// TODO(dmytro): plumb this down.
|
||||||
|
UseTailscaleDNS: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
e, err := NewUserspaceEngineAdvanced(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// RouterGen is the signature for a function that creates a
|
// NewUserspaceEngineAdvanced is like NewUserspaceEngine
|
||||||
// router.Router.
|
// but provides control over all config fields.
|
||||||
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
|
func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) {
|
||||||
|
return newUserspaceEngineAdvanced(conf)
|
||||||
// NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing
|
|
||||||
// a custom router constructor and listening port.
|
|
||||||
func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) {
|
|
||||||
return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) {
|
func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
|
||||||
|
logf := conf.Logf
|
||||||
|
|
||||||
e := &userspaceEngine{
|
e := &userspaceEngine{
|
||||||
logf: logf,
|
logf: logf,
|
||||||
reqCh: make(chan struct{}, 1),
|
reqCh: make(chan struct{}, 1),
|
||||||
waitCh: make(chan struct{}),
|
waitCh: make(chan struct{}),
|
||||||
tundev: tundev,
|
tundev: tstun.WrapTUN(logf, conf.TUN),
|
||||||
pingers: make(map[wgcfg.Key]*pinger),
|
resolver: tsdns.NewResolver(logf),
|
||||||
|
pingers: make(map[wgcfg.Key]*pinger),
|
||||||
}
|
}
|
||||||
e.linkState, _ = getLinkState()
|
e.linkState, _ = getLinkState()
|
||||||
|
|
||||||
|
// Respond to all pings only in fake mode.
|
||||||
|
if conf.EchoRespondToAll {
|
||||||
|
e.tundev.PostFilterIn = echoRespondToAll
|
||||||
|
}
|
||||||
|
if conf.UseTailscaleDNS {
|
||||||
|
e.tundev.PreFilterOut = e.handleDNS
|
||||||
|
}
|
||||||
|
|
||||||
mon, err := monitor.New(logf, func() { e.LinkChange(false) })
|
mon, err := monitor.New(logf, func() { e.LinkChange(false) })
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tundev.Close()
|
e.tundev.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
e.linkMon = mon
|
e.linkMon = mon
|
||||||
@ -149,12 +194,12 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||||||
}
|
}
|
||||||
magicsockOpts := magicsock.Options{
|
magicsockOpts := magicsock.Options{
|
||||||
Logf: logf,
|
Logf: logf,
|
||||||
Port: listenPort,
|
Port: conf.ListenPort,
|
||||||
EndpointsFunc: endpointsFn,
|
EndpointsFunc: endpointsFn,
|
||||||
}
|
}
|
||||||
e.magicConn, err = magicsock.NewConn(magicsockOpts)
|
e.magicConn, err = magicsock.NewConn(magicsockOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tundev.Close()
|
e.tundev.Close()
|
||||||
return nil, fmt.Errorf("wgengine: %v", err)
|
return nil, fmt.Errorf("wgengine: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,7 +256,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||||||
|
|
||||||
// Pass the underlying tun.(*NativeDevice) to the router:
|
// Pass the underlying tun.(*NativeDevice) to the router:
|
||||||
// routers do not Read or Write, but do access native interfaces.
|
// routers do not Read or Write, but do access native interfaces.
|
||||||
e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap())
|
e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.magicConn.Close()
|
e.magicConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -256,6 +301,37 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||||||
return e, nil
|
return e, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// echoRespondToAll is an inbound post-filter responding to all echo requests.
|
||||||
|
func echoRespondToAll(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
|
||||||
|
if p.IsEchoRequest() {
|
||||||
|
header := p.ICMPHeader()
|
||||||
|
header.ToResponse()
|
||||||
|
packet := packet.Generate(&header, p.Payload())
|
||||||
|
t.InjectOutbound(packet)
|
||||||
|
// We already handled it, stop.
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
return filter.Accept
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDNS is an outbound pre-filter resolving Tailscale domains.
|
||||||
|
func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
|
||||||
|
if e.resolver.AcceptsPacket(p) {
|
||||||
|
// TODO(dmytro): avoid this allocation without having tsdns know tstun quirks.
|
||||||
|
buf := make([]byte, tstun.MaxPacketSize)
|
||||||
|
offset := tstun.PacketStartOffset
|
||||||
|
response, err := e.resolver.Respond(p, buf[offset:])
|
||||||
|
if err != nil {
|
||||||
|
e.logf("DNS resolver error: %v", err)
|
||||||
|
} else {
|
||||||
|
t.InjectInboundDirect(buf[:offset+len(response)], offset)
|
||||||
|
}
|
||||||
|
// We already handled it, stop.
|
||||||
|
return filter.Drop
|
||||||
|
}
|
||||||
|
return filter.Accept
|
||||||
|
}
|
||||||
|
|
||||||
// pinger sends ping packets for a few seconds.
|
// pinger sends ping packets for a few seconds.
|
||||||
//
|
//
|
||||||
// These generated packets are used to ensure we trigger the spray logic in
|
// These generated packets are used to ensure we trigger the spray logic in
|
||||||
@ -447,6 +523,10 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
|
|||||||
e.tundev.SetFilter(filt)
|
e.tundev.SetFilter(filt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) {
|
||||||
|
e.resolver.SetMap(dm)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
|
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
defer e.mu.Unlock()
|
defer e.mu.Unlock()
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/wgengine/filter"
|
"tailscale.com/wgengine/filter"
|
||||||
"tailscale.com/wgengine/router"
|
"tailscale.com/wgengine/router"
|
||||||
|
"tailscale.com/wgengine/tsdns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWatchdog wraps an Engine and makes sure that all methods complete
|
// NewWatchdog wraps an Engine and makes sure that all methods complete
|
||||||
@ -74,6 +75,9 @@ func (e *watchdogEngine) GetFilter() *filter.Filter {
|
|||||||
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
|
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
|
||||||
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
|
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
|
||||||
}
|
}
|
||||||
|
func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) {
|
||||||
|
e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) })
|
||||||
|
}
|
||||||
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
|
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
|
||||||
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
|
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
|
||||||
}
|
}
|
||||||
|
@ -10,9 +10,6 @@
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/wgengine/router"
|
|
||||||
"tailscale.com/wgengine/tstun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWatchdog(t *testing.T) {
|
func TestWatchdog(t *testing.T) {
|
||||||
@ -20,8 +17,7 @@ func TestWatchdog(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("default watchdog does not fire", func(t *testing.T) {
|
t.Run("default watchdog does not fire", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
|
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -37,8 +33,7 @@ func TestWatchdog(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
|
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/wgengine/filter"
|
"tailscale.com/wgengine/filter"
|
||||||
"tailscale.com/wgengine/router"
|
"tailscale.com/wgengine/router"
|
||||||
|
"tailscale.com/wgengine/tsdns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ByteCount is the number of bytes that have been sent or received.
|
// ByteCount is the number of bytes that have been sent or received.
|
||||||
@ -65,6 +66,9 @@ type Engine interface {
|
|||||||
// SetFilter updates the packet filter.
|
// SetFilter updates the packet filter.
|
||||||
SetFilter(*filter.Filter)
|
SetFilter(*filter.Filter)
|
||||||
|
|
||||||
|
// SetDNSMap updates the DNS map.
|
||||||
|
SetDNSMap(*tsdns.Map)
|
||||||
|
|
||||||
// SetStatusCallback sets the function to call when the
|
// SetStatusCallback sets the function to call when the
|
||||||
// WireGuard status changes.
|
// WireGuard status changes.
|
||||||
SetStatusCallback(StatusCallback)
|
SetStatusCallback(StatusCallback)
|
||||||
|
Loading…
Reference in New Issue
Block a user