mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
tka: implement consensus & state computation internals
Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
parent
af412e8874
commit
4f1374ec9e
365
tka/chaintest_test.go
Normal file
365
tka/chaintest_test.go
Normal file
@ -0,0 +1,365 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/scanner"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
// chaintest_test.go implements test helpers for concisely describing
|
||||
// chains of possibly signed AUMs, to assist in making tests shorter and
|
||||
// easier to read.
|
||||
|
||||
// parsed representation of a named AUM in a test chain.
|
||||
type testchainNode struct {
|
||||
Name string
|
||||
Parent string
|
||||
Uses []scanner.Position
|
||||
|
||||
HashSeed int
|
||||
Template string
|
||||
SignedWith string
|
||||
}
|
||||
|
||||
// testChain represents a constructed web of AUMs for testing purposes.
|
||||
type testChain struct {
|
||||
Nodes map[string]*testchainNode
|
||||
AUMs map[string]AUM
|
||||
AUMHashes map[string]AUMHash
|
||||
|
||||
// Configured by options to NewTestchain()
|
||||
Template map[string]AUM
|
||||
Key map[string]*Key
|
||||
KeyPrivs map[string]ed25519.PrivateKey
|
||||
SignAllKeys []string
|
||||
}
|
||||
|
||||
// newTestchain constructs a web of AUMs based on the provided input and
|
||||
// options.
|
||||
//
|
||||
// Input is expected to be a graph & tweaks, looking like this:
|
||||
//
|
||||
// G1 -> A -> B
|
||||
// | -> C
|
||||
//
|
||||
// which defines AUMs G1, A, B, and C; with G1 having no parent, A having
|
||||
// G1 as a parent, and both B & C having A as a parent.
|
||||
//
|
||||
// Tweaks are specified like this:
|
||||
//
|
||||
// <AUM>.<tweak> = <value>
|
||||
//
|
||||
// for example: G1.hashSeed = 2
|
||||
//
|
||||
// There are 3 available tweaks:
|
||||
// - hashSeed: Set to an integer to tweak the AUM hash of that AUM.
|
||||
// - template: Set to the name of a template provided via optTemplate().
|
||||
// The template is copied and use as the content for that AUM.
|
||||
// - signedWith: Set to the name of a key provided via optKey(). This
|
||||
// key is used to sign that AUM.
|
||||
func newTestchain(t *testing.T, input string, options ...testchainOpt) *testChain {
|
||||
t.Helper()
|
||||
|
||||
var (
|
||||
s scanner.Scanner
|
||||
out = testChain{
|
||||
Nodes: map[string]*testchainNode{},
|
||||
Template: map[string]AUM{},
|
||||
Key: map[string]*Key{},
|
||||
KeyPrivs: map[string]ed25519.PrivateKey{},
|
||||
}
|
||||
)
|
||||
|
||||
// Process any options
|
||||
for _, o := range options {
|
||||
if o.Template != nil {
|
||||
out.Template[o.Name] = *o.Template
|
||||
}
|
||||
if o.Key != nil {
|
||||
out.Key[o.Name] = o.Key
|
||||
out.KeyPrivs[o.Name] = o.Private
|
||||
}
|
||||
if o.SignAllWith {
|
||||
out.SignAllKeys = append(out.SignAllKeys, o.Name)
|
||||
}
|
||||
}
|
||||
|
||||
s.Init(strings.NewReader(input))
|
||||
s.Mode = scanner.ScanIdents | scanner.SkipComments | scanner.ScanComments | scanner.ScanChars | scanner.ScanInts
|
||||
s.Whitespace ^= 1 << '\t' // clear tabs
|
||||
var (
|
||||
lastIdent string
|
||||
lastWasChain bool // if the last token was '->'
|
||||
)
|
||||
for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() {
|
||||
switch tok {
|
||||
case '\t':
|
||||
t.Fatalf("tabs disallowed, use spaces (seen at %v)", s.Pos())
|
||||
|
||||
case '.': // tweaks, like <ident>.hashSeed = <val>
|
||||
s.Scan()
|
||||
tweak := s.TokenText()
|
||||
if tok := s.Scan(); tok == '=' {
|
||||
s.Scan()
|
||||
switch tweak {
|
||||
case "hashSeed":
|
||||
out.Nodes[lastIdent].HashSeed, _ = strconv.Atoi(s.TokenText())
|
||||
case "template":
|
||||
out.Nodes[lastIdent].Template = s.TokenText()
|
||||
case "signedWith":
|
||||
out.Nodes[lastIdent].SignedWith = s.TokenText()
|
||||
}
|
||||
}
|
||||
|
||||
case scanner.Ident:
|
||||
out.recordPos(s.TokenText(), s.Pos())
|
||||
// If the last token was '->', that means
|
||||
// that the next identifier has a child relationship
|
||||
// with the identifier preceeding '->'.
|
||||
if lastWasChain {
|
||||
out.recordParent(t, s.TokenText(), lastIdent)
|
||||
}
|
||||
lastIdent = s.TokenText()
|
||||
|
||||
case '-': // handle '->'
|
||||
switch s.Peek() {
|
||||
case '>':
|
||||
s.Scan()
|
||||
lastWasChain = true
|
||||
continue
|
||||
}
|
||||
|
||||
case '|': // handle '|'
|
||||
line, col := s.Pos().Line, s.Pos().Column
|
||||
nodeLoop:
|
||||
for _, n := range out.Nodes {
|
||||
for _, p := range n.Uses {
|
||||
// Find the identifier used right here on the line above.
|
||||
if p.Line == line-1 && col <= p.Column && col > p.Column-len(n.Name) {
|
||||
lastIdent = n.Name
|
||||
out.recordPos(n.Name, s.Pos())
|
||||
break nodeLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
lastWasChain = false
|
||||
// t.Logf("tok = %v, %q", tok, s.TokenText())
|
||||
}
|
||||
|
||||
out.buildChain()
|
||||
return &out
|
||||
}
|
||||
|
||||
// called from the parser to record the location of an
|
||||
// identifier (a named AUM).
|
||||
func (c *testChain) recordPos(ident string, pos scanner.Position) {
|
||||
n := c.Nodes[ident]
|
||||
if n == nil {
|
||||
n = &testchainNode{Name: ident}
|
||||
}
|
||||
|
||||
n.Uses = append(n.Uses, pos)
|
||||
c.Nodes[ident] = n
|
||||
}
|
||||
|
||||
// called from the parser to record a parent relationship between
|
||||
// two AUMs.
|
||||
func (c *testChain) recordParent(t *testing.T, child, parent string) {
|
||||
if p := c.Nodes[child].Parent; p != "" && p != parent {
|
||||
t.Fatalf("differing parent specified for %s: %q != %q", child, p, parent)
|
||||
}
|
||||
c.Nodes[child].Parent = parent
|
||||
}
|
||||
|
||||
// called after parsing to build the web of AUM structures.
|
||||
// This method populates c.AUMs and c.AUMHashes.
|
||||
func (c *testChain) buildChain() {
|
||||
pending := make(map[string]*testchainNode, len(c.Nodes))
|
||||
for k, v := range c.Nodes {
|
||||
pending[k] = v
|
||||
}
|
||||
|
||||
// AUMs with a parent need to know their hash, so we
|
||||
// only compute AUMs who's parents have been computed
|
||||
// each iteration. Since at least the genesis AUM
|
||||
// had no parent, theres always a path to completion
|
||||
// in O(n+1) where n is the number of AUMs.
|
||||
c.AUMs = make(map[string]AUM, len(c.Nodes))
|
||||
c.AUMHashes = make(map[string]AUMHash, len(c.Nodes))
|
||||
for i := 0; i < len(c.Nodes)+1; i++ {
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
next := make([]*testchainNode, 0, 10)
|
||||
for _, v := range pending {
|
||||
if _, parentPending := pending[v.Parent]; !parentPending {
|
||||
next = append(next, v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range next {
|
||||
aum := c.makeAUM(v)
|
||||
h := aum.Hash()
|
||||
|
||||
c.AUMHashes[v.Name] = h
|
||||
c.AUMs[v.Name] = aum
|
||||
delete(pending, v.Name)
|
||||
}
|
||||
}
|
||||
panic("unexpected: incomplete despite len(Nodes)+1 iterations")
|
||||
}
|
||||
|
||||
func (c *testChain) makeAUM(v *testchainNode) AUM {
|
||||
// By default, the AUM used is just a no-op AUM
|
||||
// with a parent hash set (if any).
|
||||
//
|
||||
// If <AUM>.template is set to the same name as in
|
||||
// a provided optTemplate(), the AUM is built
|
||||
// from a copy of that instead.
|
||||
//
|
||||
// If <AUM>.hashSeed = <int> is set, the KeyID is
|
||||
// tweaked to effect tweaking the hash. This is useful
|
||||
// if you want one AUM to have a lower hash than another.
|
||||
aum := AUM{MessageKind: AUMNoOp}
|
||||
if template := v.Template; template != "" {
|
||||
aum = c.Template[template]
|
||||
}
|
||||
if v.Parent != "" {
|
||||
parentHash := c.AUMHashes[v.Parent]
|
||||
aum.PrevAUMHash = parentHash[:]
|
||||
}
|
||||
if seed := v.HashSeed; seed != 0 {
|
||||
aum.KeyID = []byte{byte(seed)}
|
||||
}
|
||||
if err := aum.StaticValidate(); err != nil {
|
||||
// Usually caused by a test writer specifying a template
|
||||
// AUM which is ultimately invalid.
|
||||
panic(fmt.Sprintf("aum %+v failed static validation: %v", aum, err))
|
||||
}
|
||||
|
||||
sigHash := aum.SigHash()
|
||||
for _, key := range c.SignAllKeys {
|
||||
aum.Signatures = append(aum.Signatures, Signature{
|
||||
KeyID: c.Key[key].ID(),
|
||||
Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
|
||||
})
|
||||
}
|
||||
|
||||
// If the aum was specified as being signed by some key, then
|
||||
// sign it using that key.
|
||||
if key := v.SignedWith; key != "" {
|
||||
aum.Signatures = append(aum.Signatures, Signature{
|
||||
KeyID: c.Key[key].ID(),
|
||||
Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
|
||||
})
|
||||
}
|
||||
|
||||
return aum
|
||||
}
|
||||
|
||||
// Chonk returns a tailchonk containing all AUMs.
|
||||
func (c *testChain) Chonk() Chonk {
|
||||
var out Mem
|
||||
for _, update := range c.AUMs {
|
||||
if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
// ChonkWith returns a tailchonk containing the named AUMs.
|
||||
func (c *testChain) ChonkWith(names ...string) Chonk {
|
||||
var out Mem
|
||||
for _, name := range names {
|
||||
update := c.AUMs[name]
|
||||
if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return &out
|
||||
}
|
||||
|
||||
type testchainOpt struct {
|
||||
Name string
|
||||
Template *AUM
|
||||
Key *Key
|
||||
Private ed25519.PrivateKey
|
||||
SignAllWith bool
|
||||
}
|
||||
|
||||
func optTemplate(name string, template AUM) testchainOpt {
|
||||
return testchainOpt{
|
||||
Name: name,
|
||||
Template: &template,
|
||||
}
|
||||
}
|
||||
|
||||
func optKey(name string, key Key, priv ed25519.PrivateKey) testchainOpt {
|
||||
return testchainOpt{
|
||||
Name: name,
|
||||
Key: &key,
|
||||
Private: priv,
|
||||
}
|
||||
}
|
||||
|
||||
func optSignAllUsing(keyName string) testchainOpt {
|
||||
return testchainOpt{
|
||||
Name: keyName,
|
||||
SignAllWith: true,
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTestchain(t *testing.T) {
|
||||
c := newTestchain(t, `
|
||||
genesis -> B -> C
|
||||
| -> D
|
||||
| -> E -> F
|
||||
|
||||
E.hashSeed = 12 // tweak E to have the lowest hash so its chosen
|
||||
F.template = test
|
||||
`, optTemplate("test", AUM{MessageKind: AUMNoOp, KeyID: []byte{10}}))
|
||||
|
||||
want := map[string]*testchainNode{
|
||||
"genesis": &testchainNode{Name: "genesis", Uses: []scanner.Position{{Line: 2, Column: 16}}},
|
||||
"B": &testchainNode{
|
||||
Name: "B",
|
||||
Parent: "genesis",
|
||||
Uses: []scanner.Position{{Line: 2, Column: 21}, {Line: 3, Column: 21}, {Line: 4, Column: 21}},
|
||||
},
|
||||
"C": &testchainNode{Name: "C", Parent: "B", Uses: []scanner.Position{{Line: 2, Column: 26}}},
|
||||
"D": &testchainNode{Name: "D", Parent: "B", Uses: []scanner.Position{{Line: 3, Column: 26}}},
|
||||
"E": &testchainNode{Name: "E", Parent: "B", HashSeed: 12, Uses: []scanner.Position{{Line: 4, Column: 26}, {Line: 6, Column: 10}}},
|
||||
"F": &testchainNode{Name: "F", Parent: "E", Template: "test", Uses: []scanner.Position{{Line: 4, Column: 31}, {Line: 7, Column: 10}}},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, c.Nodes, cmpopts.IgnoreFields(scanner.Position{}, "Offset")); diff != "" {
|
||||
t.Errorf("decoded state differs (-want, +got):\n%s", diff)
|
||||
}
|
||||
if !bytes.Equal(c.AUMs["F"].KeyID, []byte{10}) {
|
||||
t.Errorf("AUM 'F' missing KeyID from template: %v", c.AUMs["F"])
|
||||
}
|
||||
|
||||
// chonk := c.Chonk()
|
||||
// authority, err := Open(chonk)
|
||||
// if err != nil {
|
||||
// t.Errorf("failed to initialize from chonk: %v", err)
|
||||
// }
|
||||
|
||||
// if authority.Head() != c.AUMHashes["F"] {
|
||||
// t.Errorf("head = %X, want %X", authority.Head(), c.AUMHashes["F"])
|
||||
// }
|
||||
}
|
332
tka/tka.go
332
tka/tka.go
@ -4,3 +4,335 @@
|
||||
|
||||
// Package tka (WIP) implements the Tailnet Key Authority.
|
||||
package tka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// A chain describes a linear sequence of updates from Oldest to Head,
|
||||
// resulting in some State at Head.
|
||||
type chain struct {
|
||||
Oldest AUM
|
||||
Head AUM
|
||||
|
||||
state State
|
||||
|
||||
// Set to true if the AUM chain intersects with the active
|
||||
// chain from a previous run.
|
||||
chainsThroughActive bool
|
||||
}
|
||||
|
||||
// computeChainCandidates returns all possible chains based on AUMs stored
|
||||
// in the given tailchonk. A chain is defined as a unique (oldest, newest)
|
||||
// AUM tuple. chain.state is not yet populated in returned chains.
|
||||
//
|
||||
// If lastKnownOldest is provided, any chain that includes the given AUM
|
||||
// has the chainsThroughActive field set to true. This bit is leveraged
|
||||
// in computeActiveAncestor() to filter out irrelevant chains when determining
|
||||
// the active ancestor from a list of distinct chains.
|
||||
func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int) ([]chain, error) {
|
||||
heads, err := storage.Heads()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading heads: %v", err)
|
||||
}
|
||||
candidates := make([]chain, len(heads))
|
||||
for i := range heads {
|
||||
// Oldest is iteratively computed below.
|
||||
candidates[i] = chain{Oldest: heads[i], Head: heads[i]}
|
||||
}
|
||||
// Not strictly necessary, but simplifies checks in tests.
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
ih, jh := candidates[i].Oldest.Hash(), candidates[j].Oldest.Hash()
|
||||
return bytes.Compare(ih[:], jh[:]) < 0
|
||||
})
|
||||
|
||||
// candidates.Oldest needs to be computed by working backwards from
|
||||
// head as far as we can.
|
||||
iterAgain := true // if theres still work to be done.
|
||||
for i := 0; iterAgain; i++ {
|
||||
if i >= maxIter {
|
||||
return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
|
||||
}
|
||||
|
||||
iterAgain = false
|
||||
for j := range candidates {
|
||||
parent, hasParent := candidates[j].Oldest.Parent()
|
||||
if hasParent {
|
||||
parent, err := storage.AUM(parent)
|
||||
if err != nil {
|
||||
if err == os.ErrNotExist {
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("reading parent: %v", err)
|
||||
}
|
||||
candidates[j].Oldest = parent
|
||||
if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() {
|
||||
candidates[j].chainsThroughActive = true
|
||||
}
|
||||
iterAgain = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
// pickNextAUM returns the AUM which should be used as the next
|
||||
// AUM in the chain, possibly applying fork resolution logic.
|
||||
//
|
||||
// In other words: given an AUM with 3 children like this:
|
||||
// / - 1
|
||||
// P - 2
|
||||
// \ - 3
|
||||
//
|
||||
// pickNextAUM will determine and return the correct branch.
|
||||
//
|
||||
// This method takes ownership of the provided slice.
|
||||
func pickNextAUM(state State, candidates []AUM) AUM {
|
||||
switch len(candidates) {
|
||||
case 0:
|
||||
panic("pickNextAUM called with empty candidate set")
|
||||
case 1:
|
||||
return candidates[0]
|
||||
}
|
||||
|
||||
// Oooof, we have some forks in the chain. We need to pick which
|
||||
// one to use by applying the Fork Resolution Algorithm ✨
|
||||
//
|
||||
// The rules are this:
|
||||
// 1. The child with the highest signature weight is chosen.
|
||||
// 2. If equal, the child which is a RemoveKey AUM is chosen.
|
||||
// 3. If equal, the child with the lowest AUM hash is chosen.
|
||||
sort.Slice(candidates, func(j, i int) bool {
|
||||
// Rule 1.
|
||||
iSigWeight, jSigWeight := candidates[i].Weight(state), candidates[j].Weight(state)
|
||||
if iSigWeight != jSigWeight {
|
||||
return iSigWeight < jSigWeight
|
||||
}
|
||||
|
||||
// Rule 2.
|
||||
if iKind, jKind := candidates[i].MessageKind, candidates[j].MessageKind; iKind != jKind &&
|
||||
(iKind == AUMRemoveKey || jKind == AUMRemoveKey) {
|
||||
return jKind == AUMRemoveKey
|
||||
}
|
||||
|
||||
// Rule 3.
|
||||
iHash, jHash := candidates[i].Hash(), candidates[j].Hash()
|
||||
return bytes.Compare(iHash[:], jHash[:]) > 0
|
||||
})
|
||||
|
||||
return candidates[0]
|
||||
}
|
||||
|
||||
// advanceChain computes the next AUM to advance with based on all child
|
||||
// AUMs, returning the chosen AUM & the state obtained by applying that
|
||||
// AUM.
|
||||
//
|
||||
// The return value for next is nil if there are no children AUMs, hence
|
||||
// the provided state is at head (up to date).
|
||||
func advanceChain(state State, candidates []AUM) (next *AUM, out State, err error) {
|
||||
if len(candidates) == 0 {
|
||||
return nil, state, nil
|
||||
}
|
||||
|
||||
aum := pickNextAUM(state, candidates)
|
||||
if state, err = state.applyVerifiedAUM(aum); err != nil {
|
||||
return nil, State{}, fmt.Errorf("advancing state: %v", err)
|
||||
}
|
||||
return &aum, state, nil
|
||||
}
|
||||
|
||||
// fastForward iteratively advances the current state based on known AUMs until
|
||||
// the given termination function returns true or there is no more progress possible.
|
||||
//
|
||||
// The last-processed AUM, and the state computed after applying the last AUM,
|
||||
// are returned.
|
||||
func fastForward(storage Chonk, maxIter int, startState State, done func(curAUM AUM, curState State) bool) (AUM, State, error) {
|
||||
if startState.LastAUMHash == nil {
|
||||
return AUM{}, State{}, errors.New("invalid initial state")
|
||||
}
|
||||
nextAUM, err := storage.AUM(*startState.LastAUMHash)
|
||||
if err != nil {
|
||||
return AUM{}, State{}, fmt.Errorf("reading next: %v", err)
|
||||
}
|
||||
|
||||
curs := nextAUM
|
||||
state := startState
|
||||
for i := 0; i < maxIter; i++ {
|
||||
if done != nil && done(curs, state) {
|
||||
return curs, state, nil
|
||||
}
|
||||
|
||||
children, err := storage.ChildAUMs(curs.Hash())
|
||||
if err != nil {
|
||||
return AUM{}, State{}, fmt.Errorf("getting children of %X: %v", curs.Hash(), err)
|
||||
}
|
||||
next, nextState, err := advanceChain(state, children)
|
||||
if err != nil {
|
||||
return AUM{}, State{}, fmt.Errorf("advance %X: %v", curs.Hash(), err)
|
||||
}
|
||||
if next == nil {
|
||||
// There were no more children, we are at 'head'.
|
||||
return curs, state, nil
|
||||
}
|
||||
curs = *next
|
||||
state = nextState
|
||||
}
|
||||
|
||||
return AUM{}, State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
|
||||
}
|
||||
|
||||
// computeStateAt returns the State at wantHash.
|
||||
func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) {
|
||||
// TODO(tom): This is going to get expensive for really long
|
||||
// chains. We should make nodes emit a checkpoint every
|
||||
// X updates or something.
|
||||
|
||||
topAUM, err := storage.AUM(wantHash)
|
||||
if err != nil {
|
||||
return State{}, err
|
||||
}
|
||||
|
||||
// Iterate backwards till we find a starting point to compute
|
||||
// the state from.
|
||||
//
|
||||
// Valid starting points are either a checkpoint AUM, or a
|
||||
// genesis AUM.
|
||||
curs := topAUM
|
||||
var state State
|
||||
for i := 0; true; i++ {
|
||||
if i > maxIter {
|
||||
return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter)
|
||||
}
|
||||
|
||||
// Checkpoints encapsulate the state at that point, dope.
|
||||
if curs.MessageKind == AUMCheckpoint {
|
||||
state = curs.State.cloneForUpdate(&curs)
|
||||
break
|
||||
}
|
||||
parent, hasParent := curs.Parent()
|
||||
if !hasParent {
|
||||
// This is a 'genesis' update: there are none before it, so
|
||||
// this AUM can be applied to the empty state to determine
|
||||
// the state at this AUM.
|
||||
//
|
||||
// It is only valid for NoOp, AddKey, and Checkpoint AUMs
|
||||
// to be a genesis update. Checkpoint was handled earlier.
|
||||
if mk := curs.MessageKind; mk == AUMNoOp || mk == AUMAddKey {
|
||||
var err error
|
||||
if state, err = (State{}).applyVerifiedAUM(curs); err != nil {
|
||||
return State{}, fmt.Errorf("applying genesis (%+v): %v", curs, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
return State{}, fmt.Errorf("invalid genesis update: %+v", curs)
|
||||
}
|
||||
|
||||
// If we got here, the current state is dependent on the previous.
|
||||
// Keep iterating backwards till thats not the case.
|
||||
if curs, err = storage.AUM(parent); err != nil {
|
||||
return State{}, fmt.Errorf("reading parent: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We now know some starting point state. Iterate forward till we
|
||||
// are at the AUM we want state for.
|
||||
_, state, err = fastForward(storage, maxIter, state, func(curs AUM, _ State) bool {
|
||||
return curs.Hash() == wantHash
|
||||
})
|
||||
// fastForward only terminates before the done condition if it
|
||||
// doesnt have any later AUMs to process. This cant be the case
|
||||
// as we've already iterated through them above so they must exist,
|
||||
// but we check anyway to be super duper sure.
|
||||
if err == nil && *state.LastAUMHash != wantHash {
|
||||
panic("unexpected fastForward outcome")
|
||||
}
|
||||
return state, err
|
||||
}
|
||||
|
||||
// computeActiveAncestor determines which ancestor AUM to use as the
|
||||
// ancestor of the valid chain.
|
||||
//
|
||||
// If all the chains end up having the same ancestor, then thats the
|
||||
// only possible ancestor, ezpz. However if there are multiple distinct
|
||||
// ancestors, that means there are distinct chains, and we need some
|
||||
// hint to choose what to use. For that, we rely on the chainsThroughActive
|
||||
// bit, which signals to us that that ancestor was part of the
|
||||
// chain in a previous run.
|
||||
func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) {
|
||||
// Dedupe possible ancestors, tracking if they were part of
|
||||
// the active chain on a previous run.
|
||||
ancestors := make(map[AUMHash]bool, len(chains))
|
||||
for _, c := range chains {
|
||||
ancestors[c.Oldest.Hash()] = c.chainsThroughActive
|
||||
}
|
||||
|
||||
if len(ancestors) == 1 {
|
||||
// There's only one. DOPE.
|
||||
for k, _ := range ancestors {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Theres more than one, so we need to use the ancestor that was
|
||||
// part of the active chain in a previous iteration.
|
||||
// Note that there can only be one distinct ancestor that was
|
||||
// formerly part of the active chain, because AUMs can only have
|
||||
// one parent and would have converged to a common ancestor.
|
||||
for k, chainsThroughActive := range ancestors {
|
||||
if chainsThroughActive {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
|
||||
return AUMHash{}, errors.New("multiple distinct chains")
|
||||
}
|
||||
|
||||
// computeActiveChain bootstraps the runtime state of the Authority when
|
||||
// starting entirely off stored state.
|
||||
//
|
||||
// TODO(tom): Don't look at head states, just iterate forward from
|
||||
// the ancestor.
|
||||
//
|
||||
// The algorithm is as follows:
|
||||
// 1. Determine all possible 'head' (like in git) states.
|
||||
// 2. Filter these possible chains based on whether the ancestor was
|
||||
// formerly (in a previous run) part of the chain.
|
||||
// 3. Compute the state of the state machine at this ancestor. This is
|
||||
// needed for fast-forward, as each update operates on the state of
|
||||
// the update preceeding it.
|
||||
// 4. Iteratively apply updates till we reach head ('fast forward').
|
||||
func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (chain, error) {
|
||||
chains, err := computeChainCandidates(storage, lastKnownOldest, maxIter)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("computing candidates: %v", err)
|
||||
}
|
||||
|
||||
// Find the right ancestor.
|
||||
oldestHash, err := computeActiveAncestor(storage, chains)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("computing ancestor: %v", err)
|
||||
}
|
||||
ancestor, err := storage.AUM(oldestHash)
|
||||
if err != nil {
|
||||
return chain{}, err
|
||||
}
|
||||
|
||||
// At this stage we know the ancestor AUM, so we have excluded distinct
|
||||
// chains but we might still have forks (so we don't know the head AUM).
|
||||
//
|
||||
// We iterate forward from the ancestor AUM, handling any forks as we go
|
||||
// till we arrive at a head.
|
||||
out := chain{Oldest: ancestor, Head: ancestor}
|
||||
if out.state, err = computeStateAt(storage, maxIter, oldestHash); err != nil {
|
||||
return chain{}, fmt.Errorf("bootstrapping state: %v", err)
|
||||
}
|
||||
out.Head, out.state, err = fastForward(storage, maxIter, out.state, nil)
|
||||
if err != nil {
|
||||
return chain{}, fmt.Errorf("fast forward: %v", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
187
tka/tka_test.go
Normal file
187
tka/tka_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tka
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestComputeChainCandidates(t *testing.T) {
|
||||
c := newTestchain(t, `
|
||||
G1 -> I1 -> I2 -> I3 -> L2
|
||||
| -> L1 | -> L3
|
||||
|
||||
G2 -> L4
|
||||
|
||||
// We tweak these AUMs so they are different hashes.
|
||||
G2.hashSeed = 2
|
||||
L1.hashSeed = 2
|
||||
L3.hashSeed = 2
|
||||
L4.hashSeed = 3
|
||||
`)
|
||||
// Should result in 4 chains:
|
||||
// G1->L1, G1->L2, G1->L3, G2->L4
|
||||
|
||||
i1H := c.AUMHashes["I1"]
|
||||
got, err := computeChainCandidates(c.Chonk(), &i1H, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("computeChainCandidates() failed: %v", err)
|
||||
}
|
||||
|
||||
want := []chain{
|
||||
{Oldest: c.AUMs["G1"], Head: c.AUMs["L1"], chainsThroughActive: true},
|
||||
{Oldest: c.AUMs["G1"], Head: c.AUMs["L3"], chainsThroughActive: true},
|
||||
{Oldest: c.AUMs["G1"], Head: c.AUMs["L2"], chainsThroughActive: true},
|
||||
{Oldest: c.AUMs["G2"], Head: c.AUMs["L4"]},
|
||||
}
|
||||
if diff := cmp.Diff(want, got, cmp.AllowUnexported(chain{})); diff != "" {
|
||||
t.Errorf("chains differ (-want, +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForkResolutionHash(t *testing.T) {
|
||||
c := newTestchain(t, `
|
||||
G1 -> L1
|
||||
| -> L2
|
||||
|
||||
// tweak hashes so L1 & L2 are not identical
|
||||
L1.hashSeed = 2
|
||||
L2.hashSeed = 3
|
||||
`)
|
||||
|
||||
got, err := computeActiveChain(c.Chonk(), nil, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("computeActiveChain() failed: %v", err)
|
||||
}
|
||||
|
||||
// The fork with the lowest AUM hash should have been chosen.
|
||||
l1H := c.AUMHashes["L1"]
|
||||
l2H := c.AUMHashes["L2"]
|
||||
want := l1H
|
||||
if bytes.Compare(l2H[:], l1H[:]) < 0 {
|
||||
want = l2H
|
||||
}
|
||||
|
||||
if got := got.Head.Hash(); got != want {
|
||||
t.Errorf("head was %x, want %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForkResolutionSigWeight(t *testing.T) {
|
||||
pub, priv := testingKey25519(t, 1)
|
||||
key := Key{Kind: Key25519, Public: pub, Votes: 2}
|
||||
|
||||
c := newTestchain(t, `
|
||||
G1 -> L1
|
||||
| -> L2
|
||||
|
||||
G1.template = addKey
|
||||
L1.hashSeed = 2
|
||||
L2.signedWith = key
|
||||
`,
|
||||
optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}),
|
||||
optKey("key", key, priv))
|
||||
|
||||
l1H := c.AUMHashes["L1"]
|
||||
l2H := c.AUMHashes["L2"]
|
||||
if bytes.Compare(l2H[:], l1H[:]) < 0 {
|
||||
t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes")
|
||||
}
|
||||
|
||||
got, err := computeActiveChain(c.Chonk(), nil, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("computeActiveChain() failed: %v", err)
|
||||
}
|
||||
|
||||
// Based on the hash, l1H should be chosen.
|
||||
// But based on the signature weight (which has higher
|
||||
// precedence), it should be l2H
|
||||
want := l2H
|
||||
if got := got.Head.Hash(); got != want {
|
||||
t.Errorf("head was %x, want %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForkResolutionMessageType(t *testing.T) {
|
||||
pub, _ := testingKey25519(t, 1)
|
||||
key := Key{Kind: Key25519, Public: pub, Votes: 2}
|
||||
|
||||
c := newTestchain(t, `
|
||||
G1 -> L1
|
||||
| -> L2
|
||||
| -> L3
|
||||
|
||||
G1.template = addKey
|
||||
L1.hashSeed = 11
|
||||
L2.template = removeKey
|
||||
L3.hashSeed = 18
|
||||
`,
|
||||
optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}),
|
||||
optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.ID()}))
|
||||
|
||||
l1H := c.AUMHashes["L1"]
|
||||
l2H := c.AUMHashes["L2"]
|
||||
l3H := c.AUMHashes["L3"]
|
||||
if bytes.Compare(l2H[:], l1H[:]) < 0 {
|
||||
t.Fatal("failed assert: h(l1) > h(l2)\nTweak hashSeed till this passes")
|
||||
}
|
||||
if bytes.Compare(l2H[:], l3H[:]) < 0 {
|
||||
t.Fatal("failed assert: h(l3) > h(l2)\nTweak hashSeed till this passes")
|
||||
}
|
||||
|
||||
got, err := computeActiveChain(c.Chonk(), nil, 50)
|
||||
if err != nil {
|
||||
t.Fatalf("computeActiveChain() failed: %v", err)
|
||||
}
|
||||
|
||||
// Based on the hash, L1 or L3 should be chosen.
|
||||
// But based on the preference for AUMRemoveKey messages,
|
||||
// it should be L2.
|
||||
want := l2H
|
||||
if got := got.Head.Hash(); got != want {
|
||||
t.Errorf("head was %x, want %x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeStateAt(t *testing.T) {
|
||||
pub, _ := testingKey25519(t, 1)
|
||||
key := Key{Kind: Key25519, Public: pub, Votes: 2}
|
||||
|
||||
c := newTestchain(t, `
|
||||
G1 -> I1 -> I2
|
||||
I1.template = addKey
|
||||
`,
|
||||
optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}))
|
||||
|
||||
// G1 is before the key, so there shouldn't be a key there.
|
||||
state, err := computeStateAt(c.Chonk(), 500, c.AUMHashes["G1"])
|
||||
if err != nil {
|
||||
t.Fatalf("computeStateAt(G1) failed: %v", err)
|
||||
}
|
||||
if _, err := state.GetKey(key.ID()); err != ErrNoSuchKey {
|
||||
t.Errorf("expected key to be missing: err = %v", err)
|
||||
}
|
||||
if *state.LastAUMHash != c.AUMHashes["G1"] {
|
||||
t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, c.AUMHashes["G1"])
|
||||
}
|
||||
|
||||
// I1 & I2 are after the key, so the computed state should contain
|
||||
// the key.
|
||||
for _, wantHash := range []AUMHash{c.AUMHashes["I1"], c.AUMHashes["I2"]} {
|
||||
state, err = computeStateAt(c.Chonk(), 500, wantHash)
|
||||
if err != nil {
|
||||
t.Fatalf("computeStateAt(%X) failed: %v", wantHash, err)
|
||||
}
|
||||
if *state.LastAUMHash != wantHash {
|
||||
t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash)
|
||||
}
|
||||
if _, err := state.GetKey(key.ID()); err != nil {
|
||||
t.Errorf("expected key to be present at state: err = %v", err)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user