// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package prober

import (
	"bytes"
	"context"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"math/big"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"golang.org/x/crypto/ocsp"
)

// leafCert is the template for the end-entity certificate used throughout
// these tests. simpleCert self-signs it, TestCertExpiration copies and tweaks
// its validity window, and TestOCSP has it signed by the certificate built from
// issuerCertTpl. The validity window is computed once, at package
// initialization.
var leafCert = x509.Certificate{
	// Identity and the addresses the certificate is valid for.
	Version:      3,
	SerialNumber: new(big.Int).SetInt64(10001),
	Subject:      pkix.Name{CommonName: "tlsprobe.test"},
	IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},

	// Valid from five minutes before init until 60 days after it.
	NotBefore: time.Now().Add(-5 * time.Minute),
	NotAfter:  time.Now().Add(1440 * time.Hour),

	SignatureAlgorithm: x509.SHA256WithRSA,
	PublicKeyAlgorithm: x509.RSA,

	// AuthorityKeyId matches issuerCertTpl's SubjectKeyId.
	SubjectKeyId:   []byte{1, 2, 3},
	AuthorityKeyId: []byte{1, 2, 3, 4, 5},

	KeyUsage:    x509.KeyUsageDigitalSignature,
	ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}

// issuerCertTpl is the template TestOCSP uses to create the self-signed
// certificate that issues the leaf. TestOCSP also uses it, with a different key,
// for the OCSP responder certificate. Its SubjectKeyId matches
// leafCert.AuthorityKeyId. The validity window is computed once, at package
// initialization.
//
// NOTE(review): the template is not marked IsCA/BasicConstraintsValid and lacks
// KeyUsageCertSign even though it signs the leaf. Presumably that is fine
// because the tests never build a verified chain. Confirm before reusing it
// elsewhere.
var issuerCertTpl = x509.Certificate{
	// Identity and the addresses the certificate is valid for.
	Version:      3,
	SerialNumber: new(big.Int).SetInt64(10002),
	Subject:      pkix.Name{CommonName: "tlsprobe.ca.test"},
	IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},

	// Valid from five minutes before init until 60 days after it.
	NotBefore: time.Now().Add(-5 * time.Minute),
	NotAfter:  time.Now().Add(1440 * time.Hour),

	SignatureAlgorithm: x509.SHA256WithRSA,
	PublicKeyAlgorithm: x509.RSA,

	// Referenced by leafCert.AuthorityKeyId.
	SubjectKeyId: []byte{1, 2, 3, 4, 5},

	KeyUsage:    x509.KeyUsageDigitalSignature,
	ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
}

// simpleCert returns a tls.Certificate for leafCert, self-signed with a freshly
// generated 4096-bit RSA key. The DER certificate and PKCS#1 key are
// PEM-encoded and loaded back through tls.X509KeyPair. Any failure while
// generating the key, creating the certificate, or PEM-encoding is returned
// instead of being dropped.
func simpleCert() (tls.Certificate, error) {
	certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return tls.Certificate{}, err
	}
	// leafCert is both the template and the parent, so the result is self-signed.
	certBytes, err := x509.CreateCertificate(rand.Reader, &leafCert, &leafCert, &certPrivKey.PublicKey, certPrivKey)
	if err != nil {
		return tls.Certificate{}, err
	}

	certPrivKeyPEM := new(bytes.Buffer)
	if err := pem.Encode(certPrivKeyPEM, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
	}); err != nil {
		return tls.Certificate{}, err
	}
	certPEM := new(bytes.Buffer)
	if err := pem.Encode(certPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}); err != nil {
		return tls.Certificate{}, err
	}
	return tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
}

// TestTLSConnection starts a TLS server presenting simpleCert's self-signed
// certificate. It checks that probeTLS against that server fails with a
// certificate-related error.
func TestTLSConnection(t *testing.T) {
	crt, err := simpleCert()
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	srv.TLS = &tls.Config{Certificates: []tls.Certificate{crt}}
	srv.StartTLS()
	defer srv.Close()

	err = probeTLS(context.Background(), "fail.example.com", srv.Listener.Addr().String())
	// Report an unexpected success explicitly; formatting a nil error with %q
	// would print "%!q(<nil>)".
	if err == nil {
		t.Fatal("probeTLS succeeded; want a certificate error")
	}
	// The specific error message here is platform-specific ("certificate is not trusted"
	// on macOS and "certificate signed by unknown authority" on Linux), so only check
	// that it contains the word 'certificate'.
	if !strings.Contains(err.Error(), "certificate") {
		t.Errorf("unexpected error: %q", err)
	}
}

// TestCertExpiration feeds validateConnState a single-certificate chain built
// from a copy of leafCert. It checks that a NotBefore in the future, a NotAfter
// only an hour away, and a missing OCSP server are each reported with the
// expected message.
func TestCertExpiration(t *testing.T) {
	for _, tt := range []struct {
		name    string
		cert    func() *x509.Certificate
		wantErr string
	}{
		{
			"cert not valid yet",
			func() *x509.Certificate {
				c := leafCert
				c.NotBefore = time.Now().Add(time.Hour)
				return &c
			},
			"one of the certs has NotBefore in the future",
		},
		{
			"cert expiring soon",
			func() *x509.Certificate {
				c := leafCert
				c.NotAfter = time.Now().Add(time.Hour)
				return &c
			},
			"one of the certs expires in",
		},
		{
			"valid duration but no OCSP",
			func() *x509.Certificate {
				// Copy like the other cases rather than handing out a pointer
				// to the shared package-level leafCert.
				c := leafCert
				return &c
			},
			"no OCSP server presented in leaf cert for CN=tlsprobe.test",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert()}}
			err := validateConnState(context.Background(), cs)
			if err == nil {
				t.Fatalf("got nil error; want error containing %q", tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("unexpected error %q; want %q", err, tt.wantErr)
			}
		})
	}
}

// ocspServer is a fake OCSP responder used by TestOCSP as an http.Handler.
// The fields are the inputs ServeHTTP passes to ocsp.CreateResponse:
//   - issuer: the issuing certificate.
//   - responderCert: the responder certificate.
//   - template: the canned response to encode. When nil, ServeHTTP replies
//     with HTTP 500 instead. TestOCSP swaps this between subtests.
//   - priv: the signer used to sign each response.
type ocspServer struct {
	issuer        *x509.Certificate
	responderCert *x509.Certificate
	template      *ocsp.Response
	priv          crypto.Signer
}

// ServeHTTP answers every request, without looking at it, with the bytes that
// ocsp.CreateResponse produces from s.template, signed by s.priv. With no
// template configured it replies with a bare HTTP 500. It panics if the
// response cannot be created, which indicates a broken test fixture.
func (s *ocspServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	tmpl := s.template
	if tmpl == nil {
		// No canned response: simulate a failing responder.
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	body, createErr := ocsp.CreateResponse(s.issuer, s.responderCert, *tmpl, s.priv)
	if createErr != nil {
		panic(createErr)
	}
	w.Write(body)
}

// TestOCSP exercises validateConnState's OCSP handling against a local fake
// responder (ocspServer). The leaf certificate is signed by issuerCert and
// lists the fake's URL as its OCSP server. Each subtest installs a different
// canned response, or none, in which case the fake answers HTTP 500. It then
// checks the error (or lack of one) that validateConnState reports.
func TestOCSP(t *testing.T) {
	issuerKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	issuerBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &issuerKey.PublicKey, issuerKey)
	if err != nil {
		t.Fatal(err)
	}
	issuerCert, err := x509.ParseCertificate(issuerBytes)
	if err != nil {
		t.Fatal(err)
	}

	responderKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	// issuer cert template re-used here, but with a different key
	responderBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &responderKey.PublicKey, responderKey)
	if err != nil {
		t.Fatal(err)
	}
	responderCert, err := x509.ParseCertificate(responderBytes)
	if err != nil {
		t.Fatal(err)
	}

	// Responses are signed with issuerKey (priv below), not responderKey.
	handler := &ocspServer{
		issuer:        issuerCert,
		responderCert: responderCert,
		priv:          issuerKey,
	}
	srv := httptest.NewServer(handler)
	defer srv.Close()

	cert := leafCert
	cert.OCSPServer = append(cert.OCSPServer, srv.URL)
	key, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	certBytes, err := x509.CreateCertificate(rand.Reader, &cert, issuerCert, &key.PublicKey, issuerKey)
	if err != nil {
		t.Fatal(err)
	}
	parsed, err := x509.ParseCertificate(certBytes)
	if err != nil {
		t.Fatal(err)
	}

	for _, tt := range []struct {
		name    string
		resp    *ocsp.Response
		wantErr string // empty means validateConnState must succeed
	}{
		{"good response", &ocsp.Response{Status: ocsp.Good}, ""},
		{"unknown response", &ocsp.Response{Status: ocsp.Unknown}, "unknown OCSP verification status for CN=tlsprobe.test"},
		{"revoked response", &ocsp.Response{Status: ocsp.Revoked}, "cert for CN=tlsprobe.test has been revoked"},
		{"error 500 from ocsp", nil, "non-200 status code from OCSP"},
	} {
		t.Run(tt.name, func(t *testing.T) {
			handler.template = tt.resp
			if handler.template != nil {
				// Describe the leaf actually being checked rather than an
				// arbitrary serial number.
				handler.template.SerialNumber = parsed.SerialNumber
			}
			cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}}
			err := validateConnState(context.Background(), cs)

			switch {
			case tt.wantErr == "":
				// strings.Contains(s, "") is always true, so the success case
				// must be checked on its own or any error would be accepted.
				if err != nil {
					t.Errorf("unexpected error %q; want nil", err)
				}
			case err == nil:
				t.Errorf("got nil error; want error containing %q", tt.wantErr)
			case !strings.Contains(err.Error(), tt.wantErr):
				t.Errorf("unexpected error %q; want %q", err, tt.wantErr)
			}
		})
	}
}