Cleanup filter_out_stream implementation

This commit is contained in:
topjohnwu 2023-05-20 01:28:10 -07:00
parent 655f778171
commit f5aaff2b1e
6 changed files with 13 additions and 50 deletions

View File

@ -18,16 +18,11 @@ using out_strm_ptr = std::unique_ptr<out_stream>;
class filter_out_stream : public out_stream { class filter_out_stream : public out_stream {
public: public:
filter_out_stream(out_strm_ptr &&base) : base(std::move(base)) {} filter_out_stream(out_strm_ptr &&base) : base(std::move(base)) {}
bool write(const void *buf, size_t len) override; bool write(const void *buf, size_t len) override;
virtual bool write(const void *buf, size_t len, bool final);
protected: protected:
out_strm_ptr base; out_strm_ptr base;
}; };
using filter_strm_ptr = std::unique_ptr<filter_out_stream>;
// Buffered output stream, writing in chunks // Buffered output stream, writing in chunks
class chunk_out_stream : public filter_out_stream { class chunk_out_stream : public filter_out_stream {
public: public:
@ -40,7 +35,6 @@ public:
~chunk_out_stream() override { delete[] _buf; } ~chunk_out_stream() override { delete[] _buf; }
bool write(const void *buf, size_t len) final; bool write(const void *buf, size_t len) final;
bool write(const void *buf, size_t len, bool final) final;
protected: protected:
// Classes inheriting this class has to call finalize() in its destructor // Classes inheriting this class has to call finalize() in its destructor
@ -63,7 +57,7 @@ struct in_stream {
}; };
// A channel is something that is writable, readable, and seekable // A channel is something that is writable, readable, and seekable
struct channel : public in_stream, public out_stream { struct channel : public out_stream, public in_stream {
virtual off_t seek(off_t off, int whence) = 0; virtual off_t seek(off_t off, int whence) = 0;
virtual ~channel() = default; virtual ~channel() = default;
}; };

View File

@ -88,15 +88,7 @@ bool filter_out_stream::write(const void *buf, size_t len) {
return base->write(buf, len); return base->write(buf, len);
} }
bool filter_out_stream::write(const void *buf, size_t len, bool final) { bool chunk_out_stream::write(const void *_in, size_t len) {
return write(buf, len);
}
bool chunk_out_stream::write(const void *buf, size_t len) {
return write(buf, len, false);
}
bool chunk_out_stream::write(const void *_in, size_t len, bool final) {
auto in = static_cast<const uint8_t *>(_in); auto in = static_cast<const uint8_t *>(_in);
while (len) { while (len) {
if (buf_off + len >= chunk_sz) { if (buf_off + len >= chunk_sz) {
@ -114,21 +106,8 @@ bool chunk_out_stream::write(const void *_in, size_t len, bool final) {
in += chunk_sz; in += chunk_sz;
len -= chunk_sz; len -= chunk_sz;
} }
if (!write_chunk(src, chunk_sz, final && len == 0)) if (!write_chunk(src, chunk_sz, false))
return false; return false;
} else if (final) {
// Final input data, write regardless whether it is chunk sized
if (buf_off) {
memcpy(_buf + buf_off, in, len);
auto avail = buf_off + len;
buf_off = 0;
if (!write_chunk(_buf, avail, true))
return false;
} else {
if (!write_chunk(in, len, true))
return false;
}
break;
} else { } else {
// Buffer internally // Buffer internally
if (!_buf) { if (!_buf) {

View File

@ -20,14 +20,14 @@ uint64_t dyn_img_hdr::j64 = 0;
static void decompress(format_t type, int fd, const void *in, size_t size) { static void decompress(format_t type, int fd, const void *in, size_t size) {
auto ptr = get_decoder(type, make_unique<fd_channel>(fd)); auto ptr = get_decoder(type, make_unique<fd_channel>(fd));
ptr->write(in, size, true); ptr->write(in, size);
} }
static off_t compress(format_t type, int fd, const void *in, size_t size) { static off_t compress(format_t type, int fd, const void *in, size_t size) {
auto prev = lseek(fd, 0, SEEK_CUR); auto prev = lseek(fd, 0, SEEK_CUR);
{ {
auto strm = get_encoder(type, make_unique<fd_channel>(fd)); auto strm = get_encoder(type, make_unique<fd_channel>(fd));
strm->write(in, size, true); strm->write(in, size);
} }
auto now = lseek(fd, 0, SEEK_CUR); auto now = lseek(fd, 0, SEEK_CUR);
return now - prev; return now - prev;

View File

@ -195,9 +195,6 @@ public:
protected: protected:
bool write_chunk(const void *buf, size_t len, bool final) override { bool write_chunk(const void *buf, size_t len, bool final) override {
if (len == 0)
return true;
auto in = static_cast<const unsigned char *>(buf); auto in = static_cast<const unsigned char *>(buf);
in_total += len; in_total += len;
@ -514,7 +511,7 @@ public:
} }
protected: protected:
bool write_chunk(const void *buf, size_t len, bool final) override { bool write_chunk(const void *buf, size_t len, bool) override {
// This is an error // This is an error
if (len != chunk_sz) if (len != chunk_sz)
return false; return false;
@ -565,7 +562,7 @@ public:
} }
protected: protected:
bool write_chunk(const void *buf, size_t len, bool final) override { bool write_chunk(const void *buf, size_t len, bool) override {
auto in = static_cast<const char *>(buf); auto in = static_cast<const char *>(buf);
uint32_t block_sz = LZ4_compress_HC(in, out_buf, len, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX); uint32_t block_sz = LZ4_compress_HC(in, out_buf, len, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX);
if (block_sz == 0) { if (block_sz == 0) {
@ -585,7 +582,7 @@ private:
uint32_t in_total; uint32_t in_total;
}; };
filter_strm_ptr get_encoder(format_t type, out_strm_ptr &&base) { out_strm_ptr get_encoder(format_t type, out_strm_ptr &&base) {
switch (type) { switch (type) {
case XZ: case XZ:
return make_unique<xz_encoder>(std::move(base)); return make_unique<xz_encoder>(std::move(base));
@ -607,7 +604,7 @@ filter_strm_ptr get_encoder(format_t type, out_strm_ptr &&base) {
} }
} }
filter_strm_ptr get_decoder(format_t type, out_strm_ptr &&base) { out_strm_ptr get_decoder(format_t type, out_strm_ptr &&base) {
switch (type) { switch (type) {
case XZ: case XZ:
case LZMA: case LZMA:
@ -721,7 +718,6 @@ void compress(const char *method, const char *infile, const char *outfile) {
unlink(infile); unlink(infile);
} }
namespace rust {
bool decompress(const unsigned char *in, uint64_t in_size, int fd) { bool decompress(const unsigned char *in, uint64_t in_size, int fd) {
format_t type = check_fmt(in, in_size); format_t type = check_fmt(in, in_size);
@ -736,4 +732,3 @@ bool decompress(const unsigned char *in, uint64_t in_size, int fd) {
} }
return true; return true;
} }
}

View File

@ -4,14 +4,8 @@
#include "format.hpp" #include "format.hpp"
filter_strm_ptr get_encoder(format_t type, out_strm_ptr &&base); out_strm_ptr get_encoder(format_t type, out_strm_ptr &&base);
out_strm_ptr get_decoder(format_t type, out_strm_ptr &&base);
filter_strm_ptr get_decoder(format_t type, out_strm_ptr &&base);
void compress(const char *method, const char *infile, const char *outfile); void compress(const char *method, const char *infile, const char *outfile);
void decompress(char *infile, const char *outfile); void decompress(char *infile, const char *outfile);
namespace rust {
bool decompress(const unsigned char *in, uint64_t in_size, int fd); bool decompress(const unsigned char *in, uint64_t in_size, int fd);
}

View File

@ -6,13 +6,14 @@ pub use payload::*;
mod payload; mod payload;
mod update_metadata; mod update_metadata;
#[cxx::bridge(namespace = "rust")] #[cxx::bridge]
pub mod ffi { pub mod ffi {
extern "C++" { extern "C++" {
include!("compress.hpp"); include!("compress.hpp");
pub unsafe fn decompress(in_: *const u8, in_size: u64, fd: i32) -> bool; pub unsafe fn decompress(in_: *const u8, in_size: u64, fd: i32) -> bool;
} }
#[namespace = "rust"]
extern "Rust" { extern "Rust" {
unsafe fn extract_boot_from_payload( unsafe fn extract_boot_from_payload(
in_path: *const c_char, in_path: *const c_char,