tailscale/control/controlclient/sign_supported_test.go

238 lines
5.2 KiB
Go
Raw Normal View History

// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build windows && cgo
package controlclient
import (
"crypto"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"reflect"
"testing"
"time"
"github.com/tailscale/certstore"
)
// Names shared by the fake certificate chains and the lookups in this file.
const (
	// testRootCommonName is the common name passed to makeChain for chains
	// that are expected to match the subject under test.
	testRootCommonName = "testroot"

	// testRootSubject is the subject string handed to
	// selectIdentityFromSlice; it is "CN=" followed by testRootCommonName.
	testRootSubject = "CN=" + testRootCommonName
)
// testIdentity is a fake identity used by the tests in this file in place of
// a real certstore.Identity. It only carries a certificate chain.
type testIdentity struct {
	// chain is returned verbatim by CertificateChain; its first element is
	// what Certificate returns.
	chain []*x509.Certificate
}
// makeChain returns a two-element certificate chain for use in tests.
//
// The first certificate carries the validity window [notBefore, notAfter];
// the second carries rootCommonName as its subject common name. Both are
// marked as using an RSA public key algorithm. No other fields are set.
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate {
	leaf := &x509.Certificate{
		NotBefore:          notBefore,
		NotAfter:           notAfter,
		PublicKeyAlgorithm: x509.RSA,
	}
	root := &x509.Certificate{
		Subject:            pkix.Name{CommonName: rootCommonName},
		PublicKeyAlgorithm: x509.RSA,
	}
	return []*x509.Certificate{leaf, root}
}
// Certificate returns the first certificate of the stored chain and a nil
// error. The chain is indexed without a length check, so the receiver must
// hold a non-empty chain.
func (t *testIdentity) Certificate() (*x509.Certificate, error) {
	first := t.chain[0]
	return first, nil
}
// CertificateChain returns the stored chain slice itself (not a copy)
// together with a nil error.
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) {
	certs := t.chain
	return certs, nil
}
// Signer is not supported by the fake: it always returns a nil signer and a
// "not implemented" error.
func (t *testIdentity) Signer() (crypto.Signer, error) {
	err := errors.New("not implemented")
	return nil, err
}
// Delete is not supported by the fake: it always returns a "not implemented"
// error and leaves the receiver untouched.
func (t *testIdentity) Delete() error {
	err := errors.New("not implemented")
	return err
}
// Close does nothing; the fake identity holds no resources.
func (t *testIdentity) Close() {
	// Intentionally empty.
}
// TestSelectIdentityFromSlice runs selectIdentityFromSlice over a table of
// fake identity sets and verifies, for a given subject and point in time,
// which identity (if any) is returned and that the returned chain belongs to
// that identity.
func TestSelectIdentityFromSlice(t *testing.T) {
	// Four instants one year apart; table entries refer to them by index.
	stamps := []string{
		"2000-01-01T00:00:00Z",
		"2001-01-01T00:00:00Z",
		"2002-01-01T00:00:00Z",
		"2003-01-01T00:00:00Z",
	}
	var times []time.Time
	for _, stamp := range stamps {
		parsed, err := time.Parse(time.RFC3339, stamp)
		if err != nil {
			t.Fatal(err)
		}
		times = append(times, parsed)
	}

	// ident builds a fake identity whose chain is
	// makeChain(cn, times[from], times[to]).
	ident := func(cn string, from, to int) certstore.Identity {
		return &testIdentity{chain: makeChain(cn, times[from], times[to])}
	}

	type testCase struct {
		name    string
		subject string
		ids     []certstore.Identity
		now     time.Time
		// wantIndex is an index into ids, or -1 for nil.
		wantIndex int
	}

	cases := []testCase{
		{
			name:    "single unexpired identity",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident(testRootCommonName, 0, 2),
			},
			now:       times[1],
			wantIndex: 0,
		},
		{
			name:    "single expired identity",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident(testRootCommonName, 0, 1),
			},
			now:       times[2],
			wantIndex: -1,
		},
		{
			name:    "unrelated ids",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident("something", 0, 2),
				ident(testRootCommonName, 0, 2),
				ident("else", 0, 2),
			},
			now:       times[1],
			wantIndex: 1,
		},
		{
			name:    "expired with unrelated ids",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident("something", 0, 3),
				ident(testRootCommonName, 0, 1),
				ident("else", 0, 3),
			},
			now:       times[2],
			wantIndex: -1,
		},
		{
			name:    "one expired",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident(testRootCommonName, 0, 1),
				ident(testRootCommonName, 1, 3),
			},
			now:       times[2],
			wantIndex: 1,
		},
		{
			// Both entries are valid at "now"; the expected one is the
			// entry whose first certificate has the later NotBefore.
			name:    "two certs both unexpired",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident(testRootCommonName, 0, 3),
				ident(testRootCommonName, 1, 3),
			},
			now:       times[2],
			wantIndex: 1,
		},
		{
			name:    "two unexpired one expired",
			subject: testRootSubject,
			ids: []certstore.Identity{
				ident(testRootCommonName, 0, 3),
				ident(testRootCommonName, 1, 3),
				ident(testRootCommonName, 0, 1),
			},
			now:       times[2],
			wantIndex: 1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotID, gotChain := selectIdentityFromSlice(tc.subject, tc.ids, tc.now)

			// The identity and its chain must be both nil or both non-nil;
			// after that, handle the "expect nothing" and "got nothing"
			// outcomes before comparing against the wanted entry.
			switch {
			case gotID == nil && gotChain != nil:
				t.Error("id is nil: got non-nil chain, want nil chain")
				return
			case gotID != nil && gotChain == nil:
				t.Error("id is not nil: got nil chain, want non-nil chain")
				return
			case tc.wantIndex == -1:
				if gotID != nil {
					t.Error("got non-nil id, want nil id")
				}
				return
			case gotID == nil:
				t.Error("got nil id, want non-nil id")
				return
			}

			wantID := tc.ids[tc.wantIndex]
			if gotID != wantID {
				// Report where in ids the returned identity actually lives,
				// if anywhere, to make the failure easier to read.
				at := -1
				for i, candidate := range tc.ids {
					if candidate == gotID {
						at = i
						break
					}
				}
				if at >= 0 {
					t.Errorf("got id at index %v, want id at index %v", at, tc.wantIndex)
				} else {
					t.Errorf("got unknown id, want id at index %v", tc.wantIndex)
				}
			}

			wantIdentity, isTestIdentity := wantID.(*testIdentity)
			if !isTestIdentity {
				t.Error("got non-testIdentity, want testIdentity")
				return
			}
			if !reflect.DeepEqual(wantIdentity.chain, gotChain) {
				t.Errorf("got unknown chain, want chain from id at index %v", tc.wantIndex)
			}
		})
	}
}