tka: truncate long rotation signature chains

When a rotation signature chain reaches a certain size, remove the
oldest rotation signature from the chain before wrapping it in a new
rotation signature.

Since all previous rotation signatures are signed by the same wrapping
pubkey (node's own tailnet lock key), the node can re-construct the
chain, re-signing previous rotation signatures. This will satisfy the
existing certificate validation logic.

Updates #13185

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
Anton Tolchanov 2024-08-19 19:32:14 +01:00 committed by Anton Tolchanov
parent bcc47d91ca
commit fd6686d81a
4 changed files with 221 additions and 11 deletions

View File

@ -175,23 +175,24 @@ func (r *rotationTracker) addRotationDetails(np key.NodePublic, d *tka.RotationD
// obsoleteKeys returns the set of node keys that are obsolete due to key rotation. // obsoleteKeys returns the set of node keys that are obsolete due to key rotation.
func (r *rotationTracker) obsoleteKeys() set.Set[key.NodePublic] { func (r *rotationTracker) obsoleteKeys() set.Set[key.NodePublic] {
for _, v := range r.byWrappingKey { for _, v := range r.byWrappingKey {
// Do not consider signatures for keys that have been marked as obsolete
// by another signature.
v = slices.DeleteFunc(v, func(rd sigRotationDetails) bool {
return r.obsolete.Contains(rd.np)
})
if len(v) == 0 {
continue
}
// If there are multiple rotation signatures with the same wrapping // If there are multiple rotation signatures with the same wrapping
// pubkey, we need to decide which one is the "latest", and keep it. // pubkey, we need to decide which one is the "latest", and keep it.
// The signature with the largest number of previous keys is likely to // The signature with the largest number of previous keys is likely to
// be the latest, unless it has been marked as obsolete (rotated out) by // be the latest.
// another signature (which might happen in the future if we start
// compacting long rotated signature chains).
slices.SortStableFunc(v, func(a, b sigRotationDetails) int { slices.SortStableFunc(v, func(a, b sigRotationDetails) int {
// Group all obsolete keys after non-obsolete keys.
if ao, bo := r.obsolete.Contains(a.np), r.obsolete.Contains(b.np); ao != bo {
if ao {
return 1
}
return -1
}
// Sort by decreasing number of previous keys. // Sort by decreasing number of previous keys.
return b.numPrevKeys - a.numPrevKeys return b.numPrevKeys - a.numPrevKeys
}) })
// If there are several signatures with the same number of previous // If there are several signatures with the same number of previous
// keys, we cannot determine which one is the latest, so all of them are // keys, we cannot determine which one is the latest, so all of them are
// rejected for safety. // rejected for safety.

View File

@ -667,6 +667,31 @@ func TestTKAFilterNetmap(t *testing.T) {
if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" { if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" {
t.Errorf("filtered netmap differs (-want, +got):\n%s", diff) t.Errorf("filtered netmap differs (-want, +got):\n%s", diff)
} }
// Confirm that repeated rotation works correctly.
for range 100 {
n5Rotated, n5RotatedSig = resign(n5nl, n5RotatedSig)
}
n51, n51Sig := resign(n5nl, n5RotatedSig)
nm = &netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
{ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()},
{ID: 5, Key: n5Rotated.Public(), KeySignature: n5RotatedSig}, // rotated
{ID: 51, Key: n51.Public(), KeySignature: n51Sig},
}),
}
b.tkaFilterNetmapLocked(nm)
want = nodeViews([]*tailcfg.Node{
{ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()},
{ID: 51, Key: n51.Public(), KeySignature: n51Sig},
})
if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" {
t.Errorf("filtered netmap differs (-want, +got):\n%s", diff)
}
} }
func TestTKADisable(t *testing.T) { func TestTKADisable(t *testing.T) {

View File

@ -372,10 +372,15 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha
return oldNKS, nil return oldNKS, nil
} }
nested, err := maybeTrimRotationSignatureChain(oldSig, priv)
if err != nil {
return nil, fmt.Errorf("trimming rotation signature chain: %w", err)
}
newSig := NodeKeySignature{ newSig := NodeKeySignature{
SigKind: SigRotation, SigKind: SigRotation,
Pubkey: nk, Pubkey: nk,
Nested: &oldSig, Nested: &nested,
} }
if newSig.Signature, err = priv.SignNKS(newSig.SigHash()); err != nil { if newSig.Signature, err = priv.SignNKS(newSig.SigHash()); err != nil {
return nil, fmt.Errorf("signing NKS: %w", err) return nil, fmt.Errorf("signing NKS: %w", err)
@ -384,6 +389,51 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha
return newSig.Serialize(), nil return newSig.Serialize(), nil
} }
// maybeTrimRotationSignatureChain truncates rotation signature chain to ensure
// it contains no more than 15 node keys.
func maybeTrimRotationSignatureChain(sig NodeKeySignature, priv key.NLPrivate) (NodeKeySignature, error) {
if sig.SigKind != SigRotation {
return sig, nil
}
// Collect all the previous node keys, ordered from newest to oldest.
prevPubkeys := [][]byte{sig.Pubkey}
nested := sig.Nested
for nested != nil {
if len(nested.Pubkey) > 0 {
prevPubkeys = append(prevPubkeys, nested.Pubkey)
}
if nested.SigKind != SigRotation {
break
}
nested = nested.Nested
}
// Existing rotation signature with 15 keys is the maximum we can wrap in a
// new signature without hitting the CBOR nesting limit of 16 (see
// MaxNestedLevels in tka.go).
const maxPrevKeys = 15
if len(prevPubkeys) <= maxPrevKeys {
return sig, nil
}
// Create a new rotation signature chain, starting with the original
// direct signature.
var err error
result := nested // original direct signature
for i := maxPrevKeys - 2; i >= 0; i-- {
result = &NodeKeySignature{
SigKind: SigRotation,
Pubkey: prevPubkeys[i],
Nested: result,
}
if result.Signature, err = priv.SignNKS(result.SigHash()); err != nil {
return sig, fmt.Errorf("signing NKS: %w", err)
}
}
return *result, nil
}
// SignByCredential signs a node public key by a private key which has its // SignByCredential signs a node public key by a private key which has its
// signing authority delegated by a SigCredential signature. This is used by // signing authority delegated by a SigCredential signature. This is used by
// wrapped auth keys. // wrapped auth keys.

View File

@ -9,7 +9,9 @@
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/tkatype"
) )
func TestSigDirect(t *testing.T) { func TestSigDirect(t *testing.T) {
@ -74,6 +76,9 @@ func TestSigNested(t *testing.T) {
if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil { if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil {
t.Fatalf("verifySignature(oldNode) failed: %v", err) t.Fatalf("verifySignature(oldNode) failed: %v", err)
} }
if l := sigChainLength(nestedSig); l != 1 {
t.Errorf("nestedSig chain length = %v, want 1", l)
}
// The signature authorizing the rotation, signed by the // The signature authorizing the rotation, signed by the
// rotation key & embedding the original signature. // rotation key & embedding the original signature.
@ -88,6 +93,9 @@ func TestSigNested(t *testing.T) {
if err := sig.verifySignature(node.Public(), k); err != nil { if err := sig.verifySignature(node.Public(), k); err != nil {
t.Fatalf("verifySignature(node) failed: %v", err) t.Fatalf("verifySignature(node) failed: %v", err)
} }
if l := sigChainLength(sig); l != 2 {
t.Errorf("sig chain length = %v, want 2", l)
}
// Test verification fails if the wrong verification key is provided // Test verification fails if the wrong verification key is provided
kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2} kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2}
@ -497,3 +505,129 @@ func TestDecodeWrappedAuthkey(t *testing.T) {
} }
} }
func TestResignNKS(t *testing.T) {
// Tailnet lock keypair of a signing node.
authPub, authPriv := testingKey25519(t, 1)
authKey := Key{Kind: Key25519, Public: authPub, Votes: 2}
// Node's own tailnet lock key used to sign rotation signatures.
tlPriv := key.NewNLPrivate()
// The original (oldest) node key, signed by a signing node.
origNode := key.NewNode()
origPub, _ := origNode.Public().MarshalBinary()
// The original signature for the old node key, signed by
// the network-lock key.
directSig := NodeKeySignature{
SigKind: SigDirect,
KeyID: authKey.MustID(),
Pubkey: origPub,
WrappingPubkey: tlPriv.Public().Verifier(),
}
sigHash := directSig.SigHash()
directSig.Signature = ed25519.Sign(authPriv, sigHash[:])
if err := directSig.verifySignature(origNode.Public(), authKey); err != nil {
t.Fatalf("verifySignature(origNode) failed: %v", err)
}
// Generate a bunch of node keys to be used by tests.
var nodeKeys []key.NodePublic
for range 20 {
n := key.NewNode()
nodeKeys = append(nodeKeys, n.Public())
}
// mkSig creates a signature chain starting with a direct signature
// with rotation signatures matching provided keys (from the nodeKeys slice).
mkSig := func(prevKeyIDs ...int) tkatype.MarshaledSignature {
sig := &directSig
for _, i := range prevKeyIDs {
pk, _ := nodeKeys[i].MarshalBinary()
sig = &NodeKeySignature{
SigKind: SigRotation,
Pubkey: pk,
Nested: sig,
}
var err error
sig.Signature, err = tlPriv.SignNKS(sig.SigHash())
if err != nil {
t.Error(err)
}
}
return sig.Serialize()
}
tests := []struct {
name string
oldSig tkatype.MarshaledSignature
wantPrevNodeKeys []key.NodePublic
}{
{
name: "first-rotation",
oldSig: directSig.Serialize(),
wantPrevNodeKeys: []key.NodePublic{origNode.Public()},
},
{
name: "second-rotation",
oldSig: mkSig(0),
wantPrevNodeKeys: []key.NodePublic{nodeKeys[0], origNode.Public()},
},
{
name: "truncate-chain",
oldSig: mkSig(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14),
wantPrevNodeKeys: []key.NodePublic{
nodeKeys[14],
nodeKeys[13],
nodeKeys[12],
nodeKeys[11],
nodeKeys[10],
nodeKeys[9],
nodeKeys[8],
nodeKeys[7],
nodeKeys[6],
nodeKeys[5],
nodeKeys[4],
nodeKeys[3],
nodeKeys[2],
nodeKeys[1],
origNode.Public(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newNode := key.NewNode()
got, err := ResignNKS(tlPriv, newNode.Public(), tt.oldSig)
if err != nil {
t.Fatalf("ResignNKS() error = %v", err)
}
var gotSig NodeKeySignature
if err := gotSig.Unserialize(got); err != nil {
t.Fatalf("Unserialize() failed: %v", err)
}
if err := gotSig.verifySignature(newNode.Public(), authKey); err != nil {
t.Errorf("verifySignature(newNode) error: %v", err)
}
rd, err := gotSig.rotationDetails()
if err != nil {
t.Fatalf("rotationDetails() error = %v", err)
}
if sigChainLength(gotSig) != len(tt.wantPrevNodeKeys)+1 {
t.Errorf("sigChainLength() = %v, want %v", sigChainLength(gotSig), len(tt.wantPrevNodeKeys)+1)
}
if diff := cmp.Diff(tt.wantPrevNodeKeys, rd.PrevNodeKeys, cmpopts.EquateComparable(key.NodePublic{})); diff != "" {
t.Errorf("PrevNodeKeys mismatch (-want +got):\n%s", diff)
}
})
}
}
func sigChainLength(s NodeKeySignature) int {
if s.Nested != nil {
return 1 + sigChainLength(*s.Nested)
}
return 1
}