From fb0c4ea8387c084d2a8f3d8bbe2263eaeab111a7 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Wed, 3 Sep 2025 11:35:29 -0700 Subject: [PATCH] Fallback to userspace copy if splice failed Fix #9032 --- native/src/base/files.rs | 16 ++++++ native/src/core/su/pts.rs | 115 +++++++++++++++++++++++--------------- native/src/core/su/su.cpp | 37 ++++-------- 3 files changed, 99 insertions(+), 69 deletions(-) diff --git a/native/src/base/files.rs b/native/src/base/files.rs index 3af9c6cd3..d62155a6a 100644 --- a/native/src/base/files.rs +++ b/native/src/base/files.rs @@ -885,3 +885,19 @@ pub fn parse_mount_info(pid: &str) -> Vec { } res } + +pub struct PipeFd { + pub read: OwnedFd, + pub write: OwnedFd, +} + +pub fn make_pipe(flags: i32) -> OsResult<'static, PipeFd> { + let mut pipefd: [RawFd; 2] = [0; 2]; + unsafe { + libc::pipe2(pipefd.as_mut_ptr(), flags).check_os_err("pipe2", None, None)?; + Ok(PipeFd { + read: OwnedFd::from_raw_fd(pipefd[0]), + write: OwnedFd::from_raw_fd(pipefd[1]), + }) + } +} diff --git a/native/src/core/su/pts.rs b/native/src/core/su/pts.rs index 9d4cb087f..5747f4c85 100644 --- a/native/src/core/su/pts.rs +++ b/native/src/core/su/pts.rs @@ -1,13 +1,24 @@ -use base::{error, libc, warn}; -use libc::{ - POLLIN, SFD_CLOEXEC, SIG_BLOCK, SIGWINCH, STDIN_FILENO, STDOUT_FILENO, TCSADRAIN, TCSAFLUSH, - TIOCGWINSZ, TIOCSWINSZ, cfmakeraw, close, pipe, poll, pollfd, raise, read, sigaddset, - sigemptyset, signalfd, signalfd_siginfo, sigprocmask, sigset_t, splice, tcgetattr, tcsetattr, - termios, winsize, +use base::libc::ssize_t; +use base::{ + LibcReturn, LoggedResult, OsResult, PipeFd, ReadExt, ResultExt, error, libc, log_err, + make_pipe, warn, }; -use std::{ffi::c_int, mem::MaybeUninit, ptr::null_mut}; +use bytemuck::{Pod, Zeroable}; +use libc::{ + O_CLOEXEC, POLLIN, SFD_CLOEXEC, SIG_BLOCK, SIGWINCH, STDIN_FILENO, STDOUT_FILENO, TCSADRAIN, + TCSAFLUSH, TIOCGWINSZ, TIOCSWINSZ, cfmakeraw, close, poll, pollfd, raise, sigaddset, + sigemptyset, signalfd, signalfd_siginfo, sigprocmask, sigset_t, tcgetattr, tcsetattr, termios, + winsize, +}; +use std::fs::File; +use std::io::{Read, Write}; +use std::mem::{ManuallyDrop, MaybeUninit}; +use std::os::fd::{AsRawFd, FromRawFd, RawFd}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::{ffi::c_int, ptr::null_mut}; static mut OLD_STDIN: Option = None; +static SHOULD_USE_SPLICE: AtomicBool = AtomicBool::new(true); const TIOCGPTN: u32 = 0x80045430; unsafe extern "C" { @@ -71,34 +82,55 @@ fn resize_pty(outfd: i32) { } } -fn pump_via_pipe(infd: i32, outfd: i32, pipe: &[c_int; 2]) -> bool { - // usize::MAX will EINVAL in some kernels, use i32::MAX in case - let s = unsafe { splice(infd, null_mut(), pipe[1], null_mut(), i32::MAX as _, 0) }; - if s < 0 { - error!("splice error"); - return false; - } - if s == 0 { - return true; - } - let s = unsafe { splice(pipe[0], null_mut(), outfd, null_mut(), s as usize, 0) }; - if s < 0 { - error!("splice error"); - return false; - } - true +fn splice(fd_in: RawFd, fd_out: RawFd, len: usize, flags: u32) -> OsResult<'static, ssize_t> { + unsafe { libc::splice(fd_in, null_mut(), fd_out, null_mut(), len, flags) } + .as_os_result("splice", None, None) } +fn pump_via_copy(infd: RawFd, outfd: RawFd) -> LoggedResult<()> { + let mut buf = MaybeUninit::<[u8; 4096]>::uninit(); + let buf = unsafe { buf.assume_init_mut() }; + let mut infd = ManuallyDrop::new(unsafe { File::from_raw_fd(infd) }); + let mut outfd = ManuallyDrop::new(unsafe { File::from_raw_fd(outfd) }); + let len = infd.read(buf)?; + outfd.write_all(&buf[..len])?; + Ok(()) +} + +fn pump_via_splice(infd: RawFd, outfd: RawFd, pipe: &PipeFd) -> LoggedResult<()> { + if !SHOULD_USE_SPLICE.load(Ordering::Acquire) { + return pump_via_copy(infd, outfd); + } + + // The pipe capacity is by default 16 pages, let's just use 65536 + let Ok(len) = splice(infd, pipe.write.as_raw_fd(), 65536_usize, 0) else { + // If splice failed, stop using splice and fallback to userspace copy + SHOULD_USE_SPLICE.store(false, Ordering::Release); + return pump_via_copy(infd, outfd); + }; + if len == 0 { + return Ok(()); + } + splice(pipe.read.as_raw_fd(), outfd, len as usize, 0)?; + Ok(()) +} + +#[derive(Copy, Clone)] +#[repr(transparent)] +struct SignalFdInfo(signalfd_siginfo); +unsafe impl Zeroable for SignalFdInfo {} +unsafe impl Pod for SignalFdInfo {} + pub fn pump_tty(infd: i32, outfd: i32) { set_stdin_raw(); - let sfd = unsafe { + let signal_fd = unsafe { let mut mask: sigset_t = std::mem::zeroed(); sigemptyset(&mut mask); sigaddset(&mut mask, SIGWINCH); - if sigprocmask(SIG_BLOCK, &mask, null_mut()) < 0 { - error!("sigprocmask"); - } + sigprocmask(SIG_BLOCK, &mask, null_mut()) + .check_os_err("sigprocmask", None, None) + .log_ok(); signalfd(-1, &mask, SFD_CLOEXEC) }; @@ -116,17 +148,16 @@ pub fn pump_tty(infd: i32, outfd: i32) { revents: 0, }, pollfd { - fd: sfd, + fd: signal_fd, events: POLLIN, revents: 0, }, ]; - let mut p: [c_int; 2] = [0; 2]; - if unsafe { pipe(&mut p as *mut c_int) } < 0 { - error!("pipe error"); + let Ok(pipe_fd) = make_pipe(O_CLOEXEC).log() else { return; - } + }; + 'poll: loop { let ready = unsafe { poll(pfds.as_mut_ptr(), pfds.len() as _, -1) }; @@ -138,22 +169,18 @@ pub fn pump_tty(infd: i32, outfd: i32) { for pfd in &pfds { if pfd.revents & POLLIN != 0 { let res = if pfd.fd == STDIN_FILENO { - pump_via_pipe(pfd.fd, outfd, &p) + pump_via_splice(STDIN_FILENO, outfd, &pipe_fd) } else if pfd.fd == infd { - pump_via_pipe(pfd.fd, STDOUT_FILENO, &p) - } else if pfd.fd == sfd { + pump_via_splice(infd, STDOUT_FILENO, &pipe_fd) + } else if pfd.fd == signal_fd { resize_pty(outfd); - let mut buf = [MaybeUninit::::uninit(); size_of::()]; - if unsafe { read(pfd.fd, buf.as_mut_ptr() as *mut _, buf.len()) } < 0 { - error!("read error"); - false - } else { - true - } + let mut info = SignalFdInfo::zeroed(); + let mut fd = ManuallyDrop::new(unsafe { File::from_raw_fd(signal_fd) }); + fd.read_pod(&mut info).log() } else { - false + log_err!() }; - if !res { + if res.is_err() { break 'poll; } } else if pfd.revents != 0 && pfd.fd == infd { diff --git a/native/src/core/su/su.cpp b/native/src/core/su/su.cpp index 357a3d1c6..15410e45a 100644 --- a/native/src/core/su/su.cpp +++ b/native/src/core/su/su.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2017 - 2023, John Wu (@topjohnwu) + * Copyright 2017 - 2025, John Wu (@topjohnwu) * Copyright 2015, Pierre-Hugues Husson * Copyright 2010, Adam Shanks (@ChainsDD) * Copyright 2008, Zinx Verituse (@zinxv) @@ -76,9 +76,7 @@ static void sighandler(int sig) { close(STDERR_FILENO); // Put back all the default handlers - struct sigaction act; - - memset(&act, 0, sizeof(act)); + struct sigaction act{}; act.sa_handler = SIG_DFL; for (int i = 0; quit_signals[i]; ++i) { sigaction(quit_signals[i], &act, nullptr); @@ -86,8 +84,7 @@ static void sighandler(int sig) { } static void setup_sighandlers(void (*handler)(int)) { - struct sigaction act; - memset(&act, 0, sizeof(act)); + struct sigaction act{}; act.sa_handler = handler; for (int i = 0; quit_signals[i]; ++i) { sigaction(quit_signals[i], &act, nullptr); @@ -95,8 +92,7 @@ static void setup_sighandlers(void (*handler)(int)) { } int su_client_main(int argc, char *argv[]) { - int c; - struct option long_opts[] = { + option long_opts[] = { { "command", required_argument, nullptr, 'c' }, { "help", no_argument, nullptr, 'h' }, { "login", no_argument, nullptr, 'l' }, @@ -126,6 +122,7 @@ int su_client_main(int argc, char *argv[]) { bool interactive = false; + int c; while ((c = getopt_long(argc, argv, "c:hlimpds:VvuZ:Mt:g:G:", long_opts, nullptr)) != -1) { switch (c) { case 'c': { @@ -191,7 +188,7 @@ int su_client_main(int argc, char *argv[]) { fprintf(stderr, "Invalid GID: %s\n", optarg); usage(EXIT_FAILURE); } - std::copy(gids.begin(), gids.end(), std::back_inserter(req.gids)); + ranges::copy(gids, std::back_inserter(req.gids)); break; } default: @@ -207,19 +204,15 @@ int su_client_main(int argc, char *argv[]) { } /* username or uid */ if (optind < argc) { - struct passwd *pw; - pw = getpwnam(argv[optind]); - if (pw) + if (const passwd *pw = getpwnam(argv[optind])) req.target_uid = pw->pw_uid; else req.target_uid = parse_int(argv[optind]); optind++; } - int ptmx, fd; - // Connect to client - fd = connect_daemon(+RequestCode::SUPERUSER); + owned_fd fd = connect_daemon(+RequestCode::SUPERUSER); // Send request req.write_to_fd(fd); @@ -248,23 +241,17 @@ int su_client_main(int argc, char *argv[]) { if (atty) { // We need a PTY. Get one. write_int(fd, 1); - ptmx = recv_fd(fd); - } else { - write_int(fd, 0); - } - - if (atty) { + int ptmx = recv_fd(fd); setup_sighandlers(sighandler); // if stdin is not a tty, if we pump to ptmx, our process may intercept the input to ptmx and // output to stdout, which cause the target process lost input. pump_tty(ptmx, (atty & ATTY_IN) ? ptmx : -1); + } else { + write_int(fd, 0); } // Get the exit code - int code = read_int(fd); - close(fd); - - return code; + return read_int(fd); } static void drop_caps() {