From a5768e02eaf4805e2766adba440a4c593f815a57 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Sat, 20 May 2023 14:19:40 -0700 Subject: [PATCH] Cleanup byte_channel implementation --- native/src/base/files.hpp | 45 +++++++++++++++++++----------- native/src/base/include/stream.hpp | 18 ++++-------- native/src/base/stream.cpp | 41 ++++++++++++--------------- native/src/sepolicy/policydb.cpp | 10 ++----- 4 files changed, 55 insertions(+), 59 deletions(-) diff --git a/native/src/base/files.hpp b/native/src/base/files.hpp index 97dd12328..2aa81a055 100644 --- a/native/src/base/files.hpp +++ b/native/src/base/files.hpp @@ -26,19 +26,6 @@ struct file_attr { char con[128]; }; -struct byte_data { - using str_pairs = std::initializer_list>; - - uint8_t *buf = nullptr; - size_t sz = 0; - - int patch(str_pairs list) { return patch(true, list); } - int patch(bool log, str_pairs list); - bool contains(std::string_view pattern, bool log = true) const; -protected: - void swap(byte_data &o); -}; - struct mount_info { unsigned int id; unsigned int parent; @@ -56,13 +43,37 @@ struct mount_info { std::string fs_option; }; +struct byte_data { + using str_pairs = std::initializer_list>; + + uint8_t *buf = nullptr; + size_t sz = 0; + + int patch(str_pairs list) { return patch(true, list); } + int patch(bool log, str_pairs list); + bool contains(std::string_view pattern, bool log = true) const; +protected: + void swap(byte_data &o); +}; + +#define MOVE_ONLY(clazz) \ +clazz() = default; \ +clazz(const clazz&) = delete; \ +clazz(clazz &&o) { swap(o); } \ +clazz& operator=(clazz &&o) { swap(o); return *this; } + +struct heap_data : public byte_data { + MOVE_ONLY(heap_data) + + explicit heap_data(size_t sz) { this->sz = sz; buf = new uint8_t[sz]; } + ~heap_data() { free(buf); } +}; + struct mmap_data : public byte_data { - mmap_data() = default; - mmap_data(const mmap_data&) = delete; - mmap_data(mmap_data &&o) { swap(o); } + MOVE_ONLY(mmap_data) + mmap_data(const char *name, bool rw = false); ~mmap_data() { if (buf) munmap(buf, sz); } - mmap_data& operator=(mmap_data &&other) { swap(other); return *this; } }; extern "C" { diff --git a/native/src/base/include/stream.hpp b/native/src/base/include/stream.hpp index e1127001d..474208510 100644 --- a/native/src/base/include/stream.hpp +++ b/native/src/base/include/stream.hpp @@ -27,26 +27,23 @@ protected: class chunk_out_stream : public filter_out_stream { public: chunk_out_stream(out_strm_ptr &&base, size_t buf_sz, size_t chunk_sz) - : filter_out_stream(std::move(base)), chunk_sz(chunk_sz), buf_sz(buf_sz) {} + : filter_out_stream(std::move(base)), chunk_sz(chunk_sz), data(buf_sz) {} chunk_out_stream(out_strm_ptr &&base, size_t buf_sz = 4096) : chunk_out_stream(std::move(base), buf_sz, buf_sz) {} - ~chunk_out_stream() override { delete[] _buf; } - bool write(const void *buf, size_t len) final; protected: // Classes inheriting this class has to call finalize() in its destructor void finalize(); - virtual bool write_chunk(const void *buf, size_t len, bool final) = 0; + virtual bool write_chunk(const void *buf, size_t len, bool final); size_t chunk_sz; private: - size_t buf_sz; size_t buf_off = 0; - uint8_t *_buf = nullptr; + heap_data data; }; struct in_stream { @@ -67,21 +64,18 @@ using channel_ptr = std::unique_ptr; // Byte channel that dynamically allocates memory class byte_channel : public channel { public: - byte_channel(uint8_t *&buf, size_t &len); - template - byte_channel(Byte *&buf, size_t &len) : byte_channel(reinterpret_cast(buf), len) {} + byte_channel(heap_data &data) : _data(data) {} ssize_t read(void *buf, size_t len) override; bool write(const void *buf, size_t len) override; off_t seek(off_t off, int whence) override; private: - uint8_t *&_buf; - size_t &_len; + heap_data &_data; size_t _pos = 0; size_t _cap = 0; - void resize(size_t new_pos, bool zero = false); + void resize(size_t new_sz, bool zero = false); }; class file_channel : public channel { diff --git a/native/src/base/stream.cpp b/native/src/base/stream.cpp index ba7c2d00f..b66109a25 100644 --- a/native/src/base/stream.cpp +++ b/native/src/base/stream.cpp @@ -95,9 +95,9 @@ bool chunk_out_stream::write(const void *_in, size_t len) { // Enough input for a chunk const uint8_t *src; if (buf_off) { - src = _buf; + src = data.buf; auto copy = chunk_sz - buf_off; - memcpy(_buf + buf_off, in, copy); + memcpy(data.buf + buf_off, in, copy); in += copy; len -= copy; buf_off = 0; @@ -110,10 +110,7 @@ bool chunk_out_stream::write(const void *_in, size_t len) { return false; } else { // Buffer internally - if (!_buf) { - _buf = new uint8_t[buf_sz]; - } - memcpy(_buf + buf_off, in, len); + memcpy(data.buf + buf_off, in, len); buf_off += len; break; } @@ -121,33 +118,31 @@ bool chunk_out_stream::write(const void *_in, size_t len) { return true; } +bool chunk_out_stream::write_chunk(const void *buf, size_t len, bool) { + return base->write(buf, len); +} + void chunk_out_stream::finalize() { if (buf_off) { - if (!write_chunk(_buf, buf_off, true)) { + if (!write_chunk(data.buf, buf_off, true)) { LOGE("Error in finalize, file truncated\n"); } - delete[] _buf; - _buf = nullptr; buf_off = 0; } } -byte_channel::byte_channel(uint8_t *&buf, size_t &len) : _buf(buf), _len(len) { - buf = nullptr; - len = 0; -} - ssize_t byte_channel::read(void *buf, size_t len) { - len = std::min((size_t) len, _len - _pos); - memcpy(buf, _buf + _pos, len); + len = std::min((size_t) len, _data.sz - _pos); + memcpy(buf, _data.buf + _pos, len); + _pos += len; return len; } bool byte_channel::write(const void *buf, size_t len) { resize(_pos + len); - memcpy(_buf + _pos, buf, len); + memcpy(_data.buf + _pos, buf, len); _pos += len; - _len = std::max(_len, _pos); + _data.sz = std::max(_data.sz, _pos); return true; } @@ -158,7 +153,7 @@ off_t byte_channel::seek(off_t off, int whence) { np = _pos + off; break; case SEEK_END: - np = _len + off; + np = _data.sz + off; break; case SEEK_SET: np = off; @@ -171,17 +166,17 @@ off_t byte_channel::seek(off_t off, int whence) { return np; } -void byte_channel::resize(size_t new_pos, bool zero) { +void byte_channel::resize(size_t new_sz, bool zero) { bool resize = false; size_t old_cap = _cap; - while (new_pos > _cap) { + while (new_sz > _cap) { _cap = _cap ? (_cap << 1) - (_cap >> 1) : 1 << 12; resize = true; } if (resize) { - _buf = (uint8_t *) realloc(_buf, _cap); + _data.buf = (uint8_t *) realloc(_data.buf, _cap); if (zero) - memset(_buf + old_cap, 0, _cap - old_cap); + memset(_data.buf + old_cap, 0, _cap - old_cap); } } diff --git a/native/src/sepolicy/policydb.cpp b/native/src/sepolicy/policydb.cpp index be8862976..540990616 100644 --- a/native/src/sepolicy/policydb.cpp +++ b/native/src/sepolicy/policydb.cpp @@ -234,14 +234,10 @@ sepol_impl::~sepol_impl() { } bool sepolicy::to_file(const char *file) { - uint8_t *data; - size_t len; - // No partial writes are allowed to /sys/fs/selinux/load, thus the reason why we // first dump everything into memory, then directly call write system call - - auto fp = make_channel_fp(data, len); - run_finally fin([=]{ free(data); }); + heap_data data; + auto fp = make_channel_fp(data); policy_file_t pf; policy_file_init(&pf); @@ -255,7 +251,7 @@ bool sepolicy::to_file(const char *file) { int fd = xopen(file, O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644); if (fd < 0) return false; - xwrite(fd, data, len); + xwrite(fd, data.buf, data.sz); close(fd); return true;