cmd/tailscale,tka: make KeyID return an error instead of panicking

Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom DNetto 2023-01-03 09:39:55 -08:00 committed by Tom
parent 8724aa254f
commit 907f85cd67
14 changed files with 125 additions and 59 deletions

View File

@ -103,7 +103,11 @@ func runNetworkLockInit(ctx context.Context, args []string) error {
// Common mistake: Not specifying the current node's key as one of the trusted keys. // Common mistake: Not specifying the current node's key as one of the trusted keys.
foundSelfKey := false foundSelfKey := false
for _, k := range keys { for _, k := range keys {
if bytes.Equal(k.ID(), st.PublicKey.KeyID()) { keyID, err := k.ID()
if err != nil {
return err
}
if bytes.Equal(keyID, st.PublicKey.KeyID()) {
foundSelfKey = true foundSelfKey = true
break break
} }
@ -457,8 +461,13 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er
var stanza strings.Builder var stanza strings.Builder
printKey := func(key *tka.Key, prefix string) { printKey := func(key *tka.Key, prefix string) {
fmt.Fprintf(&stanza, "%sType: %s\n", prefix, key.Kind.String()) fmt.Fprintf(&stanza, "%sType: %s\n", prefix, key.Kind.String())
fmt.Fprintf(&stanza, "%sKeyID: %x\n", prefix, key.ID()) if keyID, err := key.ID(); err == nil {
fmt.Fprintf(&stanza, "%sVotes: %d\n", prefix, key.Votes) fmt.Fprintf(&stanza, "%sKeyID: %x\n", prefix, keyID)
} else {
// Older versions of the client shouldn't explode when they encounter an
// unknown key type.
fmt.Fprintf(&stanza, "%sKeyID: <Error: %v>\n", prefix, err)
}
if key.Meta != nil { if key.Meta != nil {
fmt.Fprintf(&stanza, "%sMetadata: %+v\n", prefix, key.Meta) fmt.Fprintf(&stanza, "%sMetadata: %+v\n", prefix, key.Meta)
} }

View File

@ -666,7 +666,11 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err
} }
} }
for _, removeKey := range removeKeys { for _, removeKey := range removeKeys {
if err := updater.RemoveKey(removeKey.ID()); err != nil { keyID, err := removeKey.ID()
if err != nil {
return err
}
if err := updater.RemoveKey(keyID); err != nil {
return err return err
} }
} }

View File

@ -321,7 +321,7 @@ type tkaSyncScenario struct {
name: "control has an update", name: "control has an update",
controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM {
b := a.NewUpdater(signer) b := a.NewUpdater(signer)
if err := b.RemoveKey(someKey.ID()); err != nil { if err := b.RemoveKey(someKey.MustID()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
aums, err := b.Finalize(storage) aums, err := b.Finalize(storage)
@ -336,7 +336,7 @@ type tkaSyncScenario struct {
name: "node has an update", name: "node has an update",
nodeAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { nodeAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM {
b := a.NewUpdater(signer) b := a.NewUpdater(signer)
if err := b.RemoveKey(someKey.ID()); err != nil { if err := b.RemoveKey(someKey.MustID()); err != nil {
t.Fatal(err) t.Fatal(err)
} }
aums, err := b.Finalize(storage) aums, err := b.Finalize(storage)
@ -351,7 +351,7 @@ type tkaSyncScenario struct {
name: "node and control diverge", name: "node and control diverge",
controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM {
b := a.NewUpdater(signer) b := a.NewUpdater(signer)
if err := b.SetKeyMeta(someKey.ID(), map[string]string{"ye": "swiggity"}); err != nil { if err := b.SetKeyMeta(someKey.MustID(), map[string]string{"ye": "swiggity"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
aums, err := b.Finalize(storage) aums, err := b.Finalize(storage)
@ -362,7 +362,7 @@ type tkaSyncScenario struct {
}, },
nodeAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { nodeAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM {
b := a.NewUpdater(signer) b := a.NewUpdater(signer)
if err := b.SetKeyMeta(someKey.ID(), map[string]string{"ye": "swooty"}); err != nil { if err := b.SetKeyMeta(someKey.MustID(), map[string]string{"ye": "swooty"}); err != nil {
t.Fatal(err) t.Fatal(err)
} }
aums, err := b.Finalize(storage) aums, err := b.Finalize(storage)

View File

@ -281,14 +281,20 @@ func (a *AUM) Parent() (h AUMHash, ok bool) {
return h, false return h, false
} }
func (a *AUM) sign25519(priv ed25519.PrivateKey) { func (a *AUM) sign25519(priv ed25519.PrivateKey) error {
key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)}
sigHash := a.SigHash() sigHash := a.SigHash()
keyID, err := key.ID()
if err != nil {
return err
}
a.Signatures = append(a.Signatures, tkatype.Signature{ a.Signatures = append(a.Signatures, tkatype.Signature{
KeyID: key.ID(), KeyID: keyID,
Signature: ed25519.Sign(priv, sigHash[:]), Signature: ed25519.Sign(priv, sigHash[:]),
}) })
return nil
} }
// Weight computes the 'signature weight' of the AUM // Weight computes the 'signature weight' of the AUM

View File

@ -189,7 +189,7 @@ func TestAUMWeight(t *testing.T) {
{ {
"Unary key", "Unary key",
AUM{ AUM{
Signatures: []tkatype.Signature{{KeyID: key.ID()}}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}},
}, },
State{ State{
Keys: []Key{key}, Keys: []Key{key},
@ -199,7 +199,7 @@ func TestAUMWeight(t *testing.T) {
{ {
"Multiple keys", "Multiple keys",
AUM{ AUM{
Signatures: []tkatype.Signature{{KeyID: key.ID()}, {KeyID: key2.ID()}}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key2.MustID()}},
}, },
State{ State{
Keys: []Key{key, key2}, Keys: []Key{key, key2},
@ -209,7 +209,7 @@ func TestAUMWeight(t *testing.T) {
{ {
"Double use", "Double use",
AUM{ AUM{
Signatures: []tkatype.Signature{{KeyID: key.ID()}, {KeyID: key.ID()}}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}, {KeyID: key.MustID()}},
}, },
State{ State{
Keys: []Key{key}, Keys: []Key{key},

View File

@ -60,7 +60,12 @@ func (b *UpdateBuilder) mkUpdate(update AUM) error {
// AddKey adds a new key to the authority. // AddKey adds a new key to the authority.
func (b *UpdateBuilder) AddKey(key Key) error { func (b *UpdateBuilder) AddKey(key Key) error {
if _, err := b.state.GetKey(key.ID()); err == nil { keyID, err := key.ID()
if err != nil {
return err
}
if _, err := b.state.GetKey(keyID); err == nil {
return fmt.Errorf("cannot add key %v: already exists", key) return fmt.Errorf("cannot add key %v: already exists", key)
} }
return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key})

View File

@ -19,7 +19,7 @@ func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, e
key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)} key := Key{Kind: Key25519, Public: priv.Public().(ed25519.PublicKey)}
return []tkatype.Signature{{ return []tkatype.Signature{{
KeyID: key.ID(), KeyID: key.MustID(),
Signature: ed25519.Sign(priv, sigHash[:]), Signature: ed25519.Sign(priv, sigHash[:]),
}}, nil }}, nil
} }
@ -54,7 +54,7 @@ func TestAuthorityBuilderAddKey(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
if _, err := a.state.GetKey(key2.ID()); err != nil { if _, err := a.state.GetKey(key2.MustID()); err != nil {
t.Errorf("could not read new key: %v", err) t.Errorf("could not read new key: %v", err)
} }
} }
@ -75,7 +75,7 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) {
} }
b := a.NewUpdater(signer25519(priv)) b := a.NewUpdater(signer25519(priv))
if err := b.RemoveKey(key2.ID()); err != nil { if err := b.RemoveKey(key2.MustID()); err != nil {
t.Fatalf("RemoveKey(%v) failed: %v", key2, err) t.Fatalf("RemoveKey(%v) failed: %v", key2, err)
} }
updates, err := b.Finalize(storage) updates, err := b.Finalize(storage)
@ -88,7 +88,7 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
if _, err := a.state.GetKey(key2.ID()); err != ErrNoSuchKey { if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey {
t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey)
} }
} }
@ -107,8 +107,8 @@ func TestAuthorityBuilderSetKeyVote(t *testing.T) {
} }
b := a.NewUpdater(signer25519(priv)) b := a.NewUpdater(signer25519(priv))
if err := b.SetKeyVote(key.ID(), 5); err != nil { if err := b.SetKeyVote(key.MustID(), 5); err != nil {
t.Fatalf("SetKeyVote(%v) failed: %v", key.ID(), err) t.Fatalf("SetKeyVote(%v) failed: %v", key.MustID(), err)
} }
updates, err := b.Finalize(storage) updates, err := b.Finalize(storage)
if err != nil { if err != nil {
@ -120,7 +120,7 @@ func TestAuthorityBuilderSetKeyVote(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
k, err := a.state.GetKey(key.ID()) k, err := a.state.GetKey(key.MustID())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -143,7 +143,7 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) {
} }
b := a.NewUpdater(signer25519(priv)) b := a.NewUpdater(signer25519(priv))
if err := b.SetKeyMeta(key.ID(), map[string]string{"b": "c"}); err != nil { if err := b.SetKeyMeta(key.MustID(), map[string]string{"b": "c"}); err != nil {
t.Fatalf("SetKeyMeta(%v) failed: %v", key, err) t.Fatalf("SetKeyMeta(%v) failed: %v", key, err)
} }
updates, err := b.Finalize(storage) updates, err := b.Finalize(storage)
@ -156,7 +156,7 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
k, err := a.state.GetKey(key.ID()) k, err := a.state.GetKey(key.MustID())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -185,10 +185,10 @@ func TestAuthorityBuilderMultiple(t *testing.T) {
if err := b.AddKey(key2); err != nil { if err := b.AddKey(key2); err != nil {
t.Fatalf("AddKey(%v) failed: %v", key2, err) t.Fatalf("AddKey(%v) failed: %v", key2, err)
} }
if err := b.SetKeyVote(key2.ID(), 42); err != nil { if err := b.SetKeyVote(key2.MustID(), 42); err != nil {
t.Fatalf("SetKeyVote(%v) failed: %v", key2, err) t.Fatalf("SetKeyVote(%v) failed: %v", key2, err)
} }
if err := b.RemoveKey(key.ID()); err != nil { if err := b.RemoveKey(key.MustID()); err != nil {
t.Fatalf("RemoveKey(%v) failed: %v", key, err) t.Fatalf("RemoveKey(%v) failed: %v", key, err)
} }
updates, err := b.Finalize(storage) updates, err := b.Finalize(storage)
@ -201,14 +201,14 @@ func TestAuthorityBuilderMultiple(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
k, err := a.state.GetKey(key2.ID()) k, err := a.state.GetKey(key2.MustID())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if got, want := k.Votes, uint(42); got != want { if got, want := k.Votes, uint(42); got != want {
t.Errorf("key.Votes = %d, want %d", got, want) t.Errorf("key.Votes = %d, want %d", got, want)
} }
if _, err := a.state.GetKey(key.ID()); err != ErrNoSuchKey { if _, err := a.state.GetKey(key.MustID()); err != ErrNoSuchKey {
t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey) t.Errorf("GetKey(key).err = %v, want %v", err, ErrNoSuchKey)
} }
} }
@ -243,7 +243,7 @@ func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) {
if err := a.Inform(storage, updates); err != nil { if err := a.Inform(storage, updates); err != nil {
t.Fatalf("could not apply generated updates: %v", err) t.Fatalf("could not apply generated updates: %v", err)
} }
if _, err := a.state.GetKey(key2.ID()); err != nil { if _, err := a.state.GetKey(key2.MustID()); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -267,7 +267,7 @@ func (c *testChain) makeAUM(v *testchainNode) AUM {
sigHash := aum.SigHash() sigHash := aum.SigHash()
for _, key := range c.SignAllKeys { for _, key := range c.SignAllKeys {
aum.Signatures = append(aum.Signatures, tkatype.Signature{ aum.Signatures = append(aum.Signatures, tkatype.Signature{
KeyID: c.Key[key].ID(), KeyID: c.Key[key].MustID(),
Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]), Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
}) })
} }
@ -276,7 +276,7 @@ func (c *testChain) makeAUM(v *testchainNode) AUM {
// sign it using that key. // sign it using that key.
if key := v.SignedWith; key != "" { if key := v.SignedWith; key != "" {
aum.Signatures = append(aum.Signatures, tkatype.Signature{ aum.Signatures = append(aum.Signatures, tkatype.Signature{
KeyID: c.Key[key].ID(), KeyID: c.Key[key].MustID(),
Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]), Signature: ed25519.Sign(c.KeyPrivs[key], sigHash[:]),
}) })
} }

View File

@ -74,14 +74,25 @@ func (k Key) Clone() Key {
return out return out
} }
func (k Key) ID() tkatype.KeyID { // MustID returns the KeyID of the key, panicking if an error is
// encountered. This must only be used for tests.
func (k Key) MustID() tkatype.KeyID {
id, err := k.ID()
if err != nil {
panic(err)
}
return id
}
// ID returns the KeyID of the key.
func (k Key) ID() (tkatype.KeyID, error) {
switch k.Kind { switch k.Kind {
// Because 25519 public keys are so short, we just use the 32-byte // Because 25519 public keys are so short, we just use the 32-byte
// public as their 'key ID'. // public as their 'key ID'.
case Key25519: case Key25519:
return tkatype.KeyID(k.Public) return tkatype.KeyID(k.Public), nil
default: default:
panic("unsupported key kind") return nil, fmt.Errorf("unknown key kind: %v", k.Kind)
} }
} }

View File

@ -49,7 +49,7 @@ func TestVerify25519(t *testing.T) {
sigHash := aum.SigHash() sigHash := aum.SigHash()
aum.Signatures = []tkatype.Signature{ aum.Signatures = []tkatype.Signature{
{ {
KeyID: key.ID(), KeyID: key.MustID(),
Signature: ed25519.Sign(priv, sigHash[:]), Signature: ed25519.Sign(priv, sigHash[:]),
}, },
} }
@ -92,7 +92,7 @@ func TestNLPrivate(t *testing.T) {
// We manually compute the keyID, so make sure its consistent with // We manually compute the keyID, so make sure its consistent with
// tka.Key.ID(). // tka.Key.ID().
if !bytes.Equal(k.ID(), p.KeyID()) { if !bytes.Equal(k.MustID(), p.KeyID()) {
t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.ID(), p.KeyID()) t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID())
} }
} }

View File

@ -273,7 +273,7 @@ func TestForkingPropagation(t *testing.T) {
F1.template = removeKey1`, F1.template = removeKey1`,
optSignAllUsing("key2"), optSignAllUsing("key2"),
optKey("key2", key, priv), optKey("key2", key, priv),
optTemplate("removeKey1", AUM{MessageKind: AUMRemoveKey, KeyID: s.defaultKey.ID()})), optTemplate("removeKey1", AUM{MessageKind: AUMRemoveKey, KeyID: s.defaultKey.MustID()})),
}) })
s.testSyncsBetween(control, n2) s.testSyncsBetween(control, n2)
s.checkHaveConsensus(control, n2) s.checkHaveConsensus(control, n2)
@ -282,10 +282,10 @@ func TestForkingPropagation(t *testing.T) {
s.testSyncsBetween(control, n1) s.testSyncsBetween(control, n1)
s.checkHaveConsensus(n1, n2) s.checkHaveConsensus(n1, n2)
if _, err := n1.A.state.GetKey(s.defaultKey.ID()); err != ErrNoSuchKey { if _, err := n1.A.state.GetKey(s.defaultKey.MustID()); err != ErrNoSuchKey {
t.Error("default key was still present") t.Error("default key was still present")
} }
if _, err := n1.A.state.GetKey(key.ID()); err != nil { if _, err := n1.A.state.GetKey(key.MustID()); err != nil {
t.Errorf("key2 was not trusted: %v", err) t.Errorf("key2 was not trusted: %v", err)
} }
} }
@ -305,7 +305,9 @@ func TestInvalidAUMPropagationRejected(t *testing.T) {
l3 := n1.AUMs["L3"] l3 := n1.AUMs["L3"]
l3H := l3.Hash() l3H := l3.Hash()
l4 := AUM{MessageKind: AUMAddKey, PrevAUMHash: l3H[:]} l4 := AUM{MessageKind: AUMAddKey, PrevAUMHash: l3H[:]}
l4.sign25519(s.defaultPriv) if err := l4.sign25519(s.defaultPriv); err != nil {
t.Fatal(err)
}
l4H := l4.Hash() l4H := l4.Hash()
n1.storage.CommitVerifiedAUMs([]AUM{l4}) n1.storage.CommitVerifiedAUMs([]AUM{l4})
n1.A.state.LastAUMHash = &l4H n1.A.state.LastAUMHash = &l4H
@ -371,7 +373,9 @@ func TestBadSigAUMPropagationRejected(t *testing.T) {
l3 := n1.AUMs["L3"] l3 := n1.AUMs["L3"]
l3H := l3.Hash() l3H := l3.Hash()
l4 := AUM{MessageKind: AUMNoOp, PrevAUMHash: l3H[:]} l4 := AUM{MessageKind: AUMNoOp, PrevAUMHash: l3H[:]}
l4.sign25519(s.defaultPriv) if err := l4.sign25519(s.defaultPriv); err != nil {
t.Fatal(err)
}
l4.Signatures[0].Signature[3] = 42 l4.Signatures[0].Signature[3] = 42
l4H := l4.Hash() l4H := l4.Hash()
n1.storage.CommitVerifiedAUMs([]AUM{l4}) n1.storage.CommitVerifiedAUMs([]AUM{l4})

View File

@ -22,7 +22,7 @@ func TestSigDirect(t *testing.T) {
sig := NodeKeySignature{ sig := NodeKeySignature{
SigKind: SigDirect, SigKind: SigDirect,
KeyID: k.ID(), KeyID: k.MustID(),
Pubkey: nodeKeyPub, Pubkey: nodeKeyPub,
} }
sigHash := sig.SigHash() sigHash := sig.SigHash()
@ -65,7 +65,7 @@ func TestSigNested(t *testing.T) {
// the network-lock key. // the network-lock key.
nestedSig := NodeKeySignature{ nestedSig := NodeKeySignature{
SigKind: SigDirect, SigKind: SigDirect,
KeyID: k.ID(), KeyID: k.MustID(),
Pubkey: oldPub, Pubkey: oldPub,
WrappingPubkey: rPub, WrappingPubkey: rPub,
} }
@ -132,7 +132,7 @@ func TestSigNested_DeepNesting(t *testing.T) {
// the network-lock key. // the network-lock key.
nestedSig := NodeKeySignature{ nestedSig := NodeKeySignature{
SigKind: SigDirect, SigKind: SigDirect,
KeyID: k.ID(), KeyID: k.MustID(),
Pubkey: oldPub, Pubkey: oldPub,
WrappingPubkey: rPub, WrappingPubkey: rPub,
} }
@ -204,7 +204,7 @@ func TestSigCredential(t *testing.T) {
// public key. // public key.
nestedSig := NodeKeySignature{ nestedSig := NodeKeySignature{
SigKind: SigCredential, SigKind: SigCredential,
KeyID: k.ID(), KeyID: k.MustID(),
WrappingPubkey: cPub, WrappingPubkey: cPub,
} }
sigHash := nestedSig.SigHash() sigHash := nestedSig.SigHash()
@ -280,11 +280,11 @@ func TestSigSerializeUnserialize(t *testing.T) {
key := Key{Kind: Key25519, Public: pub, Votes: 2} key := Key{Kind: Key25519, Public: pub, Votes: 2}
sig := NodeKeySignature{ sig := NodeKeySignature{
SigKind: SigDirect, SigKind: SigDirect,
KeyID: key.ID(), KeyID: key.MustID(),
Pubkey: nodeKeyPub, Pubkey: nodeKeyPub,
Nested: &NodeKeySignature{ Nested: &NodeKeySignature{
SigKind: SigDirect, SigKind: SigDirect,
KeyID: key.ID(), KeyID: key.MustID(),
Pubkey: nodeKeyPub, Pubkey: nodeKeyPub,
}, },
} }

View File

@ -45,7 +45,12 @@ type State struct {
// GetKey returns the trusted key with the specified KeyID. // GetKey returns the trusted key with the specified KeyID.
func (s State) GetKey(key tkatype.KeyID) (Key, error) { func (s State) GetKey(key tkatype.KeyID) (Key, error) {
for _, k := range s.Keys { for _, k := range s.Keys {
if bytes.Equal(k.ID(), key) { keyID, err := k.ID()
if err != nil {
return Key{}, err
}
if bytes.Equal(keyID, key) {
return k, nil return k, nil
} }
} }
@ -169,7 +174,11 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
if update.Key == nil { if update.Key == nil {
return State{}, errors.New("no key to add provided") return State{}, errors.New("no key to add provided")
} }
if _, err := s.GetKey(update.Key.ID()); err == nil { keyID, err := update.Key.ID()
if err != nil {
return State{}, err
}
if _, err := s.GetKey(keyID); err == nil {
return State{}, errors.New("key already exists") return State{}, errors.New("key already exists")
} }
out := s.cloneForUpdate(&update) out := s.cloneForUpdate(&update)
@ -192,7 +201,11 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
} }
out := s.cloneForUpdate(&update) out := s.cloneForUpdate(&update)
for i := range out.Keys { for i := range out.Keys {
if bytes.Equal(out.Keys[i].ID(), update.KeyID) { keyID, err := out.Keys[i].ID()
if err != nil {
return State{}, err
}
if bytes.Equal(keyID, update.KeyID) {
out.Keys[i] = k out.Keys[i] = k
} }
} }
@ -201,7 +214,11 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
case AUMRemoveKey: case AUMRemoveKey:
idx := -1 idx := -1
for i := range s.Keys { for i := range s.Keys {
if bytes.Equal(update.KeyID, s.Keys[i].ID()) { keyID, err := s.Keys[i].ID()
if err != nil {
return State{}, err
}
if bytes.Equal(update.KeyID, keyID) {
idx = i idx = i
break break
} }
@ -277,7 +294,17 @@ func (s *State) staticValidateCheckpoint() error {
if i == j { if i == j {
continue continue
} }
if bytes.Equal(k.ID(), k2.ID()) {
id1, err := k.ID()
if err != nil {
return fmt.Errorf("key[%d]: %w", i, err)
}
id2, err := k2.ID()
if err != nil {
return fmt.Errorf("key[%d]: %w", j, err)
}
if bytes.Equal(id1, id2) {
return fmt.Errorf("key[%d]: duplicates key[%d]", i, j) return fmt.Errorf("key[%d]: duplicates key[%d]", i, j)
} }
} }

View File

@ -124,7 +124,7 @@ func TestForkResolutionMessageType(t *testing.T) {
L3.hashSeed = 18 L3.hashSeed = 18
`, `,
optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}), optTemplate("addKey", AUM{MessageKind: AUMAddKey, Key: &key}),
optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.ID()})) optTemplate("removeKey", AUM{MessageKind: AUMRemoveKey, KeyID: key.MustID()}))
l1H := c.AUMHashes["L1"] l1H := c.AUMHashes["L1"]
l2H := c.AUMHashes["L2"] l2H := c.AUMHashes["L2"]
@ -165,7 +165,7 @@ func TestComputeStateAt(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("computeStateAt(G1) failed: %v", err) t.Fatalf("computeStateAt(G1) failed: %v", err)
} }
if _, err := state.GetKey(key.ID()); err != ErrNoSuchKey { if _, err := state.GetKey(key.MustID()); err != ErrNoSuchKey {
t.Errorf("expected key to be missing: err = %v", err) t.Errorf("expected key to be missing: err = %v", err)
} }
if *state.LastAUMHash != c.AUMHashes["G1"] { if *state.LastAUMHash != c.AUMHashes["G1"] {
@ -182,7 +182,7 @@ func TestComputeStateAt(t *testing.T) {
if *state.LastAUMHash != wantHash { if *state.LastAUMHash != wantHash {
t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash) t.Errorf("LastAUMHash = %x, want %x", *state.LastAUMHash, wantHash)
} }
if _, err := state.GetKey(key.ID()); err != nil { if _, err := state.GetKey(key.MustID()); err != nil {
t.Errorf("expected key to be present at state: err = %v", err) t.Errorf("expected key to be present at state: err = %v", err)
} }
} }
@ -234,7 +234,7 @@ func TestOpenAuthority(t *testing.T) {
i2, i2H := fakeAUM(t, 2, &i1H) i2, i2H := fakeAUM(t, 2, &i1H)
i3, i3H := fakeAUM(t, 5, &i2H) i3, i3H := fakeAUM(t, 5, &i2H)
l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.ID()}}}, &i3H) l2, l2H := fakeAUM(t, AUM{MessageKind: AUMNoOp, KeyID: []byte{7}, Signatures: []tkatype.Signature{{KeyID: key.MustID()}}}, &i3H)
l3, l3H := fakeAUM(t, 4, &i3H) l3, l3H := fakeAUM(t, 4, &i3H)
g2, g2H := fakeAUM(t, 8, nil) g2, g2H := fakeAUM(t, 8, nil)
@ -266,7 +266,7 @@ func TestOpenAuthority(t *testing.T) {
t.Fatalf("New() failed: %v", err) t.Fatalf("New() failed: %v", err)
} }
// Should include the key added in G1 // Should include the key added in G1
if _, err := a.state.GetKey(key.ID()); err != nil { if _, err := a.state.GetKey(key.MustID()); err != nil {
t.Errorf("missing G1 key: %v", err) t.Errorf("missing G1 key: %v", err)
} }
// The head of the chain should be L2. // The head of the chain should be L2.
@ -338,10 +338,10 @@ func TestCreateBootstrapAuthority(t *testing.T) {
} }
// Both authorities should trust the key laid down in the genesis state. // Both authorities should trust the key laid down in the genesis state.
if !a1.KeyTrusted(key.ID()) { if !a1.KeyTrusted(key.MustID()) {
t.Error("a1 did not trust genesis key") t.Error("a1 did not trust genesis key")
} }
if !a2.KeyTrusted(key.ID()) { if !a2.KeyTrusted(key.MustID()) {
t.Error("a2 did not trust genesis key") t.Error("a2 did not trust genesis key")
} }
} }