Improve Rust implementation

- Move mmap_file implementation into Rust
- Introduce Utf8CStr as the better c-string type to use
This commit is contained in:
topjohnwu
2023-06-12 01:07:43 -07:00
committed by John Wu
parent 866386e21f
commit 23c1f0111b
9 changed files with 352 additions and 118 deletions

View File

@@ -1,14 +1,18 @@
// Functions listed here are just to export to C++
use crate::{fd_path, mkdirs, realpath, rm_rf, slice_from_ptr_mut, Directory, ResultExt};
use std::ffi::CStr;
use std::io;
use std::os::fd::{BorrowedFd, OwnedFd, RawFd};
use anyhow::Context;
use cxx::private::c_char;
use libc::mode_t;
use std::ffi::CStr;
use std::io;
use std::os::fd::{OwnedFd, RawFd};
pub fn fd_path_for_cxx(fd: RawFd, buf: &mut [u8]) -> isize {
use crate::{
fd_path, map_fd, map_file, mkdirs, realpath, rm_rf, slice_from_ptr_mut, Directory, ResultExt,
};
pub(crate) fn fd_path_for_cxx(fd: RawFd, buf: &mut [u8]) -> isize {
fd_path(fd, buf)
.context("fd_path failed")
.log()
@@ -37,3 +41,19 @@ unsafe extern "C" fn frm_rf(fd: OwnedFd) -> bool {
}
inner(fd).map_or(false, |_| true)
}
pub(crate) fn map_file_for_cxx(path: &[u8], rw: bool) -> &'static mut [u8] {
unsafe {
map_file(CStr::from_bytes_with_nul_unchecked(path), rw)
.log()
.unwrap_or(&mut [])
}
}
pub(crate) fn map_fd_for_cxx(fd: RawFd, sz: usize, rw: bool) -> &'static mut [u8] {
unsafe {
map_fd(BorrowedFd::borrow_raw(fd), sz, rw)
.log()
.unwrap_or(&mut [])
}
}

View File

@@ -375,29 +375,19 @@ sFILE make_file(FILE *fp) {
}
mmap_data::mmap_data(const char *name, bool rw) {
int fd = xopen(name, (rw ? O_RDWR : O_RDONLY) | O_CLOEXEC);
if (fd < 0)
return;
run_finally g([=] { close(fd); });
struct stat st{};
if (fstat(fd, &st))
return;
if (S_ISBLK(st.st_mode)) {
uint64_t size;
ioctl(fd, BLKGETSIZE64, &size);
init(fd, size, rw);
} else {
init(fd, st.st_size, rw);
auto slice = rust::map_file(byte_view(name), rw);
if (!slice.empty()) {
_buf = slice.data();
_sz = slice.size();
}
}
void mmap_data::init(int fd, size_t sz, bool rw) {
_sz = sz;
void *b = sz > 0
? xmmap(nullptr, sz, PROT_READ | PROT_WRITE, rw ? MAP_SHARED : MAP_PRIVATE, fd, 0)
: nullptr;
_buf = static_cast<uint8_t *>(b);
mmap_data::mmap_data(int fd, size_t sz, bool rw) {
auto slice = rust::map_fd(fd, sz, rw);
if (!slice.empty()) {
_buf = slice.data();
_sz = slice.size();
}
}
mmap_data::~mmap_data() {

View File

@@ -45,10 +45,8 @@ struct mmap_data : public byte_data {
ALLOW_MOVE_ONLY(mmap_data)
explicit mmap_data(const char *name, bool rw = false);
mmap_data(int fd, size_t sz, bool rw = false) { init(fd, sz, rw); }
mmap_data(int fd, size_t sz, bool rw = false);
~mmap_data();
private:
void init(int fd, size_t sz, bool rw);
};
extern "C" {

View File

@@ -7,20 +7,16 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::os::fd::{AsFd, BorrowedFd, IntoRawFd};
use std::os::unix::io::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use std::{io, mem, slice};
use std::{io, mem, ptr, slice};
use libc::{c_char, c_uint, dirent, mode_t, EEXIST, ENOENT, O_CLOEXEC, O_PATH, O_RDONLY};
use libc::{c_char, c_uint, dirent, mode_t, EEXIST, ENOENT, O_CLOEXEC, O_PATH, O_RDONLY, O_RDWR};
use crate::{bfmt_cstr, copy_cstr, cstr, errno, error};
use crate::{bfmt_cstr, copy_cstr, cstr, errno, error, LibcReturn};
pub fn __open_fd_impl(path: &CStr, flags: i32, mode: mode_t) -> io::Result<OwnedFd> {
unsafe {
let fd = libc::open(path.as_ptr(), flags, mode as c_uint);
if fd >= 0 {
Ok(OwnedFd::from_raw_fd(fd))
} else {
Err(io::Error::last_os_error())
}
let fd = libc::open(path.as_ptr(), flags, mode as c_uint).check_os_err()?;
Ok(OwnedFd::from_raw_fd(fd))
}
}
@@ -43,10 +39,8 @@ pub unsafe fn readlink_unsafe(path: *const c_char, buf: *mut u8, bufsz: usize) -
}
pub fn readlink(path: &CStr, data: &mut [u8]) -> io::Result<usize> {
let r = unsafe { readlink_unsafe(path.as_ptr(), data.as_mut_ptr(), data.len()) };
if r < 0 {
return Err(io::Error::last_os_error());
}
let r =
unsafe { readlink_unsafe(path.as_ptr(), data.as_mut_ptr(), data.len()) }.check_os_err()?;
Ok(r as usize)
}
@@ -229,9 +223,7 @@ impl DirEntry<'_> {
pub fn unlink(&self) -> io::Result<()> {
let flag = if self.is_dir() { libc::AT_REMOVEDIR } else { 0 };
unsafe {
if libc::unlinkat(self.dir.as_raw_fd(), self.d_name.as_ptr(), flag) < 0 {
return Err(io::Error::last_os_error());
}
libc::unlinkat(self.dir.as_raw_fd(), self.d_name.as_ptr(), flag).check_os_err()?;
}
Ok(())
}
@@ -245,10 +237,8 @@ impl DirEntry<'_> {
self.dir.as_raw_fd(),
self.d_name.as_ptr(),
O_RDONLY | O_CLOEXEC,
);
if fd < 0 {
return Err(io::Error::last_os_error());
}
)
.check_os_err()?;
Directory::try_from(OwnedFd::from_raw_fd(fd))
}
}
@@ -262,10 +252,8 @@ impl DirEntry<'_> {
self.dir.as_raw_fd(),
self.d_name.as_ptr(),
flags | O_CLOEXEC,
);
if fd < 0 {
return Err(io::Error::last_os_error());
}
)
.check_os_err()?;
Ok(File::from_raw_fd(fd))
}
}
@@ -292,10 +280,7 @@ pub enum WalkResult {
impl<'a> Directory<'a> {
pub fn open(path: &CStr) -> io::Result<Directory> {
let dirp = unsafe { libc::opendir(path.as_ptr()) };
if dirp.is_null() {
return Err(io::Error::last_os_error());
}
let dirp = unsafe { libc::opendir(path.as_ptr()) }.check_os_err()?;
Ok(Directory {
dirp,
_phantom: PhantomData,
@@ -415,10 +400,7 @@ impl TryFrom<OwnedFd> for Directory<'_> {
type Error = io::Error;
fn try_from(fd: OwnedFd) -> io::Result<Self> {
let dirp = unsafe { libc::fdopendir(fd.into_raw_fd()) };
if dirp.is_null() {
return Err(io::Error::last_os_error());
}
let dirp = unsafe { libc::fdopendir(fd.into_raw_fd()) }.check_os_err()?;
Ok(Directory {
dirp,
_phantom: PhantomData,
@@ -449,16 +431,131 @@ impl Drop for Directory<'_> {
pub fn rm_rf(path: &CStr) -> io::Result<()> {
unsafe {
let mut stat: libc::stat = mem::zeroed();
if libc::lstat(path.as_ptr(), &mut stat) < 0 {
return Err(io::Error::last_os_error());
}
if (stat.st_mode & libc::S_IFMT as u32) == libc::S_IFDIR as u32 {
libc::lstat(path.as_ptr(), &mut stat).check_os_err()?;
if stat.is_dir() {
let mut dir = Directory::open(path)?;
dir.remove_all()?;
}
if libc::remove(path.as_ptr()) < 0 {
return Err(io::Error::last_os_error());
}
libc::remove(path.as_ptr()).check_os_err()?;
}
Ok(())
}
pub trait StatExt {
fn get_mode(&self) -> u32;
fn get_type(&self) -> u32;
fn is_dir(&self) -> bool;
fn is_blk(&self) -> bool;
}
impl StatExt for libc::stat {
fn get_mode(&self) -> u32 {
self.st_mode & libc::S_IFMT as u32
}
fn get_type(&self) -> u32 {
self.st_mode & !(libc::S_IFMT as u32)
}
fn is_dir(&self) -> bool {
self.get_type() == libc::S_IFDIR as u32
}
fn is_blk(&self) -> bool {
self.get_type() == libc::S_IFBLK as u32
}
}
pub trait FdExt {
fn size(&self) -> io::Result<usize>;
}
const BLKGETSIZE64: u32 = 0x80081272;
impl<T: AsRawFd> FdExt for T {
fn size(&self) -> io::Result<usize> {
unsafe fn inner(fd: RawFd) -> io::Result<usize> {
extern "C" {
// Don't use the declaration from the libc crate as request should be u32 not i32
fn ioctl(fd: RawFd, request: u32, ...) -> i32;
}
let mut stat: libc::stat = mem::zeroed();
libc::fstat(fd, &mut stat).check_os_err()?;
if stat.is_blk() {
let mut sz = 0_u64;
ioctl(fd, BLKGETSIZE64, &mut sz).check_os_err()?;
Ok(sz as usize)
} else {
Ok(stat.st_size as usize)
}
}
unsafe { inner(self.as_raw_fd()) }
}
}
pub struct MappedFile(&'static mut [u8]);
impl MappedFile {
pub fn open(path: &CStr) -> io::Result<MappedFile> {
Ok(MappedFile(map_file(path, false)?))
}
pub fn open_rw(path: &CStr) -> io::Result<MappedFile> {
Ok(MappedFile(map_file(path, true)?))
}
pub fn create(fd: BorrowedFd, sz: usize, rw: bool) -> io::Result<MappedFile> {
Ok(MappedFile(map_fd(fd, sz, rw)?))
}
}
impl AsRef<[u8]> for MappedFile {
fn as_ref(&self) -> &[u8] {
self.0
}
}
impl AsMut<[u8]> for MappedFile {
fn as_mut(&mut self) -> &mut [u8] {
self.0
}
}
impl Drop for MappedFile {
fn drop(&mut self) {
unsafe {
libc::munmap(self.0.as_mut_ptr().cast(), self.0.len());
}
}
}
// We mark the returned slice static because it is valid until explicitly unmapped
pub(crate) fn map_file(path: &CStr, rw: bool) -> io::Result<&'static mut [u8]> {
let flag = if rw { O_RDONLY } else { O_RDWR };
let fd = open_fd!(path, flag | O_CLOEXEC)?;
map_fd(fd.as_fd(), fd.size()?, rw)
}
pub(crate) fn map_fd(fd: BorrowedFd, sz: usize, rw: bool) -> io::Result<&'static mut [u8]> {
let flag = if rw {
libc::MAP_SHARED
} else {
libc::MAP_PRIVATE
};
unsafe {
let ptr = libc::mmap(
ptr::null_mut(),
sz,
libc::PROT_READ | libc::PROT_WRITE,
flag,
fd.as_raw_fd(),
0,
);
if ptr == libc::MAP_FAILED {
return Err(io::Error::last_os_error());
}
Ok(slice::from_raw_parts_mut(ptr.cast(), sz))
}
}

View File

@@ -38,5 +38,9 @@ pub mod ffi {
fn xpipe2(fds: &mut [i32; 2], flags: i32) -> i32;
#[rust_name = "fd_path_for_cxx"]
fn fd_path(fd: i32, buf: &mut [u8]) -> isize;
#[rust_name = "map_file_for_cxx"]
fn map_file(path: &[u8], rw: bool) -> &'static mut [u8];
#[rust_name = "map_fd_for_cxx"]
fn map_fd(fd: i32, sz: usize, rw: bool) -> &'static mut [u8];
}
}

View File

@@ -1,9 +1,12 @@
use std::cmp::min;
use std::ffi::CStr;
use std::fmt::{Arguments, Debug};
use std::ffi::{CStr, FromBytesWithNulError, OsStr};
use std::fmt::{Arguments, Debug, Display, Formatter};
use std::ops::Deref;
use std::path::Path;
use std::str::Utf8Error;
use std::{fmt, slice};
use std::{fmt, io, slice, str};
use libc::c_char;
use thiserror::Error;
pub fn copy_str<T: AsRef<[u8]>>(dest: &mut [u8], src: T) -> usize {
@@ -96,14 +99,120 @@ macro_rules! raw_cstr {
#[derive(Debug, Error)]
pub enum StrErr {
#[error(transparent)]
Invalid(#[from] Utf8Error),
Utf8Error(#[from] Utf8Error),
#[error(transparent)]
CStrError(#[from] FromBytesWithNulError),
#[error("argument is null")]
NullPointer,
NullPointerError,
}
// The better CStr: UTF-8 validated + null terminated buffer
pub struct Utf8CStr {
inner: [u8],
}
impl Utf8CStr {
pub fn from_cstr(cstr: &CStr) -> Result<&Utf8CStr, StrErr> {
// Validate the buffer during construction
str::from_utf8(cstr.to_bytes())?;
Ok(unsafe { Self::from_bytes_unchecked(cstr.to_bytes_with_nul()) })
}
pub fn from_bytes(buf: &[u8]) -> Result<&Utf8CStr, StrErr> {
Self::from_cstr(CStr::from_bytes_with_nul(buf)?)
}
pub fn from_string(s: &mut String) -> &Utf8CStr {
if s.capacity() == s.len() {
s.reserve(1);
}
// SAFETY: the string is reserved to have enough capacity to fit in the null byte
// SAFETY: the null byte is explicitly added outside of the string's length
unsafe {
let buf = slice::from_raw_parts_mut(s.as_mut_ptr(), s.len() + 1);
*buf.get_unchecked_mut(s.len()) = b'\0';
Self::from_bytes_unchecked(buf)
}
}
pub unsafe fn from_bytes_unchecked(buf: &[u8]) -> &Utf8CStr {
&*(buf as *const [u8] as *const Utf8CStr)
}
pub unsafe fn from_ptr<'a>(ptr: *const c_char) -> Result<&'a Utf8CStr, StrErr> {
if ptr.is_null() {
return Err(StrErr::NullPointerError);
}
Self::from_cstr(unsafe { CStr::from_ptr(ptr) })
}
pub fn as_bytes(&self) -> &[u8] {
// The length of the slice is at least 1 due to null termination check
unsafe { self.inner.get_unchecked(..self.inner.len() - 1) }
}
pub fn as_bytes_with_nul(&self) -> &[u8] {
&self.inner
}
pub fn as_ptr(&self) -> *const c_char {
self.inner.as_ptr().cast()
}
}
impl Deref for Utf8CStr {
type Target = str;
fn deref(&self) -> &str {
self.as_ref()
}
}
impl Display for Utf8CStr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Display::fmt(self.deref(), f)
}
}
impl Debug for Utf8CStr {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Debug::fmt(self.deref(), f)
}
}
impl AsRef<CStr> for Utf8CStr {
#[inline]
fn as_ref(&self) -> &CStr {
// SAFETY: Already validated as null terminated during construction
unsafe { CStr::from_bytes_with_nul_unchecked(&self.inner) }
}
}
impl AsRef<str> for Utf8CStr {
#[inline]
fn as_ref(&self) -> &str {
// SAFETY: Already UTF-8 validated during construction
unsafe { str::from_utf8_unchecked(self.as_bytes()) }
}
}
impl AsRef<OsStr> for Utf8CStr {
#[inline]
fn as_ref(&self) -> &OsStr {
OsStr::new(self)
}
}
impl AsRef<Path> for Utf8CStr {
#[inline]
fn as_ref(&self) -> &Path {
Path::new(self)
}
}
pub fn ptr_to_str_result<'a, T>(ptr: *const T) -> Result<&'a str, StrErr> {
if ptr.is_null() {
Err(StrErr::NullPointer)
Err(StrErr::NullPointerError)
} else {
unsafe { CStr::from_ptr(ptr.cast()) }
.to_str()
@@ -167,3 +276,38 @@ pub trait FlatData {
}
}
}
// Check libc return value and map errors to Result
pub trait LibcReturn: Copy {
fn is_error(&self) -> bool;
fn check_os_err(self) -> io::Result<Self> {
if self.is_error() {
return Err(io::Error::last_os_error());
}
Ok(self)
}
}
impl LibcReturn for i32 {
fn is_error(&self) -> bool {
*self < 0
}
}
impl LibcReturn for isize {
fn is_error(&self) -> bool {
*self < 0
}
}
impl<T> LibcReturn for *const T {
fn is_error(&self) -> bool {
self.is_null()
}
}
impl<T> LibcReturn for *mut T {
fn is_error(&self) -> bool {
self.is_null()
}
}