From ab86732c89d07a0bbee0466e70a78c0b8e9aa550 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Sat, 1 Feb 2025 01:29:08 +0800 Subject: [PATCH] Implement simple serialization over IPC --- native/src/Cargo.lock | 10 ++ native/src/Cargo.toml | 5 +- native/src/core/Cargo.toml | 1 + native/src/core/db.rs | 6 +- native/src/core/derive/Cargo.toml | 13 +++ native/src/core/derive/decodable.rs | 124 ++++++++++++++++++++++ native/src/core/derive/lib.rs | 8 ++ native/src/core/include/socket.hpp | 37 +++++-- native/src/core/socket.cpp | 18 +--- native/src/core/socket.rs | 156 +++++++++++++++++++++------- native/src/core/zygisk/daemon.rs | 16 +-- native/src/core/zygisk/module.cpp | 8 +- 12 files changed, 323 insertions(+), 79 deletions(-) create mode 100644 native/src/core/derive/Cargo.toml create mode 100644 native/src/core/derive/decodable.rs create mode 100644 native/src/core/derive/lib.rs diff --git a/native/src/Cargo.lock b/native/src/Cargo.lock index 89b2a408f..ba95d75dd 100644 --- a/native/src/Cargo.lock +++ b/native/src/Cargo.lock @@ -300,6 +300,15 @@ dependencies = [ "syn", ] +[[package]] +name = "derive" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.11.0-pre.9" @@ -461,6 +470,7 @@ dependencies = [ "bytemuck", "cxx", "cxx-gen", + "derive", "num-derive", "num-traits", "pb-rs", diff --git a/native/src/Cargo.toml b/native/src/Cargo.toml index 85c4853fe..8e3132bff 100644 --- a/native/src/Cargo.toml +++ b/native/src/Cargo.toml @@ -1,6 +1,6 @@ [workspace] exclude = ["external"] -members = ["base", "boot", "core", "init", "sepolicy"] +members = ["base", "boot", "core", "core/derive", "init", "sepolicy"] resolver = "2" [workspace.dependencies] @@ -26,6 +26,9 @@ bytemuck = "1.16" fdt = "0.1" const_format = "0.2" bit-set = "0.8" +syn = "2" +quote = "1" +proc-macro2 = "1" [workspace.dependencies.argh] git = "https://github.com/google/argh.git" diff --git a/native/src/core/Cargo.toml b/native/src/core/Cargo.toml index 25862b66b..08061a019 100644 --- a/native/src/core/Cargo.toml +++ b/native/src/core/Cargo.toml @@ -17,6 +17,7 @@ pb-rs = { workspace = true } [dependencies] base = { path = "../base", features = ["selinux"] } +derive = { path = "derive" } cxx = { workspace = true } num-traits = { workspace = true } num-derive = { workspace = true } diff --git a/native/src/core/db.rs b/native/src/core/db.rs index 97f73336f..05245401a 100644 --- a/native/src/core/db.rs +++ b/native/src/core/db.rs @@ -282,7 +282,7 @@ impl MagiskD { fn db_exec_for_client(&self, fd: OwnedFd) -> LoggedResult<()> { let mut file = File::from(fd); let mut reader = BufReader::new(&mut file); - let sql = reader.ipc_read_string()?; + let sql: String = reader.read_decodable()?; let mut writer = BufWriter::new(&mut file); let mut output_fn = |columns: &[String], values: &DbValues| { let mut out = "".to_string(); @@ -294,10 +294,10 @@ impl MagiskD { out.push('='); out.push_str(values.get_text(i as i32)); } - writer.ipc_write_string(&out).log().ok(); + writer.write_encodable(&out).log().ok(); }; self.db_exec_with_rows(&sql, &[], &mut output_fn); - writer.ipc_write_string("").log() + writer.write_encodable("").log() } } diff --git a/native/src/core/derive/Cargo.toml b/native/src/core/derive/Cargo.toml new file mode 100644 index 000000000..5e881c220 --- /dev/null +++ b/native/src/core/derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "derive" +version = "0.0.0" +edition = "2021" + +[lib] +path = "lib.rs" +proc-macro = true + +[dependencies] +syn = { workspace = true } +quote = { workspace = true } +proc-macro2 = { workspace = true } diff --git a/native/src/core/derive/decodable.rs b/native/src/core/derive/decodable.rs new file mode 100644 index 000000000..fd2d70398 --- /dev/null +++ b/native/src/core/derive/decodable.rs @@ -0,0 +1,124 @@ +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; +use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields, GenericParam}; + +pub(crate) fn derive_decodable(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let name = input.ident; + + // Add a bound `T: Decodable` to every type parameter T. + let mut generics = input.generics; + for param in &mut generics.params { + if let GenericParam::Type(ref mut type_param) = *param { + type_param + .bounds + .push(parse_quote!(crate::socket::Decodable)); + } + } + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let sum = gen_size_sum(&input.data); + let encode = gen_encode(&input.data); + let decode = gen_decode(&input.data); + + let expanded = quote! { + // The generated impl. + impl #impl_generics crate::socket::Encodable for #name #ty_generics #where_clause { + fn encoded_len(&self) -> usize { + #sum + } + fn encode(&self, w: &mut impl std::io::Write) -> std::io::Result<()> { + #encode + Ok(()) + } + } + impl #impl_generics crate::socket::Decodable for #name #ty_generics #where_clause { + fn decode(r: &mut impl std::io::Read) -> std::io::Result { + let val = #decode; + Ok(val) + } + } + }; + proc_macro::TokenStream::from(expanded) +} + +// Generate an expression to sum up the size of each field. +fn gen_size_sum(data: &Data) -> TokenStream { + match *data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) => { + // Expands to an expression like + // + // 0 + self.x.encoded_len() + self.y.encoded_len() + self.z.encoded_len() + let recurse = fields.named.iter().map(|f| { + let name = &f.ident; + quote_spanned! { f.span() => + crate::socket::Encodable::encoded_len(&self.#name) + } + }); + quote! { + 0 #(+ #recurse)* + } + } + _ => unimplemented!(), + } + } + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} + +// Generate an expression to encode each field. +fn gen_encode(data: &Data) -> TokenStream { + match *data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) => { + // Expands to an expression like + // + // self.x.encode(w)?; self.y.encode(w)?; self.z.encode(w)?; + let recurse = fields.named.iter().map(|f| { + let name = &f.ident; + quote_spanned! { f.span() => + crate::socket::Encodable::encode(&self.#name, w)?; + } + }); + quote! { + #(#recurse)* + } + } + _ => unimplemented!(), + } + } + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} + +// Generate an expression to decode each field. +fn gen_decode(data: &Data) -> TokenStream { + match *data { + Data::Struct(ref data) => { + match data.fields { + Fields::Named(ref fields) => { + // Expands to an expression like + // + // Self { x: Decodable::decode(r)?, y: Decodable::decode(r)?, } + let recurse = fields.named.iter().map(|f| { + let name = &f.ident; + quote_spanned! { f.span() => + #name: crate::socket::Decodable::decode(r)?, + } + }); + quote! { + Self { #(#recurse)* } + } + } + _ => unimplemented!(), + } + } + Data::Enum(_) | Data::Union(_) => unimplemented!(), + } +} diff --git a/native/src/core/derive/lib.rs b/native/src/core/derive/lib.rs new file mode 100644 index 000000000..6e12156e2 --- /dev/null +++ b/native/src/core/derive/lib.rs @@ -0,0 +1,8 @@ +use proc_macro::TokenStream; + +mod decodable; + +#[proc_macro_derive(Decodable)] +pub fn derive_decodable(input: TokenStream) -> TokenStream { + decodable::derive_decodable(input) +} diff --git a/native/src/core/include/socket.hpp b/native/src/core/include/socket.hpp index c51c6650d..1f4c2832b 100644 --- a/native/src/core/include/socket.hpp +++ b/native/src/core/include/socket.hpp @@ -6,29 +6,44 @@ #include #include +#include + struct sock_cred : public ucred { std::string context; }; -bool get_client_cred(int fd, sock_cred *cred); -int read_int(int fd); -int read_int_be(int fd); -void write_int(int fd, int val); -void write_int_be(int fd, int val); -std::string read_string(int fd); -bool read_string(int fd, std::string &str); -void write_string(int fd, std::string_view str); +template requires(std::is_trivially_copyable_v) +T read_any(int fd) { + T val; + if (xxread(fd, &val, sizeof(val)) != sizeof(val)) + return -1; + return val; +} + +template requires(std::is_trivially_copyable_v) +void write_any(int fd, T val) { + if (fd < 0) return; + xwrite(fd, &val, sizeof(val)); +} template requires(std::is_trivially_copyable_v) void write_vector(int fd, const std::vector &vec) { - write_int(fd, static_cast(vec.size())); + write_any(fd, vec.size()); xwrite(fd, vec.data(), vec.size() * sizeof(T)); } template requires(std::is_trivially_copyable_v) bool read_vector(int fd, std::vector &vec) { - int size = read_int(fd); - if (size == -1) return false; + auto size = read_any(fd); vec.resize(size); return xread(fd, vec.data(), size * sizeof(T)) == size * sizeof(T); } + +bool get_client_cred(int fd, sock_cred *cred); +static inline int read_int(int fd) { return read_any(fd); } +int read_int_be(int fd); +static inline void write_int(int fd, int val) { write_any(fd, val); } +void write_int_be(int fd, int val); +std::string read_string(int fd); +bool read_string(int fd, std::string &str); +void write_string(int fd, std::string_view str); diff --git a/native/src/core/socket.cpp b/native/src/core/socket.cpp index 30640d43c..cd1c22cde 100644 --- a/native/src/core/socket.cpp +++ b/native/src/core/socket.cpp @@ -19,31 +19,17 @@ bool get_client_cred(int fd, sock_cred *cred) { return true; } -int read_int(int fd) { - int val; - if (xxread(fd, &val, sizeof(val)) != sizeof(val)) - return -1; - return val; -} - int read_int_be(int fd) { return ntohl(read_int(fd)); } -void write_int(int fd, int val) { - if (fd < 0) return; - xwrite(fd, &val, sizeof(val)); -} - void write_int_be(int fd, int val) { write_int(fd, htonl(val)); } bool read_string(int fd, std::string &str) { - int len = read_int(fd); str.clear(); - if (len < 0) - return false; + auto len = read_any(fd); str.resize(len); return xxread(fd, str.data(), len) == len; } @@ -56,6 +42,6 @@ string read_string(int fd) { void write_string(int fd, string_view str) { if (fd < 0) return; - write_int(fd, str.size()); + write_any(fd, str.size()); xwrite(fd, str.data(), str.size()); } diff --git a/native/src/core/socket.rs b/native/src/core/socket.rs index fb24d3217..2cdc51f0e 100644 --- a/native/src/core/socket.rs +++ b/native/src/core/socket.rs @@ -1,56 +1,140 @@ use base::{libc, warn, ReadExt, ResultExt, WriteExt}; -use bytemuck::{bytes_of, bytes_of_mut, Pod, Zeroable}; +use bytemuck::{bytes_of, bytes_of_mut, Zeroable}; use std::io; use std::io::{ErrorKind, IoSlice, IoSliceMut, Read, Write}; use std::mem::ManuallyDrop; use std::os::fd::{FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::net::{AncillaryData, SocketAncillary, UnixStream}; +pub trait Encodable { + fn encoded_len(&self) -> usize; + fn encode(&self, w: &mut impl Write) -> io::Result<()>; +} + +pub trait Decodable: Sized + Encodable { + fn decode(r: &mut impl Read) -> io::Result; +} + +macro_rules! impl_pod_encodable { + ($($t:ty)*) => ($( + impl Encodable for $t { + #[inline(always)] + fn encoded_len(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn encode(&self, w: &mut impl Write) -> io::Result<()> { + w.write_pod(self) + } + } + impl Decodable for $t { + #[inline(always)] + fn decode(r: &mut impl Read) -> io::Result { + let mut val = Self::zeroed(); + r.read_pod(&mut val)?; + Ok(val) + } + } + )*) +} + +impl_pod_encodable! { u8 i32 usize } + +impl Encodable for bool { + #[inline(always)] + fn encoded_len(&self) -> usize { + size_of::() + } + + #[inline(always)] + fn encode(&self, w: &mut impl Write) -> io::Result<()> { + match *self { + true => 1u8.encode(w), + false => 0u8.encode(w), + } + } +} + +impl Decodable for bool { + #[inline(always)] + fn decode(r: &mut impl Read) -> io::Result { + Ok(u8::decode(r)? != 0) + } +} + +impl Encodable for Vec { + fn encoded_len(&self) -> usize { + size_of::() + size_of::() * self.len() + } + + fn encode(&self, w: &mut impl Write) -> io::Result<()> { + self.len().encode(w)?; + self.iter().try_for_each(|e| e.encode(w)) + } +} + +impl Decodable for Vec { + fn decode(r: &mut impl Read) -> io::Result { + let len = usize::decode(r)?; + let mut val = Vec::with_capacity(len); + for _ in 0..len { + val.push(T::decode(r)?); + } + Ok(val) + } +} + +impl Encodable for str { + fn encoded_len(&self) -> usize { + size_of::() + self.as_bytes().len() + } + + fn encode(&self, w: &mut impl Write) -> io::Result<()> { + self.as_bytes().len().encode(w)?; + w.write_all(self.as_bytes()) + } +} + +impl Encodable for String { + fn encoded_len(&self) -> usize { + self.as_str().encoded_len() + } + + fn encode(&self, w: &mut impl Write) -> io::Result<()> { + self.as_str().encode(w) + } +} + +impl Decodable for String { + fn decode(r: &mut impl Read) -> io::Result { + let len = usize::decode(r)?; + let mut val = String::with_capacity(len); + let mut r = r.take(len as u64); + r.read_to_string(&mut val)?; + Ok(val) + } +} + pub trait IpcRead { - fn ipc_read_int(&mut self) -> io::Result; - fn ipc_read_string(&mut self) -> io::Result; - fn ipc_read_vec(&mut self) -> io::Result>; + fn read_decodable(&mut self) -> io::Result; } impl IpcRead for T { - fn ipc_read_int(&mut self) -> io::Result { - let mut val: i32 = 0; - self.read_pod(&mut val)?; - Ok(val) - } - - fn ipc_read_string(&mut self) -> io::Result { - let len = self.ipc_read_int()?; - let mut val = "".to_string(); - self.take(len as u64).read_to_string(&mut val)?; - Ok(val) - } - - fn ipc_read_vec(&mut self) -> io::Result> { - let len = self.ipc_read_int()? as usize; - let mut vec = Vec::new(); - let mut val: E = Zeroable::zeroed(); - for _ in 0..len { - self.read_pod(&mut val)?; - vec.push(val); - } - Ok(vec) + #[inline(always)] + fn read_decodable(&mut self) -> io::Result { + E::decode(self) } } pub trait IpcWrite { - fn ipc_write_int(&mut self, val: i32) -> io::Result<()>; - fn ipc_write_string(&mut self, val: &str) -> io::Result<()>; + fn write_encodable(&mut self, val: &E) -> io::Result<()>; } impl IpcWrite for T { - fn ipc_write_int(&mut self, val: i32) -> io::Result<()> { - self.write_pod(&val) - } - - fn ipc_write_string(&mut self, val: &str) -> io::Result<()> { - self.ipc_write_int(val.len() as i32)?; - self.write_all(val.as_bytes()) + #[inline(always)] + fn write_encodable(&mut self, val: &E) -> io::Result<()> { + val.encode(self) } } @@ -63,7 +147,7 @@ pub trait UnixSocketExt { impl UnixSocketExt for UnixStream { fn send_fds(&mut self, fds: &[RawFd]) -> io::Result<()> { match fds.len() { - 0 => self.ipc_write_int(-1)?, + 0 => self.write_encodable(&-1)?, len => { // 4k buffer is reasonable enough let mut buf = [0u8; 4096]; diff --git a/native/src/core/zygisk/daemon.rs b/native/src/core/zygisk/daemon.rs index 77c649619..36de82772 100644 --- a/native/src/core/zygisk/daemon.rs +++ b/native/src/core/zygisk/daemon.rs @@ -59,7 +59,7 @@ impl MagiskD { let mut client = unsafe { UnixStream::from_raw_fd(client) }; let _: LoggedResult<()> = try { let code = ZygiskRequest { - repr: client.ipc_read_int()?, + repr: client.read_decodable()?, }; match code { ZygiskRequest::GetInfo => self.get_process_info(client)?, @@ -102,7 +102,7 @@ impl MagiskD { fn connect_zygiskd(&self, mut client: UnixStream) { let mut zygiskd_sockets = self.zygiskd_sockets.lock().unwrap(); let result: LoggedResult<()> = try { - let is_64_bit = client.ipc_read_int()? != 0; + let is_64_bit: bool = client.read_decodable()?; let socket = if is_64_bit { &mut zygiskd_sockets.1 } else { @@ -135,7 +135,7 @@ impl MagiskD { if let Some(module_fds) = self.get_module_fds(is_64_bit) { local.send_fds(&module_fds)?; } - if local.ipc_read_int()? != 0 { + if local.read_decodable::()? != 0 { Err(LoggedError::default())?; } local @@ -148,9 +148,9 @@ impl MagiskD { } fn get_process_info(&self, mut client: UnixStream) -> LoggedResult<()> { - let uid = client.ipc_read_int()?; - let process = client.ipc_read_string()?; - let is_64_bit = client.ipc_read_int()? != 0; + let uid: i32 = client.read_decodable()?; + let process: String = client.read_decodable()?; + let is_64_bit: bool = client.read_decodable()?; let mut flags: u32 = 0; update_deny_flags(uid, &process, &mut flags); if self.get_manager_uid(to_user_id(uid)) == uid { @@ -180,7 +180,7 @@ impl MagiskD { } // Read all failed modules - let failed_ids: Vec = client.ipc_read_vec()?; + let failed_ids: Vec = client.read_decodable()?; if let Some(module_list) = self.module_list.get() { for id in failed_ids { let mut buf = Utf8CStrBufArr::default(); @@ -197,7 +197,7 @@ impl MagiskD { } fn get_mod_dir(&self, mut client: UnixStream) -> LoggedResult<()> { - let id = client.ipc_read_int()?; + let id: i32 = client.read_decodable()?; let module = &self.module_list.get().unwrap()[id as usize]; let mut buf = Utf8CStrBufArr::default(); let dir = FsPathBuf::new(&mut buf).join(MODULEROOT).join(&module.name); diff --git a/native/src/core/zygisk/module.cpp b/native/src/core/zygisk/module.cpp index a9ce48593..bde9c8828 100644 --- a/native/src/core/zygisk/module.cpp +++ b/native/src/core/zygisk/module.cpp @@ -79,9 +79,9 @@ bool ZygiskModule::valid() const { int ZygiskModule::connectCompanion() const { if (int fd = zygisk_request(+ZygiskRequest::ConnectCompanion); fd >= 0) { #ifdef __LP64__ - write_int(fd, 1); + write_any(fd, true); #else - write_int(fd, 0); + write_any(fd, false); #endif write_int(fd, id); return fd; @@ -213,9 +213,9 @@ int ZygiskContext::get_module_info(int uid, rust::Vec &fds) { write_int(fd, uid); write_string(fd, process); #ifdef __LP64__ - write_int(fd, 1); + write_any(fd, true); #else - write_int(fd, 0); + write_any(fd, false); #endif xxread(fd, &info_flags, sizeof(info_flags)); if (zygisk_should_load_module(info_flags)) {