cmd/sniproxy: implement support for control configuration, multiple addresses

* Implement missing tests for sniproxy
 * Wire sniproxy to new appc package
 * Add support to tsnet for routing subnet router traffic into netstack, so it can be handled

Updates: https://github.com/tailscale/corp/issues/15038
Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom DNetto 2023-10-19 17:07:07 -07:00 committed by Tom
parent 0d86eb9da5
commit a7c80c332a
4 changed files with 378 additions and 295 deletions

View File

@ -67,6 +67,7 @@ func (s *Server) Configure(cfg *appctype.AppConnectorConfig) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.connectors = makeConnectorsFromConfig(cfg) s.connectors = makeConnectorsFromConfig(cfg)
log.Printf("installed app connector config: %+v", s.connectors)
} }
// HandleTCPFlow implements tsnet.FallbackTCPHandler. // HandleTCPFlow implements tsnet.FallbackTCPHandler.
@ -193,8 +194,7 @@ func (c *connector) handleDNS(req *dnsmessage.Message, localAddr netip.Addr) (re
} }
func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) { func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (response []byte, err error) {
buf := make([]byte, 1500) resp := dnsmessage.NewBuilder(response,
resp := dnsmessage.NewBuilder(buf,
dnsmessage.Header{ dnsmessage.Header{
ID: req.Header.ID, ID: req.Header.ID,
Response: true, Response: true,
@ -203,8 +203,8 @@ func makeDNSResponse(req *dnsmessage.Message, reachableIPs []netip.Addr) (respon
resp.EnableCompression() resp.EnableCompression()
if len(req.Questions) == 0 { if len(req.Questions) == 0 {
buf, _ = resp.Finish() response, _ = resp.Finish()
return buf, nil return response, nil
} }
q := req.Questions[0] q := req.Questions[0]
err = resp.StartQuestions() err = resp.StartQuestions()

View File

@ -10,30 +10,34 @@
import ( import (
"context" "context"
"errors" "errors"
"expvar"
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/netip"
"os" "os"
"sort"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/peterbourgon/ff/v3" "github.com/peterbourgon/ff/v3"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"inet.af/tcpproxy" "tailscale.com/appc"
"tailscale.com/client/tailscale" "tailscale.com/client/tailscale"
"tailscale.com/hostinfo" "tailscale.com/hostinfo"
"tailscale.com/metrics" "tailscale.com/ipn"
"tailscale.com/net/netutil" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/tsweb" "tailscale.com/tsweb"
"tailscale.com/types/appctype"
"tailscale.com/types/ipproto"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/util/clientmetric" "tailscale.com/util/mak"
) )
const configCapKey = "tailscale.com/sniproxy"
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
// portForward is the state for a single port forwarding entry, as passed to the --forward flag. // portForward is the state for a single port forwarding entry, as passed to the --forward flag.
@ -68,6 +72,7 @@ func parseForward(value string) (*portForward, error) {
} }
func main() { func main() {
// Parse flags
fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError) fs := flag.NewFlagSet("sniproxy", flag.ContinueOnError)
var ( var (
ports = fs.String("ports", "443", "comma-separated list of ports to proxy") ports = fs.String("ports", "443", "comma-separated list of ports to proxy")
@ -77,124 +82,197 @@ func main() {
debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint")
hostname = fs.String("hostname", "", "Hostname to register the service under") hostname = fs.String("hostname", "", "Hostname to register the service under")
) )
err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC")) err := ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_APPC"))
if err != nil { if err != nil {
log.Fatal("ff.Parse") log.Fatal("ff.Parse")
} }
if *ports == "" {
log.Fatal("no ports")
}
var ts tsnet.Server
defer ts.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
run(ctx, &ts, *wgPort, *hostname, *promoteHTTPS, *debugPort, *ports, *forwards)
}
// run actually runs the sniproxy. Its separate from main() to assist in testing.
func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, promoteHTTPS bool, debugPort int, ports, forwards string) {
// Wire up Tailscale node + app connector server
hostinfo.SetApp("sniproxy") hostinfo.SetApp("sniproxy")
var s server var s server
s.ts.Port = uint16(*wgPort) s.ts = ts
s.ts.Hostname = *hostname
defer s.ts.Close() s.ts.Port = uint16(wgPort)
s.ts.Hostname = hostname
lc, err := s.ts.LocalClient() lc, err := s.ts.LocalClient()
if err != nil { if err != nil {
log.Fatal(err) log.Fatalf("LocalClient() failed: %v", err)
} }
s.lc = lc s.lc = lc
s.initMetrics() s.ts.RegisterFallbackTCPHandler(s.appc.HandleTCPFlow)
for _, portStr := range strings.Split(*ports, ",") { // Start special-purpose listeners: dns, http promotion, debug server
ln, err := s.ts.Listen("tcp", ":"+portStr) ln, err := s.ts.Listen("udp", ":53")
if err != nil {
log.Fatalf("failed listening on port 53: %v", err)
}
defer ln.Close()
go s.serveDNS(ln)
if promoteHTTPS {
ln, err := s.ts.Listen("tcp", ":80")
if err != nil { if err != nil {
log.Fatal(err) log.Fatalf("failed listening on port 80: %v", err)
} }
log.Printf("Serving on port %v ...", portStr) defer ln.Close()
go s.serve(ln) log.Printf("Promoting HTTP to HTTPS ...")
go s.promoteHTTPS(ln)
}
if debugPort != 0 {
mux := http.NewServeMux()
tsweb.Debugger(mux)
dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", debugPort))
if err != nil {
log.Fatalf("failed listening on debug port: %v", err)
}
defer dln.Close()
go func() {
log.Fatalf("debug serve: %v", http.Serve(dln, mux))
}()
} }
for _, forwStr := range strings.Split(*forwards, ",") { // Finally, start mainloop to configure app connector based on information
// in the netmap.
// We set the NotifyInitialNetMap flag so we will always get woken with the
// current netmap, before only being woken on changes.
bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys)
if err != nil {
log.Fatalf("watching IPN bus: %v", err)
}
defer bus.Close()
for {
msg, err := bus.Next()
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
log.Fatalf("reading IPN bus: %v", err)
}
// NetMap contains app-connector configuration
if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() {
sn := nm.SelfNode.AsStruct()
var c appctype.AppConnectorConfig
nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey)
if err != nil {
log.Printf("failed to read app connector configuration from coordination server: %v", err)
} else if len(nmConf) > 0 {
c = nmConf[0]
}
if c.AdvertiseRoutes {
if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil {
log.Printf("failed to advertise routes: %v", err)
}
}
// Backwards compatibility: combine any configuration from control with flags specified
// on the command line. This is intentionally done after we advertise any routes
// because its never correct to advertise the nodes native IP addresses.
s.mergeConfigFromFlags(&c, ports, forwards)
s.appc.Configure(&c)
}
}
}
type server struct {
appc appc.Server
ts *tsnet.Server
lc *tailscale.LocalClient
}
func (s *server) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error {
// Collect the set of addresses to advertise, using a map
// to avoid duplicate entries.
addrs := map[netip.Addr]struct{}{}
for _, c := range c.SNIProxy {
for _, ip := range c.Addrs {
addrs[ip] = struct{}{}
}
}
for _, c := range c.DNAT {
for _, ip := range c.Addrs {
addrs[ip] = struct{}{}
}
}
var routes []netip.Prefix
for a := range addrs {
routes = append(routes, netip.PrefixFrom(a, a.BitLen()))
}
sort.SliceStable(routes, func(i, j int) bool {
return routes[i].Addr().Less(routes[j].Addr()) // determinism r us
})
_, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
Prefs: ipn.Prefs{
AdvertiseRoutes: routes,
},
AdvertiseRoutesSet: true,
})
return err
}
func (s *server) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, forwards string) {
ip4, ip6 := s.ts.TailscaleIPs()
sniConfigFromFlags := appctype.SNIProxyConfig{
Addrs: []netip.Addr{ip4, ip6},
}
if ports != "" {
for _, portStr := range strings.Split(ports, ",") {
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
log.Fatalf("invalid port: %s", portStr)
}
sniConfigFromFlags.IP = append(sniConfigFromFlags.IP, tailcfg.ProtoPortRange{
Proto: int(ipproto.TCP),
Ports: tailcfg.PortRange{First: uint16(port), Last: uint16(port)},
})
}
}
var forwardConfigFromFlags []appctype.DNATConfig
for _, forwStr := range strings.Split(forwards, ",") {
if forwStr == "" { if forwStr == "" {
continue continue
} }
forw, err := parseForward(forwStr) forw, err := parseForward(forwStr)
if err != nil { if err != nil {
log.Fatal(err) log.Printf("invalid forwarding spec: %v", err)
continue
} }
ln, err := s.ts.Listen("tcp", ":"+strconv.Itoa(forw.Port)) forwardConfigFromFlags = append(forwardConfigFromFlags, appctype.DNATConfig{
if err != nil { Addrs: []netip.Addr{ip4, ip6},
log.Fatal(err) To: []string{forw.Destination},
} IP: []tailcfg.ProtoPortRange{
log.Printf("Serving on port %d to %s...", forw.Port, forw.Destination) {
Proto: int(ipproto.TCP),
// Add an entry to the expvar LabelMap for Prometheus metrics, Ports: tailcfg.PortRange{First: uint16(forw.Port), Last: uint16(forw.Port)},
// and create a clientmetric to report that same value. },
service := portNumberToName(forw) },
s.numTCPsessions.SetInt64(service, 0)
metric := fmt.Sprintf("sniproxy_tcp_sessions_%s", service)
clientmetric.NewCounterFunc(metric, func() int64 {
return s.numTCPsessions.Get(service).Value()
}) })
go s.forward(ln, forw)
} }
ln, err := s.ts.Listen("udp", ":53") if len(forwardConfigFromFlags) == 0 && len(sniConfigFromFlags.IP) == 0 {
if err != nil { return // no config specified on the command line
log.Fatal(err)
}
go s.serveDNS(ln)
if *promoteHTTPS {
ln, err := s.ts.Listen("tcp", ":80")
if err != nil {
log.Fatal(err)
}
log.Printf("Promoting HTTP to HTTPS ...")
go s.promoteHTTPS(ln)
} }
if *debugPort != 0 { mak.Set(&out.SNIProxy, "flags", sniConfigFromFlags)
mux := http.NewServeMux() for i, forward := range forwardConfigFromFlags {
tsweb.Debugger(mux) mak.Set(&out.DNAT, appctype.ConfigID(fmt.Sprintf("flags_%d", i)), forward)
dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
if err != nil {
log.Fatal(err)
}
go func() {
log.Fatal(http.Serve(dln, mux))
}()
}
select {}
}
type server struct {
ts tsnet.Server
lc *tailscale.LocalClient
numTLSsessions expvar.Int
numTCPsessions *metrics.LabelMap
numBadAddrPort expvar.Int
dnsResponses expvar.Int
dnsFailures expvar.Int
httpPromoted expvar.Int
}
func (s *server) serve(ln net.Listener) {
for {
c, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
go s.serveConn(c)
}
}
func (s *server) forward(ln net.Listener, forw *portForward) {
for {
c, err := ln.Accept()
if err != nil {
log.Fatal(err)
}
go s.forwardConn(c, forw)
} }
} }
@ -202,209 +280,16 @@ func (s *server) serveDNS(ln net.Listener) {
for { for {
c, err := ln.Accept() c, err := ln.Accept()
if err != nil { if err != nil {
log.Fatal(err) log.Printf("serveDNS accept: %v", err)
return
} }
go s.serveDNSConn(c.(nettype.ConnPacketConn)) go s.appc.HandleDNS(c.(nettype.ConnPacketConn))
} }
} }
func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
defer c.Close()
c.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 1500)
n, err := c.Read(buf)
if err != nil {
log.Printf("c.Read failed: %v\n ", err)
s.dnsFailures.Add(1)
return
}
var msg dnsmessage.Message
err = msg.Unpack(buf[:n])
if err != nil {
log.Printf("dnsmessage unpack failed: %v\n ", err)
s.dnsFailures.Add(1)
return
}
buf, err = s.dnsResponse(&msg)
if err != nil {
log.Printf("s.dnsResponse failed: %v\n", err)
s.dnsFailures.Add(1)
return
}
_, err = c.Write(buf)
if err != nil {
log.Printf("c.Write failed: %v\n", err)
s.dnsFailures.Add(1)
return
}
s.dnsResponses.Add(1)
}
func (s *server) serveConn(c net.Conn) {
addrPortStr := c.LocalAddr().String()
_, port, err := net.SplitHostPort(addrPortStr)
if err != nil {
log.Printf("bogus addrPort %q", addrPortStr)
s.numBadAddrPort.Add(1)
c.Close()
return
}
var dialer net.Dialer
dialer.Timeout = 5 * time.Second
var p tcpproxy.Proxy
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
return netutil.NewOneConnListener(c, nil), nil
}
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
s.numTLSsessions.Add(1)
return &tcpproxy.DialProxy{
Addr: net.JoinHostPort(sniName, port),
DialContext: dialer.DialContext,
}, true
})
p.Start()
}
// portNumberToName returns a human-readable name for several port numbers commonly forwarded,
// and "tcp###" for everything else. It is used for metric label names.
func portNumberToName(forw *portForward) string {
switch forw.Port {
case 22:
return "ssh"
case 1433:
return "sqlserver"
case 3306:
return "mysql"
case 3389:
return "rdp"
case 5432:
return "postgres"
default:
return fmt.Sprintf("%s%d", forw.Proto, forw.Port)
}
}
// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data
// like the SNI forwarding does, it merely forwards all data to the destination specified
// in the --forward=tcp/22/github.com argument.
func (s *server) forwardConn(c net.Conn, forw *portForward) {
addrPortStr := c.LocalAddr().String()
var dialer net.Dialer
dialer.Timeout = 30 * time.Second
var p tcpproxy.Proxy
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
return netutil.NewOneConnListener(c, nil), nil
}
dial := &tcpproxy.DialProxy{
Addr: fmt.Sprintf("%s:%d", forw.Destination, forw.Port),
DialContext: dialer.DialContext,
}
p.AddRoute(addrPortStr, dial)
s.numTCPsessions.Add(portNumberToName(forw), 1)
p.Start()
}
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
resp := dnsmessage.NewBuilder(buf,
dnsmessage.Header{
ID: req.Header.ID,
Response: true,
Authoritative: true,
})
resp.EnableCompression()
if len(req.Questions) == 0 {
buf, _ = resp.Finish()
return
}
q := req.Questions[0]
err = resp.StartQuestions()
if err != nil {
return
}
resp.Question(q)
ip4, ip6 := s.ts.TailscaleIPs()
err = resp.StartAnswers()
if err != nil {
return
}
switch q.Type {
case dnsmessage.TypeAAAA:
err = resp.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AAAAResource{AAAA: ip6.As16()},
)
case dnsmessage.TypeA:
err = resp.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AResource{A: ip4.As4()},
)
case dnsmessage.TypeSOA:
err = resp.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
)
case dnsmessage.TypeNS:
err = resp.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
)
}
if err != nil {
return
}
return resp.Finish()
}
func (s *server) promoteHTTPS(ln net.Listener) { func (s *server) promoteHTTPS(ln net.Listener) {
err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.httpPromoted.Add(1)
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
})) }))
log.Fatalf("promoteHTTPS http.Serve: %v", err) log.Fatalf("promoteHTTPS http.Serve: %v", err)
} }
// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those
// same counters.
func (s *server) initMetrics() {
stats := new(metrics.Set)
stats.Set("tls_sessions", &s.numTLSsessions)
clientmetric.NewCounterFunc("sniproxy_tls_sessions", s.numTLSsessions.Value)
s.numTCPsessions = &metrics.LabelMap{Label: "proto"}
stats.Set("tcp_sessions", s.numTCPsessions)
// clientmetric doesn't have a good way to implement a Map type.
// We create clientmetrics dynamically when parsing the --forwards argument
stats.Set("bad_addrport", &s.numBadAddrPort)
clientmetric.NewCounterFunc("sniproxy_bad_addrport", s.numBadAddrPort.Value)
stats.Set("dns_responses", &s.dnsResponses)
clientmetric.NewCounterFunc("sniproxy_dns_responses", s.dnsResponses.Value)
stats.Set("dns_failed", &s.dnsFailures)
clientmetric.NewCounterFunc("sniproxy_dns_failed", s.dnsFailures.Value)
stats.Set("http_promoted", &s.httpPromoted)
clientmetric.NewCounterFunc("sniproxy_http_promoted", s.httpPromoted.Value)
expvar.Publish("sniproxy", stats)
}

View File

@ -4,10 +4,30 @@
package main package main
import ( import (
"context"
"encoding/json"
"flag"
"fmt"
"net"
"net/http/httptest"
"net/netip"
"os"
"path/filepath"
"strings" "strings"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netns"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/appctype"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/logger"
) )
func TestPortForwardingArguments(t *testing.T) { func TestPortForwardingArguments(t *testing.T) {
@ -35,3 +55,169 @@ func TestPortForwardingArguments(t *testing.T) {
} }
} }
} }
var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs")
var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs")
func startControl(t *testing.T) (control *testcontrol.Server, controlURL string) {
// Corp#4520: don't use netns for tests.
netns.SetEnabled(false)
t.Cleanup(func() {
netns.SetEnabled(true)
})
derpLogf := logger.Discard
if *verboseDERP {
derpLogf = t.Logf
}
derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
control = &testcontrol.Server{
DERPMap: derpMap,
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
MagicDNSDomain: "tail-scale.ts.net",
}
control.HTTPTestServer = httptest.NewUnstartedServer(control)
control.HTTPTestServer.Start()
t.Cleanup(control.HTTPTestServer.Close)
controlURL = control.HTTPTestServer.URL
t.Logf("testcontrol listening on %s", controlURL)
return control, controlURL
}
func startNode(t *testing.T, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) {
t.Helper()
tmp := filepath.Join(t.TempDir(), hostname)
os.MkdirAll(tmp, 0755)
s := &tsnet.Server{
Dir: tmp,
ControlURL: controlURL,
Hostname: hostname,
Store: new(mem.Store),
Ephemeral: true,
}
if !*verboseNodes {
s.Logf = logger.Discard
}
t.Cleanup(func() { s.Close() })
status, err := s.Up(ctx)
if err != nil {
t.Fatal(err)
}
return s, status.Self.PublicKey, status.TailscaleIPs[0]
}
func TestSNIProxyWithNetmapConfig(t *testing.T) {
c, controlURL := startControl(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create a listener to proxy connections to.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
// Start sniproxy
sni, nodeKey, ip := startNode(t, ctx, controlURL, "snitest")
go run(ctx, sni, 0, sni.Hostname, false, 0, "", "")
// Configure the mock coordination server to send down app connector config.
config := &appctype.AppConnectorConfig{
DNAT: map[appctype.ConfigID]appctype.DNATConfig{
"nic_test": {
Addrs: []netip.Addr{ip},
To: []string{"127.0.0.1"},
IP: []tailcfg.ProtoPortRange{
{
Proto: int(ipproto.TCP),
Ports: tailcfg.PortRange{First: uint16(ln.Addr().(*net.TCPAddr).Port), Last: uint16(ln.Addr().(*net.TCPAddr).Port)},
},
},
},
},
}
b, err := json.Marshal(config)
if err != nil {
t.Fatal(err)
}
c.SetNodeCapMap(nodeKey, tailcfg.NodeCapMap{
configCapKey: []tailcfg.RawMessage{tailcfg.RawMessage(b)},
})
// Lets spin up a second node (to represent the client).
client, _, _ := startNode(t, ctx, controlURL, "client")
// Make sure that the sni node has received its config.
l, err := sni.LocalClient()
if err != nil {
t.Fatal(err)
}
gotConfigured := false
for i := 0; i < 100; i++ {
s, err := l.StatusWithoutPeers(ctx)
if err != nil {
t.Fatal(err)
}
if len(s.Self.CapMap) > 0 {
gotConfigured = true
break // we got it
}
time.Sleep(10 * time.Millisecond)
}
if !gotConfigured {
t.Error("sni node never received its configuration from the coordination server!")
}
// Lets make the client open a connection to the sniproxy node, and
// make sure it results in a connection to our test listener.
w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port))
if err != nil {
t.Fatal(err)
}
defer w.Close()
r, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
r.Close()
}
func TestSNIProxyWithFlagConfig(t *testing.T) {
_, controlURL := startControl(t)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Create a listener to proxy connections to.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
// Start sniproxy
sni, _, ip := startNode(t, ctx, controlURL, "snitest")
go run(ctx, sni, 0, sni.Hostname, false, 0, "", fmt.Sprintf("tcp/%d/localhost", ln.Addr().(*net.TCPAddr).Port))
// Lets spin up a second node (to represent the client).
client, _, _ := startNode(t, ctx, controlURL, "client")
// Lets make the client open a connection to the sniproxy node, and
// make sure it results in a connection to our test listener.
w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port))
if err != nil {
t.Fatal(err)
}
defer w.Close()
r, err := ln.Accept()
if err != nil {
t.Fatal(err)
}
r.Close()
}

View File

@ -75,6 +75,9 @@ type Server struct {
// masquerade address to use for that peer. // masquerade address to use for that peer.
masquerades map[key.NodePublic]map[key.NodePublic]netip.Addr // node => peer => SelfNodeV{4,6}MasqAddrForThisPeer IP masquerades map[key.NodePublic]map[key.NodePublic]netip.Addr // node => peer => SelfNodeV{4,6}MasqAddrForThisPeer IP
// nodeCapMaps overrides the capability map sent down to a client.
nodeCapMaps map[key.NodePublic]tailcfg.NodeCapMap
// suppressAutoMapResponses is the set of nodes that should not be sent // suppressAutoMapResponses is the set of nodes that should not be sent
// automatic map responses from serveMap. (They should only get manually sent ones) // automatic map responses from serveMap. (They should only get manually sent ones)
suppressAutoMapResponses set.Set[key.NodePublic] suppressAutoMapResponses set.Set[key.NodePublic]
@ -369,6 +372,14 @@ func (s *Server) SetMasqueradeAddresses(pairs []MasqueradePair) {
s.updateLocked("SetMasqueradeAddresses", s.nodeIDsLocked(0)) s.updateLocked("SetMasqueradeAddresses", s.nodeIDsLocked(0))
} }
// SetNodeCapMap overrides the capability map the specified client receives.
func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap) {
s.mu.Lock()
defer s.mu.Unlock()
mak.Set(&s.nodeCapMaps, nodeKey, capMap)
s.updateLocked("SetNodeCapMap", s.nodeIDsLocked(0))
}
// nodeIDsLocked returns the node IDs of all nodes in the server, except // nodeIDsLocked returns the node IDs of all nodes in the server, except
// for the node with the given ID. // for the node with the given ID.
func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID { func (s *Server) nodeIDsLocked(except tailcfg.NodeID) []tailcfg.NodeID {
@ -881,6 +892,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
// node key rotated away (once test server supports that) // node key rotated away (once test server supports that)
return nil, nil return nil, nil
} }
node.CapMap = s.nodeCapMaps[nk]
node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP)
user, _ := s.getUser(nk) user, _ := s.getUser(nk)