mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-16 03:31:39 +00:00
cmd/sniproxy: add port forwarding and prometheus metrics
1. Add TCP port forwarding. For example: ./sniproxy -forwards=tcp/22/github.com will forward SSH to github. % ssh -i ~/.ssh/id_ecdsa.pem -T git@github.com Hi GitHubUser! You've successfully authenticated, but GitHub does not provide shell access. % ssh -i ~/.ssh/id_ecdsa.pem -T git@100.65.x.y Hi GitHubUser! You've successfully authenticated, but GitHub does not provide shell access. 2. Additionally export clientmetrics as prometheus metrics for local scraping over the tailnet: http://sniproxy-hostname:8080/debug/varz Updates https://github.com/tailscale/tailscale/issues/1748 Signed-off-by: Denton Gentry <dgentry@tailscale.com>
This commit is contained in:
parent
98a5116434
commit
24d41e4ae7
1
cmd/sniproxy/.gitignore
vendored
Normal file
1
cmd/sniproxy/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
sniproxy
|
@ -3,15 +3,20 @@
|
|||||||
|
|
||||||
// The sniproxy is an outbound SNI proxy. It receives TLS connections over
|
// The sniproxy is an outbound SNI proxy. It receives TLS connections over
|
||||||
// Tailscale on one or more TCP ports and sends them out to the same SNI
|
// Tailscale on one or more TCP ports and sends them out to the same SNI
|
||||||
// hostname & port on the internet. It only does TCP.
|
// hostname & port on the internet. It can optionally forward one or more
|
||||||
|
// TCP ports to a specific destination. It only does TCP.
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"expvar"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -19,27 +24,54 @@ import (
|
|||||||
"inet.af/tcpproxy"
|
"inet.af/tcpproxy"
|
||||||
"tailscale.com/client/tailscale"
|
"tailscale.com/client/tailscale"
|
||||||
"tailscale.com/hostinfo"
|
"tailscale.com/hostinfo"
|
||||||
|
"tailscale.com/metrics"
|
||||||
"tailscale.com/net/netutil"
|
"tailscale.com/net/netutil"
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
|
"tailscale.com/tsweb"
|
||||||
"tailscale.com/types/nettype"
|
"tailscale.com/types/nettype"
|
||||||
"tailscale.com/util/clientmetric"
|
"tailscale.com/util/clientmetric"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
ports = flag.String("ports", "443", "comma-separated list of ports to proxy")
|
||||||
|
forwards = flag.String("forwards", "", "comma-separated list of ports to transparently forward, protocol/number/destination. For example, --forwards=tcp/22/github.com,tcp/5432/sql.example.com")
|
||||||
wgPort = flag.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
|
wgPort = flag.Int("wg-listen-port", 0, "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select")
|
||||||
promoteHTTPS = flag.Bool("promote-https", true, "promote HTTP to HTTPS")
|
promoteHTTPS = flag.Bool("promote-https", true, "promote HTTP to HTTPS")
|
||||||
|
debugPort = flag.Int("debug-port", 8080, "Listening port for debug/metrics endpoint")
|
||||||
)
|
)
|
||||||
|
|
||||||
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
|
||||||
|
|
||||||
var (
|
// portForward is the state for a single port forwarding entry, as passed to the --forward flag.
|
||||||
numSessions = clientmetric.NewCounter("sniproxy_sessions")
|
type portForward struct {
|
||||||
numBadAddrPort = clientmetric.NewCounter("sniproxy_bad_addrport")
|
Port int
|
||||||
dnsResponses = clientmetric.NewCounter("sniproxy_dns_responses")
|
Proto string
|
||||||
dnsFailures = clientmetric.NewCounter("sniproxy_dns_failed")
|
Destination string
|
||||||
httpPromoted = clientmetric.NewCounter("sniproxy_http_promoted")
|
}
|
||||||
)
|
|
||||||
|
// parseForward takes a proto/port/destination tuple as an input, as would be passed
|
||||||
|
// to the --forward command line flag, and returns a *portForward struct of those parameters.
|
||||||
|
func parseForward(value string) (*portForward, error) {
|
||||||
|
parts := strings.Split(value, "/")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, errors.New("cannot parse: " + value)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := parts[0]
|
||||||
|
if proto != "tcp" {
|
||||||
|
return nil, errors.New("unsupported forwarding protocol: " + proto)
|
||||||
|
}
|
||||||
|
port, err := strconv.ParseUint(parts[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("bad forwarding port: " + parts[1])
|
||||||
|
}
|
||||||
|
host := parts[2]
|
||||||
|
if host == "" {
|
||||||
|
return nil, errors.New("bad destination: " + value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &portForward{Port: int(port), Proto: proto, Destination: host}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
@ -58,6 +90,7 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
s.lc = lc
|
s.lc = lc
|
||||||
|
s.initMetrics()
|
||||||
|
|
||||||
for _, portStr := range strings.Split(*ports, ",") {
|
for _, portStr := range strings.Split(*ports, ",") {
|
||||||
ln, err := s.ts.Listen("tcp", ":"+portStr)
|
ln, err := s.ts.Listen("tcp", ":"+portStr)
|
||||||
@ -68,6 +101,34 @@ func main() {
|
|||||||
go s.serve(ln)
|
go s.serve(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, forwStr := range strings.Split(*forwards, ",") {
|
||||||
|
if forwStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
forw, err := parseForward(forwStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := s.ts.Listen("tcp", ":"+strconv.Itoa(forw.Port))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
log.Printf("Serving on port %d to %s...", forw.Port, forw.Destination)
|
||||||
|
|
||||||
|
// Add an entry to the expvar LabelMap for Prometheus metrics,
|
||||||
|
// and create a clientmetric to report that same value.
|
||||||
|
service := portNumberToName(forw)
|
||||||
|
s.numTCPsessions.SetInt64(service, 0)
|
||||||
|
metric := fmt.Sprintf("sniproxy_tcp_sessions_%s", service)
|
||||||
|
clientmetric.NewCounterFunc(metric, func() int64 {
|
||||||
|
return s.numTCPsessions.Get(service).Value()
|
||||||
|
})
|
||||||
|
|
||||||
|
go s.forward(ln, forw)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
ln, err := s.ts.Listen("udp", ":53")
|
ln, err := s.ts.Listen("udp", ":53")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
@ -83,12 +144,31 @@ func main() {
|
|||||||
go s.promoteHTTPS(ln)
|
go s.promoteHTTPS(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *debugPort != 0 {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
tsweb.Debugger(mux)
|
||||||
|
dln, err := s.ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
log.Fatal(http.Serve(dln, mux))
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|
||||||
type server struct {
|
type server struct {
|
||||||
ts tsnet.Server
|
ts tsnet.Server
|
||||||
lc *tailscale.LocalClient
|
lc *tailscale.LocalClient
|
||||||
|
|
||||||
|
numTLSsessions expvar.Int
|
||||||
|
numTCPsessions *metrics.LabelMap
|
||||||
|
numBadAddrPort expvar.Int
|
||||||
|
dnsResponses expvar.Int
|
||||||
|
dnsFailures expvar.Int
|
||||||
|
httpPromoted expvar.Int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) serve(ln net.Listener) {
|
func (s *server) serve(ln net.Listener) {
|
||||||
@ -101,6 +181,16 @@ func (s *server) serve(ln net.Listener) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *server) forward(ln net.Listener, forw *portForward) {
|
||||||
|
for {
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
go s.forwardConn(c, forw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) serveDNS(ln net.Listener) {
|
func (s *server) serveDNS(ln net.Listener) {
|
||||||
for {
|
for {
|
||||||
c, err := ln.Accept()
|
c, err := ln.Accept()
|
||||||
@ -118,7 +208,7 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
|
|||||||
n, err := c.Read(buf)
|
n, err := c.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("c.Read failed: %v\n ", err)
|
log.Printf("c.Read failed: %v\n ", err)
|
||||||
dnsFailures.Add(1)
|
s.dnsFailures.Add(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,25 +216,25 @@ func (s *server) serveDNSConn(c nettype.ConnPacketConn) {
|
|||||||
err = msg.Unpack(buf[:n])
|
err = msg.Unpack(buf[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("dnsmessage unpack failed: %v\n ", err)
|
log.Printf("dnsmessage unpack failed: %v\n ", err)
|
||||||
dnsFailures.Add(1)
|
s.dnsFailures.Add(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
buf, err = s.dnsResponse(&msg)
|
buf, err = s.dnsResponse(&msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("s.dnsResponse failed: %v\n", err)
|
log.Printf("s.dnsResponse failed: %v\n", err)
|
||||||
dnsFailures.Add(1)
|
s.dnsFailures.Add(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.Write(buf)
|
_, err = c.Write(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("c.Write failed: %v\n", err)
|
log.Printf("c.Write failed: %v\n", err)
|
||||||
dnsFailures.Add(1)
|
s.dnsFailures.Add(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsResponses.Add(1)
|
s.dnsResponses.Add(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) serveConn(c net.Conn) {
|
func (s *server) serveConn(c net.Conn) {
|
||||||
@ -152,7 +242,7 @@ func (s *server) serveConn(c net.Conn) {
|
|||||||
_, port, err := net.SplitHostPort(addrPortStr)
|
_, port, err := net.SplitHostPort(addrPortStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("bogus addrPort %q", addrPortStr)
|
log.Printf("bogus addrPort %q", addrPortStr)
|
||||||
numBadAddrPort.Add(1)
|
s.numBadAddrPort.Add(1)
|
||||||
c.Close()
|
c.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -165,7 +255,7 @@ func (s *server) serveConn(c net.Conn) {
|
|||||||
return netutil.NewOneConnListener(c, nil), nil
|
return netutil.NewOneConnListener(c, nil), nil
|
||||||
}
|
}
|
||||||
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
|
p.AddSNIRouteFunc(addrPortStr, func(ctx context.Context, sniName string) (t tcpproxy.Target, ok bool) {
|
||||||
numSessions.Add(1)
|
s.numTLSsessions.Add(1)
|
||||||
return &tcpproxy.DialProxy{
|
return &tcpproxy.DialProxy{
|
||||||
Addr: net.JoinHostPort(sniName, port),
|
Addr: net.JoinHostPort(sniName, port),
|
||||||
DialContext: dialer.DialContext,
|
DialContext: dialer.DialContext,
|
||||||
@ -174,6 +264,49 @@ func (s *server) serveConn(c net.Conn) {
|
|||||||
p.Start()
|
p.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// portNumberToName returns a human-readable name for several port numbers commonly forwarded,
|
||||||
|
// and "tcp###" for everything else. It is used for metric label names.
|
||||||
|
func portNumberToName(forw *portForward) string {
|
||||||
|
switch forw.Port {
|
||||||
|
case 22:
|
||||||
|
return "ssh"
|
||||||
|
case 1433:
|
||||||
|
return "sqlserver"
|
||||||
|
case 3306:
|
||||||
|
return "mysql"
|
||||||
|
case 3389:
|
||||||
|
return "rdp"
|
||||||
|
case 5432:
|
||||||
|
return "postgres"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%s%d", forw.Proto, forw.Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardConn sets up a forwarder for a TCP connection. It does not inspect of the data
|
||||||
|
// like the SNI forwarding does, it merely forwards all data to the destination specified
|
||||||
|
// in the --forward=tcp/22/github.com argument.
|
||||||
|
func (s *server) forwardConn(c net.Conn, forw *portForward) {
|
||||||
|
addrPortStr := c.LocalAddr().String()
|
||||||
|
|
||||||
|
var dialer net.Dialer
|
||||||
|
dialer.Timeout = 30 * time.Second
|
||||||
|
|
||||||
|
var p tcpproxy.Proxy
|
||||||
|
p.ListenFunc = func(net, laddr string) (net.Listener, error) {
|
||||||
|
return netutil.NewOneConnListener(c, nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dial := &tcpproxy.DialProxy{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", forw.Destination, forw.Port),
|
||||||
|
DialContext: dialer.DialContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.AddRoute(addrPortStr, dial)
|
||||||
|
s.numTCPsessions.Add(portNumberToName(forw), 1)
|
||||||
|
p.Start()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
|
func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
|
||||||
resp := dnsmessage.NewBuilder(buf,
|
resp := dnsmessage.NewBuilder(buf,
|
||||||
dnsmessage.Header{
|
dnsmessage.Header{
|
||||||
@ -235,8 +368,36 @@ func (s *server) dnsResponse(req *dnsmessage.Message) (buf []byte, err error) {
|
|||||||
|
|
||||||
func (s *server) promoteHTTPS(ln net.Listener) {
|
func (s *server) promoteHTTPS(ln net.Listener) {
|
||||||
err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
err := http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
httpPromoted.Add(1)
|
s.httpPromoted.Add(1)
|
||||||
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
|
http.Redirect(w, r, "https://"+r.Host+r.RequestURI, http.StatusFound)
|
||||||
}))
|
}))
|
||||||
log.Fatalf("promoteHTTPS http.Serve: %v", err)
|
log.Fatalf("promoteHTTPS http.Serve: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initMetrics sets up local prometheus metrics, and creates clientmetrics to report those
|
||||||
|
// same counters.
|
||||||
|
func (s *server) initMetrics() {
|
||||||
|
stats := new(metrics.Set)
|
||||||
|
|
||||||
|
stats.Set("tls_sessions", &s.numTLSsessions)
|
||||||
|
clientmetric.NewCounterFunc("sniproxy_tls_sessions", s.numTLSsessions.Value)
|
||||||
|
|
||||||
|
s.numTCPsessions = &metrics.LabelMap{Label: "proto"}
|
||||||
|
stats.Set("tcp_sessions", s.numTCPsessions)
|
||||||
|
// clientmetric doesn't have a good way to implement a Map type.
|
||||||
|
// We create clientmetrics dynamically when parsing the --forwards argument
|
||||||
|
|
||||||
|
stats.Set("bad_addrport", &s.numBadAddrPort)
|
||||||
|
clientmetric.NewCounterFunc("sniproxy_bad_addrport", s.numBadAddrPort.Value)
|
||||||
|
|
||||||
|
stats.Set("dns_responses", &s.dnsResponses)
|
||||||
|
clientmetric.NewCounterFunc("sniproxy_dns_responses", s.dnsResponses.Value)
|
||||||
|
|
||||||
|
stats.Set("dns_failed", &s.dnsFailures)
|
||||||
|
clientmetric.NewCounterFunc("sniproxy_dns_failed", s.dnsFailures.Value)
|
||||||
|
|
||||||
|
stats.Set("http_promoted", &s.httpPromoted)
|
||||||
|
clientmetric.NewCounterFunc("sniproxy_http_promoted", s.httpPromoted.Value)
|
||||||
|
|
||||||
|
expvar.Publish("sniproxy", stats)
|
||||||
|
}
|
||||||
|
37
cmd/sniproxy/sniproxy_test.go
Normal file
37
cmd/sniproxy/sniproxy_test.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPortForwardingArguments(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in string
|
||||||
|
wanterr string
|
||||||
|
want *portForward
|
||||||
|
}{
|
||||||
|
{"", "", nil},
|
||||||
|
{"bad port specifier", "cannot parse", nil},
|
||||||
|
{"tcp/xyz/example.com", "bad forwarding port", nil},
|
||||||
|
{"tcp//example.com", "bad forwarding port", nil},
|
||||||
|
{"tcp/2112/", "bad destination", nil},
|
||||||
|
{"udp/53/example.com", "unsupported forwarding protocol", nil},
|
||||||
|
{"tcp/22/github.com", "", &portForward{Proto: "tcp", Port: 22, Destination: "github.com"}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, goterr := parseForward(tt.in)
|
||||||
|
if tt.wanterr != "" {
|
||||||
|
if !strings.Contains(goterr.Error(), tt.wanterr) {
|
||||||
|
t.Errorf("f(%q).err = %v; want %v", tt.in, goterr, tt.wanterr)
|
||||||
|
}
|
||||||
|
} else if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
|
t.Errorf("Parsed forward (-got, +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user