diff --git a/cmd/sniproxy/.gitignore b/cmd/sniproxy/.gitignore new file mode 100644 index 000000000..b1399c881 --- /dev/null +++ b/cmd/sniproxy/.gitignore @@ -0,0 +1 @@ +sniproxy diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index 1fbf9a1d2..04af9cd1b 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -3,15 +3,20 @@ // The sniproxy is an outbound SNI proxy. It receives TLS connections over // Tailscale on one or more TCP ports and sends them out to the same SNI -// hostname & port on the internet. It only does TCP. +// hostname & port on the internet. It can optionally forward one or more +// TCP ports to a specific destination. It only does TCP. package main import ( "context" + "errors" + "expvar" "flag" + "fmt" "log" "net" "net/http" + "strconv" "strings" "time" @@ -19,27 +24,54 @@ "inet.af/tcpproxy" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" + "tailscale.com/metrics" "tailscale.com/net/netutil" "tailscale.com/tsnet" + "tailscale.com/tsweb" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" ) var ( ports = flag.String("ports", "443", "comma-separated list of ports to proxy") + forwards = flag.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com") wgPort = flag.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") promoteHTTPS = flag.Bool("promote-https", true, "promote HTTP to HTTPS") + debugPort = flag.Int("debug-port", 8080, "Listening port for debug/metrics endpoint") ) var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") -var ( - numSessions = clientmetric.NewCounter("sniproxy_sessions") - numBadAddrPort = clientmetric.NewCounter("sniproxy_bad_addrport") - dnsResponses = clientmetric.NewCounter("sniproxy_dns_responses") - dnsFailures = clientmetric.NewCounter("sniproxy_dns_failed") - httpPromoted = clientmetric.NewCounter("sniproxy_http_promoted") -) +// portForward is the state for a single port forwarding entry, as passed to the --forward flag. +type portForward struct { + Port int + Proto string + Destination string +} + +// parseForward takes a proto/port/destination tuple as an input, as would be passed +// to the --forward command line flag, and returns a *portForward struct of those parameters. +func parseForward(value string) (*portForward, error) { + parts := strings.Split(value, "/") + if len(parts) != 3 { + return nil, errors.New("cannot parse: " + value) + } + + proto := parts[0] + if proto != "tcp" { + return nil, errors.New("unsupported forwarding protocol: " + proto) + } + port, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return nil, errors.New("bad forwarding port: " + parts[1]) + } + host := parts[2] + if host == "" { + return nil, errors.New("bad destination: " + value) + } + + return &portForward{Port: int(port), Proto: proto, Destination: host}, nil +} func main() { flag.Parse() @@ -58,6 +90,7 @@ func main() { log.Fatal(err) } s.lc = lc + s.initMetrics() for _, portStr := range strings.Split(*ports, ",") { ln, err := s.ts.Listen("tcp", ":"+portStr) @@ -68,6 +101,34 @@ func main() { go s.serve(ln) } + for _, forwStr := range strings.Split(*forwards, ",") { + if forwStr == "" { + continue + } + forw, err := parseForward(forwStr) + if err != nil { + log.Fatal(err) + } + + ln, err := s.ts.Listen("tcp", ":"+strconv.Itoa(forw.Port)) + if err != nil { + log.Fatal(err) + } + log.Printf("Serving on port %d to %s...", forw.Port, forw.Destination) + + // Add an entry to the expvar LabelMap for Prometheus metrics, + // and create a clientmetric to report that same value. + service := portNumberToName(forw) + s.numTCPsessions.SetInt64(service, 0) + metric := fmt.Sprintf("sniproxy_tcp_sessions_%s", service) + clientmetric.NewCounterFunc(metric, func() int64 { + return s.numTCPsessions.Get(service).Value() + }) + + go s.forward(ln, forw) + + } + ln, err := s.ts.Listen("udp", ":53") if err != nil { log.Fatal(err) @@ -83,12 +144,31 @@ func main() { go s.promoteHTTPS(ln) } + if *debugPort != 0 { + mux := http.NewServeMux() + tsweb.Debugger(mux) + dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort)) + if err != nil { + log.Fatal(err) + } + go func() { + log.Fatal(http.Serve(dln, mux)) + }() + } + select {} } type server struct { ts tsnet.Server lc *tailscale.LocalClient + + numTLSsessions expvar.Int + numTCPsessions *metrics.LabelMap + numBadAddrPort expvar.Int + dnsResponses expvar.Int + dnsFailures expvar.Int + httpPromoted expvar.Int } func (s *server) serve(ln net.Listener) { @@ -101,6 +181,16 @@ func (s *server) serve(ln net.Listener) { } } +func (s *server) forward(ln net.Listener, forw *portForward) { + for { + c, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go s.forwardConn(c, forw) + } +} + func (s *server) serveDNS(ln net.Listener) { for { c, err := ln.Accept() @@ -118,7 +208,7 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) { n, err := c.Read(buf) if err != nil { log.Printf("c.Read failed: %v\n ", err) - dnsFailures.Add(1) + s.dnsFailures.Add(1) return } @@ -126,25 +216,25 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) { err = msg.Unpack(buf[:n]) if err != nil { log.Printf("dnsmessage unpack failed: %v\n ", err) - dnsFailures.Add(1) + s.dnsFailures.Add(1) return } buf, err = s.dnsResponse(&msg) if err != nil { log.Printf("s.dnsResponse failed: %v\n", err) - dnsFailures.Add(1) + s.dnsFailures.Add(1) return } _, err = c.Write(buf) if err != nil { log.Printf("c.Write failed: %v\n", err) - dnsFailures.Add(1) + s.dnsFailures.Add(1) return } - dnsResponses.Add(1) + s.dnsResponses.Add(1) } func (s *server) serveConn(c net.Conn) { @@ -152,7 +242,7 @@ func (s *server) serveConn(c net.Conn) { _, port, err := net.SplitHostPort(addrPortStr) if err != nil { log.Printf("bogus addrPort %q", addrPortStr) - numBadAddrPort.Add(1) + s.numBadAddrPort.Add(1) c.Close() return } @@ -165,7 +255,7 @@ func (s *server) serveConn(c net.Conn) { return netutil.NewOneConnListener(c, nil), nil } p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) { - numSessions.Add(1) + s.numTLSsessions.Add(1) return &tcpproxy.DialProxy{ Addr: net.JoinHostPort(sniName, port), DialContext: dialer.DialContext, @@ -174,6 +264,49 @@ func (s *server) serveConn(c net.Conn) { p.Start() } +// portNumberToName returns a human-readable name for several port numbers commonly forwarded, +// and "tcp###" for everything else. It is used for metric label names. +func portNumberToName(forw *portForward) string { + switch forw.Port { + case 22: + return "ssh" + case 1433: + return "sqlserver" + case 3306: + return "mysql" + case 3389: + return "rdp" + case 5432: + return "postgres" + default: + return fmt.Sprintf("%s%d", forw.Proto, forw.Port) + } +} + +// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data +// like the SNI forwarding does, it merely forwards all data to the destination specified +// in the --forward=tcp/22/github.com argument. +func (s *server) forwardConn(c net.Conn, forw *portForward) { + addrPortStr := c.LocalAddr().String() + + var dialer net.Dialer + dialer.Timeout = 30 * time.Second + + var p tcpproxy.Proxy + p.ListenFunc = func(net, laddr string) (net.Listener, error) { + return netutil.NewOneConnListener(c, nil), nil + } + + dial := &tcpproxy.DialProxy{ + Addr: fmt.Sprintf("%s:%d", forw.Destination, forw.Port), + DialContext: dialer.DialContext, + } + + p.AddRoute(addrPortStr, dial) + s.numTCPsessions.Add(portNumberToName(forw), 1) + p.Start() +} + func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) { resp := dnsmessage.NewBuilder(buf, dnsmessage.Header{ @@ -235,8 +368,36 @@ func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) { func (s *server) promoteHTTPS(ln net.Listener) { err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - httpPromoted.Add(1) + s.httpPromoted.Add(1) http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound) })) log.Fatalf("promoteHTTPS http.Serve: %v", err) } + +// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those +// same counters. +func (s *server) initMetrics() { + stats := new(metrics.Set) + + stats.Set("tls_sessions", &s.numTLSsessions) + clientmetric.NewCounterFunc("sniproxy_tls_sessions", s.numTLSsessions.Value) + + s.numTCPsessions = &metrics.LabelMap{Label: "proto"} + stats.Set("tcp_sessions", s.numTCPsessions) + // clientmetric doesn't have a good way to implement a Map type. + // We create clientmetrics dynamically when parsing the --forwards argument + + stats.Set("bad_addrport", &s.numBadAddrPort) + clientmetric.NewCounterFunc("sniproxy_bad_addrport", s.numBadAddrPort.Value) + + stats.Set("dns_responses", &s.dnsResponses) + clientmetric.NewCounterFunc("sniproxy_dns_responses", s.dnsResponses.Value) + + stats.Set("dns_failed", &s.dnsFailures) + clientmetric.NewCounterFunc("sniproxy_dns_failed", s.dnsFailures.Value) + + stats.Set("http_promoted", &s.httpPromoted) + clientmetric.NewCounterFunc("sniproxy_http_promoted", s.httpPromoted.Value) + + expvar.Publish("sniproxy", stats) +} diff --git a/cmd/sniproxy/sniproxy_test.go b/cmd/sniproxy/sniproxy_test.go new file mode 100644 index 000000000..15cc2ec21 --- /dev/null +++ b/cmd/sniproxy/sniproxy_test.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestPortForwardingArguments(t *testing.T) { + tests := []struct { + in string + wanterr string + want *portForward + }{ + {"", "", nil}, + {"bad port specifier", "cannot parse", nil}, + {"tcp/xyz/example.com", "bad forwarding port", nil}, + {"tcp//example.com", "bad forwarding port", nil}, + {"tcp/2112/", "bad destination", nil}, + {"udp/53/example.com", "unsupported forwarding protocol", nil}, + {"tcp/22/github.com", "", &portForward{Proto: "tcp", Port: 22, Destination: "github.com"}}, + } + for _, tt := range tests { + got, goterr := parseForward(tt.in) + if tt.wanterr != "" { + if !strings.Contains(goterr.Error(), tt.wanterr) { + t.Errorf("f(%q).err = %v; want %v", tt.in, goterr, tt.wanterr) + } + } else if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("Parsed forward (-got, +want):\n%s", diff) + } + } +}