// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlhttp

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"sync"
	"testing"
	"time"

	"tailscale.com/control/controlbase"
	"tailscale.com/net/socks5"
	"tailscale.com/net/tsdial"
	"tailscale.com/types/key"
)

// httpTestParam describes a single TestControlHTTP case: the subtest name,
// the proxy the client should dial through (nil means no proxy), and whether
// the plain-HTTP server is swapped for a handler that hangs after upgrading.
type httpTestParam struct {
	name  string
	proxy proxy

	// makeHTTPHangAfterUpgrade makes the HTTP response hang after sending a
	// 101 switching protocols.
	makeHTTPHangAfterUpgrade bool
}

// TestControlHTTP runs the controlhttp client/server exchange once per
// connection path: direct, direct with a broken plain-HTTP port, through a
// SOCKS5 proxy, and through HTTP/HTTPS proxies with different capabilities.
// Each case runs as its own subtest via testControlHTTP.
func TestControlHTTP(t *testing.T) {
	// Fields left out of a case keep their zero value (nil proxy, false flags).
	cases := []httpTestParam{
		// direct connection
		{name: "no_proxy"},
		// direct connection but port 80 is MITM'ed and broken
		{name: "port80_broken_mitm", makeHTTPHangAfterUpgrade: true},
		// SOCKS5
		{name: "socks5", proxy: &socksProxy{}},
		// HTTP->HTTP
		{name: "http_to_http", proxy: &httpProxy{allowHTTP: true}},
		// HTTP->HTTPS
		{name: "http_to_https", proxy: &httpProxy{allowConnect: true}},
		// HTTP->any (will pick HTTP)
		{name: "http_to_any", proxy: &httpProxy{allowConnect: true, allowHTTP: true}},
		// HTTPS->HTTP
		{name: "https_to_http", proxy: &httpProxy{useTLS: true, allowHTTP: true}},
		// HTTPS->HTTPS
		{name: "https_to_https", proxy: &httpProxy{useTLS: true, allowConnect: true}},
		// HTTPS->any (will pick HTTP)
		{name: "https_to_any", proxy: &httpProxy{useTLS: true, allowConnect: true, allowHTTP: true}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			testControlHTTP(t, tc)
		})
	}
}

// testControlHTTP runs one end-to-end controlhttp exchange for param.
//
// It starts a plain-HTTP and an HTTPS server on ephemeral loopback ports,
// optionally starts param.proxy and routes the client through it, dials with
// dialParams.dial, and then checks that:
//   - the server-side AcceptHTTP call succeeded,
//   - client and server agree on the protocol version,
//   - each side sees the other's machine public key as its peer, and
//   - when a proxy is in use, the server saw the connection arrive from an
//     address the proxy recorded as one of its outgoing connections.
func testControlHTTP(t *testing.T, param httpTestParam) {
	proxy := param.proxy
	client, server := key.NewMachine(), key.NewMachine()

	const testProtocolVersion = 1
	// Buffered so the handler can report its result without waiting for the
	// test goroutine to be ready to receive.
	sch := make(chan serverResult, 1)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := AcceptHTTP(context.Background(), w, r, server)
		if err != nil {
			log.Print(err)
		}
		res := serverResult{
			err: err,
		}
		if conn != nil {
			res.clientAddr = conn.RemoteAddr().String()
			res.version = conn.ProtocolVersion()
			res.peer = conn.Peer()
			res.conn = conn
		}
		sch <- res
	})

	httpLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("HTTP listen: %v", err)
	}
	httpsLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("HTTPS listen: %v", err)
	}

	// The plain-HTTP server either serves the real handler or, for the
	// broken-MITM case, a handler that hangs after the 101 response.
	var httpHandler http.Handler = handler
	if param.makeHTTPHangAfterUpgrade {
		httpHandler = http.HandlerFunc(brokenMITMHandler)
	}
	httpServer := &http.Server{Handler: httpHandler}
	go httpServer.Serve(httpLn)
	defer httpServer.Close()

	httpsServer := &http.Server{
		Handler:   handler,
		TLSConfig: tlsConfig(t),
	}
	go httpsServer.ServeTLS(httpsLn, "", "")
	defer httpsServer.Close()

	ctx := context.Background()
	// Flip debugTimeout to true to bound the dial with a 5s timeout while
	// debugging a hang.
	const debugTimeout = false
	if debugTimeout {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
	}

	a := dialParams{
		host:              "localhost",
		httpPort:          strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port),
		httpsPort:         strconv.Itoa(httpsLn.Addr().(*net.TCPAddr).Port),
		machineKey:        client,
		controlKey:        server.Public(),
		version:           testProtocolVersion,
		insecureTLS:       true,
		dialer:            new(tsdial.Dialer).SystemDial,
		testFallbackDelay: 50 * time.Millisecond,
	}

	if proxy != nil {
		proxyEnv := proxy.Start(t)
		defer proxy.Close()
		proxyURL, err := url.Parse(proxyEnv)
		if err != nil {
			t.Fatal(err)
		}
		a.proxyFunc = func(*http.Request) (*url.URL, error) {
			return proxyURL, nil
		}
	} else {
		a.proxyFunc = func(*http.Request) (*url.URL, error) {
			return nil, nil
		}
	}

	conn, err := a.dial(ctx)
	if err != nil {
		t.Fatalf("dialing controlhttp: %v", err)
	}
	defer conn.Close()
	si := <-sch
	if si.conn != nil {
		defer si.conn.Close()
	}
	if si.err != nil {
		// Report the server-side error; the outer err is the (nil) dial
		// error at this point.
		t.Fatalf("controlhttp server got error: %v", si.err)
	}
	if clientVersion := conn.ProtocolVersion(); si.version != clientVersion {
		t.Fatalf("client and server don't agree on protocol version: %d vs %d", clientVersion, si.version)
	}
	if si.peer != client.Public() {
		t.Fatalf("server got peer pubkey %s, want %s", si.peer, client.Public())
	}
	if spub := conn.Peer(); spub != server.Public() {
		t.Fatalf("client got peer pubkey %s, want %s", spub, server.Public())
	}
	if proxy != nil && !proxy.ConnIsFromProxy(si.clientAddr) {
		t.Fatalf("client connected from %s, which isn't the proxy", si.clientAddr)
	}
}

// serverResult is what the test's server-side handler reports for one
// request: the error returned by AcceptHTTP (if any) and, when a connection
// was returned, its remote address, protocol version, peer public key, and
// the connection itself so the test can close it.
type serverResult struct {
	err        error
	clientAddr string
	version    int
	peer       key.MachinePublic
	conn       *controlbase.Conn
}

// proxy is a test proxy server that the controlhttp client can be routed
// through.
type proxy interface {
	// Start starts the proxy and returns its URL as a string; the test
	// parses it with url.Parse and hands it to the dialer's proxyFunc.
	Start(*testing.T) string
	// Close shuts the proxy down.
	Close()
	// ConnIsFromProxy reports whether the given address string was recorded
	// as the local end of an outgoing connection made by the proxy.
	ConnIsFromProxy(string) bool
}

// socksProxy is a test SOCKS5 proxy that records the local address of every
// outgoing connection it dials. The embedded mutex guards closed and
// clientConnAddrs.
type socksProxy struct {
	sync.Mutex
	closed          bool
	proxy           socks5.Server
	ln              net.Listener
	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
}

// Start listens on an ephemeral loopback TCP port, begins serving SOCKS5 on
// it in a background goroutine, and returns the proxy's "socks5://host:port"
// URL. It fails the test if the listener cannot be created.
func (s *socksProxy) Start(t *testing.T) (url string) {
	t.Helper()
	s.Lock()
	defer s.Unlock()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listening for SOCKS server: %v", err)
	}
	s.ln, s.clientConnAddrs = listener, map[string]bool{}

	// Route proxy logging to the test log, but drop it once Close has run.
	s.proxy.Logf = func(format string, a ...any) {
		s.Lock()
		defer s.Unlock()
		if !s.closed {
			t.Logf(format, a...)
		}
	}
	// Outgoing dials go through dialAndRecord so their local addresses are
	// remembered for ConnIsFromProxy.
	s.proxy.Dialer = s.dialAndRecord
	go s.proxy.Serve(listener)

	url = fmt.Sprintf("socks5://%s", listener.Addr().String())
	return url
}

// Close marks the proxy as closed and closes its listener. Calls after the
// first are no-ops.
func (s *socksProxy) Close() {
	s.Lock()
	defer s.Unlock()
	if !s.closed {
		s.closed = true
		s.ln.Close()
	}
}

// dialAndRecord dials addr on network with a default net.Dialer and, on
// success, records the connection's local address in clientConnAddrs before
// returning it. Dial errors are returned unchanged.
func (s *socksProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
	dialer := new(net.Dialer)
	c, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	local := c.LocalAddr().String()

	s.Lock()
	defer s.Unlock()
	s.clientConnAddrs[local] = true
	return c, nil
}

// ConnIsFromProxy reports whether addr was recorded by dialAndRecord as the
// local end of one of the proxy's outgoing connections.
func (s *socksProxy) ConnIsFromProxy(addr string) bool {
	s.Lock()
	recorded := s.clientConnAddrs[addr]
	s.Unlock()
	return recorded
}

// httpProxy is a test HTTP proxy. It can accept incoming connections over
// plain HTTP or TLS, and can be configured to allow plain HTTP proxying,
// CONNECT tunnelling, or both. The embedded mutex is held by Start and Close
// and around accesses to clientConnAddrs.
type httpProxy struct {
	useTLS       bool // take incoming connections over TLS
	allowConnect bool // allow CONNECT for TLS
	allowHTTP    bool // allow plain HTTP proxying

	sync.Mutex
	ln              net.Listener
	rp              httputil.ReverseProxy
	s               http.Server
	clientConnAddrs map[string]bool // addrs of the local end of outgoing conns from proxy
}

// Start listens on an ephemeral loopback TCP port, serves the proxy on it in
// a background goroutine (over TLS when useTLS is set), and returns the
// proxy's "http://host:port" or "https://host:port" URL. It fails the test
// if the listener cannot be created.
func (h *httpProxy) Start(t *testing.T) (url string) {
	t.Helper()
	h.Lock()
	defer h.Unlock()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listening for HTTP proxy: %v", err)
	}
	h.ln = listener

	// Plain HTTP proxying is forwarded by a ReverseProxy whose transport
	// dials through dialAndRecord, skips TLS verification, and has an empty
	// (non-nil) TLSNextProto map.
	upstream := &http.Transport{
		DialContext:     h.dialAndRecord,
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		TLSNextProto:    map[string]func(string, *tls.Conn) http.RoundTripper{},
	}
	h.rp = httputil.ReverseProxy{
		Director:  func(*http.Request) {},
		Transport: upstream,
	}
	h.clientConnAddrs = map[string]bool{}
	h.s.Handler = h

	scheme := "http"
	if h.useTLS {
		scheme = "https"
		h.s.TLSConfig = tlsConfig(t)
		go h.s.ServeTLS(h.ln, "", "")
	} else {
		go h.s.Serve(h.ln)
	}
	return fmt.Sprintf("%s://%s", scheme, listener.Addr().String())
}

// Close shuts down the proxy's HTTP server while holding the proxy lock.
func (h *httpProxy) Close() {
	h.Lock()
	defer h.Unlock()
	// The server's Close error is deliberately discarded.
	_ = h.s.Close()
}

// ServeHTTP implements http.Handler for the test proxy.
//
// Non-CONNECT requests are forwarded through the reverse proxy when
// allowHTTP is set and rejected with a 500 otherwise. CONNECT requests are
// rejected with a 500 unless allowConnect is set; otherwise the handler
// dials the requested destination, hijacks the client connection, replies
// "200 OK", and copies bytes in both directions until either direction
// finishes.
func (h *httpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodConnect {
		if !h.allowHTTP {
			http.Error(w, "http proxy not allowed", http.StatusInternalServerError)
			return
		}
		h.rp.ServeHTTP(w, r)
		return
	}

	if !h.allowConnect {
		http.Error(w, "connect not allowed", http.StatusInternalServerError)
		return
	}

	// Not every ResponseWriter can be hijacked (HTTP/2 ones cannot), so check
	// up front instead of panicking on a bare type assertion.
	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, "hijacking not supported", http.StatusInternalServerError)
		return
	}

	// For CONNECT the request target is the host:port to tunnel to.
	dst := r.RequestURI
	c, err := h.dialAndRecord(context.Background(), "tcp", dst)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer c.Close()

	cc, ccbuf, err := hj.Hijack()
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer cc.Close()

	// If the client never receives the 200, there is nothing to tunnel.
	if _, err := io.WriteString(cc, "HTTP/1.1 200 OK\r\n\r\n"); err != nil {
		return
	}

	// Capacity 1 lets the second copier complete its send after this handler
	// has consumed the first result and returned; the deferred Closes then
	// unblock whichever copy is still running.
	errc := make(chan error, 1)
	go func() {
		_, err := io.Copy(cc, c)
		errc <- err
	}()
	go func() {
		// Read via ccbuf so bytes already buffered by the server are not lost.
		_, err := io.Copy(c, ccbuf)
		errc <- err
	}()
	<-errc
}

// dialAndRecord dials addr on network with a default net.Dialer and, on
// success, records the connection's local address in clientConnAddrs before
// returning it. Dial errors are returned unchanged.
func (h *httpProxy) dialAndRecord(ctx context.Context, network, addr string) (net.Conn, error) {
	dialer := new(net.Dialer)
	c, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	local := c.LocalAddr().String()

	h.Lock()
	defer h.Unlock()
	h.clientConnAddrs[local] = true
	return c, nil
}

// ConnIsFromProxy reports whether addr was recorded by dialAndRecord as the
// local end of one of the proxy's outgoing connections.
func (h *httpProxy) ConnIsFromProxy(addr string) bool {
	h.Lock()
	recorded := h.clientConnAddrs[addr]
	h.Unlock()
	return recorded
}

// tlsConfig returns a TLS configuration holding one fixed certificate. It
// fails the test if the embedded certificate/key pair cannot be parsed.
func tlsConfig(t *testing.T) *tls.Config {
	// The certificate and private key below are taken from the example code
	// in the crypto/tls package.
	const certPEM = `-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`
	const keyPEM = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`

	pair, err := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM))
	if err != nil {
		t.Fatal(err)
	}
	conf := new(tls.Config)
	conf.Certificates = append(conf.Certificates, pair)
	return conf
}

// brokenMITMHandler answers every request with a 101 Switching Protocols
// response carrying the upgrade headers, flushes it to the client, and then
// blocks until the request's context is done without sending anything more.
func brokenMITMHandler(w http.ResponseWriter, r *http.Request) {
	hdr := w.Header()
	hdr.Set("Upgrade", upgradeHeaderValue)
	hdr.Set("Connection", "upgrade")
	w.WriteHeader(http.StatusSwitchingProtocols)

	// Push the 101 out, then hang.
	flusher := w.(http.Flusher)
	flusher.Flush()
	<-r.Context().Done()
}