mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
control/controlclient: select newer certificate
If multiple certificates match when selecting a certificate, use the one issued the most recently (as determined by the NotBefore timestamp). This also adds some tests for the function that performs that comparison. Updates tailscale/coral#6 Signed-off-by: Adrian Dewhurst <adrian@tailscale.com>
This commit is contained in:
parent
a80cef0c13
commit
adda2d2a51
@ -18,6 +18,7 @@
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/tailscale/certstore"
|
"github.com/tailscale/certstore"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
@ -73,23 +74,46 @@ func isSubjectInChain(subject string, chain []*x509.Certificate) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func selectIdentityFromSlice(subject string, ids []certstore.Identity) (certstore.Identity, []*x509.Certificate) {
|
func selectIdentityFromSlice(subject string, ids []certstore.Identity, now time.Time) (certstore.Identity, []*x509.Certificate) {
|
||||||
|
var bestCandidate struct {
|
||||||
|
id certstore.Identity
|
||||||
|
chain []*x509.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
chain, err := id.CertificateChain()
|
chain, err := id.CertificateChain()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(chain) < 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if !isSupportedCertificate(chain[0]) {
|
if !isSupportedCertificate(chain[0]) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if isSubjectInChain(subject, chain) {
|
if now.Before(chain[0].NotBefore) || now.After(chain[0].NotAfter) {
|
||||||
return id, chain
|
// Certificate is not valid at this time
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !isSubjectInChain(subject, chain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select the most recently issued certificate. If there is a tie, pick
|
||||||
|
// one arbitrarily.
|
||||||
|
if len(bestCandidate.chain) > 0 && bestCandidate.chain[0].NotBefore.After(chain[0].NotBefore) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bestCandidate.id = id
|
||||||
|
bestCandidate.chain = chain
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil
|
return bestCandidate.id, bestCandidate.chain
|
||||||
}
|
}
|
||||||
|
|
||||||
// findIdentity locates an identity from the Windows or Darwin certificate
|
// findIdentity locates an identity from the Windows or Darwin certificate
|
||||||
@ -105,7 +129,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
selected, chain := selectIdentityFromSlice(subject, ids)
|
selected, chain := selectIdentityFromSlice(subject, ids, time.Now())
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
if id != selected {
|
if id != selected {
|
||||||
|
238
control/controlclient/sign_supported_test.go
Normal file
238
control/controlclient/sign_supported_test.go
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
//go:build windows && cgo
|
||||||
|
// +build windows,cgo
|
||||||
|
|
||||||
|
package controlclient
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tailscale/certstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testRootCommonName = "testroot"
|
||||||
|
testRootSubject = "CN=testroot"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testIdentity struct {
|
||||||
|
chain []*x509.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
|
||||||
|
return []*x509.Certificate{
|
||||||
|
{
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
PublicKeyAlgorithm: x509.RSA,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: rootCommonName,
|
||||||
|
},
|
||||||
|
PublicKeyAlgorithm: x509.RSA,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testIdentity) Certificate() (*x509.Certificate, error) {
|
||||||
|
return t.chain[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
|
||||||
|
return t.chain, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testIdentity) Signer() (crypto.Signer, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testIdentity) Delete() error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testIdentity) Close() {}
|
||||||
|
|
||||||
|
func TestSelectIdentityFromSlice(t *testing.T) {
|
||||||
|
var times []time.Time
|
||||||
|
for _, ts := range []string{
|
||||||
|
"2000-01-01T00:00:00Z",
|
||||||
|
"2001-01-01T00:00:00Z",
|
||||||
|
"2002-01-01T00:00:00Z",
|
||||||
|
"2003-01-01T00:00:00Z",
|
||||||
|
} {
|
||||||
|
tm, err := time.Parse(time.RFC3339, ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
times = append(times, tm)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
subject string
|
||||||
|
ids []certstore.Identity
|
||||||
|
now time.Time
|
||||||
|
// wantIndex is an index into ids, or -1 for nil.
|
||||||
|
wantIndex int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single unexpired identity",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[1],
|
||||||
|
wantIndex: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single expired identity",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[2],
|
||||||
|
wantIndex: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unrelated ids",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain("something", times[0], times[2]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[2]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain("else", times[0], times[2]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[1],
|
||||||
|
wantIndex: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired with unrelated ids",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain("something", times[0], times[3]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain("else", times[0], times[3]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[2],
|
||||||
|
wantIndex: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one expired",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[2],
|
||||||
|
wantIndex: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two certs both unexpired",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[2],
|
||||||
|
wantIndex: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "two unexpired one expired",
|
||||||
|
subject: testRootSubject,
|
||||||
|
ids: []certstore.Identity{
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[3]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[1], times[3]),
|
||||||
|
},
|
||||||
|
&testIdentity{
|
||||||
|
chain: makeChain(testRootCommonName, times[0], times[1]),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
now: times[2],
|
||||||
|
wantIndex: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now)
|
||||||
|
|
||||||
|
if gotId == nil && gotChain != nil {
|
||||||
|
t.Error("id is nil: got non-nil chain, want nil chain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotId != nil && gotChain == nil {
|
||||||
|
t.Error("id is not nil: got nil chain, want non-nil chain")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantIndex == -1 {
|
||||||
|
if gotId != nil {
|
||||||
|
t.Error("got non-nil id, want nil id")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotId == nil {
|
||||||
|
t.Error("got nil id, want non-nil id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gotId != tt.ids[tt.wantIndex] {
|
||||||
|
found := -1
|
||||||
|
for i := range tt.ids {
|
||||||
|
if tt.ids[i] == gotId {
|
||||||
|
found = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found == -1 {
|
||||||
|
t.Errorf("got unknown id, want id at index %v", tt.wantIndex)
|
||||||
|
} else {
|
||||||
|
t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tid, ok := tt.ids[tt.wantIndex].(*testIdentity)
|
||||||
|
if !ok {
|
||||||
|
t.Error("got non-testIdentity, want testIdentity")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tid.chain, gotChain) {
|
||||||
|
t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user