mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-11 13:18:53 +00:00
appc,cmd/sniproxy,ipn/ipnlocal: split sniproxy configuration code out of appc
The design changed during integration and testing, resulting in the earlier implementation growing in the appc package to be intended now only for the sniproxy implementation. That code is moved to it's final location, and the current App Connector code is now renamed. Updates tailscale/corp#15437 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:

committed by
James Tucker

parent
6c0ac8bef3
commit
f27b2cf569
104
cmd/sniproxy/handlers.go
Normal file
104
cmd/sniproxy/handlers.go
Normal file
@@ -0,0 +1,104 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"inet.af/tcpproxy"
|
||||
"tailscale.com/net/netutil"
|
||||
)
|
||||
|
||||
type tcpRoundRobinHandler struct {
|
||||
// To is a list of destination addresses to forward to.
|
||||
// An entry may be either an IP address or a DNS name.
|
||||
To []string
|
||||
|
||||
// DialContext is used to make the outgoing TCP connection.
|
||||
DialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// ReachableIPs enumerates the IP addresses this handler is reachable on.
|
||||
ReachableIPs []netip.Addr
|
||||
}
|
||||
|
||||
// ReachableOn returns the IP addresses this handler is reachable on.
|
||||
func (h *tcpRoundRobinHandler) ReachableOn() []netip.Addr {
|
||||
return h.ReachableIPs
|
||||
}
|
||||
|
||||
func (h *tcpRoundRobinHandler) Handle(c net.Conn) {
|
||||
addrPortStr := c.LocalAddr().String()
|
||||
_, port, err := net.SplitHostPort(addrPortStr)
|
||||
if err != nil {
|
||||
log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var p tcpproxy.Proxy
|
||||
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
|
||||
return netutil.NewOneConnListener(c, nil), nil
|
||||
}
|
||||
|
||||
dest := h.To[rand.Intn(len(h.To))]
|
||||
dial := &tcpproxy.DialProxy{
|
||||
Addr: fmt.Sprintf("%s:%s", dest, port),
|
||||
DialContext: h.DialContext,
|
||||
}
|
||||
|
||||
p.AddRoute(addrPortStr, dial)
|
||||
p.Start()
|
||||
}
|
||||
|
||||
type tcpSNIHandler struct {
|
||||
// Allowlist enumerates the FQDNs which may be proxied via SNI. An
|
||||
// empty slice means all domains are permitted.
|
||||
Allowlist []string
|
||||
|
||||
// DialContext is used to make the outgoing TCP connection.
|
||||
DialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// ReachableIPs enumerates the IP addresses this handler is reachable on.
|
||||
ReachableIPs []netip.Addr
|
||||
}
|
||||
|
||||
// ReachableOn returns the IP addresses this handler is reachable on.
|
||||
func (h *tcpSNIHandler) ReachableOn() []netip.Addr {
|
||||
return h.ReachableIPs
|
||||
}
|
||||
|
||||
func (h *tcpSNIHandler) Handle(c net.Conn) {
|
||||
addrPortStr := c.LocalAddr().String()
|
||||
_, port, err := net.SplitHostPort(addrPortStr)
|
||||
if err != nil {
|
||||
log.Printf("tcpSNIHandler.Handle: bogus addrPort %q", addrPortStr)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var p tcpproxy.Proxy
|
||||
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
|
||||
return netutil.NewOneConnListener(c, nil), nil
|
||||
}
|
||||
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
|
||||
if len(h.Allowlist) > 0 {
|
||||
// TODO(tom): handle subdomains
|
||||
if slices.Index(h.Allowlist, sniName) < 0 {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return &tcpproxy.DialProxy{
|
||||
Addr: net.JoinHostPort(sniName, port),
|
||||
DialContext: h.DialContext,
|
||||
}, true
|
||||
})
|
||||
p.Start()
|
||||
}
|
159
cmd/sniproxy/handlers_test.go
Normal file
159
cmd/sniproxy/handlers_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/net/memnet"
|
||||
)
|
||||
|
||||
func echoConnOnce(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
b := make([]byte, 256)
|
||||
n, err := conn.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := conn.Write(b[:n]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRoundRobinHandler(t *testing.T) {
|
||||
h := tcpRoundRobinHandler{
|
||||
To: []string{"yeet.com"},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if network != "tcp" {
|
||||
t.Errorf("network = %s, want %s", network, "tcp")
|
||||
}
|
||||
if addr != "yeet.com:22" {
|
||||
t.Errorf("addr = %s, want %s", addr, "yeet.com:22")
|
||||
}
|
||||
|
||||
c, s := memnet.NewConn("outbound", 1024)
|
||||
go echoConnOnce(s)
|
||||
return c, nil
|
||||
},
|
||||
}
|
||||
|
||||
cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:22"), 1024)
|
||||
h.Handle(sSock)
|
||||
|
||||
// Test data write and read, the other end will echo back
|
||||
// a single stanza
|
||||
want := "hello"
|
||||
if _, err := io.WriteString(cSock, want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := make([]byte, len(want))
|
||||
if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// The other end closed the socket after the first echo, so
|
||||
// any following read should error.
|
||||
io.WriteString(cSock, "deadass heres some data on god fr")
|
||||
if _, err := io.ReadAtLeast(cSock, got, len(got)); err == nil {
|
||||
t.Error("read succeeded on closed socket")
|
||||
}
|
||||
}
|
||||
|
||||
// Capture of first TCP data segment for a connection to https://pkgs.tailscale.com
|
||||
const tlsStart = `45000239ff1840004006f9f5c0a801f2
|
||||
c726b5efcf9e01bbe803b21394e3b752
|
||||
801801f641dc00000101080ade3474f2
|
||||
2fb93ee71603010200010001fc030303
|
||||
c3acbd19d2624765bb19af4bce03365e
|
||||
1d197f5bb939cdadeff26b0f8e7a0620
|
||||
295b04127b82bae46aac4ff58cffef25
|
||||
eba75a4b7a6de729532c411bd9dd0d2c
|
||||
00203a3a130113021303c02bc02fc02c
|
||||
c030cca9cca8c013c014009c009d002f
|
||||
003501000193caca0000000a000a0008
|
||||
1a1a001d001700180010000e000c0268
|
||||
3208687474702f312e31002b0007062a
|
||||
2a03040303ff01000100000d00120010
|
||||
04030804040105030805050108060601
|
||||
000b00020100002300000033002b0029
|
||||
1a1a000100001d0020d3c76bef062979
|
||||
a812ce935cfb4dbe6b3a84dc5ba9226f
|
||||
23b0f34af9d1d03b4a001b0003020002
|
||||
00120000446900050003026832000000
|
||||
170015000012706b67732e7461696c73
|
||||
63616c652e636f6d002d000201010005
|
||||
00050100000000001700003a3a000100
|
||||
0015002d000000000000000000000000
|
||||
00000000000000000000000000000000
|
||||
00000000000000000000000000000000
|
||||
0000290094006f0069e76f2016f963ad
|
||||
38c8632d1f240cd75e00e25fdef295d4
|
||||
7042b26f3a9a543b1c7dc74939d77803
|
||||
20527d423ff996997bda2c6383a14f49
|
||||
219eeef8a053e90a32228df37ddbe126
|
||||
eccf6b085c93890d08341d819aea6111
|
||||
0d909f4cd6b071d9ea40618e74588a33
|
||||
90d494bbb5c3002120d5a164a16c9724
|
||||
c9ef5e540d8d6f007789a7acf9f5f16f
|
||||
bf6a1907a6782ed02b`
|
||||
|
||||
func fakeSNIHeader() []byte {
|
||||
b, err := hex.DecodeString(strings.Replace(tlsStart, "\n", "", -1))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b[0x34:] // trim IP + TCP header
|
||||
}
|
||||
|
||||
func TestTCPSNIHandler(t *testing.T) {
|
||||
h := tcpSNIHandler{
|
||||
Allowlist: []string{"pkgs.tailscale.com"},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if network != "tcp" {
|
||||
t.Errorf("network = %s, want %s", network, "tcp")
|
||||
}
|
||||
if addr != "pkgs.tailscale.com:443" {
|
||||
t.Errorf("addr = %s, want %s", addr, "pkgs.tailscale.com:443")
|
||||
}
|
||||
|
||||
c, s := memnet.NewConn("outbound", 1024)
|
||||
go echoConnOnce(s)
|
||||
return c, nil
|
||||
},
|
||||
}
|
||||
|
||||
cSock, sSock := memnet.NewTCPConn(netip.MustParseAddrPort("10.64.1.2:22"), netip.MustParseAddrPort("10.64.1.2:443"), 1024)
|
||||
h.Handle(sSock)
|
||||
|
||||
// Fake a TLS handshake record with an SNI in it.
|
||||
if _, err := cSock.Write(fakeSNIHeader()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Test read, the other end will echo back
|
||||
// a single stanza, which is at least the beginning of the SNI header.
|
||||
want := fakeSNIHeader()[:5]
|
||||
if _, err := cSock.Write(want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := make([]byte, len(want))
|
||||
if _, err := io.ReadAtLeast(cSock, got, len(got)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
327
cmd/sniproxy/server.go
Normal file
327
cmd/sniproxy/server.go
Normal file
@@ -0,0 +1,327 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"expvar"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/appctype"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/mak"
|
||||
)
|
||||
|
||||
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||
|
||||
// target describes the predicates which route some inbound
|
||||
// traffic to the app connector to a specific handler.
|
||||
type target struct {
|
||||
Dest netip.Prefix
|
||||
Matching tailcfg.ProtoPortRange
|
||||
}
|
||||
|
||||
// Server implements an App Connector as expressed in sniproxy.
|
||||
type Server struct {
|
||||
mu sync.RWMutex // mu guards following fields
|
||||
connectors map[appctype.ConfigID]connector
|
||||
}
|
||||
|
||||
type appcMetrics struct {
|
||||
dnsResponses expvar.Int
|
||||
dnsFailures expvar.Int
|
||||
tcpConns expvar.Int
|
||||
sniConns expvar.Int
|
||||
unhandledConns expvar.Int
|
||||
}
|
||||
|
||||
var getMetrics = sync.OnceValue[*appcMetrics](func() *appcMetrics {
|
||||
m := appcMetrics{}
|
||||
|
||||
stats := new(metrics.Set)
|
||||
stats.Set("tls_sessions", &m.sniConns)
|
||||
clientmetric.NewCounterFunc("sniproxy_tls_sessions", m.sniConns.Value)
|
||||
stats.Set("tcp_sessions", &m.tcpConns)
|
||||
clientmetric.NewCounterFunc("sniproxy_tcp_sessions", m.tcpConns.Value)
|
||||
stats.Set("dns_responses", &m.dnsResponses)
|
||||
clientmetric.NewCounterFunc("sniproxy_dns_responses", m.dnsResponses.Value)
|
||||
stats.Set("dns_failed", &m.dnsFailures)
|
||||
clientmetric.NewCounterFunc("sniproxy_dns_failed", m.dnsFailures.Value)
|
||||
expvar.Publish("sniproxy", stats)
|
||||
|
||||
return &m
|
||||
})
|
||||
|
||||
// Configure applies the provided configuration to the app connector.
|
||||
func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.connectors = makeConnectorsFromConfig(cfg)
|
||||
log.Printf("installed app connector config: %+v", s.connectors)
|
||||
}
|
||||
|
||||
// HandleTCPFlow implements tsnet.FallbackTCPHandler.
|
||||
func (s *Server) HandleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
|
||||
m := getMetrics()
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for _, c := range s.connectors {
|
||||
if handler, intercept := c.handleTCPFlow(src, dst, m); intercept {
|
||||
return handler, intercept
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// HandleDNS handles a DNS request to the app connector.
|
||||
func (s *Server) HandleDNS(c nettype.ConnPacketConn) {
|
||||
defer c.Close()
|
||||
c.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
m := getMetrics()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
n, err := c.Read(buf)
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: read failed: %v\n ", err)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
addrPortStr := c.LocalAddr().String()
|
||||
host, _, err := net.SplitHostPort(addrPortStr)
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: bogus addrPort %q", addrPortStr)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
localAddr, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: bogus local address %q", host)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
var msg dnsmessage.Message
|
||||
err = msg.Unpack(buf[:n])
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, connector := range s.connectors {
|
||||
resp, err := connector.handleDNS(&msg, localAddr)
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: connector handling failed: %v\n", err)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
if len(resp) > 0 {
|
||||
// This connector handled the DNS request
|
||||
_, err = c.Write(resp)
|
||||
if err != nil {
|
||||
log.Printf("HandleDNS: write failed: %v\n", err)
|
||||
m.dnsFailures.Add(1)
|
||||
return
|
||||
}
|
||||
|
||||
m.dnsResponses.Add(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connector describes a logical collection of
|
||||
// services which need to be proxied.
|
||||
type connector struct {
|
||||
Handlers map[target]handler
|
||||
}
|
||||
|
||||
// handleTCPFlow implements tsnet.FallbackTCPHandler.
|
||||
func (c *connector) handleTCPFlow(src, dst netip.AddrPort, m *appcMetrics) (handler func(net.Conn), intercept bool) {
|
||||
for t, h := range c.Handlers {
|
||||
if t.Matching.Proto != 0 && t.Matching.Proto != int(ipproto.TCP) {
|
||||
continue
|
||||
}
|
||||
if !t.Dest.Contains(dst.Addr()) {
|
||||
continue
|
||||
}
|
||||
if !t.Matching.Ports.Contains(dst.Port()) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch h.(type) {
|
||||
case *tcpSNIHandler:
|
||||
m.sniConns.Add(1)
|
||||
case *tcpRoundRobinHandler:
|
||||
m.tcpConns.Add(1)
|
||||
default:
|
||||
log.Printf("handleTCPFlow: unhandled handler type %T", h)
|
||||
}
|
||||
|
||||
return h.Handle, true
|
||||
}
|
||||
|
||||
m.unhandledConns.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// handleDNS returns the DNS response to the given query. If this
|
||||
// connector is unable to handle the request, nil is returned.
|
||||
func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (response []byte, err error) {
|
||||
for t, h := range c.Handlers {
|
||||
if t.Dest.Contains(localAddr) {
|
||||
return makeDNSResponse(req, h.ReachableOn())
|
||||
}
|
||||
}
|
||||
|
||||
// Did not match, signal 'not handled' to caller
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
|
||||
resp := dnsmessage.NewBuilder(response,
|
||||
dnsmessage.Header{
|
||||
ID: req.Header.ID,
|
||||
Response: true,
|
||||
Authoritative: true,
|
||||
})
|
||||
resp.EnableCompression()
|
||||
|
||||
if len(req.Questions) == 0 {
|
||||
response, _ = resp.Finish()
|
||||
return response, nil
|
||||
}
|
||||
q := req.Questions[0]
|
||||
err = resp.StartQuestions()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp.Question(q)
|
||||
|
||||
err = resp.StartAnswers()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeAAAA:
|
||||
for _, ip := range reachableIPs {
|
||||
if ip.Is6() {
|
||||
err = resp.AAAAResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.AAAAResource{AAAA: ip.As16()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
case dnsmessage.TypeA:
|
||||
for _, ip := range reachableIPs {
|
||||
if ip.Is4() {
|
||||
err = resp.AResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.AResource{A: ip.As4()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
case dnsmessage.TypeSOA:
|
||||
err = resp.SOAResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
|
||||
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
|
||||
)
|
||||
case dnsmessage.TypeNS:
|
||||
err = resp.NSResource(
|
||||
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
|
||||
dnsmessage.NSResource{NS: tsMBox},
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Finish()
|
||||
}
|
||||
|
||||
type handler interface {
|
||||
// Handle handles the given socket.
|
||||
Handle(c net.Conn)
|
||||
|
||||
// ReachableOn returns the IP addresses this handler is reachable on.
|
||||
ReachableOn() []netip.Addr
|
||||
}
|
||||
|
||||
func installDNATHandler(d *appctype.DNATConfig, out *connector) {
|
||||
// These handlers don't actually do DNAT, they just
|
||||
// proxy the data over the connection.
|
||||
var dialer net.Dialer
|
||||
dialer.Timeout = 5 * time.Second
|
||||
h := tcpRoundRobinHandler{
|
||||
To: d.To,
|
||||
DialContext: dialer.DialContext,
|
||||
ReachableIPs: d.Addrs,
|
||||
}
|
||||
|
||||
for _, addr := range d.Addrs {
|
||||
for _, protoPort := range d.IP {
|
||||
t := target{
|
||||
Dest: netip.PrefixFrom(addr, addr.BitLen()),
|
||||
Matching: protoPort,
|
||||
}
|
||||
|
||||
mak.Set(&out.Handlers, t, handler(&h))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func installSNIHandler(c *appctype.SNIProxyConfig, out *connector) {
|
||||
var dialer net.Dialer
|
||||
dialer.Timeout = 5 * time.Second
|
||||
h := tcpSNIHandler{
|
||||
Allowlist: c.AllowedDomains,
|
||||
DialContext: dialer.DialContext,
|
||||
ReachableIPs: c.Addrs,
|
||||
}
|
||||
|
||||
for _, addr := range c.Addrs {
|
||||
for _, protoPort := range c.IP {
|
||||
t := target{
|
||||
Dest: netip.PrefixFrom(addr, addr.BitLen()),
|
||||
Matching: protoPort,
|
||||
}
|
||||
|
||||
mak.Set(&out.Handlers, t, handler(&h))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeConnectorsFromConfig(cfg *appctype.AppConnectorConfig) map[appctype.ConfigID]connector {
|
||||
var connectors map[appctype.ConfigID]connector
|
||||
|
||||
for cID, d := range cfg.DNAT {
|
||||
c := connectors[cID]
|
||||
installDNATHandler(&d, &c)
|
||||
mak.Set(&connectors, cID, c)
|
||||
}
|
||||
for cID, d := range cfg.SNIProxy {
|
||||
c := connectors[cID]
|
||||
installSNIHandler(&d, &c)
|
||||
mak.Set(&connectors, cID, c)
|
||||
}
|
||||
|
||||
return connectors
|
||||
}
|
95
cmd/sniproxy/server_test.go
Normal file
95
cmd/sniproxy/server_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/appctype"
|
||||
)
|
||||
|
||||
func TestMakeConnectorsFromConfig(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
input *appctype.AppConnectorConfig
|
||||
want map[appctype.ConfigID]connector
|
||||
}{
|
||||
{
|
||||
"empty",
|
||||
&appctype.AppConnectorConfig{},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"DNAT",
|
||||
&appctype.AppConnectorConfig{
|
||||
DNAT: map[appctype.ConfigID]appctype.DNATConfig{
|
||||
"swiggity_swooty": {
|
||||
Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
|
||||
To: []string{"example.org"},
|
||||
IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[appctype.ConfigID]connector{
|
||||
"swiggity_swooty": {
|
||||
Handlers: map[target]handler{
|
||||
{
|
||||
Dest: netip.MustParsePrefix("100.64.0.1/32"),
|
||||
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
|
||||
}: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
|
||||
{
|
||||
Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
|
||||
}: &tcpRoundRobinHandler{To: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"SNIProxy",
|
||||
&appctype.AppConnectorConfig{
|
||||
SNIProxy: map[appctype.ConfigID]appctype.SNIProxyConfig{
|
||||
"swiggity_swooty": {
|
||||
Addrs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")},
|
||||
AllowedDomains: []string{"example.org"},
|
||||
IP: []tailcfg.ProtoPortRange{{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}}},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[appctype.ConfigID]connector{
|
||||
"swiggity_swooty": {
|
||||
Handlers: map[target]handler{
|
||||
{
|
||||
Dest: netip.MustParsePrefix("100.64.0.1/32"),
|
||||
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
|
||||
}: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
|
||||
{
|
||||
Dest: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
Matching: tailcfg.ProtoPortRange{Proto: 0, Ports: tailcfg.PortRange{First: 0, Last: 65535}},
|
||||
}: &tcpSNIHandler{Allowlist: []string{"example.org"}, ReachableIPs: []netip.Addr{netip.MustParseAddr("100.64.0.1"), netip.MustParseAddr("fd7a:115c:a1e0::1")}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
connectors := makeConnectorsFromConfig(tc.input)
|
||||
|
||||
if diff := cmp.Diff(connectors, tc.want,
|
||||
cmpopts.IgnoreFields(tcpRoundRobinHandler{}, "DialContext"),
|
||||
cmpopts.IgnoreFields(tcpSNIHandler{}, "DialContext"),
|
||||
cmp.Comparer(func(x, y netip.Addr) bool {
|
||||
return x == y
|
||||
})); diff != "" {
|
||||
t.Fatalf("mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -22,8 +22,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/peterbourgon/ff/v3"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"tailscale.com/appc"
|
||||
"tailscale.com/client/tailscale"
|
||||
"tailscale.com/hostinfo"
|
||||
"tailscale.com/ipn"
|
||||
@@ -38,8 +36,6 @@ import (
|
||||
|
||||
const configCapKey = "tailscale.com/sniproxy"
|
||||
|
||||
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||
|
||||
// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
|
||||
type portForward struct {
|
||||
Port int
|
||||
@@ -99,7 +95,7 @@ func main() {
|
||||
func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
|
||||
// Wire up Tailscale node + app connector server
|
||||
hostinfo.SetApp("sniproxy")
|
||||
var s server
|
||||
var s sniproxy
|
||||
s.ts = ts
|
||||
|
||||
s.ts.Port = uint16(wgPort)
|
||||
@@ -110,7 +106,7 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro
|
||||
log.Fatalf("LocalClient() failed: %v", err)
|
||||
}
|
||||
s.lc = lc
|
||||
s.ts.RegisterFallbackTCPHandler(s.appc.HandleTCPFlow)
|
||||
s.ts.RegisterFallbackTCPHandler(s.srv.HandleTCPFlow)
|
||||
|
||||
// Start special-purpose listeners: dns, http promotion, debug server
|
||||
ln, err := s.ts.Listen("udp", ":53")
|
||||
@@ -181,18 +177,18 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro
|
||||
// on the command line. This is intentionally done after we advertise any routes
|
||||
// because its never correct to advertise the nodes native IP addresses.
|
||||
s.mergeConfigFromFlags(&c, ports, forwards)
|
||||
s.appc.Configure(&c)
|
||||
s.srv.Configure(&c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type server struct {
|
||||
appc appc.Server
|
||||
ts *tsnet.Server
|
||||
lc *tailscale.LocalClient
|
||||
type sniproxy struct {
|
||||
srv Server
|
||||
ts *tsnet.Server
|
||||
lc *tailscale.LocalClient
|
||||
}
|
||||
|
||||
func (s *server) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
|
||||
func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
|
||||
// Collect the set of addresses to advertise, using a map
|
||||
// to avoid duplicate entries.
|
||||
addrs := map[netip.Addr]struct{}{}
|
||||
@@ -224,7 +220,7 @@ func (s *server) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppC
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *server) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
|
||||
func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
|
||||
ip4, ip6 := s.ts.TailscaleIPs()
|
||||
|
||||
sniConfigFromFlags := appctype.SNIProxyConfig{
|
||||
@@ -276,18 +272,18 @@ func (s *server) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, f
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) serveDNS(ln net.Listener) {
|
||||
func (s *sniproxy) serveDNS(ln net.Listener) {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
log.Printf("serveDNS accept: %v", err)
|
||||
return
|
||||
}
|
||||
go s.appc.HandleDNS(c.(nettype.ConnPacketConn))
|
||||
go s.srv.HandleDNS(c.(nettype.ConnPacketConn))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) promoteHTTPS(ln net.Listener) {
|
||||
func (s *sniproxy) promoteHTTPS(ln net.Listener) {
|
||||
err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
|
||||
}))
|
||||
|
Reference in New Issue
Block a user