From e1b80859f27f995d48dd1e77e59286ad70514777 Mon Sep 17 00:00:00 2001
From: Alexander Neumann <alexander@bumpern.de>
Date: Sat, 28 Oct 2017 10:59:55 +0200
Subject: [PATCH] Make crypto.Key implement cipher.AEAD

---
 internal/crypto/crypto.go          | 144 ++++++++++++++++++++++++++++-
 internal/crypto/crypto_int_test.go |  41 ++++----
 2 files changed, 166 insertions(+), 19 deletions(-)

diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go
index aaa42b54f..62ec62e25 100644
--- a/internal/crypto/crypto.go
+++ b/internal/crypto/crypto.go
@@ -147,7 +147,9 @@ func NewRandomKey() *Key {
 	return k
 }
 
-func newIV() []byte {
+// NewRandomNonce returns a new random nonce. It panics on error so that the
+// program is safely terminated.
+func NewRandomNonce() []byte {
 	iv := make([]byte, ivSize)
 	n, err := rand.Read(iv)
 	if n != ivSize || err != nil {
@@ -233,6 +235,144 @@ func (k *EncryptionKey) Valid() bool {
 // holds the plaintext.
 var ErrInvalidCiphertext = errors.New("invalid ciphertext, same slice used for plaintext")
 
+// validNonce checks that nonce is not all zero.
+func validNonce(nonce []byte) bool {
+	sum := 0
+	for b := range nonce {
+		sum += b
+	}
+	return sum > 0
+}
+
+// statically ensure that *Key implements crypto/cipher.AEAD
+var _ cipher.AEAD = &Key{}
+
+// NonceSize returns the size of the nonce that must be passed to Seal
+// and Open.
+func (k *Key) NonceSize() int {
+	return ivSize
+}
+
+// Overhead returns the maximum difference between the lengths of a
+// plaintext and its ciphertext.
+func (k *Key) Overhead() int {
+	return macSize
+}
+
+// Seal encrypts and authenticates plaintext, authenticates the
+// additional data and appends the result to dst, returning the updated
+// slice. The nonce must be NonceSize() bytes long and unique for all
+// time, for a given key.
+//
+// The plaintext and dst may alias exactly or not at all. To reuse
+// plaintext's storage for the encrypted output, use plaintext[:0] as dst.
+func (k *Key) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
+	if !k.Valid() {
+		panic("key is invalid")
+	}
+
+	if len(additionalData) > 0 {
+		panic("additional data is not supported")
+	}
+
+	if len(nonce) != ivSize {
+		panic("incorrect nonce length")
+	}
+
+	if !validNonce(nonce) {
+		panic("nonce is invalid")
+	}
+
+	// extend dst so that the ciphertext fits
+	ciphertextLength := len(plaintext) + k.Overhead()
+	pos := len(dst)
+
+	capacity := cap(dst) - len(dst)
+	if capacity < ciphertextLength {
+		dst = dst[:cap(dst)]
+		dst = append(dst, make([]byte, ciphertextLength-capacity)...)
+	} else {
+		dst = dst[:pos+ciphertextLength]
+	}
+
+	c, err := aes.NewCipher(k.EncryptionKey[:])
+	if err != nil {
+		panic(fmt.Sprintf("unable to create cipher: %v", err))
+	}
+	e := cipher.NewCTR(c, nonce)
+	e.XORKeyStream(dst[pos:pos+len(plaintext)], plaintext)
+
+	// truncate to only cover the ciphertext
+	dst = dst[:pos+len(plaintext)]
+
+	mac := poly1305MAC(dst[pos:], nonce, &k.MACKey)
+	dst = append(dst, mac...)
+
+	return dst
+}
+
+// Open decrypts and authenticates ciphertext, authenticates the
+// additional data and, if successful, appends the resulting plaintext
+// to dst, returning the updated slice. The nonce must be NonceSize()
+// bytes long and both it and the additional data must match the
+// value passed to Seal.
+//
+// The ciphertext and dst may alias exactly or not at all. To reuse
+// ciphertext's storage for the decrypted output, use ciphertext[:0] as dst.
+//
+// Even if the function fails, the contents of dst, up to its capacity,
+// may be overwritten.
+func (k *Key) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
+	if !k.Valid() {
+		return nil, errors.New("invalid key")
+	}
+
+	// check parameters
+	if len(nonce) != ivSize {
+		panic("incorrect nonce length")
+	}
+
+	if !validNonce(nonce) {
+		return nil, errors.New("nonce is invalid")
+	}
+
+	// check for plausible length
+	if len(ciphertext) < k.Overhead() {
+		return nil, errors.Errorf("trying to decrypt invalid data: ciphertext too small")
+	}
+
+	// extract mac
+	l := len(ciphertext) - macSize
+	ct, mac := ciphertext[:l], ciphertext[l:]
+
+	// verify mac
+	if !poly1305Verify(ct, nonce, &k.MACKey, mac) {
+		return nil, ErrUnauthenticated
+	}
+
+	// extend dst so that the plaintext fits
+	plaintextLength := len(ct)
+	pos := len(dst)
+
+	capacity := cap(dst) - len(dst)
+	if capacity < plaintextLength {
+		dst = dst[:cap(dst)]
+		dst = append(dst, make([]byte, plaintextLength-capacity)...)
+	} else {
+		dst = dst[:pos+plaintextLength]
+	}
+
+	// decrypt data
+	c, err := aes.NewCipher(k.EncryptionKey[:])
+	if err != nil {
+		panic(fmt.Sprintf("unable to create cipher: %v", err))
+	}
+	e := cipher.NewCTR(c, nonce)
+	e.XORKeyStream(dst[pos:], ct)
+
+	return dst, nil
+}
+
 // Encrypt encrypts and authenticates data. Stored in ciphertext is IV || Ciphertext ||
 // MAC. Encrypt returns the new ciphertext slice, which is extended when
 // necessary. ciphertext and plaintext may not point to (exactly) the same
@@ -255,7 +395,7 @@ func (k *Key) Encrypt(ciphertext []byte, plaintext []byte) ([]byte, error) {
 		ciphertext = append(ciphertext, make([]byte, ext)...)
 	}
 
-	iv := newIV()
+	iv := NewRandomNonce()
 	copy(ciphertext, iv[:])
 
 	c, err := aes.NewCipher(k.EncryptionKey[:])
diff --git a/internal/crypto/crypto_int_test.go b/internal/crypto/crypto_int_test.go
index 3ace3f393..a5995a9ad 100644
--- a/internal/crypto/crypto_int_test.go
+++ b/internal/crypto/crypto_int_test.go
@@ -113,43 +113,50 @@ func TestCrypto(t *testing.T) {
 			MACKey:        tv.skey,
 		}
 
-		msg, err := k.Encrypt(msg, tv.plaintext)
-		if err != nil {
-			t.Fatal(err)
-		}
+		nonce := NewRandomNonce()
+		ciphertext := k.Seal(msg, nonce, tv.plaintext, nil)
 
 		// decrypt message
 		buf := make([]byte, len(tv.plaintext))
-		n, err := k.Decrypt(buf, msg)
+		buf, err := k.Open(buf, nonce, ciphertext, nil)
 		if err != nil {
 			t.Fatal(err)
 		}
-		buf = buf[:n]
 
-		// change mac, this must fail
-		msg[len(msg)-8] ^= 0x23
-
-		if _, err = k.Decrypt(buf, msg); err != ErrUnauthenticated {
-			t.Fatal("wrong MAC value not detected")
+		if !bytes.Equal(buf, tv.plaintext) {
+			t.Fatalf("wrong plaintext returned")
 		}
 
+		// change mac, this must fail
+		ciphertext[len(ciphertext)-8] ^= 0x23
+
+		if _, err = k.Open(buf, nonce, ciphertext, nil); err != ErrUnauthenticated {
+			t.Fatal("wrong MAC value not detected")
+		}
 		// reset mac
-		msg[len(msg)-8] ^= 0x23
+		ciphertext[len(ciphertext)-8] ^= 0x23
+
+		// tamper with nonce, this must fail
+		nonce[2] ^= 0x88
+		if _, err = k.Open(buf, nonce, ciphertext, nil); err != ErrUnauthenticated {
+			t.Fatal("tampered nonce not detected")
+		}
+		// reset nonce
+		nonce[2] ^= 0x88
 
 		// tamper with message, this must fail
-		msg[16+5] ^= 0x85
-
-		if _, err = k.Decrypt(buf, msg); err != ErrUnauthenticated {
+		ciphertext[16+5] ^= 0x85
+		if _, err = k.Open(buf, nonce, ciphertext, nil); err != ErrUnauthenticated {
 			t.Fatal("tampered message not detected")
 		}
 
 		// test decryption
 		p := make([]byte, len(tv.ciphertext))
-		n, err = k.Decrypt(p, tv.ciphertext)
+		nonce, ciphertext = tv.ciphertext[:16], tv.ciphertext[16:]
+		p, err = k.Open(p, nonce, ciphertext, nil)
 		if err != nil {
 			t.Fatal(err)
 		}
-		p = p[:n]
 
 		if !bytes.Equal(p, tv.plaintext) {
 			t.Fatalf("wrong plaintext: expected %q but got %q\n", tv.plaintext, p)