mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-16 11:41:39 +00:00
crypto/x509: add support for CertPool to load certs lazily
(from patchset 1, 7cdc3c3e7427c9ef69e19224d6036c09c5ea1723, of https://go-review.googlesource.com/c/go/+/229917/1) This will allow building CertPools that consume less memory. (Most certs are never accessed. Different users/programs access different ones, but not many.) This CL only adds the new internal mechanism (and uses it for the old AddCert) but does not modify any existing root pool behavior. (That is, the default Unix roots are still all slurped into memory as of this CL) Change-Id: Ib3a42e4050627b5e34413c595d8ced839c7bfa14
This commit is contained in:
parent
6b232b5a79
commit
f5993f2440
@ -12,9 +12,15 @@ import (
|
|||||||
|
|
||||||
// CertPool is a set of certificates.
|
// CertPool is a set of certificates.
|
||||||
type CertPool struct {
|
type CertPool struct {
|
||||||
bySubjectKeyId map[string][]int
|
bySubjectKeyId map[string][]int // cert.SubjectKeyId => getCert index
|
||||||
byName map[string][]int
|
byName map[string][]int // cert.RawSubject => getCert index
|
||||||
certs []*Certificate
|
|
||||||
|
// getCert contains funcs that return the certificates.
|
||||||
|
getCert []func() (*Certificate, error)
|
||||||
|
|
||||||
|
// rawSubjects is each cert's RawSubject field.
|
||||||
|
// Its indexes correspond to the getCert indexes.
|
||||||
|
rawSubjects [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCertPool returns a new, empty CertPool.
|
// NewCertPool returns a new, empty CertPool.
|
||||||
@ -25,11 +31,26 @@ func NewCertPool() *CertPool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// len returns the number of certs in the set.
|
||||||
|
// A nil set is a valid empty set.
|
||||||
|
func (s *CertPool) len() int {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return len(s.getCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cert returns cert index n in s.
|
||||||
|
func (s *CertPool) cert(n int) (*Certificate, error) {
|
||||||
|
return s.getCert[n]()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *CertPool) copy() *CertPool {
|
func (s *CertPool) copy() *CertPool {
|
||||||
p := &CertPool{
|
p := &CertPool{
|
||||||
bySubjectKeyId: make(map[string][]int, len(s.bySubjectKeyId)),
|
bySubjectKeyId: make(map[string][]int, len(s.bySubjectKeyId)),
|
||||||
byName: make(map[string][]int, len(s.byName)),
|
byName: make(map[string][]int, len(s.byName)),
|
||||||
certs: make([]*Certificate, len(s.certs)),
|
getCert: make([]func() (*Certificate, error), len(s.getCert)),
|
||||||
|
rawSubjects: make([][]byte, len(s.rawSubjects)),
|
||||||
}
|
}
|
||||||
for k, v := range s.bySubjectKeyId {
|
for k, v := range s.bySubjectKeyId {
|
||||||
indexes := make([]int, len(v))
|
indexes := make([]int, len(v))
|
||||||
@ -41,7 +62,8 @@ func (s *CertPool) copy() *CertPool {
|
|||||||
copy(indexes, v)
|
copy(indexes, v)
|
||||||
p.byName[k] = indexes
|
p.byName[k] = indexes
|
||||||
}
|
}
|
||||||
copy(p.certs, s.certs)
|
copy(p.getCert, s.getCert)
|
||||||
|
copy(p.rawSubjects, s.rawSubjects)
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,19 +104,22 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int {
|
|||||||
return candidates
|
return candidates
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CertPool) contains(cert *Certificate) bool {
|
func (s *CertPool) contains(cert *Certificate) (bool, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates := s.byName[string(cert.RawSubject)]
|
candidates := s.byName[string(cert.RawSubject)]
|
||||||
for _, c := range candidates {
|
for _, i := range candidates {
|
||||||
if s.certs[c].Equal(cert) {
|
c, err := s.cert(i)
|
||||||
return true
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if c.Equal(cert) {
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddCert adds a certificate to a pool.
|
// AddCert adds a certificate to a pool.
|
||||||
@ -102,21 +127,47 @@ func (s *CertPool) AddCert(cert *Certificate) {
|
|||||||
if cert == nil {
|
if cert == nil {
|
||||||
panic("adding nil Certificate to CertPool")
|
panic("adding nil Certificate to CertPool")
|
||||||
}
|
}
|
||||||
|
err := s.AddCertFunc(string(cert.RawSubject), string(cert.SubjectKeyId), func() (*Certificate, error) {
|
||||||
|
return cert, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCertFunc adds metadata about a certificate to a pool, along with
|
||||||
|
// a func to fetch that certificate later when needed.
|
||||||
|
//
|
||||||
|
// The rawSubject is Certificate.RawSubject and must be non-empty.
|
||||||
|
// The subjectKeyID is Certificate.SubjectKeyId and may be empty.
|
||||||
|
// The getCert func may be called 0 or more times.
|
||||||
|
func (s *CertPool) AddCertFunc(rawSubject, subjectKeyID string, getCert func() (*Certificate, error)) error {
|
||||||
|
if getCert == nil {
|
||||||
|
panic("getCert can't be nil")
|
||||||
|
}
|
||||||
|
|
||||||
// Check that the certificate isn't being added twice.
|
// Check that the certificate isn't being added twice.
|
||||||
if s.contains(cert) {
|
if len(s.byName[rawSubject]) > 0 {
|
||||||
return
|
c, err := getCert()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if dup, err := s.contains(c); dup {
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n := len(s.certs)
|
n := len(s.getCert)
|
||||||
s.certs = append(s.certs, cert)
|
s.getCert = append(s.getCert, getCert)
|
||||||
|
|
||||||
if len(cert.SubjectKeyId) > 0 {
|
if subjectKeyID != "" {
|
||||||
keyId := string(cert.SubjectKeyId)
|
s.bySubjectKeyId[subjectKeyID] = append(s.bySubjectKeyId[subjectKeyID], n)
|
||||||
s.bySubjectKeyId[keyId] = append(s.bySubjectKeyId[keyId], n)
|
|
||||||
}
|
}
|
||||||
name := string(cert.RawSubject)
|
s.byName[rawSubject] = append(s.byName[rawSubject], n)
|
||||||
s.byName[name] = append(s.byName[name], n)
|
s.rawSubjects = append(s.rawSubjects, []byte(rawSubject))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
|
// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
|
||||||
@ -151,9 +202,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
|
|||||||
// Subjects returns a list of the DER-encoded subjects of
|
// Subjects returns a list of the DER-encoded subjects of
|
||||||
// all of the certificates in the pool.
|
// all of the certificates in the pool.
|
||||||
func (s *CertPool) Subjects() [][]byte {
|
func (s *CertPool) Subjects() [][]byte {
|
||||||
res := make([][]byte, len(s.certs))
|
res := make([][]byte, s.len())
|
||||||
for i, c := range s.certs {
|
for i, s := range s.rawSubjects {
|
||||||
res[i] = c.RawSubject
|
res[i] = s
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
@ -1993,7 +1993,7 @@ func TestConstraintCases(t *testing.T) {
|
|||||||
pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||||
return buf.String()
|
return buf.String()
|
||||||
}
|
}
|
||||||
t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.certs[0]))
|
t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.mustCert(0)))
|
||||||
t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert))
|
t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2019,10 +2019,18 @@ func writePEMsToTempFile(certs []*Certificate) *os.File {
|
|||||||
return file
|
return file
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func allCerts(p *CertPool) []*Certificate {
|
||||||
|
all := make([]*Certificate, p.len())
|
||||||
|
for i := range all {
|
||||||
|
all[i] = p.mustCert(i)
|
||||||
|
}
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) {
|
func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) {
|
||||||
args := []string{"verify", "-no_check_time"}
|
args := []string{"verify", "-no_check_time"}
|
||||||
|
|
||||||
rootsFile := writePEMsToTempFile(roots.certs)
|
rootsFile := writePEMsToTempFile(allCerts(roots))
|
||||||
if debugOpenSSLFailure {
|
if debugOpenSSLFailure {
|
||||||
println("roots file:", rootsFile.Name())
|
println("roots file:", rootsFile.Name())
|
||||||
} else {
|
} else {
|
||||||
@ -2030,8 +2038,8 @@ func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool)
|
|||||||
}
|
}
|
||||||
args = append(args, "-CAfile", rootsFile.Name())
|
args = append(args, "-CAfile", rootsFile.Name())
|
||||||
|
|
||||||
if len(intermediates.certs) > 0 {
|
if intermediates.len() > 0 {
|
||||||
intermediatesFile := writePEMsToTempFile(intermediates.certs)
|
intermediatesFile := writePEMsToTempFile(allCerts(intermediates))
|
||||||
if debugOpenSSLFailure {
|
if debugOpenSSLFailure {
|
||||||
println("intermediates file:", intermediatesFile.Name())
|
println("intermediates file:", intermediatesFile.Name())
|
||||||
} else {
|
} else {
|
||||||
|
@ -84,7 +84,7 @@ func loadSystemRoots() (*CertPool, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(roots.certs) > 0 || firstErr == nil {
|
if roots.len() > 0 || firstErr == nil {
|
||||||
return roots, nil
|
return roots, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,15 +113,15 @@ func TestEnvVars(t *testing.T) {
|
|||||||
|
|
||||||
// Verify that the returned certs match, otherwise report where the mismatch is.
|
// Verify that the returned certs match, otherwise report where the mismatch is.
|
||||||
for i, cn := range tc.cns {
|
for i, cn := range tc.cns {
|
||||||
if i >= len(r.certs) {
|
if i >= r.len() {
|
||||||
t.Errorf("missing cert %v @ %v", cn, i)
|
t.Errorf("missing cert %v @ %v", cn, i)
|
||||||
} else if r.certs[i].Subject.CommonName != cn {
|
} else if r.mustCert(i).Subject.CommonName != cn {
|
||||||
fmt.Printf("%#v\n", r.certs[0].Subject)
|
fmt.Printf("%#v\n", r.mustCert(0).Subject)
|
||||||
t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
|
t.Errorf("unexpected cert common name %q, want %q", r.mustCert(i).Subject.CommonName, cn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(r.certs) > len(tc.cns) {
|
if r.len() > len(tc.cns) {
|
||||||
t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
|
t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -197,6 +197,10 @@ func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
|
|||||||
strCertPool := func(p *CertPool) string {
|
strCertPool := func(p *CertPool) string {
|
||||||
return string(bytes.Join(p.Subjects(), []byte("\n")))
|
return string(bytes.Join(p.Subjects(), []byte("\n")))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
zeroPoolFuncs(gotPool)
|
||||||
|
zeroPoolFuncs(wantPool)
|
||||||
|
|
||||||
if !reflect.DeepEqual(gotPool, wantPool) {
|
if !reflect.DeepEqual(gotPool, wantPool) {
|
||||||
g, w := strCertPool(gotPool), strCertPool(wantPool)
|
g, w := strCertPool(gotPool), strCertPool(wantPool)
|
||||||
t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
|
t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
|
||||||
|
@ -737,11 +737,13 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
|
|||||||
if len(c.Raw) == 0 {
|
if len(c.Raw) == 0 {
|
||||||
return nil, errNotParsed
|
return nil, errNotParsed
|
||||||
}
|
}
|
||||||
if opts.Intermediates != nil {
|
for i := 0; i < opts.Intermediates.len(); i++ {
|
||||||
for _, intermediate := range opts.Intermediates.certs {
|
c, err := opts.Intermediates.cert(i)
|
||||||
if len(intermediate.Raw) == 0 {
|
if err != nil {
|
||||||
return nil, errNotParsed
|
return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err)
|
||||||
}
|
}
|
||||||
|
if len(c.Raw) == 0 {
|
||||||
|
return nil, errNotParsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -770,8 +772,10 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
|
|||||||
}
|
}
|
||||||
|
|
||||||
var candidateChains [][]*Certificate
|
var candidateChains [][]*Certificate
|
||||||
if opts.Roots.contains(c) {
|
if inRoots, err := opts.Roots.contains(c); inRoots {
|
||||||
candidateChains = append(candidateChains, []*Certificate{c})
|
candidateChains = append(candidateChains, []*Certificate{c})
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil {
|
if candidateChains, err = c.buildChains(nil, []*Certificate{c}, nil, &opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -868,10 +872,18 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, rootNum := range opts.Roots.findPotentialParents(c) {
|
for _, rootNum := range opts.Roots.findPotentialParents(c) {
|
||||||
considerCandidate(rootCertificate, opts.Roots.certs[rootNum])
|
c, err := opts.Roots.cert(rootNum)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err)
|
||||||
|
}
|
||||||
|
considerCandidate(rootCertificate, c)
|
||||||
}
|
}
|
||||||
for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
|
for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
|
||||||
considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum])
|
c, err := opts.Intermediates.cert(intermediateNum)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("crypto/x509: error fetching cert: %w", err)
|
||||||
|
}
|
||||||
|
considerCandidate(intermediateCertificate, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(chains) > 0 {
|
if len(chains) > 0 {
|
||||||
|
@ -1983,6 +1983,8 @@ func TestSystemCertPool(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
zeroPoolFuncs(a)
|
||||||
|
zeroPoolFuncs(b)
|
||||||
if !reflect.DeepEqual(a, b) {
|
if !reflect.DeepEqual(a, b) {
|
||||||
t.Fatal("two calls to SystemCertPool had different results")
|
t.Fatal("two calls to SystemCertPool had different results")
|
||||||
}
|
}
|
||||||
@ -2644,3 +2646,19 @@ func TestCreateRevocationList(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *CertPool) mustCert(n int) *Certificate {
|
||||||
|
c, err := s.getCert[n]()
|
||||||
|
if err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// zeroPoolFuncs zeros out funcs in p so two pools can be compared
|
||||||
|
// with reflect.DeepEqual.
|
||||||
|
func zeroPoolFuncs(p *CertPool) {
|
||||||
|
for i := range p.getCert {
|
||||||
|
p.getCert[i] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user