Migrate away from unsafe set_len of Utf8CStr

This commit is contained in:
topjohnwu
2025-09-09 22:19:05 -07:00
parent c8caaa98f5
commit 111136733a
5 changed files with 108 additions and 64 deletions

View File

@@ -3,7 +3,7 @@ use libc::c_char;
use nix::NixPath;
use std::borrow::Borrow;
use std::cmp::{Ordering, min};
use std::ffi::{CStr, FromBytesWithNulError, OsStr};
use std::ffi::{CStr, FromBytesUntilNulError, FromBytesWithNulError, OsStr};
use std::fmt::{Debug, Display, Formatter, Write};
use std::ops::Deref;
use std::os::unix::ffi::OsStrExt;
@@ -73,13 +73,6 @@ pub trait Utf8CStrBuf: Display + Write + AsRef<Utf8CStr> + Deref<Target = Utf8CS
// The length of the string without the terminating null character.
// assert_true(len <= capacity - 1)
fn len(&self) -> usize;
// Set the length of the string
//
// It is your responsibility to:
// 1. Null terminate the string by setting the next byte after len to null
// 2. Ensure len <= capacity - 1
// 3. All bytes from 0 to len is valid UTF-8 and does not contain null
unsafe fn set_len(&mut self, len: usize);
fn push_str(&mut self, s: &str) -> usize;
// The capacity of the internal buffer. The maximum string length this buffer can contain
// is capacity - 1, because the last byte is reserved for the terminating null character.
@@ -87,6 +80,10 @@ pub trait Utf8CStrBuf: Display + Write + AsRef<Utf8CStr> + Deref<Target = Utf8CS
fn clear(&mut self);
fn as_mut_ptr(&mut self) -> *mut c_char;
fn truncate(&mut self, new_len: usize);
// Rebuild the Utf8CStr based on the contents of the internal buffer. Required after any
// unsafe modifications directly through the pointer obtained from self.as_mut_ptr().
// If an error is returned, the internal buffer will be reset, resulting in an empty string.
fn rebuild(&mut self) -> Result<(), StrErr>;
#[inline(always)]
fn is_empty(&self) -> bool {
@@ -161,12 +158,6 @@ impl Utf8CStrBuf for Utf8CString {
self.0.len()
}
unsafe fn set_len(&mut self, len: usize) {
unsafe {
self.0.as_mut_vec().set_len(len);
}
}
fn push_str(&mut self, s: &str) -> usize {
self.0.push_str(s);
self.0.nul_terminate();
@@ -190,6 +181,32 @@ impl Utf8CStrBuf for Utf8CString {
self.0.truncate(new_len);
self.0.nul_terminate();
}
fn rebuild(&mut self) -> Result<(), StrErr> {
// Temporarily move the internal String out
let mut tmp = String::new();
mem::swap(&mut tmp, &mut self.0);
let (ptr, _, capacity) = tmp.into_raw_parts();
unsafe {
// Validate the entire buffer, including the unused part
let bytes = slice::from_raw_parts(ptr, capacity);
match Utf8CStr::from_bytes_until_nul(bytes) {
Ok(s) => {
// Move the String with the new length back
self.0 = String::from_raw_parts(ptr, s.len(), capacity);
}
Err(e) => {
// Move the String with 0 length back
self.0 = String::from_raw_parts(ptr, 0, capacity);
self.0.nul_terminate();
return Err(e);
}
}
}
Ok(())
}
}
impl From<String> for Utf8CString {
@@ -267,7 +284,9 @@ pub enum StrErr {
#[error(transparent)]
Utf8Error(#[from] Utf8Error),
#[error(transparent)]
CStrError(#[from] FromBytesWithNulError),
CStrWithNullError(#[from] FromBytesWithNulError),
#[error(transparent)]
CStrUntilNullError(#[from] FromBytesUntilNulError),
#[error("argument is null")]
NullPointerError,
}
@@ -283,8 +302,12 @@ impl Utf8CStr {
Ok(unsafe { Self::from_bytes_unchecked(cstr.to_bytes_with_nul()) })
}
pub fn from_bytes(buf: &[u8]) -> Result<&Utf8CStr, StrErr> {
Self::from_cstr(CStr::from_bytes_with_nul(buf)?)
// Build a Utf8CStr from the bytes up to the first null byte found in `bytes`.
// Fails if CStr::from_bytes_until_nul or Self::from_cstr rejects the input.
fn from_bytes_until_nul(bytes: &[u8]) -> Result<&Utf8CStr, StrErr> {
    let cstr = CStr::from_bytes_until_nul(bytes)?;
    Self::from_cstr(cstr)
}
// Build a Utf8CStr from a byte slice via CStr::from_bytes_with_nul.
// Fails if CStr::from_bytes_with_nul or Self::from_cstr rejects the input.
pub fn from_bytes(bytes: &[u8]) -> Result<&Utf8CStr, StrErr> {
    let cstr = CStr::from_bytes_with_nul(bytes)?;
    Self::from_cstr(cstr)
}
pub fn from_string(s: &mut String) -> &Utf8CStr {
@@ -294,8 +317,8 @@ impl Utf8CStr {
}
#[inline(always)]
pub const unsafe fn from_bytes_unchecked(buf: &[u8]) -> &Utf8CStr {
unsafe { mem::transmute(buf) }
pub const unsafe fn from_bytes_unchecked(bytes: &[u8]) -> &Utf8CStr {
unsafe { mem::transmute(bytes) }
}
pub unsafe fn from_ptr<'a>(ptr: *const c_char) -> Result<&'a Utf8CStr, StrErr> {
@@ -564,10 +587,6 @@ macro_rules! impl_cstr_buf {
self.used
}
#[inline(always)]
unsafe fn set_len(&mut self, len: usize) {
self.used = len;
}
#[inline(always)]
fn push_str(&mut self, s: &str) -> usize {
// SAFETY: self.used is guaranteed to always be <= SIZE - 1
let dest = unsafe { self.buf.get_unchecked_mut(self.used..) };
@@ -595,6 +614,18 @@ macro_rules! impl_cstr_buf {
self.buf[new_len] = b'\0';
self.used = new_len;
}
fn rebuild(&mut self) -> Result<(), StrErr> {
    // Scan the entire backing buffer (not only the first `used` bytes) for the
    // first null byte, and keep just the resulting length.
    let scanned = Utf8CStr::from_bytes_until_nul(&self.buf).map(|s| s.len());
    match scanned {
        Ok(len) => {
            self.used = len;
            Ok(())
        }
        Err(e) => {
            // Validation failed: reset to an empty, null-terminated string.
            self.used = 0;
            self.buf[0] = b'\0';
            Err(e)
        }
    }
}
}
)*}
}

View File

@@ -212,15 +212,15 @@ impl Directory {
) -> OsResult<'a, ()> {
buf.clear();
unsafe {
let r = readlinkat(
readlinkat(
self.as_raw_fd(),
name.as_ptr(),
buf.as_mut_ptr().cast(),
buf.capacity(),
)
.into_os_result("readlinkat", Some(name), None)? as usize;
buf.set_len(r);
.check_os_err("readlinkat", Some(name), None)?;
}
buf.rebuild().ok();
Ok(())
}

View File

@@ -257,8 +257,8 @@ impl Utf8CStr {
let r = libc::readlink(self.as_ptr(), buf.as_mut_ptr(), buf.capacity() - 1)
.into_os_result("readlink", Some(self), None)? as isize;
*(buf.as_mut_ptr().offset(r) as *mut u8) = b'\0';
buf.set_len(r as usize);
}
buf.rebuild().ok();
Ok(())
}
@@ -336,23 +336,25 @@ impl Utf8CStr {
}
pub fn get_secontext(&self, con: &mut dyn Utf8CStrBuf) -> OsResult<'_, ()> {
unsafe {
let sz = libc::lgetxattr(
con.clear();
let result = unsafe {
libc::lgetxattr(
self.as_ptr(),
XATTR_NAME_SELINUX.as_ptr(),
con.as_mut_ptr().cast(),
con.capacity(),
);
if sz < 1 {
con.clear();
if *errno() != libc::ENODATA {
return Err(OsError::last_os_error("lgetxattr", Some(self), None));
}
} else {
con.set_len((sz - 1) as usize);
)
.check_err()
};
match result {
Ok(_) => {
con.rebuild().ok();
Ok(())
}
Err(Errno::ENODATA) => Ok(()),
Err(e) => Err(OsError::new(e, "lgetxattr", Some(self), None)),
}
Ok(())
}
pub fn set_secontext<'a>(&'a self, con: &'a Utf8CStr) -> OsResult<'a, ()> {
@@ -549,23 +551,25 @@ impl FsPathFollow {
}
pub fn get_secontext(&self, con: &mut dyn Utf8CStrBuf) -> OsResult<'_, ()> {
unsafe {
let sz = libc::getxattr(
con.clear();
let result = unsafe {
libc::getxattr(
self.as_ptr(),
XATTR_NAME_SELINUX.as_ptr(),
con.as_mut_ptr().cast(),
con.capacity(),
);
if sz < 1 {
con.clear();
if *errno() != libc::ENODATA {
return Err(OsError::last_os_error("getxattr", Some(self), None));
}
} else {
con.set_len((sz - 1) as usize);
)
.check_err()
};
match result {
Ok(_) => {
con.rebuild().ok();
Ok(())
}
Err(Errno::ENODATA) => Ok(()),
Err(e) => Err(OsError::new(e, "getxattr", Some(self), None)),
}
Ok(())
}
pub fn set_secontext<'a>(&'a self, con: &'a Utf8CStr) -> OsResult<'a, ()> {
@@ -660,22 +664,25 @@ pub fn fd_set_attr(fd: RawFd, attr: &FileAttr) -> OsResult<'_, ()> {
}
pub fn fd_get_secontext(fd: RawFd, con: &mut dyn Utf8CStrBuf) -> OsResult<'static, ()> {
unsafe {
let sz = libc::fgetxattr(
con.clear();
let result = unsafe {
libc::fgetxattr(
fd,
XATTR_NAME_SELINUX.as_ptr(),
con.as_mut_ptr().cast(),
con.capacity(),
);
if sz < 1 {
if *errno() != libc::ENODATA {
return Err(OsError::last_os_error("fgetxattr", None, None));
}
} else {
con.set_len((sz - 1) as usize);
)
.check_err()
};
match result {
Ok(_) => {
con.rebuild().ok();
Ok(())
}
Err(Errno::ENODATA) => Ok(()),
Err(e) => Err(OsError::new(e, "fgetxattr", None, None)),
}
Ok(())
}
pub fn fd_set_secontext(fd: RawFd, con: &Utf8CStr) -> OsResult<'_, ()> {

View File

@@ -1,3 +1,4 @@
#![feature(vec_into_raw_parts)]
#![allow(clippy::missing_safety_doc)]
pub use const_format;

View File

@@ -313,21 +313,26 @@ extern "C" fn logfile_writer(arg: *mut c_void) -> *mut c_void {
// the crate cannot fetch the proper local timezone without pulling in a bunch of
// timezone handling code. To reduce binary size, fallback to use localtime_r in libc.
unsafe {
let secs: time_t = now.as_secs() as time_t;
let secs = now.as_secs() as time_t;
let mut tm: tm = std::mem::zeroed();
if localtime_r(&secs, &mut tm).is_null() {
continue;
}
let len = strftime(aux.as_mut_ptr(), aux.capacity(), raw_cstr!("%m-%d %T"), &tm);
aux.set_len(len);
aux.write_fmt(format_args!(
strftime(aux.as_mut_ptr(), aux.capacity(), raw_cstr!("%m-%d %T"), &tm);
}
if aux.rebuild().is_ok() {
write!(
aux,
".{:03} {:5} {:5} {} : ",
now.subsec_millis(),
meta.pid,
meta.tid,
prio
))
)
.ok();
} else {
continue;
}
let io1 = IoSlice::new(aux.as_bytes());