util/winutil: add AllocateContiguousBuffer and SetNTString helper funcs

AllocateContiguousBuffer is for allocating structs with trailing buffers
containing additional data. It is to be used for various Windows structures
containing pointers to data located immediately after the struct.

SetNTString performs in-place setting of windows.NTString and
windows.NTUnicodeString.

Updates #12383

Signed-off-by: Aaron Klotz <aaron@tailscale.com>
This commit is contained in:
Aaron Klotz 2024-06-05 14:47:36 -06:00
parent c3e2b7347b
commit df86576989
4 changed files with 247 additions and 1 deletions

View File

@ -175,6 +175,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa
golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/box from tailscale.com/types/key
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
W golang.org/x/exp/constraints from tailscale.com/util/winutil
L golang.org/x/net/bpf from github.com/mdlayher/netlink+ L golang.org/x/net/bpf from github.com/mdlayher/netlink+
golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/dns/dnsmessage from net+
golang.org/x/net/http/httpguts from net/http golang.org/x/net/http/httpguts from net/http

View File

@ -182,7 +182,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box
golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12
golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+
W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+
golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli
golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/bpf from github.com/mdlayher/netlink+
golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/dns/dnsmessage from net+

View File

@ -7,14 +7,17 @@
"errors" "errors"
"fmt" "fmt"
"log" "log"
"math"
"os/exec" "os/exec"
"os/user" "os/user"
"reflect"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/exp/constraints"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
) )
@ -643,3 +646,141 @@ func LogonSessionID(token windows.Token) (logonSessionID windows.LUID, err error
return origin.originatingLogonSession, nil return origin.originatingLogonSession, nil
} }
// BufUnit is a type constraint for buffers passed into AllocateContiguousBuffer.
type BufUnit interface {
byte | uint16
}
// AllocateContiguousBuffer allocates memory to satisfy the Windows idiom where
// some structs contain pointers that are expected to refer to memory within the
// same buffer containing the struct itself. T is the type that contains
// the pointers. values must contain the actual data that is to be copied
// into the buffer after T. AllocateContiguousBuffer returns a pointer to the
// struct, the total length of the buffer in bytes, and a slice containing
// each value within the buffer. The caller may use slcs to populate any
// pointers in t as needed. Each element of slcs corresponds to the element of
// values in the same position.
//
// It is the responsibility of the caller to ensure that any values expected
// to contain null-terminated strings are in fact null-terminated!
//
// AllocateContiguousBuffer panics if no values are passed in, as there are
// better alternatives for allocating a struct in that case.
func AllocateContiguousBuffer[T any, BU BufUnit](values ...[]BU) (t *T, tLenBytes uint32, slcs [][]BU) {
if len(values) == 0 {
panic("len(values) must be > 0")
}
// Get the sizes of T and BU, then compute a preferred alignment for T.
tT := reflect.TypeFor[T]()
szT := tT.Size()
szBU := int(unsafe.Sizeof(BU(0)))
alignment := max(tT.Align(), szBU)
// Our buffers for values will start at the next szBU boundary.
tLenBytes = alignUp(uint32(szT), szBU)
firstValueOffset := tLenBytes
// Accumulate the length of each value into tLenBytes
for _, v := range values {
tLenBytes += uint32(len(v) * szBU)
}
// Now that we know the final length, align up to our preferred boundary.
tLenBytes = alignUp(tLenBytes, alignment)
// Allocate the buffer. We choose a type for the slice that is appropriate
// for the desired alignment. Note that we do not have a strict requirement
// that T contain pointer fields; we could just be appending more data
// within the same buffer.
bufLen := tLenBytes / uint32(alignment)
var pt unsafe.Pointer
switch alignment {
case 1:
pt = unsafe.Pointer(unsafe.SliceData(make([]byte, bufLen)))
case 2:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint16, bufLen)))
case 4:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint32, bufLen)))
case 8:
pt = unsafe.Pointer(unsafe.SliceData(make([]uint64, bufLen)))
default:
panic(fmt.Sprintf("bad alignment %d", alignment))
}
t = (*T)(pt)
slcs = make([][]BU, 0, len(values))
// Use the limits of the buffer area after t to construct a slice representing the remaining buffer.
firstValuePtr := unsafe.Pointer(uintptr(pt) + uintptr(firstValueOffset))
buf := unsafe.Slice((*BU)(firstValuePtr), (tLenBytes-firstValueOffset)/uint32(szBU))
// Copy each value into the buffer and record a slice describing each value's limits into slcs.
var index int
for _, v := range values {
if len(v) == 0 {
// We allow zero-length values; we simply append a nil slice.
slcs = append(slcs, nil)
continue
}
valueSlice := buf[index : index+len(v)]
copy(valueSlice, v)
slcs = append(slcs, valueSlice)
index += len(v)
}
return t, tLenBytes, slcs
}
// alignment must be a power of 2
func alignUp[V constraints.Integer](v V, alignment int) V {
return v + ((-v) & (V(alignment) - 1))
}
// NTStr is a type constraint requiring the type to be either a
// windows.NTString or a windows.NTUnicodeString.
type NTStr interface {
windows.NTString | windows.NTUnicodeString
}
// SetNTString sets the value of nts in-place to point to the string contained
// within buf. A nul terminator is optional in buf.
func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) {
isEmpty := len(buf) == 0
codeUnitSize := uint16(unsafe.Sizeof(BU(0)))
lenBytes := len(buf) * int(codeUnitSize)
if lenBytes > math.MaxUint16 {
panic("buffer length must fit into uint16")
}
lenBytes16 := uint16(lenBytes)
switch p := any(nts).(type) {
case *windows.NTString:
if isEmpty {
*p = windows.NTString{}
break
}
p.Buffer = unsafe.SliceData(any(buf).([]byte))
p.MaximumLength = lenBytes16
p.Length = lenBytes16
// account for nul terminator when present
if buf[len(buf)-1] == 0 {
p.Length -= codeUnitSize
}
case *windows.NTUnicodeString:
if isEmpty {
*p = windows.NTUnicodeString{}
break
}
p.Buffer = unsafe.SliceData(any(buf).([]uint16))
p.MaximumLength = lenBytes16
p.Length = lenBytes16
// account for nul terminator when present
if buf[len(buf)-1] == 0 {
p.Length -= codeUnitSize
}
default:
panic("unknown type")
}
}

View File

@ -4,9 +4,13 @@
package winutil package winutil
import ( import (
"reflect"
"testing" "testing"
"unsafe"
) )
//lint:file-ignore U1000 Fields are unused but necessary for tests.
const ( const (
localSystemSID = "S-1-5-18" localSystemSID = "S-1-5-18"
networkSID = "S-1-5-2" networkSID = "S-1-5-2"
@ -28,3 +32,103 @@ func TestLookupPseudoUser(t *testing.T) {
t.Errorf("LookupPseudoUser(%q) unexpectedly succeeded", networkSID) t.Errorf("LookupPseudoUser(%q) unexpectedly succeeded", networkSID)
} }
} }
type testType interface {
byte | uint16 | uint32 | uint64
}
type noPointers[T testType] struct {
foo byte
bar T
baz bool
}
type hasPointer struct {
foo byte
bar uint32
s1 *struct{}
baz byte
}
func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, ptLen uint32, slcs [][]BU) {
szBU := int(unsafe.Sizeof(BU(0)))
expectedAlign := max(reflect.TypeFor[T]().Align(), szBU)
// Check that pointer is aligned
if rem := uintptr(unsafe.Pointer(pt)) % uintptr(expectedAlign); rem != 0 {
t.Errorf("pointer alignment got %d, want 0", rem)
}
// Check that alloc length is aligned
if rem := int(ptLen) % expectedAlign; rem != 0 {
t.Errorf("allocation length alignment got %d, want 0", rem)
}
expectedLen := int(unsafe.Sizeof(*pt))
expectedLen = alignUp(expectedLen, szBU)
expectedLen += len(extra) * szBU
expectedLen = alignUp(expectedLen, expectedAlign)
if gotLen := int(ptLen); gotLen != expectedLen {
t.Errorf("allocation length got %d, want %d", gotLen, expectedLen)
}
if l := len(slcs); l != 1 {
t.Errorf("len(slcs) got %d, want 1", l)
}
if len(extra) == 0 && slcs[0] != nil {
t.Error("slcs[0] got non-nil, want nil")
}
if len(extra) != len(slcs[0]) {
t.Errorf("len(slcs[0]) got %d, want %d", len(slcs[0]), len(extra))
} else if rem := uintptr(unsafe.Pointer(unsafe.SliceData(slcs[0]))) % uintptr(szBU); rem != 0 {
t.Errorf("additional data alignment got %d, want 0", rem)
}
}
func TestAllocateContiguousBuffer(t *testing.T) {
t.Run("NoValues", testNoValues)
t.Run("NoPointers", testNoPointers)
t.Run("HasPointer", testHasPointer)
}
func testNoValues(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Error("expected panic but didn't get one")
}
}()
AllocateContiguousBuffer[hasPointer, byte]()
}
const maxTestBufLen = 8
func testNoPointers(t *testing.T) {
buf8 := make([]byte, maxTestBufLen)
buf16 := make([]uint16, maxTestBufLen)
for i := range maxTestBufLen {
s8, sl, slcs8 := AllocateContiguousBuffer[noPointers[byte]](buf8[:i])
checkContiguousBuffer(t, buf8[:i], s8, sl, slcs8)
s16, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint16]](buf8[:i])
checkContiguousBuffer(t, buf8[:i], s16, sl, slcs8)
s32, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint32]](buf8[:i])
checkContiguousBuffer(t, buf8[:i], s32, sl, slcs8)
s64, sl, slcs8 := AllocateContiguousBuffer[noPointers[uint64]](buf8[:i])
checkContiguousBuffer(t, buf8[:i], s64, sl, slcs8)
s8, sl, slcs16 := AllocateContiguousBuffer[noPointers[byte]](buf16[:i])
checkContiguousBuffer(t, buf16[:i], s8, sl, slcs16)
s16, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint16]](buf16[:i])
checkContiguousBuffer(t, buf16[:i], s16, sl, slcs16)
s32, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint32]](buf16[:i])
checkContiguousBuffer(t, buf16[:i], s32, sl, slcs16)
s64, sl, slcs16 = AllocateContiguousBuffer[noPointers[uint64]](buf16[:i])
checkContiguousBuffer(t, buf16[:i], s64, sl, slcs16)
}
}
func testHasPointer(t *testing.T) {
buf8 := make([]byte, maxTestBufLen)
buf16 := make([]uint16, maxTestBufLen)
for i := range maxTestBufLen {
s, sl, slcs8 := AllocateContiguousBuffer[hasPointer](buf8[:i])
checkContiguousBuffer(t, buf8[:i], s, sl, slcs8)
s, sl, slcs16 := AllocateContiguousBuffer[hasPointer](buf16[:i])
checkContiguousBuffer(t, buf16[:i], s, sl, slcs16)
}
}