tailscale/util/sha256x/sha256.go
Joe Tsai 1c3c6b5382
util/sha256x: make Hash.Sum non-escaping (#5338)
Since Hash is a concrete type, we can make it such that
Sum never escapes the input.

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
2022-08-10 17:31:44 -07:00

159 lines
3.8 KiB
Go

// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package sha256x is like crypto/sha256 with extra methods.
// It exports a concrete Hash type
// rather than only returning an interface implementation.
package sha256x
import (
"crypto/sha256"
"encoding/binary"
"hash"
)
// Compile-time check that *Hash satisfies hash.Hash.
var _ hash.Hash = (*Hash)(nil)
// Hash is a hash.Hash for SHA-256,
// but has efficient methods for hashing fixed-width integers.
//
// Input is staged in x and handed to the wrapped hash h one whole block at a
// time (HashBytes also passes whole blocks straight from the caller's slice).
// Sum is the exception: it flushes whatever partial block is pending.
//
// New returns a ready-to-use Hash. A zero Hash has a nil h; Reset installs a
// fresh SHA-256 state in that case.
type Hash struct {
// The optimization is to maintain our own block and
// only call h.Write with entire blocks.
// This avoids double-copying of buffers within sha256.digest itself.
// However, it does mean that sha256.digest.x goes unused,
// which is a waste of 64B.
h hash.Hash // always *sha256.digest
// x stages bytes that have not yet been written to h.
x [sha256.BlockSize]byte // equivalent to sha256.digest.x
// nx is the number of staged bytes at the start of x (0 through len(x)).
nx int // equivalent to sha256.digest.nx
}
// New returns a Hash backed by a freshly created SHA-256 state,
// with nothing staged in its block buffer.
func New() *Hash {
	hasher := new(Hash)
	hasher.h = sha256.New()
	return hasher
}
// Write feeds b to HashBytes.
// It always reports len(b) bytes consumed and a nil error.
func (h *Hash) Write(b []byte) (int, error) {
	consumed := len(b)
	h.HashBytes(b)
	return consumed, nil
}
// Sum appends the SHA-256 checksum of everything hashed so far to b and
// returns the extended slice.
//
// Bytes still staged in h.x are pushed to the wrapped hash first. That leaves
// the wrapped hash off block alignment: later operations remain correct, but
// are less efficient until Reset is called.
func (h *Hash) Sum(b []byte) []byte {
	if pending := h.x[:h.nx]; len(pending) != 0 {
		h.h.Write(pending)
		h.nx = 0
	}

	// Calling Sum through the hash.Hash interface makes its argument escape,
	// because escape analysis cannot prove anything past an interface method
	// call. h is assumed to escape already, so the digest is produced in h.x
	// (free scratch space once the pending bytes are flushed) and then copied
	// onto b.
	digest := h.h.Sum(h.x[:0])
	return append(b, digest...)
}
// Reset returns h to its initial state: staged bytes are discarded and the
// wrapped hash is reset. A Hash whose wrapped hash is still nil (for example
// a zero Hash) gets a fresh SHA-256 state first.
func (h *Hash) Reset() {
	h.nx = 0
	if h.h == nil {
		h.h = sha256.New()
	}
	h.h.Reset()
}
// Size returns the number of bytes Sum appends: sha256.Size (32).
//
// The value is a property of SHA-256 rather than of the running state, so it
// is answered from the package constant instead of the wrapped hash. This
// keeps Size usable on a Hash whose wrapped hash has not been set up yet
// (a zero Hash before Reset), where calling through the nil interface would
// panic.
func (h *Hash) Size() int {
	return sha256.Size
}
// BlockSize returns the SHA-256 block size in bytes: sha256.BlockSize (64).
//
// The value is a property of SHA-256 rather than of the running state, so it
// is answered from the package constant instead of the wrapped hash. This
// keeps BlockSize usable on a Hash whose wrapped hash has not been set up yet
// (a zero Hash before Reset), where calling through the nil interface would
// panic.
func (h *Hash) BlockSize() int {
	return sha256.BlockSize
}
// HashUint8 appends n to the hashed input as a single byte.
func (h *Hash) HashUint8(n uint8) {
// NOTE: This method is carefully written to be inlineable.
// Fast path: x has at least one free byte, so store n in place.
if h.nx <= len(h.x)-1 {
h.x[h.nx] = n
h.nx += 1
} else {
// x is full; the slow path flushes it to the wrapped hash before storing n.
h.hashUint8Slow(n) // mark "noinline" to keep this within inline budget
}
}
// hashUint8Slow is the out-of-line path of HashUint8.
// It forwards n to hashUint as a one-byte value.
//
//go:noinline
func (h *Hash) hashUint8Slow(n uint8) {
	const width = 1 // bytes in a uint8
	h.hashUint(uint64(n), width)
}
// HashUint16 appends n to the hashed input as two bytes in little-endian order.
func (h *Hash) HashUint16(n uint16) {
// NOTE: This method is carefully written to be inlineable.
// Fast path: x has at least two free bytes, so encode n in place.
if h.nx <= len(h.x)-2 {
binary.LittleEndian.PutUint16(h.x[h.nx:], n)
h.nx += 2
} else {
// Fewer than two free bytes remain; the slow path stores n byte by byte,
// flushing x to the wrapped hash when it fills.
h.hashUint16Slow(n) // mark "noinline" to keep this within inline budget
}
}
// hashUint16Slow is the out-of-line path of HashUint16.
// It forwards n to hashUint as a two-byte value.
//
//go:noinline
func (h *Hash) hashUint16Slow(n uint16) {
	const width = 2 // bytes in a uint16
	h.hashUint(uint64(n), width)
}
// HashUint32 appends n to the hashed input as four bytes in little-endian order.
func (h *Hash) HashUint32(n uint32) {
// NOTE: This method is carefully written to be inlineable.
// Fast path: x has at least four free bytes, so encode n in place.
if h.nx <= len(h.x)-4 {
binary.LittleEndian.PutUint32(h.x[h.nx:], n)
h.nx += 4
} else {
// Fewer than four free bytes remain; the slow path stores n byte by byte,
// flushing x to the wrapped hash when it fills.
h.hashUint32Slow(n) // mark "noinline" to keep this within inline budget
}
}
// hashUint32Slow is the out-of-line path of HashUint32.
// It forwards n to hashUint as a four-byte value.
//
//go:noinline
func (h *Hash) hashUint32Slow(n uint32) {
	const width = 4 // bytes in a uint32
	h.hashUint(uint64(n), width)
}
// HashUint64 appends n to the hashed input as eight bytes in little-endian order.
func (h *Hash) HashUint64(n uint64) {
// NOTE: This method is carefully written to be inlineable.
// Fast path: x has at least eight free bytes, so encode n in place.
if h.nx <= len(h.x)-8 {
binary.LittleEndian.PutUint64(h.x[h.nx:], n)
h.nx += 8
} else {
// Fewer than eight free bytes remain; the slow path stores n byte by byte,
// flushing x to the wrapped hash when it fills.
h.hashUint64Slow(n) // mark "noinline" to keep this within inline budget
}
}
// hashUint64Slow is the out-of-line path of HashUint64.
// It forwards n to hashUint as an eight-byte value.
//
//go:noinline
func (h *Hash) hashUint64Slow(n uint64) {
	const width = 8 // bytes in a uint64
	h.hashUint(n, width)
}
// hashUint appends the i low-order bytes of n to the hashed input,
// least significant byte first.
//
// Bytes go into the staging buffer h.x one at a time. Whenever the buffer is
// already full when the next byte arrives, it is written to the wrapped hash
// and restarted. A buffer that becomes full on the last byte is left staged.
func (h *Hash) hashUint(n uint64, i int) {
	for emitted := 0; emitted < i; emitted++ {
		if h.nx == len(h.x) {
			h.h.Write(h.x[:])
			h.nx = 0
		}
		h.x[h.nx] = byte(n)
		n >>= 8
		h.nx++
	}
}
// HashBytes appends b to the hashed input.
//
// It mirrors sha256.digest.Write and works in three steps:
// top up a partially filled staging buffer (flushing it once full),
// pass every remaining whole block to the wrapped hash directly from b,
// and stage whatever tail is left for a later call.
func (h *Hash) HashBytes(b []byte) {
	// Step 1: finish the block already in progress.
	if h.nx != 0 {
		copied := copy(h.x[h.nx:], b)
		b = b[copied:]
		h.nx += copied
		if h.nx == len(h.x) {
			h.h.Write(h.x[:])
			h.nx = 0
		}
	}

	// Step 2: whole blocks bypass the staging buffer entirely.
	// whole is len(b) rounded down to a multiple of the block size.
	if whole := len(b) - len(b)%len(h.x); whole > 0 {
		h.h.Write(b[:whole])
		b = b[whole:]
	}

	// Step 3: keep the sub-block remainder. The buffer is empty here,
	// because a non-empty remainder implies step 1 either did not run
	// or ended with a flush.
	if len(b) != 0 {
		h.nx = copy(h.x[:], b)
	}
}
// TODO: Add Hash.MarshalBinary and Hash.UnmarshalBinary?