diff --git a/native/jni/magiskboot/compress.cpp b/native/jni/magiskboot/compress.cpp index b46fe8c87..e78d99e9a 100644 --- a/native/jni/magiskboot/compress.cpp +++ b/native/jni/magiskboot/compress.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -14,13 +15,11 @@ using namespace std; int64_t decompress(format_t type, int fd, const void *from, size_t size) { - unique_ptr cmp(get_decoder(type)); - return cmp->one_step(fd, from, size); + return unique_ptr(get_decoder(type))->one_step(fd, from, size); } int64_t compress(format_t type, int fd, const void *from, size_t size) { - unique_ptr cmp(get_encoder(type)); - return cmp->one_step(fd, from, size); + return unique_ptr(get_encoder(type))->one_step(fd, from, size); } static bool read_file(FILE *fp, const function &fn) { @@ -68,10 +67,10 @@ void decompress(char *infile, const char *outfile) { } out_fd = strcmp(outfile, "-") == 0 ? STDOUT_FILENO : creat(outfile, 0644); - cmp->set_outfd(out_fd); + cmp->set_out(make_unique(out_fd)); if (ext) *ext = '.'; } - if (!cmp->update(buf, len)) + if (!cmp->write(buf, len)) LOGE("Decompression error!\n"); }); @@ -113,10 +112,10 @@ void compress(const char *method, const char *infile, const char *outfile) { out_fd = strcmp(infile, "-") == 0 ? STDOUT_FILENO : creat(infile, 0644); } - cmp->set_outfd(out_fd); + cmp->set_out(make_unique(out_fd)); read_file(in_file, [&](void *buf, size_t len) -> void { - if (!cmp->update(buf, len)) + if (!cmp->write(buf, len)) LOGE("Compression error!\n"); }); @@ -166,21 +165,9 @@ Compression *get_decoder(format_t type) { } } -Compression::Compression() : fn([](auto, auto) -> void {}) {} - -void Compression::set_outfn(std::function &&fn) { - this->fn = std::move(fn); -} - -void Compression::set_outfd(int fd) { - fn = [=](const void *out, size_t len) -> void { - xwrite(fd, out, len); - }; -} - int64_t Compression::one_step(int outfd, const void *in, size_t size) { - set_outfd(outfd); - if (!update(in, size)) + set_out(make_unique(outfd)); + if (!write(in, size)) return -1; return finalize(); } @@ -196,12 +183,12 @@ GZStream::GZStream(int mode) : mode(mode), strm({}) { } } -bool GZStream::update(const void *in, size_t size) { - return update(in, size, Z_NO_FLUSH); +bool GZStream::write(const void *in, size_t size) { + return write(in, size, Z_NO_FLUSH); } uint64_t GZStream::finalize() { - update(nullptr, 0, Z_FINISH); + write(nullptr, 0, Z_FINISH); uint64_t total = strm.total_out; switch(mode) { case 0: @@ -214,7 +201,7 @@ uint64_t GZStream::finalize() { return total; } -bool GZStream::update(const void *in, size_t size, int flush) { +bool GZStream::write(const void *in, size_t size, int flush) { int ret; strm.next_in = (Bytef *) in; strm.avail_in = size; @@ -233,7 +220,7 @@ bool GZStream::update(const void *in, size_t size, int flush) { LOGW("Gzip %s failed (%d)\n", mode ? "encode" : "decode", ret); return false; } - fn(outbuf, sizeof(outbuf) - strm.avail_out); + FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); } while (strm.avail_out == 0); return true; } @@ -249,13 +236,13 @@ BZStream::BZStream(int mode) : mode(mode), strm({}) { } } -bool BZStream::update(const void *in, size_t size) { - return update(in, size, BZ_RUN); +bool BZStream::write(const void *in, size_t size) { + return write(in, size, BZ_RUN); } uint64_t BZStream::finalize() { if (mode) - update(nullptr, 0, BZ_FINISH); + write(nullptr, 0, BZ_FINISH); uint64_t total = ((uint64_t) strm.total_out_hi32 << 32) + strm.total_out_lo32; switch(mode) { case 0: @@ -268,7 +255,7 @@ uint64_t BZStream::finalize() { return total; } -bool BZStream::update(const void *in, size_t size, int flush) { +bool BZStream::write(const void *in, size_t size, int flush) { int ret; strm.next_in = (char *) in; strm.avail_in = size; @@ -287,7 +274,7 @@ bool BZStream::update(const void *in, size_t size, int flush) { LOGW("Bzip2 %s failed (%d)\n", mode ? "encode" : "decode", ret); return false; } - fn(outbuf, sizeof(outbuf) - strm.avail_out); + FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); } while (strm.avail_out == 0); return true; } @@ -316,18 +303,18 @@ LZMAStream::LZMAStream(int mode) : mode(mode), strm(LZMA_STREAM_INIT) { } } -bool LZMAStream::update(const void *in, size_t size) { - return update(in, size, LZMA_RUN); +bool LZMAStream::write(const void *in, size_t size) { + return write(in, size, LZMA_RUN); } uint64_t LZMAStream::finalize() { - update(nullptr, 0, LZMA_FINISH); + write(nullptr, 0, LZMA_FINISH); uint64_t total = strm.total_out; lzma_end(&strm); return total; } -bool LZMAStream::update(const void *in, size_t size, lzma_action flush) { +bool LZMAStream::write(const void *in, size_t size, lzma_action flush) { int ret; strm.next_in = (uint8_t *) in; strm.avail_in = size; @@ -339,7 +326,7 @@ bool LZMAStream::update(const void *in, size_t size, lzma_action flush) { LOGW("LZMA %s failed (%d)\n", mode ? "encode" : "decode", ret); return false; } - fn(outbuf, sizeof(outbuf) - strm.avail_out); + FilterOutStream::write(outbuf, sizeof(outbuf) - strm.avail_out); } while (strm.avail_out == 0); return true; } @@ -353,7 +340,7 @@ LZ4FDecoder::~LZ4FDecoder() { delete[] outbuf; } -bool LZ4FDecoder::update(const void *in, size_t size) { +bool LZ4FDecoder::write(const void *in, size_t size) { auto inbuf = (const uint8_t *) in; if (!outbuf) read_header(inbuf, size); @@ -370,7 +357,7 @@ bool LZ4FDecoder::update(const void *in, size_t size) { size -= read; inbuf += read; total += write; - fn(outbuf, write); + FilterOutStream::write(outbuf, write); } while (size != 0 || write != 0); return true; } @@ -404,7 +391,7 @@ LZ4FEncoder::~LZ4FEncoder() { delete[] outbuf; } -bool LZ4FEncoder::update(const void *in, size_t size) { +bool LZ4FEncoder::write(const void *in, size_t size) { if (!outbuf) write_header(); auto inbuf = (const uint8_t *) in; @@ -419,7 +406,7 @@ bool LZ4FEncoder::update(const void *in, size_t size) { size -= read; inbuf += read; total += write; - fn(outbuf, write); + FilterOutStream::write(outbuf, write); } while (size != 0); return true; } @@ -427,7 +414,7 @@ bool LZ4FEncoder::update(const void *in, size_t size) { uint64_t LZ4FEncoder::finalize() { size_t write = LZ4F_compressEnd(ctx, outbuf, outCapacity, nullptr); total += write; - fn(outbuf, write); + FilterOutStream::write(outbuf, write); return total; } @@ -446,7 +433,7 @@ void LZ4FEncoder::write_header() { outbuf = new uint8_t[outCapacity]; size_t write = LZ4F_compressBegin(ctx, outbuf, outCapacity, &prefs); total += write; - fn(outbuf, write); + FilterOutStream::write(outbuf, write); } LZ4Decoder::LZ4Decoder() : init(false), buf_off(0), total(0), block_sz(0) { @@ -459,7 +446,7 @@ LZ4Decoder::~LZ4Decoder() { delete[] buf; } -bool LZ4Decoder::update(const void *in, size_t size) { +bool LZ4Decoder::write(const void *in, size_t size) { const char *inbuf = (const char *) in; if (!init) { // Skip magic @@ -485,7 +472,7 @@ bool LZ4Decoder::update(const void *in, size_t size) { LOGW("LZ4HC decompression failure (%d)\n", write); return false; } - fn(outbuf, write); + FilterOutStream::write(outbuf, write); total += write; // Reset @@ -515,9 +502,9 @@ LZ4Encoder::~LZ4Encoder() { delete[] buf; } -bool LZ4Encoder::update(const void *in, size_t size) { +bool LZ4Encoder::write(const void *in, size_t size) { if (!init) { - fn("\x02\x21\x4c\x18", 4); + FilterOutStream::write("\x02\x21\x4c\x18", 4); init = true; } in_total += size; @@ -536,8 +523,8 @@ bool LZ4Encoder::update(const void *in, size_t size) { LOGW("LZ4HC compression failure\n"); return false; } - fn(&write, sizeof(write)); - fn(outbuf, write); + FilterOutStream::write(&write, sizeof(write)); + FilterOutStream::write(outbuf, write); out_total += write + sizeof(write); // Reset buffer @@ -555,10 +542,10 @@ bool LZ4Encoder::update(const void *in, size_t size) { uint64_t LZ4Encoder::finalize() { if (buf_off) { int write = LZ4_compress_HC(buf, outbuf, buf_off, LZ4_COMPRESSED, 9); - fn(&write, sizeof(write)); - fn(outbuf, write); + FilterOutStream::write(&write, sizeof(write)); + FilterOutStream::write(outbuf, write); out_total += write + sizeof(write); } - fn(&in_total, sizeof(in_total)); + FilterOutStream::write(&in_total, sizeof(in_total)); return out_total + sizeof(in_total); } diff --git a/native/jni/magiskboot/compress.h b/native/jni/magiskboot/compress.h index e83e78768..f23460ac2 100644 --- a/native/jni/magiskboot/compress.h +++ b/native/jni/magiskboot/compress.h @@ -1,41 +1,26 @@ #pragma once -#include - #include #include #include #include #include #include +#include #include "format.h" #define CHUNK 0x40000 -class Compression { +class Compression : public FilterOutStream { public: - virtual ~Compression() = default; - void set_outfn(std::function &&fn); - void set_outfd(int fd); int64_t one_step(int outfd, const void *in, size_t size); - virtual bool update(const void *in, size_t size) = 0; virtual uint64_t finalize() = 0; - - template - static int64_t one_step(int outfd, const void *in, size_t size) { - T cmp; - return cmp.one_step(outfd, in, size); - } - -protected: - Compression(); - std::function fn; }; class GZStream : public Compression { public: - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; protected: @@ -46,7 +31,7 @@ private: z_stream strm; uint8_t outbuf[CHUNK]; - bool update(const void *in, size_t size, int flush); + bool write(const void *in, size_t size, int flush); }; class GZDecoder : public GZStream { @@ -61,7 +46,7 @@ public: class BZStream : public Compression { public: - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; protected: @@ -72,7 +57,7 @@ private: bz_stream strm; char outbuf[CHUNK]; - bool update(const void *in, size_t size, int flush); + bool write(const void *in, size_t size, int flush); }; class BZDecoder : public BZStream { @@ -87,7 +72,7 @@ public: class LZMAStream : public Compression { public: - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; protected: @@ -98,7 +83,7 @@ private: lzma_stream strm; uint8_t outbuf[CHUNK]; - bool update(const void *in, size_t size, lzma_action flush); + bool write(const void *in, size_t size, lzma_action flush); }; class LZMADecoder : public LZMAStream { @@ -120,7 +105,7 @@ class LZ4FDecoder : public Compression { public: LZ4FDecoder(); ~LZ4FDecoder() override; - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; private: @@ -136,7 +121,7 @@ class LZ4FEncoder : public Compression { public: LZ4FEncoder(); ~LZ4FEncoder() override; - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; private: @@ -156,7 +141,7 @@ class LZ4Decoder : public Compression { public: LZ4Decoder(); ~LZ4Decoder() override; - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; private: @@ -172,7 +157,7 @@ class LZ4Encoder : public Compression { public: LZ4Encoder(); ~LZ4Encoder() override; - bool update(const void *in, size_t size) override; + bool write(const void *in, size_t size) override; uint64_t finalize() override; private: diff --git a/native/jni/utils/include/OutStream.h b/native/jni/utils/include/OutStream.h new file mode 100644 index 000000000..057f51c13 --- /dev/null +++ b/native/jni/utils/include/OutStream.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +class OutStream { +public: + virtual bool write(const void *buf, size_t len) = 0; + virtual ~OutStream() = default; +}; + +typedef std::unique_ptr strm_ptr; + +class FilterOutStream : public OutStream { +public: + FilterOutStream() = default; + + FilterOutStream(strm_ptr &&ptr) : out(std::move(ptr)) {} + + void set_out(strm_ptr &&ptr) { out = std::move(ptr); } + + bool write(const void *buf, size_t len) override { + return out ? out->write(buf, len) : false; + } + +protected: + strm_ptr out; +}; + +class FDOutStream : public OutStream { +public: + FDOutStream(int fd, bool close = false) : fd(fd), close(close) {} + + bool write(const void *buf, size_t len) override { + return ::write(fd, buf, len) == len; + } + + ~FDOutStream() override { + if (close) + ::close(fd); + } + +protected: + int fd; + bool close; +};