diff --git a/native/jni/magiskpolicy/policydb.cpp b/native/jni/magiskpolicy/policydb.cpp index 120c209ff..717ddf5c3 100644 --- a/native/jni/magiskpolicy/policydb.cpp +++ b/native/jni/magiskpolicy/policydb.cpp @@ -7,6 +7,7 @@ #include #include +#include #include "magiskpolicy.h" #include "sepolicy.h" @@ -180,7 +181,7 @@ int dump_policydb(const char *file) { size_t len; pf.type = PF_USE_STDIO; - pf.fp = open_memfile(data, len); + pf.fp = open_stream(data, len); if (policydb_write(magisk_policydb, &pf)) { LOGE("Fail to create policy image\n"); return 1; diff --git a/native/jni/utils/Android.mk b/native/jni/utils/Android.mk index 67aebd666..a746c427b 100644 --- a/native/jni/utils/Android.mk +++ b/native/jni/utils/Android.mk @@ -11,6 +11,7 @@ LOCAL_SRC_FILES := \ selinux.cpp \ logging.cpp \ cpio.cpp \ - xwrap.cpp + xwrap.cpp \ + stream.cpp include $(BUILD_STATIC_LIBRARY) diff --git a/native/jni/utils/file.cpp b/native/jni/utils/file.cpp index c85523359..42f8463b5 100644 --- a/native/jni/utils/file.cpp +++ b/native/jni/utils/file.cpp @@ -407,87 +407,3 @@ void parse_mnt(const char *file, const function &fn) { } } } - -struct io_buf { - uint8_t *&buf; - size_t &len; - size_t cap = 0; - size_t pos = 0; - - io_buf(uint8_t *&buf, size_t &len) : buf(buf), len(len) { - buf = nullptr; - len = 0; - } - uint8_t *cur() { - return buf + pos; - } - int max_read() { - return len - pos; - } - void resize(int new_pos, bool zero = false) { - bool resize = false; - size_t old_cap = cap; - while (new_pos > cap) { - cap = cap ? (cap << 1) - (cap >> 1) : 1 << 12; - resize = true; - } - if (resize) { - buf = (uint8_t *) xrealloc(buf, cap); - if (zero) - memset(buf + old_cap, 0, cap - old_cap); - } - } -}; - -static int mem_read(void *v, char *buf, int len) { - auto io = reinterpret_cast(v); - len = std::min(len, io->max_read()); - memcpy(buf, io->cur(), len); - return len; -} - -static int mem_write(void *v, const char *buf, int len) { - auto io = reinterpret_cast(v); - io->resize(io->pos + len); - memcpy(io->cur(), buf, len); - io->pos += len; - io->len = std::max(io->len, io->pos); - return len; -} - -static fpos_t mem_seek(void *v, fpos_t off, int whence) { - auto io = reinterpret_cast(v); - off_t new_pos; - switch (whence) { - case SEEK_CUR: - new_pos = io->pos + off; - break; - case SEEK_END: - new_pos = io->len + off; - break; - case SEEK_SET: - new_pos = off; - break; - default: - return -1; - } - if (new_pos < 0) - return -1; - - io->resize(new_pos, true); - io->pos = new_pos; - return new_pos; -} - -static int mem_close(void *v) { - auto io = reinterpret_cast(v); - delete io; - return 0; -} - -FILE *open_memfile(uint8_t *&buf, size_t &len) { - auto io = new io_buf(buf, len); - FILE *fp = funopen(io, mem_read, mem_write, mem_seek, mem_close); - setbuf(fp, nullptr); - return fp; -} diff --git a/native/jni/utils/files.h b/native/jni/utils/files.h index a0c3a12fc..e23fba6b1 100644 --- a/native/jni/utils/files.h +++ b/native/jni/utils/files.h @@ -38,7 +38,6 @@ void *__mmap(const char *filename, size_t *size, bool rw); void frm_rf(int dirfd, std::initializer_list excl = std::initializer_list()); void clone_dir(int src, int dest, bool overwrite = true); void parse_mnt(const char *file, const std::function &fn); -FILE *open_memfile(uint8_t *&buf, size_t &len); template void full_read(const char *filename, T &buf, size_t &size) { diff --git a/native/jni/utils/include/stream.h b/native/jni/utils/include/stream.h index 5216e75e2..43d0562d0 100644 --- a/native/jni/utils/include/stream.h +++ b/native/jni/utils/include/stream.h @@ -1,9 +1,105 @@ #pragma once #include +#include #include -#include "utils.h" +#include + +class stream; + +FILE *open_stream(stream *strm); + +template +FILE *open_stream(Args &&... args) { + return open_stream(new T(args...)); +} + +/* Base classes */ + +class stream { +public: + virtual int read(void *buf, size_t len); + virtual int write(const void *buf, size_t len); + virtual off_t seek(off_t off, int whence); + virtual int close(); + virtual ~stream() = default; +}; + +class filter_stream : public stream { +public: + filter_stream(FILE *fp) : fp(fp) {} + int close() override { return fclose(fp); } + virtual ~filter_stream() { close(); } + + void set_base(FILE *f) { + if (fp) fclose(fp); + fp = f; + } + + template + void set_base(Args&&... args) { + set_base(open_stream(args...)); + } + +protected: + FILE *fp; +}; + +class filter_in_stream : public filter_stream { +public: + filter_in_stream(FILE *fp = nullptr) : filter_stream(fp) {} + int read(void *buf, size_t len) override { return fread(buf, len, 1, fp); } +}; + +class filter_out_stream : public filter_stream { +public: + filter_out_stream(FILE *fp = nullptr) : filter_stream(fp) {} + int write(const void *buf, size_t len) override { return fwrite(buf, len, 1, fp); } +}; + +class seekable_stream : public stream { +protected: + size_t _pos = 0; + + off_t new_pos(off_t off, int whence); + virtual size_t end_pos() = 0; +}; + +/* Concrete classes */ + +class byte_stream : public seekable_stream { +public: + byte_stream(uint8_t *&buf, size_t &len); + template + byte_stream(byte *&buf, size_t &len) : byte_stream(reinterpret_cast(buf), len) {} + int read(void *buf, size_t len) override; + int write(const void *buf, size_t len) override; + off_t seek(off_t off, int whence) override; + virtual ~byte_stream() = default; + +private: + uint8_t *&_buf; + size_t &_len; + size_t _cap = 0; + + void resize(size_t new_pos, bool zero = false); + size_t end_pos() override { return _len; } +}; + +class fd_stream : stream { +public: + fd_stream(int fd) : fd(fd) {} + int read(void *buf, size_t len) override; + int write(const void *buf, size_t len) override; + off_t seek(off_t off, int whence) override; + virtual ~fd_stream() = default; + +private: + int fd; +}; + +/* TODO: Replace classes below to new implementation */ class OutStream { public: diff --git a/native/jni/utils/stream.cpp b/native/jni/utils/stream.cpp new file mode 100644 index 000000000..442373cf6 --- /dev/null +++ b/native/jni/utils/stream.cpp @@ -0,0 +1,122 @@ +#include +#include + +static int strm_read(void *v, char *buf, int len) { + auto strm = reinterpret_cast(v); + return strm->read(buf, len); +} + +static int strm_write(void *v, const char *buf, int len) { + auto strm = reinterpret_cast(v); + return strm->write(buf, len); +} + +static fpos_t strm_seek(void *v, fpos_t off, int whence) { + auto strm = reinterpret_cast(v); + return strm->seek(off, whence); +} + +static int strm_close(void *v) { + auto strm = reinterpret_cast(v); + int ret = strm->close(); + delete strm; + return ret; +} + +FILE *open_stream(stream *strm) { + FILE *fp = funopen(strm, strm_read, strm_write, strm_seek, strm_close); + // Disable buffering + setbuf(fp, nullptr); + return fp; +} + +int stream::read(void *buf, size_t len) { + LOGE("This stream does not support read"); + return -1; +} + +int stream::write(const void *buf, size_t len) { + LOGE("This stream does not support write"); + return -1; +} + +off_t stream::seek(off_t off, int whence) { + LOGE("This stream does not support seek"); + return -1; +} + +int stream::close() { + return 0; +} + +off_t seekable_stream::new_pos(off_t off, int whence) { + off_t new_pos; + switch (whence) { + case SEEK_CUR: + new_pos = _pos + off; + break; + case SEEK_END: + new_pos = end_pos() + off; + break; + case SEEK_SET: + new_pos = off; + break; + default: + return -1; + } + return new_pos; +} + +byte_stream::byte_stream(uint8_t *&buf, size_t &len) : _buf(buf), _len(len) { + buf = nullptr; + len = 0; +} + +int byte_stream::read(void *buf, size_t len) { + len = std::min(len, _len - _pos); + memcpy(buf, _buf + _pos, len); + return len; +} + +int byte_stream::write(const void *buf, size_t len) { + resize(_pos + len); + memcpy(_buf + _pos, buf, len); + _pos += len; + _len = std::max(_len, _pos); + return len; +} + +off_t byte_stream::seek(off_t off, int whence) { + off_t np = new_pos(off, whence); + if (np < 0) + return -1; + resize(np, true); + _pos = np; + return np; +} + +void byte_stream::resize(size_t new_pos, bool zero) { + bool resize = false; + size_t old_cap = _cap; + while (new_pos > _cap) { + _cap = _cap ? (_cap << 1) - (_cap >> 1) : 1 << 12; + resize = true; + } + if (resize) { + _buf = (uint8_t *) xrealloc(_buf, _cap); + if (zero) + memset(_buf + old_cap, 0, _cap - old_cap); + } +} + +int fd_stream::read(void *buf, size_t len) { + return ::read(fd, buf, len); +} + +int fd_stream::write(const void *buf, size_t len) { + return ::write(fd, buf, len); +} + +off_t fd_stream::seek(off_t off, int whence) { + return lseek(fd, off, whence); +}