tailscale/prober/tls_test.go

236 lines
6.6 KiB
Go
Raw Permalink Normal View History

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package prober
import (
"bytes"
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"golang.org/x/crypto/ocsp"
)
// leafCert is the template for the end-entity certificate used by the tests
// in this file. Its AuthorityKeyId equals the SubjectKeyId of issuerCertTpl.
// It sets no OCSP server; tests that need one add it to a copy.
var leafCert = x509.Certificate{
	// Identity and the addresses the certificate is valid for.
	Version:      3,
	SerialNumber: big.NewInt(10001),
	Subject:      pkix.Name{CommonName: "tlsprobe.test"},
	IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},

	// Key identifiers; AuthorityKeyId refers to issuerCertTpl.
	SubjectKeyId:   []byte{1, 2, 3},
	AuthorityKeyId: []byte{1, 2, 3, 4, 5},

	// Algorithms and permitted usages.
	SignatureAlgorithm: x509.SHA256WithRSA,
	PublicKeyAlgorithm: x509.RSA,
	KeyUsage:           x509.KeyUsageDigitalSignature,
	ExtKeyUsage:        []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},

	// Validity window: from five minutes ago until sixty days from now,
	// both relative to package initialization time.
	NotBefore: time.Now().Add(-5 * time.Minute),
	NotAfter:  time.Now().Add(60 * 24 * time.Hour),
}
// issuerCertTpl is the template TestOCSP uses to create both the issuer
// certificate and the OCSP responder certificate. Its SubjectKeyId is the
// value leafCert carries as AuthorityKeyId.
var issuerCertTpl = x509.Certificate{
	// Identity and the addresses the certificate is valid for.
	Version:      3,
	SerialNumber: big.NewInt(10002),
	Subject:      pkix.Name{CommonName: "tlsprobe.ca.test"},
	IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},

	// Key identifier referenced by leafCert.AuthorityKeyId.
	SubjectKeyId: []byte{1, 2, 3, 4, 5},

	// Algorithms and permitted usages.
	SignatureAlgorithm: x509.SHA256WithRSA,
	PublicKeyAlgorithm: x509.RSA,
	KeyUsage:           x509.KeyUsageDigitalSignature,
	ExtKeyUsage:        []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},

	// Validity window: from five minutes ago until sixty days from now,
	// both relative to package initialization time.
	NotBefore: time.Now().Add(-5 * time.Minute),
	NotAfter:  time.Now().Add(60 * 24 * time.Hour),
}
// simpleCert generates a fresh 4096-bit RSA key and uses it to self-sign a
// certificate built from leafCert (leafCert is both template and parent).
// It returns the key and certificate as a tls.Certificate, or the first error
// encountered while generating, encoding, or pairing them.
func simpleCert() (tls.Certificate, error) {
	certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return tls.Certificate{}, err
	}

	// PEM-encode the private key; tls.X509KeyPair takes PEM input.
	certPrivKeyPEM := new(bytes.Buffer)
	if err := pem.Encode(certPrivKeyPEM, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
	}); err != nil {
		return tls.Certificate{}, err
	}

	certBytes, err := x509.CreateCertificate(rand.Reader, &leafCert, &leafCert, &certPrivKey.PublicKey, certPrivKey)
	if err != nil {
		return tls.Certificate{}, err
	}

	certPEM := new(bytes.Buffer)
	if err := pem.Encode(certPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}); err != nil {
		return tls.Certificate{}, err
	}

	return tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
}
// TestTLSConnection starts a TLS server that presents the self-signed
// certificate from simpleCert and checks that probeTLS fails with a
// certificate-related error.
func TestTLSConnection(t *testing.T) {
	crt, err := simpleCert()
	if err != nil {
		t.Fatal(err)
	}

	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	srv.TLS = &tls.Config{Certificates: []tls.Certificate{crt}}
	srv.StartTLS()
	defer srv.Close()

	err = probeTLS(context.Background(), srv.Listener.Addr().String())
	// Report an unexpected success on its own: formatting a nil error with
	// %q would print the unhelpful "%!q(<nil>)".
	if err == nil {
		t.Fatal("probeTLS succeeded against a self-signed certificate; want a certificate error")
	}
	// The specific error message here is platform-specific ("certificate is not trusted"
	// on macOS and "certificate signed by unknown authority" on Linux), so only check
	// that it contains the word 'certificate'.
	if !strings.Contains(err.Error(), "certificate") {
		t.Errorf("unexpected error: %q", err)
	}
}
// TestCertExpiration feeds validateConnState a connection state holding a
// single certificate derived from leafCert and checks that each variant is
// rejected with the expected error text.
func TestCertExpiration(t *testing.T) {
	for _, tt := range []struct {
		name    string
		cert    func() *x509.Certificate // builds the peer certificate for the case
		wantErr string                   // substring the returned error must contain
	}{
		{
			name: "cert not valid yet",
			cert: func() *x509.Certificate {
				c := leafCert
				c.NotBefore = time.Now().Add(time.Hour)
				return &c
			},
			wantErr: "one of the certs has NotBefore in the future",
		},
		{
			name: "cert expiring soon",
			cert: func() *x509.Certificate {
				c := leafCert
				c.NotAfter = time.Now().Add(time.Hour)
				return &c
			},
			wantErr: "one of the certs expires in",
		},
		{
			name:    "valid duration but no OCSP",
			cert:    func() *x509.Certificate { return &leafCert },
			wantErr: "no OCSP server presented in leaf cert for CN=tlsprobe.test",
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert()}}
			err := validateConnState(context.Background(), cs)
			// Every case expects a failure. Handle success separately so the
			// message does not format a nil error with %q ("%!q(<nil>)").
			if err == nil {
				t.Fatalf("validateConnState succeeded; want error containing %q", tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("unexpected error %q; want %q", err, tt.wantErr)
			}
		})
	}
}
// ocspServer is an http.Handler that answers every request with an OCSP
// response built from template, or with HTTP 500 when template is nil.
type ocspServer struct {
	// template is the response to serve; nil makes the handler return 500.
	template *ocsp.Response

	// issuer, responderCert, and priv are handed to ocsp.CreateResponse
	// together with template to build the response.
	issuer        *x509.Certificate
	responderCert *x509.Certificate
	priv          crypto.Signer
}
// ServeHTTP writes an OCSP response created from s.template to w. With no
// template configured it replies with HTTP 500 and no body. A failure to
// create the response panics.
func (s *ocspServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	tpl := s.template
	if tpl == nil {
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	body, err := ocsp.CreateResponse(s.issuer, s.responderCert, *tpl, s.priv)
	if err != nil {
		panic(err)
	}
	w.Write(body)
}
// TestOCSP checks validateConnState's OCSP handling against a local responder.
// It issues a leaf certificate (from leafCert) whose OCSPServer points at an
// httptest server backed by ocspServer, then validates a connection state
// holding that leaf and its issuer while the responder serves good, unknown,
// and revoked statuses, and finally an HTTP 500.
func TestOCSP(t *testing.T) {
	issuerKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	issuerBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &issuerKey.PublicKey, issuerKey)
	if err != nil {
		t.Fatal(err)
	}
	issuerCert, err := x509.ParseCertificate(issuerBytes)
	if err != nil {
		t.Fatal(err)
	}

	responderKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	// issuer cert template re-used here, but with a different key
	responderBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &responderKey.PublicKey, responderKey)
	if err != nil {
		t.Fatal(err)
	}
	responderCert, err := x509.ParseCertificate(responderBytes)
	if err != nil {
		t.Fatal(err)
	}

	// Responses are signed with the issuer key; the template is filled in
	// per sub-test below.
	handler := &ocspServer{
		issuer:        issuerCert,
		responderCert: responderCert,
		priv:          issuerKey,
	}
	srv := httptest.NewUnstartedServer(handler)
	srv.Start()
	defer srv.Close()

	// Issue the leaf from a copy of leafCert that advertises the test
	// responder's URL.
	cert := leafCert
	cert.OCSPServer = append(cert.OCSPServer, srv.URL)
	key, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		t.Fatal(err)
	}
	certBytes, err := x509.CreateCertificate(rand.Reader, &cert, issuerCert, &key.PublicKey, issuerKey)
	if err != nil {
		t.Fatal(err)
	}
	parsed, err := x509.ParseCertificate(certBytes)
	if err != nil {
		t.Fatal(err)
	}

	for _, tt := range []struct {
		name    string
		resp    *ocsp.Response // template served by the responder; nil yields HTTP 500
		wantErr string         // expected error substring; empty means success is expected
	}{
		{"good response", &ocsp.Response{Status: ocsp.Good}, ""},
		{"unknown response", &ocsp.Response{Status: ocsp.Unknown}, "unknown OCSP verification status for CN=tlsprobe.test"},
		{"revoked response", &ocsp.Response{Status: ocsp.Revoked}, "cert for CN=tlsprobe.test has been revoked"},
		{"error 500 from ocsp", nil, "non-200 status code from OCSP"},
	} {
		t.Run(tt.name, func(t *testing.T) {
			handler.template = tt.resp
			if handler.template != nil {
				handler.template.SerialNumber = big.NewInt(1337)
			}

			cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}}
			err := validateConnState(context.Background(), cs)

			// Check the success case explicitly: strings.Contains(s, "") is
			// always true, so a substring check against an empty wantErr
			// would accept any error.
			if tt.wantErr == "" {
				if err != nil {
					t.Errorf("unexpected error %q; want success", err)
				}
				return
			}
			if err == nil {
				t.Fatalf("validateConnState succeeded; want error containing %q", tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("unexpected error %q; want %q", err, tt.wantErr)
			}
		})
	}
}