// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package controlhttp

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/netip"
	"net/url"
	"runtime"
	"strconv"
	"sync"
	"testing"
	"time"

	"tailscale.com/control/controlbase"
	"tailscale.com/net/dnscache"
	"tailscale.com/net/socks5"
	"tailscale.com/net/tsdial"
	"tailscale.com/tailcfg"
	"tailscale.com/tstest"
	"tailscale.com/tstime"
	"tailscale.com/types/key"
	"tailscale.com/types/logger"
)

// httpTestParam describes a single TestControlHTTP scenario.
type httpTestParam struct {
	name  string // subtest name passed to t.Run
	proxy proxy  // proxy the client dials through; nil means no proxy

	// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
	// 101 switching protocols.
	makeHTTPHangAfterUpgrade bool

	// doEarlyWrite makes the server hand AcceptHTTP an early-write callback;
	// the client side of the test then reads that payload back from the
	// established connection.
	doEarlyWrite bool
}

// TestControlHTTP runs testControlHTTP once per scenario: direct
// connections, a SOCKS5 proxy, plain and TLS HTTP proxies with different
// combinations of CONNECT / plain-HTTP support, and an early-write case.
func TestControlHTTP(t *testing.T) {
	scenarios := []httpTestParam{
		{
			// Direct connection.
			name:  "no_proxy",
			proxy: nil,
		},
		{
			// Direct connection, but port 80 is MITM'ed and broken.
			name:                     "port80_broken_mitm",
			proxy:                    nil,
			makeHTTPHangAfterUpgrade: true,
		},
		{
			// SOCKS5.
			name:  "socks5",
			proxy: &socksProxy{},
		},
		{
			// HTTP->HTTP.
			name:  "http_to_http",
			proxy: &httpProxy{useTLS: false, allowConnect: false, allowHTTP: true},
		},
		{
			// HTTP->HTTPS.
			name:  "http_to_https",
			proxy: &httpProxy{useTLS: false, allowConnect: true, allowHTTP: false},
		},
		{
			// HTTP->any (will pick HTTP).
			name:  "http_to_any",
			proxy: &httpProxy{useTLS: false, allowConnect: true, allowHTTP: true},
		},
		{
			// HTTPS->HTTP.
			name:  "https_to_http",
			proxy: &httpProxy{useTLS: true, allowConnect: false, allowHTTP: true},
		},
		{
			// HTTPS->HTTPS.
			name:  "https_to_https",
			proxy: &httpProxy{useTLS: true, allowConnect: true, allowHTTP: false},
		},
		{
			// HTTPS->any (will pick HTTP).
			name:  "https_to_any",
			proxy: &httpProxy{useTLS: true, allowConnect: true, allowHTTP: true},
		},
		{
			// Early write.
			name:         "early_write",
			doEarlyWrite: true,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			testControlHTTP(t, sc)
		})
	}
}

// testControlHTTP runs one TestControlHTTP scenario.
//
// It starts an HTTP and an HTTPS server on ephemeral loopback ports. Both
// serve a handler that calls AcceptHTTP and reports the outcome on a channel,
// except that the HTTP server uses brokenMITMHandler instead when
// param.makeHTTPHangAfterUpgrade is set. It then dials with a Dialer
// (through param.proxy when non-nil) and verifies that:
//   - neither side reported an error,
//   - client and server agree on the protocol version,
//   - each side sees the other's machine public key,
//   - the server saw the connection arrive from the proxy, if one was used,
//   - the early-write payload is readable by the client, if requested.
func testControlHTTP(t *testing.T, param httpTestParam) {
	proxy := param.proxy
	client, server := key.NewMachine(), key.NewMachine()

	const testProtocolVersion = 1
	const earlyWriteMsg = "Hello, world!"
	// Buffered so the handler never blocks on reporting its single result.
	sch := make(chan serverResult, 1)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var earlyWriteFn func(protocolVersion int, w io.Writer) error
		if param.doEarlyWrite {
			earlyWriteFn = func(protocolVersion int, w io.Writer) error {
				if protocolVersion != testProtocolVersion {
					t.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
					return fmt.Errorf("unexpected protocol version %d; want %d", protocolVersion, testProtocolVersion)
				}
				_, err := io.WriteString(w, earlyWriteMsg)
				return err
			}
		}
		conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn)
		if err != nil {
			log.Print(err)
		}
		res := serverResult{
			err: err,
		}
		if conn != nil {
			res.clientAddr = conn.RemoteAddr().String()
			res.version = conn.ProtocolVersion()
			res.peer = conn.Peer()
			res.conn = conn
		}
		sch <- res
	})

	httpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("HTTP listen: %v", err)
	}
	httpsLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("HTTPS listen: %v", err)
	}

	var httpHandler http.Handler = handler
	const fallbackDelay = 50 * time.Millisecond
	clock := tstest.NewClock(tstest.ClockOpts{Step: 2 * fallbackDelay})
	// Advance once to init the clock.
	clock.Now()
	if param.makeHTTPHangAfterUpgrade {
		httpHandler = brokenMITMHandler(clock)
	}
	httpServer := &http.Server{Handler: httpHandler}
	go httpServer.Serve(httpLn)
	defer httpServer.Close()

	httpsServer := &http.Server{
		Handler:   handler,
		TLSConfig: tlsConfig(t),
	}
	go httpsServer.ServeTLS(httpsLn, "", "")
	defer httpsServer.Close()

	ctx := context.Background()
	// Flip debugTimeout to true while debugging to bound the dial with a
	// 5-second timeout instead of letting it block indefinitely.
	const debugTimeout = false
	if debugTimeout {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
	}

	a := &Dialer{
		Hostname:             "localhost",
		HTTPPort:             strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
		HTTPSPort:            strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
		MachineKey:           client,
		ControlKey:           server.Public(),
		ProtocolVersion:      testProtocolVersion,
		Dialer:               new(tsdial.Dialer).SystemDial,
		Logf:                 t.Logf,
		omitCertErrorLogging: true,
		testFallbackDelay:    fallbackDelay,
		Clock:                clock,
	}

	if proxy != nil {
		proxyEnv := proxy.Start(t)
		defer proxy.Close()
		proxyURL, err := url.Parse(proxyEnv)
		if err != nil {
			t.Fatal(err)
		}
		a.proxyFunc = func(*http.Request) (*url.URL, error) {
			return proxyURL, nil
		}
	} else {
		a.proxyFunc = func(*http.Request) (*url.URL, error) {
			return nil, nil
		}
	}

	conn, err := a.dial(ctx)
	if err != nil {
		t.Fatalf("dialing controlhttp: %v", err)
	}
	defer conn.Close()
	si := <-sch
	if si.conn != nil {
		defer si.conn.Close()
	}
	if si.err != nil {
		// Report the server-side error; the client-side err is nil here.
		t.Fatalf("controlhttp server got error: %v", si.err)
	}
	if clientVersion := conn.ProtocolVersion(); si.version != clientVersion {
		t.Fatalf("client and server don't agree on protocol version: %d vs %d", clientVersion, si.version)
	}
	if si.peer != client.Public() {
		t.Fatalf("server got peer pubkey %s, want %s", si.peer, client.Public())
	}
	if spub := conn.Peer(); spub != server.Public() {
		t.Fatalf("client got peer pubkey %s, want %s", spub, server.Public())
	}
	if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
		t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
	}
	if param.doEarlyWrite {
		buf := make([]byte, len(earlyWriteMsg))
		if _, err := io.ReadFull(conn, buf); err != nil {
			t.Fatalf("reading early write: %v", err)
		}
		if string(buf) != earlyWriteMsg {
			t.Errorf("early write = %q; want %q", buf, earlyWriteMsg)
		}
	}
}

// serverResult is what the test's HTTP handler reports after calling
// AcceptHTTP for an incoming request.
type serverResult struct {
	err        error             // error returned by AcceptHTTP, if any
	clientAddr string            // conn.RemoteAddr().String() of the accepted conn
	version    int               // conn.ProtocolVersion() of the accepted conn
	peer       key.MachinePublic // conn.Peer() of the accepted conn
	conn       *controlbase.Conn // the accepted conn; left nil if AcceptHTTP returned none
}

// proxy is a test proxy server that the controlhttp client can be pointed at.
type proxy interface {
	// Start launches the proxy and returns the proxy URL clients should use.
	Start(*testing.T) string
	// Close shuts the proxy down.
	Close()
	// ConnIsFromProxy reports whether the given address is the local
	// address of an outgoing connection that the proxy dialed.
	ConnIsFromProxy(string) bool
}

// socksProxy is a proxy implementation backed by a socks5.Server.
// The embedded Mutex is held whenever closed or clientConnAddrs is accessed.
type socksProxy struct {
	sync.Mutex
	closed          bool // set by Close; once set, the server's Logf drops messages
	proxy           socks5.Server
	ln              net.Listener    // listener passed to proxy.Serve
	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
}

// Start listens on an ephemeral loopback port, serves SOCKS5 on it in a
// background goroutine, and returns the listener's address as a socks5://
// URL. Outgoing connections are dialed through dialAndRecord so their local
// addresses can later be checked with ConnIsFromProxy.
func (s *socksProxy) Start(t *testing.T) (url string) {
	t.Helper()
	s.Lock()
	defer s.Unlock()

	listener, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatalf("listening for SOCKS server: %v", listenErr)
	}
	s.ln = listener
	s.clientConnAddrs = make(map[string]bool)
	s.proxy.Dialer = s.dialAndRecord
	// Forward server logs to the test log only until Close has been called
	// (presumably so that nothing is logged via t after the test is over).
	s.proxy.Logf = func(format string, a ...any) {
		s.Lock()
		defer s.Unlock()
		if !s.closed {
			t.Logf(format, a...)
		}
	}

	go s.proxy.Serve(listener)
	url = "socks5://" + listener.Addr().String()
	return url
}

// Close marks the proxy as closed and closes its listener. Only the first
// call has any effect; later calls return immediately.
func (s *socksProxy) Close() {
	s.Lock()
	defer s.Unlock()
	if !s.closed {
		s.closed = true
		s.ln.Close()
	}
}

// dialAndRecord dials addr with a default net.Dialer and, on success,
// remembers the connection's local address so that ConnIsFromProxy can
// recognise it later. Dial errors are returned unchanged.
func (s *socksProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
	out, err := new(net.Dialer).DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	local := out.LocalAddr().String()

	s.Lock()
	defer s.Unlock()
	s.clientConnAddrs[local] = true
	return out, nil
}

// ConnIsFromProxy reports whether addr was recorded by dialAndRecord as the
// local address of a connection dialed by this proxy.
func (s *socksProxy) ConnIsFromProxy(addr string) bool {
	s.Lock()
	recorded := s.clientConnAddrs[addr]
	s.Unlock()
	return recorded
}

// httpProxy is a proxy implementation that acts as an HTTP proxy. Depending
// on its configuration it accepts client connections over plain TCP or TLS,
// and forwards plain HTTP requests and/or CONNECT tunnels. It is its own
// http.Handler (see ServeHTTP).
type httpProxy struct {
	useTLS       bool // take incoming connections over TLS
	allowConnect bool // allow CONNECT for TLS
	allowHTTP    bool // allow plain HTTP proxying

	// The embedded Mutex is held whenever clientConnAddrs is accessed.
	sync.Mutex
	ln              net.Listener
	rp              httputil.ReverseProxy
	s               http.Server
	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
}

// Start listens on an ephemeral loopback port and serves the proxy on it in
// a background goroutine, over TLS when h.useTLS is set. It returns the
// listener's address as an https:// or http:// URL accordingly.
//
// Plain HTTP requests are forwarded by a ReverseProxy whose transport dials
// through dialAndRecord and skips TLS certificate verification.
func (h *httpProxy) Start(t *testing.T) (url string) {
	t.Helper()
	h.Lock()
	defer h.Unlock()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listening for HTTP proxy: %v", err)
	}
	h.ln = listener
	h.clientConnAddrs = make(map[string]bool)

	transport := &http.Transport{
		DialContext:     h.dialAndRecord,
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		// Non-nil but empty: per the net/http docs this keeps the transport
		// from enabling HTTP/2 automatically.
		TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
	}
	h.rp = httputil.ReverseProxy{
		Director:  func(*http.Request) {},
		Transport: transport,
	}
	h.s.Handler = h

	hostPort := listener.Addr().String()
	if !h.useTLS {
		go h.s.Serve(h.ln)
		return "http://" + hostPort
	}
	h.s.TLSConfig = tlsConfig(t)
	go h.s.ServeTLS(h.ln, "", "")
	return "https://" + hostPort
}

// Close closes the proxy's HTTP server while holding the proxy lock. The
// error returned by the server's Close is discarded.
func (h *httpProxy) Close() {
	h.Lock()
	defer h.Unlock()
	_ = h.s.Close()
}

// ServeHTTP implements the proxy.
//
// Non-CONNECT requests are handed to the reverse proxy when allowHTTP is
// set, and rejected with a 500 otherwise. CONNECT requests are rejected with
// a 500 unless allowConnect is set; otherwise the handler dials the
// requested target, hijacks the client connection, replies "200 OK", and
// relays bytes in both directions until either direction finishes.
func (h *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	isConnect := r.Method == http.MethodConnect
	switch {
	case !isConnect && !h.allowHTTP:
		http.Error(w, "http proxy not allowed", http.StatusInternalServerError)
		return
	case !isConnect:
		h.rp.ServeHTTP(w, r)
		return
	case !h.allowConnect:
		http.Error(w, "connect not allowed", http.StatusInternalServerError)
		return
	}

	// For CONNECT, the request URI is the host:port to tunnel to.
	upstream, err := h.dialAndRecord(context.Background(), "tcp", r.RequestURI)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer upstream.Close()

	hijacker := w.(http.Hijacker)
	client, clientBuf, err := hijacker.Hijack()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer client.Close()

	io.WriteString(client, "HTTP/1.1 200 OK\r\n\r\n")

	// Relay both directions. The handler returns as soon as one direction
	// ends; the deferred Closes then tear down both connections.
	finished := make(chan error, 1)
	relay := func(dst io.Writer, src io.Reader) {
		_, copyErr := io.Copy(dst, src)
		finished <- copyErr
	}
	go relay(client, upstream)
	go relay(upstream, clientBuf)
	<-finished
}

// dialAndRecord dials addr with a default net.Dialer and, on success,
// remembers the connection's local address so that ConnIsFromProxy can
// recognise it later. Dial errors are returned unchanged.
func (h *httpProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
	out, err := new(net.Dialer).DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	local := out.LocalAddr().String()

	h.Lock()
	defer h.Unlock()
	h.clientConnAddrs[local] = true
	return out, nil
}

// ConnIsFromProxy reports whether addr was recorded by dialAndRecord as the
// local address of a connection dialed by this proxy.
func (h *httpProxy) ConnIsFromProxy(addr string) bool {
	h.Lock()
	recorded := h.clientConnAddrs[addr]
	h.Unlock()
	return recorded
}

// tlsConfig returns a TLS configuration holding a single hard-coded
// certificate/key pair. The test is failed immediately if the pair cannot be
// parsed.
func tlsConfig(t *testing.T) *tls.Config {
	// Cert and key taken from the example code in the crypto/tls
	// package.
	const certPEM = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`
	const keyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		t.Fatal(err)
	}
	cfg := &tls.Config{}
	cfg.Certificates = []tls.Certificate{pair}
	return cfg
}

// brokenMITMHandler returns a handler that answers every request with a
// flushed "101 Switching Protocols" response carrying the upgrade headers,
// reads the clock once, and then blocks until the request's context is done
// without ever speaking the upgraded protocol.
func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
	hang := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Upgrade", upgradeHeaderValue)
		hdr.Set("Connection", "upgrade")
		w.WriteHeader(http.StatusSwitchingProtocols)
		flusher := w.(http.Flusher)
		flusher.Flush()
		// Advance the clock to trigger HTTPS fallback.
		clock.Now()
		<-r.Context().Done()
	}
	return hang
}

// TestDialPlan verifies that a Dialer configured with a ControlDialPlan ends
// up connected to the expected candidate address. For every case it starts
// HTTP/HTTPS servers on several distinct loopback addresses, one of which
// (brokenAddr) is served by brokenMITMHandler, dials, and compares the remote
// IP of the resulting connection with the expected one. It is skipped on
// non-Linux platforms.
func TestDialPlan(t *testing.T) {
	if runtime.GOOS != "linux" {
		t.Skip("only works on Linux due to multiple localhost addresses")
	}

	client, server := key.NewMachine(), key.NewMachine()

	const (
		testProtocolVersion = 1
	)

	// getRandomPort listens on TCP port 0, closes the listener again, and
	// returns the port the listener was bound to, as a string.
	getRandomPort := func() string {
		ln, err := net.Listen("tcp", ":0")
		if err != nil {
			t.Fatalf("net.Listen: %v", err)
		}
		defer ln.Close()
		_, port, err := net.SplitHostPort(ln.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		return port
	}

	// We need consistent ports for each address; these are chosen
	// randomly and we hope that they won't conflict during this test.
	httpPort := getRandomPort()
	httpsPort := getRandomPort()

	// makeHandler starts an HTTP server on host:httpPort and an HTTPS server
	// on host:httpsPort. Both serve the same handler: one that accepts a
	// controlhttp connection and then parks until the (sub)test's cleanup
	// runs, or whatever wrap returns for it when wrap is non-nil. Both
	// servers are closed via t.Cleanup.
	makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
		done := make(chan struct{})
		t.Cleanup(func() {
			close(done)
		})
		var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			conn, err := AcceptHTTP(context.Background(), w, r, server, nil)
			if err != nil {
				log.Print(err)
			} else {
				defer conn.Close()
			}
			w.Header().Set("X-Handler-Name", name)
			// Hold the request (and the accepted conn, if any) open until
			// the test's cleanup closes done.
			<-done
		})
		if wrap != nil {
			handler = wrap(handler)
		}

		httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
		if err != nil {
			t.Fatalf("HTTP listen: %v", err)
		}
		httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
		if err != nil {
			t.Fatalf("HTTPS listen: %v", err)
		}

		httpServer := &http.Server{Handler: handler}
		go httpServer.Serve(httpLn)
		t.Cleanup(func() {
			httpServer.Close()
		})

		httpsServer := &http.Server{
			Handler:   handler,
			TLSConfig: tlsConfig(t),
			ErrorLog:  logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
		}
		go httpsServer.ServeTLS(httpsLn, "", "")
		t.Cleanup(func() {
			httpsServer.Close()
		})
		return
	}

	// Distinct loopback addresses, each of which gets its own pair of
	// servers in every subtest below.
	fallbackAddr := netip.MustParseAddr("127.0.0.1")
	goodAddr := netip.MustParseAddr("127.0.0.2")
	otherAddr := netip.MustParseAddr("127.0.0.3")
	other2Addr := netip.MustParseAddr("127.0.0.4")
	brokenAddr := netip.MustParseAddr("127.0.0.10")

	// Each case supplies a dial plan and the address the resulting
	// connection is expected to reach. allowFallback makes the Dialer use
	// "localhost" as its Hostname instead of "example.com". The wrap field
	// is not read by the loop below.
	testCases := []struct {
		name string
		plan *tailcfg.ControlDialPlan
		wrap func(http.Handler) http.Handler
		want netip.Addr

		allowFallback bool
	}{
		{
			name: "single",
			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
			}},
			want: goodAddr,
		},
		{
			name: "broken-then-good",
			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
				// Dials the broken one, which fails, and then
				// eventually dials the good one and succeeds
				{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
			}},
			want: goodAddr,
		},
		// TODO(#8442): fix this test
		// {
		// 	name: "multiple-priority-fast-path",
		// 	plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
		// 		// Dials some good IPs and our bad one (which
		// 		// hangs forever), which then hits the fast
		// 		// path where we bail without waiting.
		// 		{IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
		// 		{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
		// 		{IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
		// 		{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
		// 	}},
		// 	want: otherAddr,
		// },
		{
			name: "multiple-priority-slow-path",
			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
				// Our broken address is the highest priority,
				// so we don't hit our fast path.
				{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
				{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
				{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
			}},
			want: otherAddr,
		},
		{
			name: "fallback",
			plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
				{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
			}},
			want:          fallbackAddr,
			allowFallback: true,
		},
	}
	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			// TODO(awly): replace this with tstest.NewClock and update the
			// test to advance the clock correctly.
			clock := tstime.StdClock{}
			makeHandler(t, "fallback", fallbackAddr, nil)
			makeHandler(t, "good", goodAddr, nil)
			makeHandler(t, "other", otherAddr, nil)
			makeHandler(t, "other2", other2Addr, nil)
			// The broken address ignores the normal accept handler and
			// serves brokenMITMHandler instead.
			makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
				return brokenMITMHandler(clock)
			})

			// Wrap the system dialer so that dialer.Done can report any
			// connection that was dialed but never closed.
			dialer := closeTrackDialer{
				t:     t,
				inner: new(tsdial.Dialer).SystemDial,
				conns: make(map[*closeTrackConn]bool),
			}
			defer dialer.Done()

			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			// By default, we intentionally point to something that
			// we know won't connect, since we want a fallback to
			// DNS to be an error.
			host := "example.com"
			if tt.allowFallback {
				host = "localhost"
			}

			// drained is handed to the Dialer as drainFinished and is
			// waited on at the end of the subtest.
			drained := make(chan struct{})
			a := &Dialer{
				Hostname:             host,
				HTTPPort:             httpPort,
				HTTPSPort:            httpsPort,
				MachineKey:           client,
				ControlKey:           server.Public(),
				ProtocolVersion:      testProtocolVersion,
				Dialer:               dialer.Dial,
				Logf:                 t.Logf,
				DialPlan:             tt.plan,
				proxyFunc:            func(*http.Request) (*url.URL, error) { return nil, nil },
				drainFinished:        drained,
				omitCertErrorLogging: true,
				testFallbackDelay:    50 * time.Millisecond,
				Clock:                clock,
			}

			conn, err := a.dial(ctx)
			if err != nil {
				t.Fatalf("dialing controlhttp: %v", err)
			}
			defer conn.Close()

			raddr := conn.RemoteAddr().(*net.TCPAddr)

			got, ok := netip.AddrFromSlice(raddr.IP)
			if !ok {
				t.Errorf("invalid remote IP: %v", raddr.IP)
			} else if got != tt.want {
				t.Errorf("got connection from %q; want %q", got, tt.want)
			} else {
				t.Logf("successfully connected to %q", raddr.String())
			}

			// Wait until our dialer drains so we can verify that
			// all connections are closed.
			<-drained
		})
	}
}

// closeTrackDialer wraps a dial function and keeps track of every connection
// it has handed out that has not been closed yet, so that a test can check
// for unclosed connections with Done.
type closeTrackDialer struct {
	t     testing.TB               // used by Done to report unclosed connections
	inner dnscache.DialContextFunc // the dial function being wrapped
	mu    sync.Mutex               // held whenever conns is accessed
	conns map[*closeTrackConn]bool // connections dialed and not yet closed
}

// Dial dials through the wrapped dial function. On success the connection is
// wrapped in a closeTrackConn and registered as open; on failure the error
// is returned unchanged.
func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
	raw, err := d.inner(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	tracked := &closeTrackConn{d: d, Conn: raw}
	d.mu.Lock()
	defer d.mu.Unlock()
	d.conns[tracked] = true
	return tracked, nil
}

// Done waits for every tracked connection to be closed and reports a test
// error for each one that is still open once it gives up.
//
// It checks up to 100 times, sleeping 100ms between checks, and returns as
// soon as no connections remain.
func (d *closeTrackDialer) Done() {
	// Unfortunately, tsdial.Dialer.SystemDial closes connections
	// asynchronously in a goroutine, so we can't assume that everything is
	// closed by the time we get here.
	//
	// Poll a bounded number of times on the assumption that things will
	// close "eventually".
	const iters = 100
	for attempt := 1; ; attempt++ {
		d.mu.Lock()
		if len(d.conns) == 0 {
			d.mu.Unlock()
			return
		}

		// Out of attempts: report whatever is still open and stop.
		if attempt == iters {
			for conn := range d.conns {
				d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
			}
			d.mu.Unlock()
			return
		}

		d.mu.Unlock()
		time.Sleep(100 * time.Millisecond)
	}
}

// noteClose removes c from the set of open connections. Calling it again for
// the same connection is harmless.
func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	delete(d.conns, c) // no-op when c is no longer tracked
}

// closeTrackConn is a net.Conn handed out by closeTrackDialer. Its Close
// method tells the dialer that the connection has been closed.
type closeTrackConn struct {
	net.Conn
	d *closeTrackDialer // dialer that created this connection
}

// Close deregisters the connection from its dialer and then closes the
// underlying connection, returning that close's error.
func (c *closeTrackConn) Close() error {
	c.d.noteClose(c)
	closeErr := c.Conn.Close()
	return closeErr
}