diff --git a/types/key/key.go b/types/key/key.go index 5aeaa0f7c..8b70fffa6 100644 --- a/types/key/key.go +++ b/types/key/key.go @@ -28,6 +28,8 @@ func NewPrivate() Private { if _, err := io.ReadFull(crand.Reader, p[:]); err != nil { panic(err) } + p[0] &= 248 + p[31] = (p[31] & 127) | 64 return p } diff --git a/types/key/key_test.go b/types/key/key_test.go index b2fc88618..4c5c97625 100644 --- a/types/key/key_test.go +++ b/types/key/key_test.go @@ -6,6 +6,8 @@ package key import ( "testing" + + "github.com/tailscale/wireguard-go/wgcfg" ) func TestTextUnmarshal(t *testing.T) { @@ -22,3 +24,31 @@ func TestTextUnmarshal(t *testing.T) { t.Fatalf("mismatch; got %x want %x", p2, p) } } + +func TestClamping(t *testing.T) { + t.Run("NewPrivate", func(t *testing.T) { testClamping(t, NewPrivate) }) + + // Also test the wgcfg package, as their behavior should match. + t.Run("wgcfg", func(t *testing.T) { + testClamping(t, func() Private { + k, err := wgcfg.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + return Private(k) + }) + }) +} + +func testClamping(t *testing.T, newKey func() Private) { + for i := 0; i < 100; i++ { + k := newKey() + if k[0]&0b111 != 0 { + t.Fatalf("Bogus clamping in first byte: %#08b", k[0]) + return + } + if k[31]>>6 != 1 { + t.Fatalf("Bogus clamping in last byte: %#08b", k[0]) + } + } +}