Implement simple serialization over IPC

This commit is contained in:
topjohnwu 2025-02-01 01:29:08 +08:00 committed by John Wu
parent 59622d1688
commit ab86732c89
12 changed files with 323 additions and 79 deletions

10
native/src/Cargo.lock generated
View File

@ -300,6 +300,15 @@ dependencies = [
"syn",
]
[[package]]
name = "derive"
version = "0.0.0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "digest"
version = "0.11.0-pre.9"
@ -461,6 +470,7 @@ dependencies = [
"bytemuck",
"cxx",
"cxx-gen",
"derive",
"num-derive",
"num-traits",
"pb-rs",

View File

@ -1,6 +1,6 @@
[workspace]
exclude = ["external"]
members = ["base", "boot", "core", "init", "sepolicy"]
members = ["base", "boot", "core", "core/derive", "init", "sepolicy"]
resolver = "2"
[workspace.dependencies]
@ -26,6 +26,9 @@ bytemuck = "1.16"
fdt = "0.1"
const_format = "0.2"
bit-set = "0.8"
syn = "2"
quote = "1"
proc-macro2 = "1"
[workspace.dependencies.argh]
git = "https://github.com/google/argh.git"

View File

@ -17,6 +17,7 @@ pb-rs = { workspace = true }
[dependencies]
base = { path = "../base", features = ["selinux"] }
derive = { path = "derive" }
cxx = { workspace = true }
num-traits = { workspace = true }
num-derive = { workspace = true }

View File

@ -282,7 +282,7 @@ impl MagiskD {
fn db_exec_for_client(&self, fd: OwnedFd) -> LoggedResult<()> {
let mut file = File::from(fd);
let mut reader = BufReader::new(&mut file);
let sql = reader.ipc_read_string()?;
let sql: String = reader.read_decodable()?;
let mut writer = BufWriter::new(&mut file);
let mut output_fn = |columns: &[String], values: &DbValues| {
let mut out = "".to_string();
@ -294,10 +294,10 @@ impl MagiskD {
out.push('=');
out.push_str(values.get_text(i as i32));
}
writer.ipc_write_string(&out).log().ok();
writer.write_encodable(&out).log().ok();
};
self.db_exec_with_rows(&sql, &[], &mut output_fn);
writer.ipc_write_string("").log()
writer.write_encodable("").log()
}
}

View File

@ -0,0 +1,13 @@
[package]
name = "derive"
version = "0.0.0"
edition = "2021"
[lib]
path = "lib.rs"
proc-macro = true
[dependencies]
syn = { workspace = true }
quote = { workspace = true }
proc-macro2 = { workspace = true }

View File

@ -0,0 +1,124 @@
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, GenericParam};
pub(crate) fn derive_decodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = input.ident;
// Add a bound `T: Decodable` to every type parameter T.
let mut generics = input.generics;
for param in &mut generics.params {
if let GenericParam::Type(ref mut type_param) = *param {
type_param
.bounds
.push(parse_quote!(crate::socket::Decodable));
}
}
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let sum = gen_size_sum(&input.data);
let encode = gen_encode(&input.data);
let decode = gen_decode(&input.data);
let expanded = quote! {
// The generated impl.
impl #impl_generics crate::socket::Encodable for #name #ty_generics #where_clause {
fn encoded_len(&self) -> usize {
#sum
}
fn encode(&self, w: &mut impl std::io::Write) -> std::io::Result<()> {
#encode
Ok(())
}
}
impl #impl_generics crate::socket::Decodable for #name #ty_generics #where_clause {
fn decode(r: &mut impl std::io::Read) -> std::io::Result<Self> {
let val = #decode;
Ok(val)
}
}
};
proc_macro::TokenStream::from(expanded)
}
// Generate an expression to sum up the size of each field.
fn gen_size_sum(data: &Data) -> TokenStream {
match *data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) => {
// Expands to an expression like
//
// 0 + self.x.encoded_len() + self.y.encoded_len() + self.z.encoded_len()
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
quote_spanned! { f.span() =>
crate::socket::Encodable::encoded_len(&self.#name)
}
});
quote! {
0 #(+ #recurse)*
}
}
_ => unimplemented!(),
}
}
Data::Enum(_) | Data::Union(_) => unimplemented!(),
}
}
// Generate an expression to encode each field.
fn gen_encode(data: &Data) -> TokenStream {
match *data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) => {
// Expands to an expression like
//
// self.x.encode(w)?; self.y.encode(w)?; self.z.encode(w)?;
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
quote_spanned! { f.span() =>
crate::socket::Encodable::encode(&self.#name, w)?;
}
});
quote! {
#(#recurse)*
}
}
_ => unimplemented!(),
}
}
Data::Enum(_) | Data::Union(_) => unimplemented!(),
}
}
// Generate an expression to decode each field.
fn gen_decode(data: &Data) -> TokenStream {
match *data {
Data::Struct(ref data) => {
match data.fields {
Fields::Named(ref fields) => {
// Expands to an expression like
//
// Self { x: Decodable::decode(r)?, y: Decodable::decode(r)?, }
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
quote_spanned! { f.span() =>
#name: crate::socket::Decodable::decode(r)?,
}
});
quote! {
Self { #(#recurse)* }
}
}
_ => unimplemented!(),
}
}
Data::Enum(_) | Data::Union(_) => unimplemented!(),
}
}

View File

@ -0,0 +1,8 @@
use proc_macro::TokenStream;
mod decodable;
#[proc_macro_derive(Decodable)]
pub fn derive_decodable(input: TokenStream) -> TokenStream {
decodable::derive_decodable(input)
}

View File

@ -6,29 +6,44 @@
#include <string>
#include <vector>
#include <base.hpp>
struct sock_cred : public ucred {
std::string context;
};
bool get_client_cred(int fd, sock_cred *cred);
int read_int(int fd);
int read_int_be(int fd);
void write_int(int fd, int val);
void write_int_be(int fd, int val);
std::string read_string(int fd);
bool read_string(int fd, std::string &str);
void write_string(int fd, std::string_view str);
template<typename T> requires(std::is_trivially_copyable_v<T>)
T read_any(int fd) {
T val;
if (xxread(fd, &val, sizeof(val)) != sizeof(val))
return -1;
return val;
}
template<typename T> requires(std::is_trivially_copyable_v<T>)
void write_any(int fd, T val) {
if (fd < 0) return;
xwrite(fd, &val, sizeof(val));
}
template<typename T> requires(std::is_trivially_copyable_v<T>)
void write_vector(int fd, const std::vector<T> &vec) {
write_int(fd, static_cast<int>(vec.size()));
write_any(fd, vec.size());
xwrite(fd, vec.data(), vec.size() * sizeof(T));
}
template<typename T> requires(std::is_trivially_copyable_v<T>)
bool read_vector(int fd, std::vector<T> &vec) {
int size = read_int(fd);
if (size == -1) return false;
auto size = read_any<size_t>(fd);
vec.resize(size);
return xread(fd, vec.data(), size * sizeof(T)) == size * sizeof(T);
}
bool get_client_cred(int fd, sock_cred *cred);
static inline int read_int(int fd) { return read_any<int>(fd); }
int read_int_be(int fd);
static inline void write_int(int fd, int val) { write_any(fd, val); }
void write_int_be(int fd, int val);
std::string read_string(int fd);
bool read_string(int fd, std::string &str);
void write_string(int fd, std::string_view str);

View File

@ -19,31 +19,17 @@ bool get_client_cred(int fd, sock_cred *cred) {
return true;
}
int read_int(int fd) {
int val;
if (xxread(fd, &val, sizeof(val)) != sizeof(val))
return -1;
return val;
}
int read_int_be(int fd) {
return ntohl(read_int(fd));
}
void write_int(int fd, int val) {
if (fd < 0) return;
xwrite(fd, &val, sizeof(val));
}
void write_int_be(int fd, int val) {
write_int(fd, htonl(val));
}
bool read_string(int fd, std::string &str) {
int len = read_int(fd);
str.clear();
if (len < 0)
return false;
auto len = read_any<size_t>(fd);
str.resize(len);
return xxread(fd, str.data(), len) == len;
}
@ -56,6 +42,6 @@ string read_string(int fd) {
void write_string(int fd, string_view str) {
if (fd < 0) return;
write_int(fd, str.size());
write_any(fd, str.size());
xwrite(fd, str.data(), str.size());
}

View File

@ -1,56 +1,140 @@
use base::{libc, warn, ReadExt, ResultExt, WriteExt};
use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable};
use bytemuck::{bytes_of, bytes_of_mut, Zeroable};
use std::io;
use std::io::{ErrorKind, IoSlice, IoSliceMut, Read, Write};
use std::mem::ManuallyDrop;
use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream};
pub trait Encodable {
fn encoded_len(&self) -> usize;
fn encode(&self, w: &mut impl Write) -> io::Result<()>;
}
pub trait Decodable: Sized + Encodable {
fn decode(r: &mut impl Read) -> io::Result<Self>;
}
macro_rules! impl_pod_encodable {
($($t:ty)*) => ($(
impl Encodable for $t {
#[inline(always)]
fn encoded_len(&self) -> usize {
size_of::<Self>()
}
#[inline(always)]
fn encode(&self, w: &mut impl Write) -> io::Result<()> {
w.write_pod(self)
}
}
impl Decodable for $t {
#[inline(always)]
fn decode(r: &mut impl Read) -> io::Result<Self> {
let mut val = Self::zeroed();
r.read_pod(&mut val)?;
Ok(val)
}
}
)*)
}
impl_pod_encodable! { u8 i32 usize }
impl Encodable for bool {
#[inline(always)]
fn encoded_len(&self) -> usize {
size_of::<u8>()
}
#[inline(always)]
fn encode(&self, w: &mut impl Write) -> io::Result<()> {
match *self {
true => 1u8.encode(w),
false => 0u8.encode(w),
}
}
}
impl Decodable for bool {
#[inline(always)]
fn decode(r: &mut impl Read) -> io::Result<Self> {
Ok(u8::decode(r)? != 0)
}
}
impl<T: Decodable> Encodable for Vec<T> {
fn encoded_len(&self) -> usize {
size_of::<usize>() + size_of::<T>() * self.len()
}
fn encode(&self, w: &mut impl Write) -> io::Result<()> {
self.len().encode(w)?;
self.iter().try_for_each(|e| e.encode(w))
}
}
impl<T: Decodable> Decodable for Vec<T> {
fn decode(r: &mut impl Read) -> io::Result<Self> {
let len = usize::decode(r)?;
let mut val = Vec::with_capacity(len);
for _ in 0..len {
val.push(T::decode(r)?);
}
Ok(val)
}
}
impl Encodable for str {
fn encoded_len(&self) -> usize {
size_of::<usize>() + self.as_bytes().len()
}
fn encode(&self, w: &mut impl Write) -> io::Result<()> {
self.as_bytes().len().encode(w)?;
w.write_all(self.as_bytes())
}
}
impl Encodable for String {
fn encoded_len(&self) -> usize {
self.as_str().encoded_len()
}
fn encode(&self, w: &mut impl Write) -> io::Result<()> {
self.as_str().encode(w)
}
}
impl Decodable for String {
fn decode(r: &mut impl Read) -> io::Result<String> {
let len = usize::decode(r)?;
let mut val = String::with_capacity(len);
let mut r = r.take(len as u64);
r.read_to_string(&mut val)?;
Ok(val)
}
}
pub trait IpcRead {
fn ipc_read_int(&mut self) -> io::Result<i32>;
fn ipc_read_string(&mut self) -> io::Result<String>;
fn ipc_read_vec<E: Pod>(&mut self) -> io::Result<Vec<E>>;
fn read_decodable<E: Decodable>(&mut self) -> io::Result<E>;
}
impl<T: Read> IpcRead for T {
fn ipc_read_int(&mut self) -> io::Result<i32> {
let mut val: i32 = 0;
self.read_pod(&mut val)?;
Ok(val)
}
fn ipc_read_string(&mut self) -> io::Result<String> {
let len = self.ipc_read_int()?;
let mut val = "".to_string();
self.take(len as u64).read_to_string(&mut val)?;
Ok(val)
}
fn ipc_read_vec<E: Pod>(&mut self) -> io::Result<Vec<E>> {
let len = self.ipc_read_int()? as usize;
let mut vec = Vec::new();
let mut val: E = Zeroable::zeroed();
for _ in 0..len {
self.read_pod(&mut val)?;
vec.push(val);
}
Ok(vec)
#[inline(always)]
fn read_decodable<E: Decodable>(&mut self) -> io::Result<E> {
E::decode(self)
}
}
pub trait IpcWrite {
fn ipc_write_int(&mut self, val: i32) -> io::Result<()>;
fn ipc_write_string(&mut self, val: &str) -> io::Result<()>;
fn write_encodable<E: Encodable + ?Sized>(&mut self, val: &E) -> io::Result<()>;
}
impl<T: Write> IpcWrite for T {
fn ipc_write_int(&mut self, val: i32) -> io::Result<()> {
self.write_pod(&val)
}
fn ipc_write_string(&mut self, val: &str) -> io::Result<()> {
self.ipc_write_int(val.len() as i32)?;
self.write_all(val.as_bytes())
#[inline(always)]
fn write_encodable<E: Encodable + ?Sized>(&mut self, val: &E) -> io::Result<()> {
val.encode(self)
}
}
@ -63,7 +147,7 @@ pub trait UnixSocketExt {
impl UnixSocketExt for UnixStream {
fn send_fds(&mut self, fds: &[RawFd]) -> io::Result<()> {
match fds.len() {
0 => self.ipc_write_int(-1)?,
0 => self.write_encodable(&-1)?,
len => {
// 4k buffer is reasonable enough
let mut buf = [0u8; 4096];

View File

@ -59,7 +59,7 @@ impl MagiskD {
let mut client = unsafe { UnixStream::from_raw_fd(client) };
let _: LoggedResult<()> = try {
let code = ZygiskRequest {
repr: client.ipc_read_int()?,
repr: client.read_decodable()?,
};
match code {
ZygiskRequest::GetInfo => self.get_process_info(client)?,
@ -102,7 +102,7 @@ impl MagiskD {
fn connect_zygiskd(&self, mut client: UnixStream) {
let mut zygiskd_sockets = self.zygiskd_sockets.lock().unwrap();
let result: LoggedResult<()> = try {
let is_64_bit = client.ipc_read_int()? != 0;
let is_64_bit: bool = client.read_decodable()?;
let socket = if is_64_bit {
&mut zygiskd_sockets.1
} else {
@ -135,7 +135,7 @@ impl MagiskD {
if let Some(module_fds) = self.get_module_fds(is_64_bit) {
local.send_fds(&module_fds)?;
}
if local.ipc_read_int()? != 0 {
if local.read_decodable::<i32>()? != 0 {
Err(LoggedError::default())?;
}
local
@ -148,9 +148,9 @@ impl MagiskD {
}
fn get_process_info(&self, mut client: UnixStream) -> LoggedResult<()> {
let uid = client.ipc_read_int()?;
let process = client.ipc_read_string()?;
let is_64_bit = client.ipc_read_int()? != 0;
let uid: i32 = client.read_decodable()?;
let process: String = client.read_decodable()?;
let is_64_bit: bool = client.read_decodable()?;
let mut flags: u32 = 0;
update_deny_flags(uid, &process, &mut flags);
if self.get_manager_uid(to_user_id(uid)) == uid {
@ -180,7 +180,7 @@ impl MagiskD {
}
// Read all failed modules
let failed_ids: Vec<i32> = client.ipc_read_vec()?;
let failed_ids: Vec<i32> = client.read_decodable()?;
if let Some(module_list) = self.module_list.get() {
for id in failed_ids {
let mut buf = Utf8CStrBufArr::default();
@ -197,7 +197,7 @@ impl MagiskD {
}
fn get_mod_dir(&self, mut client: UnixStream) -> LoggedResult<()> {
let id = client.ipc_read_int()?;
let id: i32 = client.read_decodable()?;
let module = &self.module_list.get().unwrap()[id as usize];
let mut buf = Utf8CStrBufArr::default();
let dir = FsPathBuf::new(&mut buf).join(MODULEROOT).join(&module.name);

View File

@ -79,9 +79,9 @@ bool ZygiskModule::valid() const {
int ZygiskModule::connectCompanion() const {
if (int fd = zygisk_request(+ZygiskRequest::ConnectCompanion); fd >= 0) {
#ifdef __LP64__
write_int(fd, 1);
write_any<bool>(fd, true);
#else
write_int(fd, 0);
write_any<bool>(fd, false);
#endif
write_int(fd, id);
return fd;
@ -213,9 +213,9 @@ int ZygiskContext::get_module_info(int uid, rust::Vec<int> &fds) {
write_int(fd, uid);
write_string(fd, process);
#ifdef __LP64__
write_int(fd, 1);
write_any<bool>(fd, true);
#else
write_int(fd, 0);
write_any<bool>(fd, false);
#endif
xxread(fd, &info_flags, sizeof(info_flags));
if (zygisk_should_load_module(info_flags)) {