mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 10:09:17 +00:00 
			
		
		
		
	 486a423716
			
		
	
	486a423716
	
	
	
		
			
			This adds a new `UserLogf` field to the `Server` struct. When set, any logs generated by Server are logged using `UserLogf`, and all spammy backend logs are logged to `Logf`. If `UserLogf` is unset, we default to `log.Printf`, and if `Logf` is unset, we discard all the spammy logs. Fixes #12094 Signed-off-by: Maisem Ali <maisem@tailscale.com>
		
			
				
	
	
		
			363 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			363 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| // The pgproxy server is a proxy for the Postgres wire protocol.
 | |
| package main
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"crypto/ecdsa"
 | |
| 	"crypto/elliptic"
 | |
| 	crand "crypto/rand"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"crypto/x509/pkix"
 | |
| 	"expvar"
 | |
| 	"flag"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"log"
 | |
| 	"math/big"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"tailscale.com/client/tailscale"
 | |
| 	"tailscale.com/metrics"
 | |
| 	"tailscale.com/tsnet"
 | |
| 	"tailscale.com/tsweb"
 | |
| )
 | |
| 
 | |
| var ( // command-line flags; main exits if hostname, upstream-addr, upstream-ca-file or state-dir is empty
 | |
| 	hostname     = flag.String("hostname", "", "Tailscale hostname to serve on")
 | |
| 	port         = flag.Int("port", 5432, "Listening port for client connections")
 | |
| 	debugPort    = flag.Int("debug-port", 80, "Listening port for debug/metrics endpoint") // 0 disables the debug listener
 | |
| 	upstreamAddr = flag.String("upstream-addr", "", "Address of the upstream Postgres server, in host:port format")
 | |
| 	upstreamCA   = flag.String("upstream-ca-file", "", "File containing the PEM-encoded CA certificate for the upstream server")
 | |
| 	tailscaleDir = flag.String("state-dir", "", "Directory in which to store the Tailscale auth state")
 | |
| )
 | |
| 
 | |
| func main() {
 | |
| 	flag.Parse()
 | |
| 	if *hostname == "" {
 | |
| 		log.Fatal("missing --hostname")
 | |
| 	}
 | |
| 	if *upstreamAddr == "" {
 | |
| 		log.Fatal("missing --upstream-addr")
 | |
| 	}
 | |
| 	if *upstreamCA == "" {
 | |
| 		log.Fatal("missing --upstream-ca-file")
 | |
| 	}
 | |
| 	if *tailscaleDir == "" {
 | |
| 		log.Fatal("missing --state-dir")
 | |
| 	}
 | |
| 
 | |
| 	ts := &tsnet.Server{
 | |
| 		Dir:      *tailscaleDir,
 | |
| 		Hostname: *hostname,
 | |
| 	}
 | |
| 
 | |
| 	if os.Getenv("TS_AUTHKEY") == "" {
 | |
| 		log.Print("Note: you need to run this with TS_AUTHKEY=... the first time, to join your tailnet of choice.")
 | |
| 	}
 | |
| 
 | |
| 	tsclient, err := ts.LocalClient()
 | |
| 	if err != nil {
 | |
| 		log.Fatalf("getting tsnet API client: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	p, err := newProxy(*upstreamAddr, *upstreamCA, tsclient)
 | |
| 	if err != nil {
 | |
| 		log.Fatal(err)
 | |
| 	}
 | |
| 	expvar.Publish("pgproxy", p.Expvar())
 | |
| 
 | |
| 	if *debugPort != 0 {
 | |
| 		mux := http.NewServeMux()
 | |
| 		tsweb.Debugger(mux)
 | |
| 		srv := &http.Server{
 | |
| 			Handler: mux,
 | |
| 		}
 | |
| 		dln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *debugPort))
 | |
| 		if err != nil {
 | |
| 			log.Fatal(err)
 | |
| 		}
 | |
| 		go func() {
 | |
| 			log.Fatal(srv.Serve(dln))
 | |
| 		}()
 | |
| 	}
 | |
| 
 | |
| 	ln, err := ts.Listen("tcp", fmt.Sprintf(":%d", *port))
 | |
| 	if err != nil {
 | |
| 		log.Fatal(err)
 | |
| 	}
 | |
| 	log.Printf("serving access to %s on port %d", *upstreamAddr, *port)
 | |
| 	log.Fatal(p.Serve(ln))
 | |
| }
 | |
| 
 | |
| // proxy is a postgres wire protocol proxy, which strictly enforces
 | |
| // the security of the TLS connection to its upstream regardless of
 | |
| // what the client's TLS configuration is. Construct one with newProxy.
 | |
| type proxy struct {
 | |
| 	upstreamAddr     string // "my.database.com:5432"; dialed for every session
 | |
| 	upstreamHost     string // "my.database.com"; used as the TLS ServerName towards the upstream
 | |
| 	upstreamCertPool *x509.CertPool // root CAs used to verify the upstream's certificate
 | |
| 	downstreamCert   []tls.Certificate // self-signed certificate presented to clients that ask for TLS
 | |
| 	client           *tailscale.LocalClient // used for WhoIs lookups on connecting peers
 | |
| 
 | |
| 	activeSessions  expvar.Int // exported as "sessions_active" by Expvar
 | |
| 	startedSessions expvar.Int // exported as "sessions_started" by Expvar
 | |
| 	errors          metrics.LabelMap // exported as "session_errors"; counts keyed by the "kind" label
 | |
| }
 | |
| 
 | |
| // newProxy returns a proxy that forwards connections to
 | |
| // upstreamAddr. The upstream's TLS session is verified using the CA
 | |
| // cert(s) in upstreamCAPath.
 | |
| func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) {
 | |
| 	bs, err := os.ReadFile(upstreamCAPath)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	upstreamCertPool := x509.NewCertPool()
 | |
| 	if !upstreamCertPool.AppendCertsFromPEM(bs) {
 | |
| 		return nil, fmt.Errorf("invalid CA cert in %q", upstreamCAPath)
 | |
| 	}
 | |
| 
 | |
| 	h, _, err := net.SplitHostPort(upstreamAddr)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	downstreamCert, err := mkSelfSigned(h)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &proxy{
 | |
| 		upstreamAddr:     upstreamAddr,
 | |
| 		upstreamHost:     h,
 | |
| 		upstreamCertPool: upstreamCertPool,
 | |
| 		downstreamCert:   []tls.Certificate{downstreamCert},
 | |
| 		client:           client,
 | |
| 		errors:           metrics.LabelMap{Label: "kind"},
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // Expvar returns p's monitoring metrics.
 | |
| func (p *proxy) Expvar() expvar.Var {
 | |
| 	ret := &metrics.Set{}
 | |
| 	ret.Set("sessions_active", &p.activeSessions)
 | |
| 	ret.Set("sessions_started", &p.startedSessions)
 | |
| 	ret.Set("session_errors", &p.errors)
 | |
| 	return ret
 | |
| }
 | |
| 
 | |
| // Serve accepts postgres client connections on ln and proxies them to
 | |
| // the configured upstream. ln can be any net.Listener, but all client
 | |
| // connections must originate from tailscale IPs that can be verified
 | |
| // with WhoIs.
 | |
| func (p *proxy) Serve(ln net.Listener) error {
 | |
| 	var lastSessionID int64
 | |
| 	for {
 | |
| 		c, err := ln.Accept()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		id := time.Now().UnixNano()
 | |
| 		if id == lastSessionID {
 | |
| 			// Bluntly enforce SID uniqueness, even if collisions are
 | |
| 			// fantastically unlikely (but OSes vary in how much timer
 | |
| 			// precision they expose to the OS, so id might be rounded
 | |
| 			// e.g. to the same millisecond)
 | |
| 			id++
 | |
| 		}
 | |
| 		lastSessionID = id
 | |
| 		go func(sessionID int64) {
 | |
| 			if err := p.serve(sessionID, c); err != nil {
 | |
| 				log.Printf("%d: session ended with error: %v", sessionID, err)
 | |
| 			}
 | |
| 		}(id)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| var (
 | |
| 	// sslStart is the 8-byte message that postgres clients send to indicate
 | |
| 	// that they want to do a TLS handshake. Servers should respond with
 | |
| 	// the single byte "S" before starting a normal TLS handshake; serve also sends it to the upstream.
 | |
| 	sslStart = [8]byte{0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f}
 | |
| 	// plaintextStart is the 8-byte header that postgres clients use to
 | |
| 	// indicate that they're starting a plaintext authentication
 | |
| 	// handshake. NOTE(review): looks like a 4-byte message length (86) followed by protocol version 3.0 — confirm against the protocol spec.
 | |
| 	plaintextStart = [8]byte{0, 0, 0, 86, 0, 3, 0, 0}
 | |
| )
 | |
| 
 | |
| // serve proxies the postgres client on c to the proxy's upstream,
 | |
| // enforcing strict TLS to the upstream.
 | |
| func (p *proxy) serve(sessionID int64, c net.Conn) error {
 | |
| 	defer c.Close()
 | |
| 
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 | |
| 	defer cancel()
 | |
| 
 | |
| 	whois, err := p.client.WhoIs(ctx, c.RemoteAddr().String())
 | |
| 	if err != nil {
 | |
| 		p.errors.Add("whois-failed", 1)
 | |
| 		return fmt.Errorf("getting client identity: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Before anything else, log the connection attempt.
 | |
| 	user, machine := "", ""
 | |
| 	if whois.Node != nil {
 | |
| 		if whois.Node.Hostinfo.ShareeNode() {
 | |
| 			machine = "external-device"
 | |
| 		} else {
 | |
| 			machine = strings.TrimSuffix(whois.Node.Name, ".")
 | |
| 		}
 | |
| 	}
 | |
| 	if whois.UserProfile != nil {
 | |
| 		user = whois.UserProfile.LoginName
 | |
| 		if user == "tagged-devices" && whois.Node != nil {
 | |
| 			user = strings.Join(whois.Node.Tags, ",")
 | |
| 		}
 | |
| 	}
 | |
| 	if user == "" || machine == "" {
 | |
| 		p.errors.Add("no-ts-identity", 1)
 | |
| 		return fmt.Errorf("couldn't identify source user and machine (user %q, machine %q)", user, machine)
 | |
| 	}
 | |
| 	log.Printf("%d: session start, from %s (machine %s, user %s)", sessionID, c.RemoteAddr(), machine, user)
 | |
| 	start := time.Now()
 | |
| 	defer func() {
 | |
| 		elapsed := time.Since(start)
 | |
| 		log.Printf("%d: session end, from %s (machine %s, user %s), lasted %s", sessionID, c.RemoteAddr(), machine, user, elapsed.Round(time.Millisecond))
 | |
| 	}()
 | |
| 
 | |
| 	// Read the client's opening message, to figure out if it's trying
 | |
| 	// to TLS or not.
 | |
| 	var buf [8]byte
 | |
| 	if _, err := io.ReadFull(c, buf[:len(sslStart)]); err != nil {
 | |
| 		p.errors.Add("network-error", 1)
 | |
| 		return fmt.Errorf("initial magic read: %v", err)
 | |
| 	}
 | |
| 	var clientIsTLS bool
 | |
| 	switch {
 | |
| 	case buf == sslStart:
 | |
| 		clientIsTLS = true
 | |
| 	case buf == plaintextStart:
 | |
| 		clientIsTLS = false
 | |
| 	default:
 | |
| 		p.errors.Add("client-bad-protocol", 1)
 | |
| 		return fmt.Errorf("unrecognized initial packet = % 02x", buf)
 | |
| 	}
 | |
| 
 | |
| 	// Dial & verify upstream connection.
 | |
| 	var d net.Dialer
 | |
| 	d.Timeout = 10 * time.Second
 | |
| 	upc, err := d.Dial("tcp", p.upstreamAddr)
 | |
| 	if err != nil {
 | |
| 		p.errors.Add("network-error", 1)
 | |
| 		return fmt.Errorf("upstream dial: %v", err)
 | |
| 	}
 | |
| 	defer upc.Close()
 | |
| 	if _, err := upc.Write(sslStart[:]); err != nil {
 | |
| 		p.errors.Add("network-error", 1)
 | |
| 		return fmt.Errorf("upstream write of start-ssl magic: %v", err)
 | |
| 	}
 | |
| 	if _, err := io.ReadFull(upc, buf[:1]); err != nil {
 | |
| 		p.errors.Add("network-error", 1)
 | |
| 		return fmt.Errorf("reading upstream start-ssl response: %v", err)
 | |
| 	}
 | |
| 	if buf[0] != 'S' {
 | |
| 		p.errors.Add("upstream-bad-protocol", 1)
 | |
| 		return fmt.Errorf("upstream didn't acknowledge start-ssl, said %q", buf[0])
 | |
| 	}
 | |
| 	tlsConf := &tls.Config{
 | |
| 		ServerName: p.upstreamHost,
 | |
| 		RootCAs:    p.upstreamCertPool,
 | |
| 		MinVersion: tls.VersionTLS12,
 | |
| 	}
 | |
| 	uptc := tls.Client(upc, tlsConf)
 | |
| 	if err = uptc.HandshakeContext(ctx); err != nil {
 | |
| 		p.errors.Add("upstream-tls", 1)
 | |
| 		return fmt.Errorf("upstream TLS handshake: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	// Accept the client conn and set it up the way the client wants.
 | |
| 	var clientConn net.Conn
 | |
| 	if clientIsTLS {
 | |
| 		io.WriteString(c, "S") // yeah, we're good to speak TLS
 | |
| 		s := tls.Server(c, &tls.Config{
 | |
| 			ServerName:   p.upstreamHost,
 | |
| 			Certificates: p.downstreamCert,
 | |
| 			MinVersion:   tls.VersionTLS12,
 | |
| 		})
 | |
| 		if err = uptc.HandshakeContext(ctx); err != nil {
 | |
| 			p.errors.Add("client-tls", 1)
 | |
| 			return fmt.Errorf("client TLS handshake: %v", err)
 | |
| 		}
 | |
| 		clientConn = s
 | |
| 	} else {
 | |
| 		// Repeat the header we read earlier up to the server.
 | |
| 		if _, err := uptc.Write(plaintextStart[:]); err != nil {
 | |
| 			p.errors.Add("network-error", 1)
 | |
| 			return fmt.Errorf("sending initial client bytes to upstream: %v", err)
 | |
| 		}
 | |
| 		clientConn = c
 | |
| 	}
 | |
| 
 | |
| 	// Finally, proxy the client to the upstream.
 | |
| 	errc := make(chan error, 1)
 | |
| 	go func() {
 | |
| 		_, err := io.Copy(uptc, clientConn)
 | |
| 		errc <- err
 | |
| 	}()
 | |
| 	go func() {
 | |
| 		_, err := io.Copy(clientConn, uptc)
 | |
| 		errc <- err
 | |
| 	}()
 | |
| 	if err := <-errc; err != nil {
 | |
| 		// Don't increment error counts here, because the most common
 | |
| 		// cause of termination is client or server closing the
 | |
| 		// connection normally, and it'll obscure "interesting"
 | |
| 		// handshake errors.
 | |
| 		return fmt.Errorf("session terminated with error: %v", err)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// mkSelfSigned creates and returns a self-signed TLS certificate for
// hostname, backed by a freshly generated ECDSA P-256 key.
//
// hostname may be a DNS name or an IP literal. An IP literal is
// recorded as an IP subject alternative name rather than a DNS one,
// because certificate hostname verification only matches IP
// addresses against IP SANs. The certificate is valid from now for
// roughly ten years and is restricted to TLS server authentication.
//
// The returned tls.Certificate has its Leaf populated. On error the
// zero tls.Certificate is returned.
func mkSelfSigned(hostname string) (tls.Certificate, error) {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), crand.Reader)
	if err != nil {
		return tls.Certificate{}, err
	}

	// Sample the clock once so NotBefore/NotAfter are exactly ten
	// (365-day) years apart.
	now := time.Now()
	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"pgproxy"},
		},
		NotBefore:             now,
		NotAfter:              now.Add(10 * 365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}
	if ip := net.ParseIP(hostname); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else {
		template.DNSNames = []string{hostname}
	}

	// Self-signed: the template is both the certificate and its issuer.
	derBytes, err := x509.CreateCertificate(crand.Reader, &template, &template, priv.Public(), priv)
	if err != nil {
		return tls.Certificate{}, err
	}
	leaf, err := x509.ParseCertificate(derBytes)
	if err != nil {
		return tls.Certificate{}, err
	}

	return tls.Certificate{
		Certificate: [][]byte{derBytes},
		PrivateKey:  priv,
		Leaf:        leaf,
	}, nil
}
 |