Stream should always write all bytes

This commit is contained in:
topjohnwu 2021-11-21 04:38:22 -08:00
parent d8b9265484
commit 5787aa1078
3 changed files with 218 additions and 231 deletions

View File

@ -29,19 +29,15 @@ constexpr size_t ZOPFLI_CHUNK = ZOPFLI_MASTER_BLOCK_SIZE;
constexpr size_t ZOPFLI_CHUNK = CHUNK;
#endif
class cpr_stream : public filter_stream {
public:
class out_stream : public filter_stream {
using filter_stream::filter_stream;
using stream::read;
ssize_t writeFully(void *buf, size_t len) override {
return write(buf, len);
}
};
class gz_strm : public cpr_stream {
class gz_strm : public out_stream {
public:
ssize_t write(const void *buf, size_t len) override {
return len ? write(buf, len, Z_NO_FLUSH) : 0;
bool write(const void *buf, size_t len) override {
return len == 0 || write(buf, len, Z_NO_FLUSH);
}
~gz_strm() override {
@ -63,7 +59,7 @@ protected:
} mode;
gz_strm(mode_t mode, stream_ptr &&base) :
cpr_stream(std::move(base)), mode(mode), strm{}, outbuf{0} {
out_stream(std::move(base)), mode(mode), strm{}, outbuf{0} {
switch(mode) {
case DECODE:
inflateInit2(&strm, 15 | 16);
@ -78,8 +74,7 @@ private:
z_stream strm;
uint8_t outbuf[CHUNK];
ssize_t write(const void *buf, size_t len, int flush) {
size_t ret = 0;
bool write(const void *buf, size_t len, int flush) {
strm.next_in = (Bytef *) buf;
strm.avail_in = len;
do {
@ -96,11 +91,12 @@ private:
}
if (code == Z_STREAM_ERROR) {
LOGW("gzip %s failed (%d)\n", mode ? "encode" : "decode", code);
return -1;
return false;
}
ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out))
return false;
} while (strm.avail_out == 0);
return ret;
return true;
}
};
@ -114,14 +110,40 @@ public:
explicit gz_encoder(stream_ptr &&base) : gz_strm(ENCODE, std::move(base)) {};
};
class zopfli_encoder : public cpr_stream {
class zopfli_encoder : public out_stream {
public:
ssize_t write(const void *buf, size_t len) override {
return len ? write(static_cast<const unsigned char *>(buf), len) : 0;
bool write(const void *buf, size_t len) override {
if (len == 0)
return true;
auto in = static_cast<const unsigned char *>(buf);
in_size += len;
crcvalue = crc32_z(crcvalue, in, len);
for (size_t offset = 0; offset < len; offset += ZOPFLI_CHUNK) {
size_t end_offset = std::min(len, offset + ZOPFLI_CHUNK);
ZopfliDeflatePart(&zo, 2, 0, in, offset, end_offset, &bp, &out, &outsize);
if (bp) {
// The last byte is not complete
if (!bwrite(out, outsize - 1))
return false;
uint8_t b = out[outsize - 1];
free_out();
ZOPFLI_APPEND_DATA(b, &out, &outsize);
} else {
if (!bwrite(out, outsize))
return false;
free_out();
}
}
explicit zopfli_encoder(stream_ptr &&base) : cpr_stream(std::move(base)),
zo({}), out(nullptr), outsize(0), bp(0), crcvalue(crc32_z(0L, Z_NULL, 0)), in_size(0) {
return true;
}
explicit zopfli_encoder(stream_ptr &&base) :
out_stream(std::move(base)), zo({}), out(nullptr), outsize(0), bp(0),
crcvalue(crc32_z(0L, Z_NULL, 0)), in_size(0) {
ZopfliInitOptions(&zo);
// Speed things up a bit, this still leads to better compression than zlib
@ -174,36 +196,12 @@ private:
out = nullptr;
outsize = 0;
}
ssize_t write(const unsigned char *buf, size_t len) {
ssize_t ret = 0;
in_size += len;
crcvalue = crc32_z(crcvalue, buf, len);
for (size_t offset = 0; offset < len; offset += ZOPFLI_CHUNK) {
size_t end_offset = std::min(len, offset + ZOPFLI_CHUNK);
ZopfliDeflatePart(&zo, 2, 0, buf, offset, end_offset, &bp, &out, &outsize);
if (bp) {
// The last byte is not complete
ret += bwrite(out, outsize - 1);
uint8_t b = out[outsize - 1];
free_out();
ZOPFLI_APPEND_DATA(b, &out, &outsize);
} else {
ret += bwrite(out, outsize);
free_out();
}
}
return ret;
}
};
class bz_strm : public cpr_stream {
class bz_strm : public out_stream {
public:
ssize_t write(const void *buf, size_t len) override {
return len ? write(buf, len, BZ_RUN) : 0;
bool write(const void *buf, size_t len) override {
return len == 0 || write(buf, len, BZ_RUN);
}
~bz_strm() override {
@ -225,7 +223,7 @@ protected:
} mode;
bz_strm(mode_t mode, stream_ptr &&base) :
cpr_stream(std::move(base)), mode(mode), strm{}, outbuf{0} {
out_stream(std::move(base)), mode(mode), strm{}, outbuf{0} {
switch(mode) {
case DECODE:
BZ2_bzDecompressInit(&strm, 0, 0);
@ -240,8 +238,7 @@ private:
bz_stream strm;
char outbuf[CHUNK];
ssize_t write(const void *buf, size_t len, int flush) {
size_t ret = 0;
bool write(const void *buf, size_t len, int flush) {
strm.next_in = (char *) buf;
strm.avail_in = len;
do {
@ -258,11 +255,12 @@ private:
}
if (code < 0) {
LOGW("bzip2 %s failed (%d)\n", mode ? "encode" : "decode", code);
return -1;
return false;
}
ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out))
return false;
} while (strm.avail_out == 0);
return ret;
return true;
}
};
@ -276,10 +274,10 @@ public:
explicit bz_encoder(stream_ptr &&base) : bz_strm(ENCODE, std::move(base)) {};
};
class lzma_strm : public cpr_stream {
class lzma_strm : public out_stream {
public:
ssize_t write(const void *buf, size_t len) override {
return len ? write(buf, len, LZMA_RUN) : 0;
bool write(const void *buf, size_t len) override {
return len == 0 || write(buf, len, LZMA_RUN);
}
~lzma_strm() override {
@ -295,7 +293,7 @@ protected:
} mode;
lzma_strm(mode_t mode, stream_ptr &&base) :
cpr_stream(std::move(base)), mode(mode), strm(LZMA_STREAM_INIT), outbuf{0} {
out_stream(std::move(base)), mode(mode), strm(LZMA_STREAM_INIT), outbuf{0} {
lzma_options_lzma opt;
// Initialize preset
@ -326,8 +324,7 @@ private:
lzma_stream strm;
uint8_t outbuf[CHUNK];
ssize_t write(const void *buf, size_t len, lzma_action flush) {
size_t ret = 0;
bool write(const void *buf, size_t len, lzma_action flush) {
strm.next_in = (uint8_t *) buf;
strm.avail_in = len;
do {
@ -336,11 +333,12 @@ private:
int code = lzma_code(&strm, flush);
if (code != LZMA_OK && code != LZMA_STREAM_END) {
LOGW("LZMA %s failed (%d)\n", mode ? "encode" : "decode", code);
return -1;
return false;
}
ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out);
if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out))
return false;
} while (strm.avail_out == 0);
return ret;
return true;
}
};
@ -359,10 +357,10 @@ public:
explicit lzma_encoder(stream_ptr &&base) : lzma_strm(ENCODE_LZMA, std::move(base)) {}
};
class LZ4F_decoder : public cpr_stream {
class LZ4F_decoder : public out_stream {
public:
explicit LZ4F_decoder(stream_ptr &&base) :
cpr_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) {
out_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) {
LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION);
}
@ -371,35 +369,10 @@ public:
delete[] outbuf;
}
ssize_t write(const void *buf, size_t len) override {
size_t ret = 0;
auto inbuf = reinterpret_cast<const uint8_t *>(buf);
if (!outbuf)
read_header(inbuf, len);
size_t read, write;
LZ4F_errorCode_t code;
do {
read = len;
write = outCapacity;
code = LZ4F_decompress(ctx, outbuf, &write, inbuf, &read, nullptr);
if (LZ4F_isError(code)) {
LOGW("LZ4F decode error: %s\n", LZ4F_getErrorName(code));
return -1;
}
len -= read;
inbuf += read;
ret += bwrite(outbuf, write);
} while (len != 0 || write != 0);
return ret;
}
private:
LZ4F_decompressionContext_t ctx;
uint8_t *outbuf;
size_t outCapacity;
void read_header(const uint8_t *&in, size_t &size) {
size_t read = size;
bool write(const void *buf, size_t len) override {
auto in = reinterpret_cast<const uint8_t *>(buf);
if (!outbuf) {
size_t read = len;
LZ4F_frameInfo_t info;
LZ4F_getFrameInfo(ctx, &info, in, &read);
switch (info.blockSizeID) {
@ -411,54 +384,41 @@ private:
}
outbuf = new uint8_t[outCapacity];
in += read;
size -= read;
len -= read;
}
};
class LZ4F_encoder : public cpr_stream {
public:
explicit LZ4F_encoder(stream_ptr &&base) :
cpr_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) {
LZ4F_createCompressionContext(&ctx, LZ4F_VERSION);
}
ssize_t write(const void *buf, size_t len) override {
size_t ret = 0;
if (!outbuf)
ret += write_header();
if (len == 0)
return 0;
auto inbuf = reinterpret_cast<const uint8_t *>(buf);
size_t read, write;
LZ4F_errorCode_t code;
do {
read = len > BLOCK_SZ ? BLOCK_SZ : len;
write = LZ4F_compressUpdate(ctx, outbuf, outCapacity, inbuf, read, nullptr);
if (LZ4F_isError(write)) {
LOGW("LZ4F encode error: %s\n", LZ4F_getErrorName(write));
return -1;
read = len;
write = outCapacity;
code = LZ4F_decompress(ctx, outbuf, &write, in, &read, nullptr);
if (LZ4F_isError(code)) {
LOGW("LZ4F decode error: %s\n", LZ4F_getErrorName(code));
return false;
}
len -= read;
inbuf += read;
ret += bwrite(outbuf, write);
} while (len != 0);
return ret;
}
~LZ4F_encoder() override {
size_t len = LZ4F_compressEnd(ctx, outbuf, outCapacity, nullptr);
bwrite(outbuf, len);
LZ4F_freeCompressionContext(ctx);
delete[] outbuf;
in += read;
if (!bwrite(outbuf, write))
return false;
} while (len != 0 || write != 0);
return true;
}
private:
LZ4F_compressionContext_t ctx;
LZ4F_decompressionContext_t ctx;
uint8_t *outbuf;
size_t outCapacity;
};
static constexpr size_t BLOCK_SZ = 1 << 22;
class LZ4F_encoder : public out_stream {
public:
explicit LZ4F_encoder(stream_ptr &&base) :
out_stream(std::move(base)), ctx(nullptr), out_buf(nullptr), outCapacity(0) {
LZ4F_createCompressionContext(&ctx, LZ4F_VERSION);
}
int write_header() {
bool write(const void *buf, size_t len) override {
if (!out_buf) {
LZ4F_preferences_t prefs {
.frameInfo = {
.blockSizeID = LZ4F_max4MB,
@ -470,24 +430,50 @@ private:
.autoFlush = 1,
};
outCapacity = LZ4F_compressBound(BLOCK_SZ, &prefs);
outbuf = new uint8_t[outCapacity];
size_t write = LZ4F_compressBegin(ctx, outbuf, outCapacity, &prefs);
return bwrite(outbuf, write);
out_buf = new uint8_t[outCapacity];
size_t write = LZ4F_compressBegin(ctx, out_buf, outCapacity, &prefs);
if (!bwrite(out_buf, write))
return false;
}
if (len == 0)
return true;
auto in = reinterpret_cast<const uint8_t *>(buf);
size_t read, write;
do {
read = len > BLOCK_SZ ? BLOCK_SZ : len;
write = LZ4F_compressUpdate(ctx, out_buf, outCapacity, in, read, nullptr);
if (LZ4F_isError(write)) {
LOGW("LZ4F encode error: %s\n", LZ4F_getErrorName(write));
return false;
}
len -= read;
in += read;
if (!bwrite(out_buf, write))
return false;
} while (len != 0);
return true;
}
~LZ4F_encoder() override {
size_t len = LZ4F_compressEnd(ctx, out_buf, outCapacity, nullptr);
bwrite(out_buf, len);
LZ4F_freeCompressionContext(ctx);
delete[] out_buf;
}
private:
LZ4F_compressionContext_t ctx;
uint8_t *out_buf;
size_t outCapacity;
static constexpr size_t BLOCK_SZ = 1 << 22;
};
class buf_cpr_stream : public chunk_out_stream {
public:
using chunk_out_stream::chunk_out_stream;
ssize_t writeFully(void *buf, size_t len) override {
return write(buf, len);
}
};
class LZ4_decoder : public buf_cpr_stream {
class LZ4_decoder : public chunk_out_stream {
public:
explicit LZ4_decoder(stream_ptr &&base) :
buf_cpr_stream(std::move(base), LZ4_COMPRESSED, sizeof(block_sz) + 4),
chunk_out_stream(std::move(base), LZ4_COMPRESSED, sizeof(block_sz) + 4),
out_buf(new char[LZ4_UNCOMPRESSED]), block_sz(0) {}
~LZ4_decoder() override {
@ -496,10 +482,10 @@ public:
}
protected:
ssize_t write_chunk(const void *buf, size_t len) override {
bool write_chunk(const void *buf, size_t len) override {
// This is an error
if (len != chunk_sz)
return -1;
return false;
auto in = reinterpret_cast<const char *>(buf);
@ -511,14 +497,14 @@ protected:
memcpy(&block_sz, in, sizeof(block_sz));
}
chunk_sz = block_sz;
return 0;
return true;
} else {
int r = LZ4_decompress_safe(in, out_buf, block_sz, LZ4_UNCOMPRESSED);
chunk_sz = sizeof(block_sz);
block_sz = 0;
if (r < 0) {
LOGW("LZ4HC decompression failure (%d)\n", r);
return -1;
return false;
}
return bwrite(out_buf, r);
}
@ -529,10 +515,10 @@ private:
uint32_t block_sz;
};
class LZ4_encoder : public buf_cpr_stream {
class LZ4_encoder : public chunk_out_stream {
public:
explicit LZ4_encoder(stream_ptr &&base, bool lg) :
buf_cpr_stream(std::move(base), LZ4_UNCOMPRESSED),
chunk_out_stream(std::move(base), LZ4_UNCOMPRESSED),
out_buf(new char[LZ4_COMPRESSED]), lg(lg), in_total(0) {
bwrite("\x02\x21\x4c\x18", 4);
}
@ -545,15 +531,13 @@ public:
}
protected:
ssize_t write_chunk(const void *buf, size_t len) override {
bool write_chunk(const void *buf, size_t len) override {
int r = LZ4_compress_HC((const char *) buf, out_buf, len, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX);
if (r == 0) {
LOGW("LZ4HC compression failure\n");
return -1;
return false;
}
bwrite(&r, sizeof(r));
bwrite(out_buf, r);
return r + sizeof(r);
return bwrite(&r, sizeof(r)) && bwrite(out_buf, r);
}
private:
@ -644,7 +628,7 @@ void decompress(char *infile, const char *outfile) {
strm = get_decoder(type, make_unique<fp_stream>(out_fp));
if (ext) *ext = '.';
}
if (strm->write(buf, len) < 0)
if (!strm->write(buf, len))
LOGE("Decompression error!\n");
}
@ -687,9 +671,9 @@ void compress(const char *method, const char *infile, const char *outfile) {
char buf[4096];
size_t len;
while ((len = fread(buf, 1, sizeof(buf), in_fp))) {
if (strm->write(buf, len) < 0)
if (!strm->write(buf, len))
LOGE("Compression error!\n");
};
}
strm.reset(nullptr);
fclose(in_fp);

View File

@ -11,8 +11,7 @@ public:
virtual ssize_t read(void *buf, size_t len);
virtual ssize_t readFully(void *buf, size_t len);
virtual ssize_t readv(const iovec *iov, int iovcnt);
virtual ssize_t write(const void *buf, size_t len);
virtual ssize_t writeFully(void *buf, size_t len);
virtual bool write(const void *buf, size_t len);
virtual ssize_t writev(const iovec *iov, int iovcnt);
virtual off_t seek(off_t off, int whence);
virtual ~stream() = default;
@ -26,7 +25,7 @@ public:
filter_stream(stream_ptr &&base) : base(std::move(base)) {}
ssize_t read(void *buf, size_t len) override;
ssize_t write(const void *buf, size_t len) override;
bool write(const void *buf, size_t len) override;
// Seeking while filtering does not make sense
off_t seek(off_t off, int whence) final { return stream::seek(off, whence); }
@ -44,16 +43,16 @@ public:
chunk_out_stream(stream_ptr &&base, size_t buf_sz = 4096)
: chunk_out_stream(std::move(base), buf_sz, buf_sz) {}
~chunk_out_stream() { delete[] _buf; }
~chunk_out_stream() override { delete[] _buf; }
// Reading does not make sense
ssize_t read(void *buf, size_t len) final { return stream::read(buf, len); }
ssize_t write(const void *buf, size_t len) final;
bool write(const void *buf, size_t len) final;
protected:
// Classes inheriting this class has to call close() in the destructor
// Classes inheriting this class has to call close() in its destructor
void close();
virtual ssize_t write_chunk(const void *buf, size_t len) = 0;
virtual bool write_chunk(const void *buf, size_t len) = 0;
size_t chunk_sz;
@ -71,7 +70,7 @@ public:
byte_stream(Byte *&buf, size_t &len) : byte_stream(reinterpret_cast<uint8_t *&>(buf), len) {}
ssize_t read(void *buf, size_t len) override;
ssize_t write(const void *buf, size_t len) override;
bool write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override;
private:
@ -83,16 +82,23 @@ private:
void resize(size_t new_pos, bool zero = false);
};
class file_stream : public stream {
public:
bool write(const void *buf, size_t len) final;
protected:
virtual ssize_t do_write(const void *buf, size_t len) = 0;
};
// File stream but does not close the file descriptor at any time
class fd_stream : public stream {
class fd_stream : public file_stream {
public:
fd_stream(int fd) : fd(fd) {}
ssize_t read(void *buf, size_t len) override;
ssize_t readv(const iovec *iov, int iovcnt) override;
ssize_t write(const void *buf, size_t len) override;
ssize_t writev(const iovec *iov, int iovcnt) override;
off_t seek(off_t off, int whence) override;
protected:
ssize_t do_write(const void *buf, size_t len) override;
private:
int fd;
};
@ -102,15 +108,14 @@ private:
* ****************************************/
// sFILE -> stream_ptr
class fp_stream final : public stream {
class fp_stream final : public file_stream {
public:
fp_stream(FILE *fp = nullptr) : fp(fp, fclose) {}
fp_stream(sFILE &&fp) : fp(std::move(fp)) {}
ssize_t read(void *buf, size_t len) override;
ssize_t write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override;
protected:
ssize_t do_write(const void *buf, size_t len) override;
private:
sFILE fp;
};

View File

@ -13,7 +13,9 @@ static int strm_read(void *v, char *buf, int len) {
static int strm_write(void *v, const char *buf, int len) {
auto strm = static_cast<stream *>(v);
return strm->write(buf, len);
if (!strm->write(buf, len))
return -1;
return len;
}
static fpos_t strm_seek(void *v, fpos_t off, int whence) {
@ -64,33 +66,17 @@ ssize_t stream::readv(const iovec *iov, int iovcnt) {
return read_sz;
}
ssize_t stream::write(const void *buf, size_t len) {
bool stream::write(const void *buf, size_t len) {
LOGE("This stream does not implement write\n");
return -1;
}
ssize_t stream::writeFully(void *buf, size_t len) {
size_t write_sz = 0;
ssize_t ret;
do {
ret = write((byte *) buf + write_sz, len - write_sz);
if (ret < 0) {
if (errno == EINTR)
continue;
return ret;
}
write_sz += ret;
} while (write_sz != len && ret != 0);
return write_sz;
return false;
}
ssize_t stream::writev(const iovec *iov, int iovcnt) {
size_t write_sz = 0;
for (int i = 0; i < iovcnt; ++i) {
auto ret = writeFully(iov[i].iov_base, iov[i].iov_len);
if (ret < 0)
return ret;
write_sz += ret;
if (!write(iov[i].iov_base, iov[i].iov_len))
return write_sz;
write_sz += iov[i].iov_len;
}
return write_sz;
}
@ -105,7 +91,7 @@ ssize_t fp_stream::read(void *buf, size_t len) {
return ret ? ret : (ferror(fp.get()) ? -1 : 0);
}
ssize_t fp_stream::write(const void *buf, size_t len) {
ssize_t fp_stream::do_write(const void *buf, size_t len) {
return fwrite(buf, 1, len, fp.get());
}
@ -117,12 +103,11 @@ ssize_t filter_stream::read(void *buf, size_t len) {
return base->read(buf, len);
}
ssize_t filter_stream::write(const void *buf, size_t len) {
bool filter_stream::write(const void *buf, size_t len) {
return base->write(buf, len);
}
ssize_t chunk_out_stream::write(const void *_in, size_t len) {
ssize_t ret = 0;
bool chunk_out_stream::write(const void *_in, size_t len) {
auto in = static_cast<const uint8_t *>(_in);
while (len) {
if (buf_off + len >= chunk_sz) {
@ -140,10 +125,8 @@ ssize_t chunk_out_stream::write(const void *_in, size_t len) {
in += chunk_sz;
len -= chunk_sz;
}
auto r = write_chunk(src, chunk_sz);
if (r < 0)
return ret;
ret += r;
if (!write_chunk(src, chunk_sz))
return false;
} else {
// Buffer internally
if (!_buf) {
@ -154,7 +137,7 @@ ssize_t chunk_out_stream::write(const void *_in, size_t len) {
break;
}
}
return ret;
return true;
}
void chunk_out_stream::close() {
@ -177,12 +160,12 @@ ssize_t byte_stream::read(void *buf, size_t len) {
return len;
}
ssize_t byte_stream::write(const void *buf, size_t len) {
bool byte_stream::write(const void *buf, size_t len) {
resize(_pos + len);
memcpy(_buf + _pos, buf, len);
_pos += len;
_len = std::max(_len, _pos);
return len;
return true;
}
off_t byte_stream::seek(off_t off, int whence) {
@ -227,7 +210,7 @@ ssize_t fd_stream::readv(const iovec *iov, int iovcnt) {
return ::readv(fd, iov, iovcnt);
}
ssize_t fd_stream::write(const void *buf, size_t len) {
ssize_t fd_stream::do_write(const void *buf, size_t len) {
return ::write(fd, buf, len);
}
@ -238,3 +221,18 @@ ssize_t fd_stream::writev(const iovec *iov, int iovcnt) {
off_t fd_stream::seek(off_t off, int whence) {
return lseek(fd, off, whence);
}
bool file_stream::write(const void *buf, size_t len) {
size_t write_sz = 0;
ssize_t ret;
do {
ret = do_write((byte *) buf + write_sz, len - write_sz);
if (ret < 0) {
if (errno == EINTR)
continue;
return false;
}
write_sz += ret;
} while (write_sz != len && ret != 0);
return true;
}