diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 18c3408ee..d531eb983 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -15,7 +15,6 @@ "time" "tailscale.com/tstest" - "tailscale.com/types/logid" ) func TestFastShutdown(t *testing.T) { @@ -296,28 +295,6 @@ func TestParseAndRemoveLogLevel(t *testing.T) { } } -func TestPublicIDUnmarshalText(t *testing.T) { - const hexStr = "6c60a9e0e7af57170bb1347b2d477e4cbc27d4571a4923b21651456f931e3d55" - x := []byte(hexStr) - - var id logid.PublicID - if err := id.UnmarshalText(x); err != nil { - t.Fatal(err) - } - if id.String() != hexStr { - t.Errorf("String = %q; want %q", id.String(), hexStr) - } - err := tstest.MinAllocsPerRun(t, 0, func() { - var id logid.PublicID - if err := id.UnmarshalText(x); err != nil { - t.Fatal(err) - } - }) - if err != nil { - t.Fatal(err) - } -} - func unmarshalOne(t *testing.T, body []byte) map[string]any { t.Helper() var entries []map[string]any diff --git a/types/logid/id.go b/types/logid/id.go index 5046150e2..f3d705f18 100644 --- a/types/logid/id.go +++ b/types/logid/id.go @@ -1,27 +1,30 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +// Package logid contains ID types for interacting with the log service. package logid import ( + "bytes" "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" - "errors" "fmt" + "unicode/utf8" + + "golang.org/x/exp/slices" ) -// PrivateID represents an instance that write logs. +// PrivateID represents a log steam for writing. // Private IDs are only shared with the server when writing logs. type PrivateID [32]byte -// Safely generate a new PrivateId for use in Config objects. -// You should persist this across runs of an instance of your app, so that -// it can append to the same log file on each run. +// NewPrivateID generates a new random PrivateID. +// This should persist across runs of an instance of the application, +// so that it can append to the same log stream for each invocation. func NewPrivateID() (id PrivateID, err error) { - _, err = rand.Read(id[:]) - if err != nil { + if _, err := rand.Read(id[:]); err != nil { return PrivateID{}, err } // Clamping, for future use. @@ -30,137 +33,58 @@ func NewPrivateID() (id PrivateID, err error) { return id, nil } +// ParsePrivateID returns a PrivateID from its hex representation. +func ParsePrivateID(in string) (out PrivateID, err error) { + err = parseID("logid.ParsePublicID", (*[32]byte)(&out), in) + return out, err +} + func (id PrivateID) MarshalText() ([]byte, error) { - b := make([]byte, hex.EncodedLen(len(id))) - if i := hex.Encode(b, id[:]); i != len(b) { - return nil, fmt.Errorf("logid.PrivateID.MarshalText: i=%d", i) - } - return b, nil + return formatID(id), nil } -// ParsePrivateID returns a PrivateID from its hex (String) representation. -func ParsePrivateID(s string) (PrivateID, error) { - if len(s) != 64 { - return PrivateID{}, errors.New("invalid length") - } - var p PrivateID - for i := range p { - a, ok1 := fromHexChar(s[i*2+0]) - b, ok2 := fromHexChar(s[i*2+1]) - if !ok1 || !ok2 { - return PrivateID{}, errors.New("invalid hex character") - } - p[i] = (a << 4) | b - } - return p, nil -} - -// IsZero reports whether id is the zero value. -func (id PrivateID) IsZero() bool { return id == PrivateID{} } - -func (id *PrivateID) UnmarshalText(s []byte) error { - b, err := hex.DecodeString(string(s)) - if err != nil { - return fmt.Errorf("logid.PrivateID.UnmarshalText: %v", err) - } - if len(b) != len(id) { - return fmt.Errorf("logid.PrivateID.UnmarshalText: invalid hex length: %d", len(b)) - } - copy(id[:], b) - return nil +func (id *PrivateID) UnmarshalText(in []byte) error { + return parseID("logid.PrivateID", (*[32]byte)(id), in) } func (id PrivateID) String() string { - b, err := id.MarshalText() - if err != nil { - panic(err) - } - return string(b) + return string(formatID(id)) } +func (id PrivateID) IsZero() bool { + return id == PrivateID{} +} + +// Public returns the public ID of the private ID, +// which is the SHA-256 hash of the private ID. func (id PrivateID) Public() (pub PublicID) { - h := sha256.New() - h.Write(id[:]) - if n := copy(pub[:], h.Sum(pub[:0])); n != len(pub) { - panic(fmt.Sprintf("public id short copy: %d", n)) - } - return pub + return PublicID(sha256.Sum256(id[:])) } -// PublicID represents an instance in the logs service for reading and adoption. -// The public ID value is a SHA-256 hash of a private ID. +// PublicID represents a log stream for reading. +// The PrivateID cannot be feasibly reversed from the PublicID. type PublicID [sha256.Size]byte -// ParsePublicID returns a PublicID from its hex (String) representation. -func ParsePublicID(s string) (PublicID, error) { - if len(s) != sha256.Size*2 { - return PublicID{}, errors.New("invalid length") - } - var p PublicID - for i := range p { - a, ok1 := fromHexChar(s[i*2+0]) - b, ok2 := fromHexChar(s[i*2+1]) - if !ok1 || !ok2 { - return PublicID{}, errors.New("invalid hex character") - } - p[i] = (a << 4) | b - } - return p, nil +// ParsePublicID returns a PublicID from its hex representation. +func ParsePublicID(in string) (out PublicID, err error) { + err = parseID("logid.ParsePublicID", (*[32]byte)(&out), in) + return out, err } func (id PublicID) MarshalText() ([]byte, error) { - b := make([]byte, hex.EncodedLen(len(id))) - if i := hex.Encode(b, id[:]); i != len(b) { - return nil, fmt.Errorf("logid.PublicID.MarshalText: i=%d", i) - } - return b, nil + return formatID(id), nil } -func (id *PublicID) UnmarshalText(s []byte) error { - if len(s) != len(id)*2 { - return fmt.Errorf("logid.PublicID.UnmarshalText: invalid hex length: %d", len(s)) - } - for i := range id { - a, ok1 := fromHexChar(s[i*2+0]) - b, ok2 := fromHexChar(s[i*2+1]) - if !ok1 || !ok2 { - return errors.New("invalid hex character") - } - id[i] = (a << 4) | b - } - return nil +func (id *PublicID) UnmarshalText(in []byte) error { + return parseID("logid.ParsePublicID", (*[32]byte)(id), in) } func (id PublicID) String() string { - b, err := id.MarshalText() - if err != nil { - panic(err) - } - return string(b) -} - -// fromHexChar converts a hex character into its value and a success flag. -func fromHexChar(c byte) (byte, bool) { - switch { - case '0' <= c && c <= '9': - return c - '0', true - case 'a' <= c && c <= 'f': - return c - 'a' + 10, true - case 'A' <= c && c <= 'F': - return c - 'A' + 10, true - } - - return 0, false + return string(formatID(id)) } func (id1 PublicID) Less(id2 PublicID) bool { - for i, c1 := range id1[:] { - c2 := id2[i] - if c1 != c2 { - return c1 < c2 - } - } - return false // equal + return slices.Compare(id1[:], id2[:]) < 0 } func (id PublicID) IsZero() bool { @@ -170,3 +94,22 @@ func (id PublicID) IsZero() bool { func (id PublicID) Prefix64() uint64 { return binary.BigEndian.Uint64(id[:8]) } + +func formatID(in [32]byte) []byte { + var hexArr [2 * len(in)]byte + hex.Encode(hexArr[:], in[:]) + return hexArr[:] +} + +func parseID[Bytes []byte | string](funcName string, out *[32]byte, in Bytes) (err error) { + if len(in) != 2*len(out) { + return fmt.Errorf("%s: invalid hex length: %d", funcName, len(in)) + } + var hexArr [2 * len(out)]byte + copy(hexArr[:], in) + if _, err := hex.Decode(out[:], hexArr[:]); err != nil { + r, _ := utf8.DecodeRune(bytes.TrimLeft([]byte(in), "0123456789abcdefABCDEF")) + return fmt.Errorf("%s: invalid hex character: %c", funcName, r) + } + return nil +} diff --git a/types/logid/id_test.go b/types/logid/id_test.go index 9390545ea..fb41de860 100644 --- a/types/logid/id_test.go +++ b/types/logid/id_test.go @@ -5,6 +5,8 @@ import ( "testing" + + "tailscale.com/tstest" ) func TestIDs(t *testing.T) { @@ -63,4 +65,15 @@ func TestIDs(t *testing.T) { if id1 != id4 { t.Fatalf("ParsePrivateID returned different id") } + + hexString := "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + hexBytes := []byte(hexString) + if err := tstest.MinAllocsPerRun(t, 0, func() { + ParsePrivateID(hexString) + new(PrivateID).UnmarshalText(hexBytes) + ParsePublicID(hexString) + new(PublicID).UnmarshalText(hexBytes) + }); err != nil { + t.Fatal(err) + } }