diff --git a/native/jni/magiskboot/compress.cpp b/native/jni/magiskboot/compress.cpp index edc3df1fc..dba7d22dc 100644 --- a/native/jni/magiskboot/compress.cpp +++ b/native/jni/magiskboot/compress.cpp @@ -476,142 +476,90 @@ private: } }; -class LZ4_decoder : public cpr_stream { +class buf_cpr_stream : public chunk_out_stream { +public: + using chunk_out_stream::chunk_out_stream; + ssize_t writeFully(void *buf, size_t len) override { + return write(buf, len); + } +}; + +class LZ4_decoder : public buf_cpr_stream { public: explicit LZ4_decoder(stream_ptr &&base) : - cpr_stream(std::move(base)), out_buf(new char[LZ4_UNCOMPRESSED]), - buf(new char[LZ4_COMPRESSED]), init(false), block_sz(0), buf_off(0) {} + buf_cpr_stream(std::move(base), LZ4_COMPRESSED, sizeof(block_sz) + 4), + out_buf(new char[LZ4_UNCOMPRESSED]), block_sz(0) {} ~LZ4_decoder() override { + close(); delete[] out_buf; - delete[] buf; } - ssize_t write(const void *in, size_t size) override { - size_t ret = 0; - auto inbuf = static_cast(in); - if (!init) { - // Skip magic - inbuf += 4; - size -= 4; - init = true; - } - for (size_t consumed; size != 0;) { - if (block_sz == 0) { - if (buf_off + size >= sizeof(block_sz)) { - consumed = sizeof(block_sz) - buf_off; - memcpy(buf + buf_off, inbuf, consumed); - memcpy(&block_sz, buf, sizeof(block_sz)); - buf_off = 0; - } else { - consumed = size; - memcpy(buf + buf_off, inbuf, size); - } - inbuf += consumed; - size -= consumed; - } else if (buf_off + size >= block_sz) { - consumed = block_sz - buf_off; - memcpy(buf + buf_off, inbuf, consumed); - inbuf += consumed; - size -= consumed; +protected: + ssize_t write_chunk(const void *buf, size_t len) override { + // This is an error + if (len != chunk_sz) + return -1; - int write = LZ4_decompress_safe(buf, out_buf, block_sz, LZ4_UNCOMPRESSED); - if (write < 0) { - LOGW("LZ4HC decompression failure (%d)\n", write); - return -1; - } - ret += bwrite(out_buf, write); + auto in = reinterpret_cast(buf); - // Reset - buf_off = 0; - block_sz = 0; + if (block_sz == 0) { + if (chunk_sz == sizeof(block_sz) + 4) { + // Skip the first 4 bytes, which is magic + memcpy(&block_sz, in + 4, sizeof(block_sz)); } else { - // Copy to internal buffer - memcpy(buf + buf_off, inbuf, size); - buf_off += size; - break; + memcpy(&block_sz, in, sizeof(block_sz)); } + chunk_sz = block_sz; + return 0; + } else { + int r = LZ4_decompress_safe(in, out_buf, block_sz, LZ4_UNCOMPRESSED); + chunk_sz = sizeof(block_sz); + block_sz = 0; + if (r < 0) { + LOGW("LZ4HC decompression failure (%d)\n", r); + return -1; + } + return bwrite(out_buf, r); } - return ret; } private: char *out_buf; - char *buf; - bool init; - unsigned block_sz; - int buf_off; + uint32_t block_sz; }; -class LZ4_encoder : public cpr_stream { +class LZ4_encoder : public buf_cpr_stream { public: explicit LZ4_encoder(stream_ptr &&base, bool lg) : - cpr_stream(std::move(base)), outbuf(new char[LZ4_COMPRESSED]), - buf(new char[LZ4_UNCOMPRESSED]), init(false), lg(lg), buf_off(0), in_total(0) {} - - ssize_t write(const void *in, size_t size) override { - size_t ret = 0; - if (!init) { - ret += bwrite("\x02\x21\x4c\x18", 4); - init = true; - } - if (size == 0) - return 0; - in_total += size; - const char *inbuf = (const char *) in; - size_t consumed; - do { - if (buf_off + size >= LZ4_UNCOMPRESSED) { - consumed = LZ4_UNCOMPRESSED - buf_off; - memcpy(buf + buf_off, inbuf, consumed); - inbuf += consumed; - size -= consumed; - buf_off = LZ4_UNCOMPRESSED; - - if (int written = write_block(); written < 0) - return -1; - else - ret += written; - - // Reset buffer - buf_off = 0; - } else { - // Copy to internal buffer - memcpy(buf + buf_off, inbuf, size); - buf_off += size; - size = 0; - } - } while (size != 0); - return ret; + buf_cpr_stream(std::move(base), LZ4_UNCOMPRESSED), + out_buf(new char[LZ4_COMPRESSED]), lg(lg), in_total(0) { + bwrite("\x02\x21\x4c\x18", 4); } ~LZ4_encoder() override { - if (buf_off) - write_block(); + close(); if (lg) bwrite(&in_total, sizeof(in_total)); - delete[] outbuf; - delete[] buf; + delete[] out_buf; } -private: - char *outbuf; - char *buf; - bool init; - bool lg; - int buf_off; - unsigned in_total; - - int write_block() { - int written = LZ4_compress_HC(buf, outbuf, buf_off, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX); - if (written == 0) { +protected: + ssize_t write_chunk(const void *buf, size_t len) override { + int r = LZ4_compress_HC((const char *) buf, out_buf, len, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX); + if (r == 0) { LOGW("LZ4HC compression failure\n"); return -1; } - bwrite(&written, sizeof(written)); - bwrite(outbuf, written); - return written + sizeof(written); + bwrite(&r, sizeof(r)); + bwrite(out_buf, r); + return r + sizeof(r); } + +private: + char *out_buf; + bool lg; + unsigned in_total; }; stream_ptr get_encoder(format_t type, stream_ptr &&base) { diff --git a/native/jni/utils/include/stream.hpp b/native/jni/utils/include/stream.hpp index ce2230cb4..7873a055b 100644 --- a/native/jni/utils/include/stream.hpp +++ b/native/jni/utils/include/stream.hpp @@ -35,6 +35,34 @@ protected: stream_ptr base; }; +// Buffered output stream, writing in chunks +class chunk_out_stream : public filter_stream { +public: + chunk_out_stream(stream_ptr &&base, size_t buf_sz, size_t chunk_sz) + : filter_stream(std::move(base)), chunk_sz(chunk_sz), buf_sz(buf_sz) {} + + chunk_out_stream(stream_ptr &&base, size_t buf_sz = 4096) + : chunk_out_stream(std::move(base), buf_sz, buf_sz) {} + + ~chunk_out_stream() { delete[] _buf; } + + // Reading does not make sense + ssize_t read(void *buf, size_t len) final { return stream::read(buf, len); } + ssize_t write(const void *buf, size_t len) final; + +protected: + // Classes inheriting this class has to call close() in the destructor + void close(); + virtual ssize_t write_chunk(const void *buf, size_t len) = 0; + + size_t chunk_sz; + +private: + size_t buf_sz; + size_t buf_off = 0; + uint8_t *_buf = nullptr; +}; + // Byte stream that dynamically allocates memory class byte_stream : public stream { public: diff --git a/native/jni/utils/stream.cpp b/native/jni/utils/stream.cpp index e870c61ad..bd63c9bb1 100644 --- a/native/jni/utils/stream.cpp +++ b/native/jni/utils/stream.cpp @@ -121,6 +121,51 @@ ssize_t filter_stream::write(const void *buf, size_t len) { return base->write(buf, len); } +ssize_t chunk_out_stream::write(const void *_in, size_t len) { + ssize_t ret = 0; + auto in = static_cast(_in); + while (len) { + if (buf_off + len >= chunk_sz) { + const uint8_t *src; + if (buf_off) { + // Copy the rest of the chunk to internal buffer + src = _buf; + auto copy = chunk_sz - buf_off; + memcpy(_buf + buf_off, in, copy); + in += copy; + len -= copy; + buf_off = 0; + } else { + src = in; + in += chunk_sz; + len -= chunk_sz; + } + auto r = write_chunk(src, chunk_sz); + if (r < 0) + return ret; + ret += r; + } else { + // Buffer internally + if (!_buf) { + _buf = new uint8_t[buf_sz]; + } + memcpy(_buf + buf_off, in, len); + buf_off += len; + break; + } + } + return ret; +} + +void chunk_out_stream::close() { + if (buf_off) { + write_chunk(_buf, buf_off); + delete[] _buf; + _buf = nullptr; + buf_off = 0; + } +} + byte_stream::byte_stream(uint8_t *&buf, size_t &len) : _buf(buf), _len(len) { buf = nullptr; len = 0;