From df865769898fcd0e8037784db7a99e31e510a57c Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Wed, 5 Jun 2024 14:47:36 -0600 Subject: [PATCH] util/winutil: add AllocateContiguousBuffer and SetNTString helper funcs AllocateContiguousBuffer is for allocating structs with trailing buffers containing additional data. It is to be used for various Windows structures containing pointers to data located immediately after the struct. SetNTString performs in-place setting of windows.NTString and windows.NTUnicodeString. Updates #12383 Signed-off-by: Aaron Klotz --- cmd/derper/depaware.txt | 1 + cmd/tailscale/depaware.txt | 2 +- util/winutil/winutil_windows.go | 141 +++++++++++++++++++++++++++ util/winutil/winutil_windows_test.go | 104 ++++++++++++++++++++ 4 files changed, 247 insertions(+), 1 deletion(-) diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 780e9f2ee..fd2de6e8c 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -175,6 +175,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + W golang.org/x/exp/constraints from tailscale.com/util/winutil L golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index eff138ac6..c0b626f13 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -182,7 +182,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe + W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 30602d1de..f464d01d4 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -7,14 +7,17 @@ "errors" "fmt" "log" + "math" "os/exec" "os/user" + "reflect" "runtime" "strings" "syscall" "time" "unsafe" + "golang.org/x/exp/constraints" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" ) @@ -643,3 +646,141 @@ func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error return origin.originatingLogonSession, nil } + +// BufUnit is a type constraint for buffers passed into AllocateContiguousBuffer. +type BufUnit interface { + byte | uint16 +} + +// AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where +// some structs contain pointers that are expected to refer to memory within the +// same buffer containing the struct itself. T is the type that contains +// the pointers. values must contain the actual data that is to be copied +// into the buffer after T. AllocateContiguousBuffer returns a pointer to the +// struct, the total length of the buffer in bytes, and a slice containing +// each value within the buffer. The caller may use slcs to populate any +// pointers in t as needed. Each element of slcs corresponds to the element of +// values in the same position. +// +// It is the responsibility of the caller to ensure that any values expected +// to contain null-terminated strings are in fact null-terminated! +// +// AllocateContiguousBuffer panics if no values are passed in, as there are +// better alternatives for allocating a struct in that case. +func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) { + if len(values) == 0 { + panic("len(values) must be > 0") + } + + // Get the sizes of T and BU, then compute a preferred alignment for T. + tT := reflect.TypeFor[T]() + szT := tT.Size() + szBU := int(unsafe.Sizeof(BU(0))) + alignment := max(tT.Align(), szBU) + + // Our buffers for values will start at the next szBU boundary. + tLenBytes = alignUp(uint32(szT), szBU) + firstValueOffset := tLenBytes + + // Accumulate the length of each value into tLenBytes + for _, v := range values { + tLenBytes += uint32(len(v) * szBU) + } + + // Now that we know the final length, align up to our preferred boundary. + tLenBytes = alignUp(tLenBytes, alignment) + + // Allocate the buffer. We choose a type for the slice that is appropriate + // for the desired alignment. Note that we do not have a strict requirement + // that T contain pointer fields; we could just be appending more data + // within the same buffer. + bufLen := tLenBytes / uint32(alignment) + var pt unsafe.Pointer + switch alignment { + case 1: + pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen))) + case 2: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen))) + case 4: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen))) + case 8: + pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen))) + default: + panic(fmt.Sprintf("bad alignment %d", alignment)) + } + + t = (*T)(pt) + slcs = make([][]BU, 0, len(values)) + + // Use the limits of the buffer area after t to construct a slice representing the remaining buffer. + firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset)) + buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU)) + + // Copy each value into the buffer and record a slice describing each value's limits into slcs. + var index int + for _, v := range values { + if len(v) == 0 { + // We allow zero-length values; we simply append a nil slice. + slcs = append(slcs, nil) + continue + } + valueSlice := buf[index : index+len(v)] + copy(valueSlice, v) + slcs = append(slcs, valueSlice) + index += len(v) + } + + return t, tLenBytes, slcs +} + +// alignment must be a power of 2 +func alignUp[V constraints.Integer](v V, alignment int) V { + return v + ((-v) & (V(alignment) - 1)) +} + +// NTStr is a type constraint requiring the type to be either a +// windows.NTString or a windows.NTUnicodeString. +type NTStr interface { + windows.NTString | windows.NTUnicodeString +} + +// SetNTString sets the value of nts in-place to point to the string contained +// within buf. A nul terminator is optional in buf. +func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) { + isEmpty := len(buf) == 0 + codeUnitSize := uint16(unsafe.Sizeof(BU(0))) + lenBytes := len(buf) * int(codeUnitSize) + if lenBytes > math.MaxUint16 { + panic("buffer length must fit into uint16") + } + lenBytes16 := uint16(lenBytes) + + switch p := any(nts).(type) { + case *windows.NTString: + if isEmpty { + *p = windows.NTString{} + break + } + p.Buffer = unsafe.SliceData(any(buf).([]byte)) + p.MaximumLength = lenBytes16 + p.Length = lenBytes16 + // account for nul terminator when present + if buf[len(buf)-1] == 0 { + p.Length -= codeUnitSize + } + case *windows.NTUnicodeString: + if isEmpty { + *p = windows.NTUnicodeString{} + break + } + p.Buffer = unsafe.SliceData(any(buf).([]uint16)) + p.MaximumLength = lenBytes16 + p.Length = lenBytes16 + // account for nul terminator when present + if buf[len(buf)-1] == 0 { + p.Length -= codeUnitSize + } + default: + panic("unknown type") + } +} diff --git a/util/winutil/winutil_windows_test.go b/util/winutil/winutil_windows_test.go index bf22d26ca..d437ffa38 100644 --- a/util/winutil/winutil_windows_test.go +++ b/util/winutil/winutil_windows_test.go @@ -4,9 +4,13 @@ package winutil import ( + "reflect" "testing" + "unsafe" ) +//lint:file-ignore U1000 Fields are unused but necessary for tests. + const ( localSystemSID = "S-1-5-18" networkSID = "S-1-5-2" @@ -28,3 +32,103 @@ func TestLookupPseudoUser(t *testing.T) { t.Errorf("LookupPseudoUser(%q) unexpectedly succeeded", networkSID) } } + +type testType interface { + byte | uint16 | uint32 | uint64 +} + +type noPointers[T testType] struct { + foo byte + bar T + baz bool +} + +type hasPointer struct { + foo byte + bar uint32 + s1 *struct{} + baz byte +} + +func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, ptLen uint32, slcs [][]BU) { + szBU := int(unsafe.Sizeof(BU(0))) + expectedAlign := max(reflect.TypeFor[T]().Align(), szBU) + // Check that pointer is aligned + if rem := uintptr(unsafe.Pointer(pt)) % uintptr(expectedAlign); rem != 0 { + t.Errorf("pointer alignment got %d, want 0", rem) + } + // Check that alloc length is aligned + if rem := int(ptLen) % expectedAlign; rem != 0 { + t.Errorf("allocation length alignment got %d, want 0", rem) + } + expectedLen := int(unsafe.Sizeof(*pt)) + expectedLen = alignUp(expectedLen, szBU) + expectedLen += len(extra) * szBU + expectedLen = alignUp(expectedLen, expectedAlign) + if gotLen := int(ptLen); gotLen != expectedLen { + t.Errorf("allocation length got %d, want %d", gotLen, expectedLen) + } + if l := len(slcs); l != 1 { + t.Errorf("len(slcs) got %d, want 1", l) + } + if len(extra) == 0 && slcs[0] != nil { + t.Error("slcs[0] got non-nil, want nil") + } + if len(extra) != len(slcs[0]) { + t.Errorf("len(slcs[0]) got %d, want %d", len(slcs[0]), len(extra)) + } else if rem := uintptr(unsafe.Pointer(unsafe.SliceData(slcs[0]))) % uintptr(szBU); rem != 0 { + t.Errorf("additional data alignment got %d, want 0", rem) + } +} + +func TestAllocateContiguousBuffer(t *testing.T) { + t.Run("NoValues", testNoValues) + t.Run("NoPointers", testNoPointers) + t.Run("HasPointer", testHasPointer) +} + +func testNoValues(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic but didn't get one") + } + }() + + AllocateContiguousBuffer[hasPointer, byte]() +} + +const maxTestBufLen = 8 + +func testNoPointers(t *testing.T) { + buf8 := make([]byte, maxTestBufLen) + buf16 := make([]uint16, maxTestBufLen) + for i := range maxTestBufLen { + s8, sl, slcs8 := AllocateContiguousBuffer[noPointers[byte]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s8, sl, slcs8) + s16, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint16]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s16, sl, slcs8) + s32, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint32]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s32, sl, slcs8) + s64, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint64]](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s64, sl, slcs8) + s8, sl, slcs16 := AllocateContiguousBuffer[noPointers[byte]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s8, sl, slcs16) + s16, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint16]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s16, sl, slcs16) + s32, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint32]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s32, sl, slcs16) + s64, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint64]](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s64, sl, slcs16) + } +} + +func testHasPointer(t *testing.T) { + buf8 := make([]byte, maxTestBufLen) + buf16 := make([]uint16, maxTestBufLen) + for i := range maxTestBufLen { + s, sl, slcs8 := AllocateContiguousBuffer[hasPointer](buf8[:i]) + checkContiguousBuffer(t, buf8[:i], s, sl, slcs8) + s, sl, slcs16 := AllocateContiguousBuffer[hasPointer](buf16[:i]) + checkContiguousBuffer(t, buf16[:i], s, sl, slcs16) + } +}