From fd6686d81acbe24cc665103e08f8569245de42e2 Mon Sep 17 00:00:00 2001 From: Anton Tolchanov Date: Mon, 19 Aug 2024 19:32:14 +0100 Subject: [PATCH] tka: truncate long rotation signature chains When a rotation signature chain reaches a certain size, remove the oldest rotation signature from the chain before wrapping it in a new rotation signature. Since all previous rotation signatures are signed by the same wrapping pubkey (node's own tailnet lock key), the node can re-construct the chain, re-signing previous rotation signatures. This will satisfy the existing certificate validation logic. Updates #13185 Signed-off-by: Anton Tolchanov --- ipn/ipnlocal/network-lock.go | 21 ++--- ipn/ipnlocal/network-lock_test.go | 25 ++++++ tka/sig.go | 52 +++++++++++- tka/sig_test.go | 134 ++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 11 deletions(-) diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index b27b6427a..d20bf94eb 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -175,23 +175,24 @@ func (r *rotationTracker) addRotationDetails(np key.NodePublic, d *tka.RotationD // obsoleteKeys returns the set of node keys that are obsolete due to key rotation. func (r *rotationTracker) obsoleteKeys() set.Set[key.NodePublic] { for _, v := range r.byWrappingKey { + // Do not consider signatures for keys that have been marked as obsolete + // by another signature. + v = slices.DeleteFunc(v, func(rd sigRotationDetails) bool { + return r.obsolete.Contains(rd.np) + }) + if len(v) == 0 { + continue + } + // If there are multiple rotation signatures with the same wrapping // pubkey, we need to decide which one is the "latest", and keep it. // The signature with the largest number of previous keys is likely to - // be the latest, unless it has been marked as obsolete (rotated out) by - // another signature (which might happen in the future if we start - // compacting long rotated signature chains). + // be the latest. slices.SortStableFunc(v, func(a, b sigRotationDetails) int { - // Group all obsolete keys after non-obsolete keys. - if ao, bo := r.obsolete.Contains(a.np), r.obsolete.Contains(b.np); ao != bo { - if ao { - return 1 - } - return -1 - } // Sort by decreasing number of previous keys. return b.numPrevKeys - a.numPrevKeys }) + // If there are several signatures with the same number of previous // keys, we cannot determine which one is the latest, so all of them are // rejected for safety. diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index c3576dfb0..4b79136c8 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -667,6 +667,31 @@ func TestTKAFilterNetmap(t *testing.T) { if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" { t.Errorf("filtered netmap differs (-want, +got):\n%s", diff) } + + // Confirm that repeated rotation works correctly. + for range 100 { + n5Rotated, n5RotatedSig = resign(n5nl, n5RotatedSig) + } + + n51, n51Sig := resign(n5nl, n5RotatedSig) + + nm = &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + {ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()}, + {ID: 5, Key: n5Rotated.Public(), KeySignature: n5RotatedSig}, // rotated + {ID: 51, Key: n51.Public(), KeySignature: n51Sig}, + }), + } + + b.tkaFilterNetmapLocked(nm) + + want = nodeViews([]*tailcfg.Node{ + {ID: 1, Key: n1.Public(), KeySignature: n1GoodSig.Serialize()}, + {ID: 51, Key: n51.Public(), KeySignature: n51Sig}, + }) + if diff := cmp.Diff(want, nm.Peers, nodePubComparer); diff != "" { + t.Errorf("filtered netmap differs (-want, +got):\n%s", diff) + } } func TestTKADisable(t *testing.T) { diff --git a/tka/sig.go b/tka/sig.go index d3fe0ff6c..c82f9715c 100644 --- a/tka/sig.go +++ b/tka/sig.go @@ -372,10 +372,15 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha return oldNKS, nil } + nested, err := maybeTrimRotationSignatureChain(oldSig, priv) + if err != nil { + return nil, fmt.Errorf("trimming rotation signature chain: %w", err) + } + newSig := NodeKeySignature{ SigKind: SigRotation, Pubkey: nk, - Nested: &oldSig, + Nested: &nested, } if newSig.Signature, err = priv.SignNKS(newSig.SigHash()); err != nil { return nil, fmt.Errorf("signing NKS: %w", err) @@ -384,6 +389,51 @@ func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.Marsha return newSig.Serialize(), nil } +// maybeTrimRotationSignatureChain truncates rotation signature chain to ensure +// it contains no more than 15 node keys. +func maybeTrimRotationSignatureChain(sig NodeKeySignature, priv key.NLPrivate) (NodeKeySignature, error) { + if sig.SigKind != SigRotation { + return sig, nil + } + + // Collect all the previous node keys, ordered from newest to oldest. + prevPubkeys := [][]byte{sig.Pubkey} + nested := sig.Nested + for nested != nil { + if len(nested.Pubkey) > 0 { + prevPubkeys = append(prevPubkeys, nested.Pubkey) + } + if nested.SigKind != SigRotation { + break + } + nested = nested.Nested + } + + // Existing rotation signature with 15 keys is the maximum we can wrap in a + // new signature without hitting the CBOR nesting limit of 16 (see + // MaxNestedLevels in tka.go). + const maxPrevKeys = 15 + if len(prevPubkeys) <= maxPrevKeys { + return sig, nil + } + + // Create a new rotation signature chain, starting with the original + // direct signature. + var err error + result := nested // original direct signature + for i := maxPrevKeys - 2; i >= 0; i-- { + result = &NodeKeySignature{ + SigKind: SigRotation, + Pubkey: prevPubkeys[i], + Nested: result, + } + if result.Signature, err = priv.SignNKS(result.SigHash()); err != nil { + return sig, fmt.Errorf("signing NKS: %w", err) + } + } + return *result, nil +} + // SignByCredential signs a node public key by a private key which has its // signing authority delegated by a SigCredential signature. This is used by // wrapped auth keys. diff --git a/tka/sig_test.go b/tka/sig_test.go index d857eaf55..d64575e7c 100644 --- a/tka/sig_test.go +++ b/tka/sig_test.go @@ -9,7 +9,9 @@ "testing" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/types/key" + "tailscale.com/types/tkatype" ) func TestSigDirect(t *testing.T) { @@ -74,6 +76,9 @@ func TestSigNested(t *testing.T) { if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil { t.Fatalf("verifySignature(oldNode) failed: %v", err) } + if l := sigChainLength(nestedSig); l != 1 { + t.Errorf("nestedSig chain length = %v, want 1", l) + } // The signature authorizing the rotation, signed by the // rotation key & embedding the original signature. @@ -88,6 +93,9 @@ func TestSigNested(t *testing.T) { if err := sig.verifySignature(node.Public(), k); err != nil { t.Fatalf("verifySignature(node) failed: %v", err) } + if l := sigChainLength(sig); l != 2 { + t.Errorf("sig chain length = %v, want 2", l) + } // Test verification fails if the wrong verification key is provided kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2} @@ -497,3 +505,129 @@ func TestDecodeWrappedAuthkey(t *testing.T) { } } + +func TestResignNKS(t *testing.T) { + // Tailnet lock keypair of a signing node. + authPub, authPriv := testingKey25519(t, 1) + authKey := Key{Kind: Key25519, Public: authPub, Votes: 2} + + // Node's own tailnet lock key used to sign rotation signatures. + tlPriv := key.NewNLPrivate() + + // The original (oldest) node key, signed by a signing node. + origNode := key.NewNode() + origPub, _ := origNode.Public().MarshalBinary() + + // The original signature for the old node key, signed by + // the network-lock key. + directSig := NodeKeySignature{ + SigKind: SigDirect, + KeyID: authKey.MustID(), + Pubkey: origPub, + WrappingPubkey: tlPriv.Public().Verifier(), + } + sigHash := directSig.SigHash() + directSig.Signature = ed25519.Sign(authPriv, sigHash[:]) + if err := directSig.verifySignature(origNode.Public(), authKey); err != nil { + t.Fatalf("verifySignature(origNode) failed: %v", err) + } + + // Generate a bunch of node keys to be used by tests. + var nodeKeys []key.NodePublic + for range 20 { + n := key.NewNode() + nodeKeys = append(nodeKeys, n.Public()) + } + + // mkSig creates a signature chain starting with a direct signature + // with rotation signatures matching provided keys (from the nodeKeys slice). + mkSig := func(prevKeyIDs ...int) tkatype.MarshaledSignature { + sig := &directSig + for _, i := range prevKeyIDs { + pk, _ := nodeKeys[i].MarshalBinary() + sig = &NodeKeySignature{ + SigKind: SigRotation, + Pubkey: pk, + Nested: sig, + } + var err error + sig.Signature, err = tlPriv.SignNKS(sig.SigHash()) + if err != nil { + t.Error(err) + } + } + return sig.Serialize() + } + + tests := []struct { + name string + oldSig tkatype.MarshaledSignature + wantPrevNodeKeys []key.NodePublic + }{ + { + name: "first-rotation", + oldSig: directSig.Serialize(), + wantPrevNodeKeys: []key.NodePublic{origNode.Public()}, + }, + { + name: "second-rotation", + oldSig: mkSig(0), + wantPrevNodeKeys: []key.NodePublic{nodeKeys[0], origNode.Public()}, + }, + { + name: "truncate-chain", + oldSig: mkSig(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), + wantPrevNodeKeys: []key.NodePublic{ + nodeKeys[14], + nodeKeys[13], + nodeKeys[12], + nodeKeys[11], + nodeKeys[10], + nodeKeys[9], + nodeKeys[8], + nodeKeys[7], + nodeKeys[6], + nodeKeys[5], + nodeKeys[4], + nodeKeys[3], + nodeKeys[2], + nodeKeys[1], + origNode.Public(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + newNode := key.NewNode() + got, err := ResignNKS(tlPriv, newNode.Public(), tt.oldSig) + if err != nil { + t.Fatalf("ResignNKS() error = %v", err) + } + var gotSig NodeKeySignature + if err := gotSig.Unserialize(got); err != nil { + t.Fatalf("Unserialize() failed: %v", err) + } + if err := gotSig.verifySignature(newNode.Public(), authKey); err != nil { + t.Errorf("verifySignature(newNode) error: %v", err) + } + + rd, err := gotSig.rotationDetails() + if err != nil { + t.Fatalf("rotationDetails() error = %v", err) + } + if sigChainLength(gotSig) != len(tt.wantPrevNodeKeys)+1 { + t.Errorf("sigChainLength() = %v, want %v", sigChainLength(gotSig), len(tt.wantPrevNodeKeys)+1) + } + if diff := cmp.Diff(tt.wantPrevNodeKeys, rd.PrevNodeKeys, cmpopts.EquateComparable(key.NodePublic{})); diff != "" { + t.Errorf("PrevNodeKeys mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func sigChainLength(s NodeKeySignature) int { + if s.Nested != nil { + return 1 + sigChainLength(*s.Nested) + } + return 1 +}