Introduce owned_fd

This commit is contained in:
topjohnwu 2023-11-15 13:58:51 -08:00
parent 1ff7b9055f
commit 7c2e93d266
2 changed files with 26 additions and 16 deletions

View File

@ -11,11 +11,10 @@
#include "xwrap.hpp" #include "xwrap.hpp"
#define DISALLOW_COPY_AND_MOVE(clazz) \ #define DISALLOW_COPY_AND_MOVE(clazz) \
clazz(const clazz &) = delete; \ clazz(const clazz&) = delete; \
clazz(clazz &&) = delete; clazz(clazz &&) = delete;
#define ALLOW_MOVE_ONLY(clazz) \ #define ALLOW_MOVE_ONLY(clazz) \
clazz() = default; \
clazz(const clazz&) = delete; \ clazz(const clazz&) = delete; \
clazz(clazz &&o) { swap(o); } \ clazz(clazz &&o) { swap(o); } \
clazz& operator=(clazz &&o) { swap(o); return *this; } clazz& operator=(clazz &&o) { swap(o); return *this; }
@ -211,6 +210,7 @@ class byte_channel;
struct heap_data : public byte_data { struct heap_data : public byte_data {
ALLOW_MOVE_ONLY(heap_data) ALLOW_MOVE_ONLY(heap_data)
heap_data() = default;
explicit heap_data(size_t sz) : byte_data(calloc(sz, 1), sz) {} explicit heap_data(size_t sz) : byte_data(calloc(sz, 1), sz) {}
~heap_data() { free(_buf); } ~heap_data() { free(_buf); }
@ -218,6 +218,21 @@ struct heap_data : public byte_data {
friend byte_channel; friend byte_channel;
}; };
struct owned_fd {
ALLOW_MOVE_ONLY(owned_fd)
owned_fd() : fd(-1) {}
owned_fd(int fd) : fd(fd) {}
~owned_fd() { close(fd); fd = -1; }
operator int() { return fd; }
int release() { int f = fd; fd = -1; return f; }
void swap(owned_fd &owned) { std::swap(fd, owned.fd); }
private:
int fd;
};
rust::Vec<size_t> mut_u8_patch( rust::Vec<size_t> mut_u8_patch(
rust::Slice<uint8_t> buf, rust::Slice<uint8_t> buf,
rust::Slice<const uint8_t> from, rust::Slice<const uint8_t> from,

View File

@ -211,7 +211,7 @@ static bool is_client(pid_t pid) {
} }
static void handle_request(pollfd *pfd) { static void handle_request(pollfd *pfd) {
int client = xaccept4(pfd->fd, nullptr, nullptr, SOCK_CLOEXEC); owned_fd client = xaccept4(pfd->fd, nullptr, nullptr, SOCK_CLOEXEC);
// Verify client credentials // Verify client credentials
sock_cred cred; sock_cred cred;
@ -221,7 +221,7 @@ static void handle_request(pollfd *pfd) {
if (!get_client_cred(client, &cred)) { if (!get_client_cred(client, &cred)) {
// Client died // Client died
goto done; return;
} }
is_root = cred.uid == AID_ROOT; is_root = cred.uid == AID_ROOT;
is_zygote = cred.context == "u:r:zygote:s0"; is_zygote = cred.context == "u:r:zygote:s0";
@ -229,7 +229,7 @@ static void handle_request(pollfd *pfd) {
if (!is_root && !is_zygote && !is_client(cred.pid)) { if (!is_root && !is_zygote && !is_client(cred.pid)) {
// Unsupported client state // Unsupported client state
write_int(client, MainResponse::ACCESS_DENIED); write_int(client, MainResponse::ACCESS_DENIED);
goto done; return;
} }
code = read_int(client); code = read_int(client);
@ -237,7 +237,7 @@ static void handle_request(pollfd *pfd) {
code == MainRequest::_SYNC_BARRIER_ || code == MainRequest::_SYNC_BARRIER_ ||
code == MainRequest::_STAGE_BARRIER_) { code == MainRequest::_STAGE_BARRIER_) {
// Unknown request code // Unknown request code
goto done; return;
} }
// Check client permissions // Check client permissions
@ -251,20 +251,20 @@ static void handle_request(pollfd *pfd) {
case MainRequest::STOP_DAEMON: case MainRequest::STOP_DAEMON:
if (!is_root) { if (!is_root) {
write_int(client, MainResponse::ROOT_REQUIRED); write_int(client, MainResponse::ROOT_REQUIRED);
goto done; return;
} }
break; break;
case MainRequest::REMOVE_MODULES: case MainRequest::REMOVE_MODULES:
if (!is_root && cred.uid != AID_SHELL) { if (!is_root && cred.uid != AID_SHELL) {
write_int(client, MainResponse::ACCESS_DENIED); write_int(client, MainResponse::ACCESS_DENIED);
goto done; return;
} }
break; break;
case MainRequest::ZYGISK: case MainRequest::ZYGISK:
if (!is_zygote) { if (!is_zygote) {
// Invalid client context // Invalid client context
write_int(client, MainResponse::ACCESS_DENIED); write_int(client, MainResponse::ACCESS_DENIED);
goto done; return;
} }
break; break;
default: default:
@ -275,16 +275,11 @@ static void handle_request(pollfd *pfd) {
if (code < MainRequest::_SYNC_BARRIER_) { if (code < MainRequest::_SYNC_BARRIER_) {
handle_request_sync(client, code); handle_request_sync(client, code);
goto done;
} else if (code < MainRequest::_STAGE_BARRIER_) { } else if (code < MainRequest::_STAGE_BARRIER_) {
exec_task([=] { handle_request_async(client, code, cred); }); exec_task([=, fd = client.release()] { handle_request_async(fd, code, cred); });
} else { } else {
exec_task([=] { boot_stage_handler(client, code); }); exec_task([=, fd = client.release()] { boot_stage_handler(fd, code); });
} }
return;
done:
close(client);
} }
static void switch_cgroup(const char *cgroup, int pid) { static void switch_cgroup(const char *cgroup, int pid) {