From 5787aa1078db1ea618a50c6a229e6925ea7f31a9 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Sun, 21 Nov 2021 04:38:22 -0800 Subject: [PATCH] Stream should always write all bytes --- native/jni/magiskboot/compress.cpp | 348 +++++++++++++--------------- native/jni/utils/include/stream.hpp | 35 +-- native/jni/utils/stream.cpp | 66 +++--- 3 files changed, 218 insertions(+), 231 deletions(-) diff --git a/native/jni/magiskboot/compress.cpp b/native/jni/magiskboot/compress.cpp index dba7d22dc..da46dde26 100644 --- a/native/jni/magiskboot/compress.cpp +++ b/native/jni/magiskboot/compress.cpp @@ -29,30 +29,26 @@ constexpr size_t ZOPFLI_CHUNK = ZOPFLI_MASTER_BLOCK_SIZE; constexpr size_t ZOPFLI_CHUNK = CHUNK; #endif -class cpr_stream : public filter_stream { -public: +class out_stream : public filter_stream { using filter_stream::filter_stream; using stream::read; - ssize_t writeFully(void *buf, size_t len) override { - return write(buf, len); - } }; -class gz_strm : public cpr_stream { +class gz_strm : public out_stream { public: - ssize_t write(const void *buf, size_t len) override { - return len ? write(buf, len, Z_NO_FLUSH) : 0; + bool write(const void *buf, size_t len) override { + return len == 0 || write(buf, len, Z_NO_FLUSH); } ~gz_strm() override { write(nullptr, 0, Z_FINISH); switch(mode) { - case DECODE: - inflateEnd(&strm); - break; - case ENCODE: - deflateEnd(&strm); - break; + case DECODE: + inflateEnd(&strm); + break; + case ENCODE: + deflateEnd(&strm); + break; } } @@ -63,14 +59,14 @@ protected: } mode; gz_strm(mode_t mode, stream_ptr &&base) : - cpr_stream(std::move(base)), mode(mode), strm{}, outbuf{0} { + out_stream(std::move(base)), mode(mode), strm{}, outbuf{0} { switch(mode) { - case DECODE: - inflateInit2(&strm, 15 | 16); - break; - case ENCODE: - deflateInit2(&strm, 9, Z_DEFLATED, 15 | 16, 8, Z_DEFAULT_STRATEGY); - break; + case DECODE: + inflateInit2(&strm, 15 | 16); + break; + case ENCODE: + deflateInit2(&strm, 9, Z_DEFLATED, 15 | 16, 8, Z_DEFAULT_STRATEGY); + break; } } @@ -78,8 +74,7 @@ private: z_stream strm; uint8_t outbuf[CHUNK]; - ssize_t write(const void *buf, size_t len, int flush) { - size_t ret = 0; + bool write(const void *buf, size_t len, int flush) { strm.next_in = (Bytef *) buf; strm.avail_in = len; do { @@ -96,11 +91,12 @@ private: } if (code == Z_STREAM_ERROR) { LOGW("gzip %s failed (%d)\n", mode ? "encode" : "decode", code); - return -1; + return false; } - ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out); + if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out)) + return false; } while (strm.avail_out == 0); - return ret; + return true; } }; @@ -114,14 +110,40 @@ public: explicit gz_encoder(stream_ptr &&base) : gz_strm(ENCODE, std::move(base)) {}; }; -class zopfli_encoder : public cpr_stream { +class zopfli_encoder : public out_stream { public: - ssize_t write(const void *buf, size_t len) override { - return len ? write(static_cast(buf), len) : 0; + bool write(const void *buf, size_t len) override { + if (len == 0) + return true; + + auto in = static_cast(buf); + in_size += len; + crcvalue = crc32_z(crcvalue, in, len); + + for (size_t offset = 0; offset < len; offset += ZOPFLI_CHUNK) { + size_t end_offset = std::min(len, offset + ZOPFLI_CHUNK); + ZopfliDeflatePart(&zo, 2, 0, in, offset, end_offset, &bp, &out, &outsize); + + if (bp) { + // The last byte is not complete + if (!bwrite(out, outsize - 1)) + return false; + uint8_t b = out[outsize - 1]; + free_out(); + ZOPFLI_APPEND_DATA(b, &out, &outsize); + } else { + if (!bwrite(out, outsize)) + return false; + free_out(); + } + } + + return true; } - explicit zopfli_encoder(stream_ptr &&base) : cpr_stream(std::move(base)), - zo({}), out(nullptr), outsize(0), bp(0), crcvalue(crc32_z(0L, Z_NULL, 0)), in_size(0) { + explicit zopfli_encoder(stream_ptr &&base) : + out_stream(std::move(base)), zo({}), out(nullptr), outsize(0), bp(0), + crcvalue(crc32_z(0L, Z_NULL, 0)), in_size(0) { ZopfliInitOptions(&zo); // Speed things up a bit, this still leads to better compression than zlib @@ -174,36 +196,12 @@ private: out = nullptr; outsize = 0; } - - ssize_t write(const unsigned char *buf, size_t len) { - ssize_t ret = 0; - in_size += len; - crcvalue = crc32_z(crcvalue, buf, len); - - for (size_t offset = 0; offset < len; offset += ZOPFLI_CHUNK) { - size_t end_offset = std::min(len, offset + ZOPFLI_CHUNK); - ZopfliDeflatePart(&zo, 2, 0, buf, offset, end_offset, &bp, &out, &outsize); - - if (bp) { - // The last byte is not complete - ret += bwrite(out, outsize - 1); - uint8_t b = out[outsize - 1]; - free_out(); - ZOPFLI_APPEND_DATA(b, &out, &outsize); - } else { - ret += bwrite(out, outsize); - free_out(); - } - } - - return ret; - } }; -class bz_strm : public cpr_stream { +class bz_strm : public out_stream { public: - ssize_t write(const void *buf, size_t len) override { - return len ? write(buf, len, BZ_RUN) : 0; + bool write(const void *buf, size_t len) override { + return len == 0 || write(buf, len, BZ_RUN); } ~bz_strm() override { @@ -225,14 +223,14 @@ protected: } mode; bz_strm(mode_t mode, stream_ptr &&base) : - cpr_stream(std::move(base)), mode(mode), strm{}, outbuf{0} { + out_stream(std::move(base)), mode(mode), strm{}, outbuf{0} { switch(mode) { - case DECODE: - BZ2_bzDecompressInit(&strm, 0, 0); - break; - case ENCODE: - BZ2_bzCompressInit(&strm, 9, 0, 0); - break; + case DECODE: + BZ2_bzDecompressInit(&strm, 0, 0); + break; + case ENCODE: + BZ2_bzCompressInit(&strm, 9, 0, 0); + break; } } @@ -240,8 +238,7 @@ private: bz_stream strm; char outbuf[CHUNK]; - ssize_t write(const void *buf, size_t len, int flush) { - size_t ret = 0; + bool write(const void *buf, size_t len, int flush) { strm.next_in = (char *) buf; strm.avail_in = len; do { @@ -249,20 +246,21 @@ private: strm.avail_out = sizeof(outbuf); strm.next_out = outbuf; switch(mode) { - case DECODE: - code = BZ2_bzDecompress(&strm); - break; - case ENCODE: - code = BZ2_bzCompress(&strm, flush); - break; + case DECODE: + code = BZ2_bzDecompress(&strm); + break; + case ENCODE: + code = BZ2_bzCompress(&strm, flush); + break; } if (code < 0) { LOGW("bzip2 %s failed (%d)\n", mode ? "encode" : "decode", code); - return -1; + return false; } - ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out); + if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out)) + return false; } while (strm.avail_out == 0); - return ret; + return true; } }; @@ -276,10 +274,10 @@ public: explicit bz_encoder(stream_ptr &&base) : bz_strm(ENCODE, std::move(base)) {}; }; -class lzma_strm : public cpr_stream { +class lzma_strm : public out_stream { public: - ssize_t write(const void *buf, size_t len) override { - return len ? write(buf, len, LZMA_RUN) : 0; + bool write(const void *buf, size_t len) override { + return len == 0 || write(buf, len, LZMA_RUN); } ~lzma_strm() override { @@ -295,7 +293,7 @@ protected: } mode; lzma_strm(mode_t mode, stream_ptr &&base) : - cpr_stream(std::move(base)), mode(mode), strm(LZMA_STREAM_INIT), outbuf{0} { + out_stream(std::move(base)), mode(mode), strm(LZMA_STREAM_INIT), outbuf{0} { lzma_options_lzma opt; // Initialize preset @@ -307,15 +305,15 @@ protected: lzma_ret code; switch(mode) { - case DECODE: - code = lzma_auto_decoder(&strm, UINT64_MAX, 0); - break; - case ENCODE_XZ: - code = lzma_stream_encoder(&strm, filters, LZMA_CHECK_CRC32); - break; - case ENCODE_LZMA: - code = lzma_alone_encoder(&strm, &opt); - break; + case DECODE: + code = lzma_auto_decoder(&strm, UINT64_MAX, 0); + break; + case ENCODE_XZ: + code = lzma_stream_encoder(&strm, filters, LZMA_CHECK_CRC32); + break; + case ENCODE_LZMA: + code = lzma_alone_encoder(&strm, &opt); + break; } if (code != LZMA_OK) { LOGE("LZMA initialization failed (%d)\n", code); @@ -326,8 +324,7 @@ private: lzma_stream strm; uint8_t outbuf[CHUNK]; - ssize_t write(const void *buf, size_t len, lzma_action flush) { - size_t ret = 0; + bool write(const void *buf, size_t len, lzma_action flush) { strm.next_in = (uint8_t *) buf; strm.avail_in = len; do { @@ -336,11 +333,12 @@ private: int code = lzma_code(&strm, flush); if (code != LZMA_OK && code != LZMA_STREAM_END) { LOGW("LZMA %s failed (%d)\n", mode ? "encode" : "decode", code); - return -1; + return false; } - ret += bwrite(outbuf, sizeof(outbuf) - strm.avail_out); + if (!bwrite(outbuf, sizeof(outbuf) - strm.avail_out)) + return false; } while (strm.avail_out == 0); - return ret; + return true; } }; @@ -359,10 +357,10 @@ public: explicit lzma_encoder(stream_ptr &&base) : lzma_strm(ENCODE_LZMA, std::move(base)) {} }; -class LZ4F_decoder : public cpr_stream { +class LZ4F_decoder : public out_stream { public: explicit LZ4F_decoder(stream_ptr &&base) : - cpr_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) { + out_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) { LZ4F_createDecompressionContext(&ctx, LZ4F_VERSION); } @@ -371,124 +369,112 @@ public: delete[] outbuf; } - ssize_t write(const void *buf, size_t len) override { - size_t ret = 0; - auto inbuf = reinterpret_cast(buf); - if (!outbuf) - read_header(inbuf, len); + bool write(const void *buf, size_t len) override { + auto in = reinterpret_cast(buf); + if (!outbuf) { + size_t read = len; + LZ4F_frameInfo_t info; + LZ4F_getFrameInfo(ctx, &info, in, &read); + switch (info.blockSizeID) { + case LZ4F_default: + case LZ4F_max64KB: outCapacity = 1 << 16; break; + case LZ4F_max256KB: outCapacity = 1 << 18; break; + case LZ4F_max1MB: outCapacity = 1 << 20; break; + case LZ4F_max4MB: outCapacity = 1 << 22; break; + } + outbuf = new uint8_t[outCapacity]; + in += read; + len -= read; + } size_t read, write; LZ4F_errorCode_t code; do { read = len; write = outCapacity; - code = LZ4F_decompress(ctx, outbuf, &write, inbuf, &read, nullptr); + code = LZ4F_decompress(ctx, outbuf, &write, in, &read, nullptr); if (LZ4F_isError(code)) { LOGW("LZ4F decode error: %s\n", LZ4F_getErrorName(code)); - return -1; + return false; } len -= read; - inbuf += read; - ret += bwrite(outbuf, write); + in += read; + if (!bwrite(outbuf, write)) + return false; } while (len != 0 || write != 0); - return ret; + return true; } private: LZ4F_decompressionContext_t ctx; uint8_t *outbuf; size_t outCapacity; - - void read_header(const uint8_t *&in, size_t &size) { - size_t read = size; - LZ4F_frameInfo_t info; - LZ4F_getFrameInfo(ctx, &info, in, &read); - switch (info.blockSizeID) { - case LZ4F_default: - case LZ4F_max64KB: outCapacity = 1 << 16; break; - case LZ4F_max256KB: outCapacity = 1 << 18; break; - case LZ4F_max1MB: outCapacity = 1 << 20; break; - case LZ4F_max4MB: outCapacity = 1 << 22; break; - } - outbuf = new uint8_t[outCapacity]; - in += read; - size -= read; - } }; -class LZ4F_encoder : public cpr_stream { +class LZ4F_encoder : public out_stream { public: explicit LZ4F_encoder(stream_ptr &&base) : - cpr_stream(std::move(base)), ctx(nullptr), outbuf(nullptr), outCapacity(0) { + out_stream(std::move(base)), ctx(nullptr), out_buf(nullptr), outCapacity(0) { LZ4F_createCompressionContext(&ctx, LZ4F_VERSION); } - ssize_t write(const void *buf, size_t len) override { - size_t ret = 0; - if (!outbuf) - ret += write_header(); + bool write(const void *buf, size_t len) override { + if (!out_buf) { + LZ4F_preferences_t prefs { + .frameInfo = { + .blockSizeID = LZ4F_max4MB, + .blockMode = LZ4F_blockIndependent, + .contentChecksumFlag = LZ4F_contentChecksumEnabled, + .blockChecksumFlag = LZ4F_noBlockChecksum, + }, + .compressionLevel = 9, + .autoFlush = 1, + }; + outCapacity = LZ4F_compressBound(BLOCK_SZ, &prefs); + out_buf = new uint8_t[outCapacity]; + size_t write = LZ4F_compressBegin(ctx, out_buf, outCapacity, &prefs); + if (!bwrite(out_buf, write)) + return false; + } if (len == 0) - return 0; - auto inbuf = reinterpret_cast(buf); + return true; + + auto in = reinterpret_cast(buf); size_t read, write; do { read = len > BLOCK_SZ ? BLOCK_SZ : len; - write = LZ4F_compressUpdate(ctx, outbuf, outCapacity, inbuf, read, nullptr); + write = LZ4F_compressUpdate(ctx, out_buf, outCapacity, in, read, nullptr); if (LZ4F_isError(write)) { LOGW("LZ4F encode error: %s\n", LZ4F_getErrorName(write)); - return -1; + return false; } len -= read; - inbuf += read; - ret += bwrite(outbuf, write); + in += read; + if (!bwrite(out_buf, write)) + return false; } while (len != 0); - return ret; + return true; } ~LZ4F_encoder() override { - size_t len = LZ4F_compressEnd(ctx, outbuf, outCapacity, nullptr); - bwrite(outbuf, len); + size_t len = LZ4F_compressEnd(ctx, out_buf, outCapacity, nullptr); + bwrite(out_buf, len); LZ4F_freeCompressionContext(ctx); - delete[] outbuf; + delete[] out_buf; } private: LZ4F_compressionContext_t ctx; - uint8_t *outbuf; + uint8_t *out_buf; size_t outCapacity; static constexpr size_t BLOCK_SZ = 1 << 22; - - int write_header() { - LZ4F_preferences_t prefs { - .frameInfo = { - .blockSizeID = LZ4F_max4MB, - .blockMode = LZ4F_blockIndependent, - .contentChecksumFlag = LZ4F_contentChecksumEnabled, - .blockChecksumFlag = LZ4F_noBlockChecksum, - }, - .compressionLevel = 9, - .autoFlush = 1, - }; - outCapacity = LZ4F_compressBound(BLOCK_SZ, &prefs); - outbuf = new uint8_t[outCapacity]; - size_t write = LZ4F_compressBegin(ctx, outbuf, outCapacity, &prefs); - return bwrite(outbuf, write); - } }; -class buf_cpr_stream : public chunk_out_stream { -public: - using chunk_out_stream::chunk_out_stream; - ssize_t writeFully(void *buf, size_t len) override { - return write(buf, len); - } -}; - -class LZ4_decoder : public buf_cpr_stream { +class LZ4_decoder : public chunk_out_stream { public: explicit LZ4_decoder(stream_ptr &&base) : - buf_cpr_stream(std::move(base), LZ4_COMPRESSED, sizeof(block_sz) + 4), - out_buf(new char[LZ4_UNCOMPRESSED]), block_sz(0) {} + chunk_out_stream(std::move(base), LZ4_COMPRESSED, sizeof(block_sz) + 4), + out_buf(new char[LZ4_UNCOMPRESSED]), block_sz(0) {} ~LZ4_decoder() override { close(); @@ -496,10 +482,10 @@ public: } protected: - ssize_t write_chunk(const void *buf, size_t len) override { + bool write_chunk(const void *buf, size_t len) override { // This is an error if (len != chunk_sz) - return -1; + return false; auto in = reinterpret_cast(buf); @@ -511,14 +497,14 @@ protected: memcpy(&block_sz, in, sizeof(block_sz)); } chunk_sz = block_sz; - return 0; + return true; } else { int r = LZ4_decompress_safe(in, out_buf, block_sz, LZ4_UNCOMPRESSED); chunk_sz = sizeof(block_sz); block_sz = 0; if (r < 0) { LOGW("LZ4HC decompression failure (%d)\n", r); - return -1; + return false; } return bwrite(out_buf, r); } @@ -529,11 +515,11 @@ private: uint32_t block_sz; }; -class LZ4_encoder : public buf_cpr_stream { +class LZ4_encoder : public chunk_out_stream { public: explicit LZ4_encoder(stream_ptr &&base, bool lg) : - buf_cpr_stream(std::move(base), LZ4_UNCOMPRESSED), - out_buf(new char[LZ4_COMPRESSED]), lg(lg), in_total(0) { + chunk_out_stream(std::move(base), LZ4_UNCOMPRESSED), + out_buf(new char[LZ4_COMPRESSED]), lg(lg), in_total(0) { bwrite("\x02\x21\x4c\x18", 4); } @@ -545,15 +531,13 @@ public: } protected: - ssize_t write_chunk(const void *buf, size_t len) override { + bool write_chunk(const void *buf, size_t len) override { int r = LZ4_compress_HC((const char *) buf, out_buf, len, LZ4_COMPRESSED, LZ4HC_CLEVEL_MAX); if (r == 0) { LOGW("LZ4HC compression failure\n"); - return -1; + return false; } - bwrite(&r, sizeof(r)); - bwrite(out_buf, r); - return r + sizeof(r); + return bwrite(&r, sizeof(r)) && bwrite(out_buf, r); } private: @@ -644,7 +628,7 @@ void decompress(char *infile, const char *outfile) { strm = get_decoder(type, make_unique(out_fp)); if (ext) *ext = '.'; } - if (strm->write(buf, len) < 0) + if (!strm->write(buf, len)) LOGE("Decompression error!\n"); } @@ -687,9 +671,9 @@ void compress(const char *method, const char *infile, const char *outfile) { char buf[4096]; size_t len; while ((len = fread(buf, 1, sizeof(buf), in_fp))) { - if (strm->write(buf, len) < 0) + if (!strm->write(buf, len)) LOGE("Compression error!\n"); - }; + } strm.reset(nullptr); fclose(in_fp); diff --git a/native/jni/utils/include/stream.hpp b/native/jni/utils/include/stream.hpp index 7873a055b..0afdbb48d 100644 --- a/native/jni/utils/include/stream.hpp +++ b/native/jni/utils/include/stream.hpp @@ -11,8 +11,7 @@ public: virtual ssize_t read(void *buf, size_t len); virtual ssize_t readFully(void *buf, size_t len); virtual ssize_t readv(const iovec *iov, int iovcnt); - virtual ssize_t write(const void *buf, size_t len); - virtual ssize_t writeFully(void *buf, size_t len); + virtual bool write(const void *buf, size_t len); virtual ssize_t writev(const iovec *iov, int iovcnt); virtual off_t seek(off_t off, int whence); virtual ~stream() = default; @@ -26,7 +25,7 @@ public: filter_stream(stream_ptr &&base) : base(std::move(base)) {} ssize_t read(void *buf, size_t len) override; - ssize_t write(const void *buf, size_t len) override; + bool write(const void *buf, size_t len) override; // Seeking while filtering does not make sense off_t seek(off_t off, int whence) final { return stream::seek(off, whence); } @@ -44,16 +43,16 @@ public: chunk_out_stream(stream_ptr &&base, size_t buf_sz = 4096) : chunk_out_stream(std::move(base), buf_sz, buf_sz) {} - ~chunk_out_stream() { delete[] _buf; } + ~chunk_out_stream() override { delete[] _buf; } // Reading does not make sense ssize_t read(void *buf, size_t len) final { return stream::read(buf, len); } - ssize_t write(const void *buf, size_t len) final; + bool write(const void *buf, size_t len) final; protected: - // Classes inheriting this class has to call close() in the destructor + // Classes inheriting this class has to call close() in its destructor void close(); - virtual ssize_t write_chunk(const void *buf, size_t len) = 0; + virtual bool write_chunk(const void *buf, size_t len) = 0; size_t chunk_sz; @@ -71,7 +70,7 @@ public: byte_stream(Byte *&buf, size_t &len) : byte_stream(reinterpret_cast(buf), len) {} ssize_t read(void *buf, size_t len) override; - ssize_t write(const void *buf, size_t len) override; + bool write(const void *buf, size_t len) override; off_t seek(off_t off, int whence) override; private: @@ -83,16 +82,23 @@ private: void resize(size_t new_pos, bool zero = false); }; +class file_stream : public stream { +public: + bool write(const void *buf, size_t len) final; +protected: + virtual ssize_t do_write(const void *buf, size_t len) = 0; +}; + // File stream but does not close the file descriptor at any time -class fd_stream : public stream { +class fd_stream : public file_stream { public: fd_stream(int fd) : fd(fd) {} ssize_t read(void *buf, size_t len) override; ssize_t readv(const iovec *iov, int iovcnt) override; - ssize_t write(const void *buf, size_t len) override; ssize_t writev(const iovec *iov, int iovcnt) override; off_t seek(off_t off, int whence) override; - +protected: + ssize_t do_write(const void *buf, size_t len) override; private: int fd; }; @@ -102,15 +108,14 @@ private: * ****************************************/ // sFILE -> stream_ptr -class fp_stream final : public stream { +class fp_stream final : public file_stream { public: fp_stream(FILE *fp = nullptr) : fp(fp, fclose) {} fp_stream(sFILE &&fp) : fp(std::move(fp)) {} - ssize_t read(void *buf, size_t len) override; - ssize_t write(const void *buf, size_t len) override; off_t seek(off_t off, int whence) override; - +protected: + ssize_t do_write(const void *buf, size_t len) override; private: sFILE fp; }; diff --git a/native/jni/utils/stream.cpp b/native/jni/utils/stream.cpp index bd63c9bb1..f9d2114bd 100644 --- a/native/jni/utils/stream.cpp +++ b/native/jni/utils/stream.cpp @@ -13,7 +13,9 @@ static int strm_read(void *v, char *buf, int len) { static int strm_write(void *v, const char *buf, int len) { auto strm = static_cast(v); - return strm->write(buf, len); + if (!strm->write(buf, len)) + return -1; + return len; } static fpos_t strm_seek(void *v, fpos_t off, int whence) { @@ -64,33 +66,17 @@ ssize_t stream::readv(const iovec *iov, int iovcnt) { return read_sz; } -ssize_t stream::write(const void *buf, size_t len) { +bool stream::write(const void *buf, size_t len) { LOGE("This stream does not implement write\n"); - return -1; -} - -ssize_t stream::writeFully(void *buf, size_t len) { - size_t write_sz = 0; - ssize_t ret; - do { - ret = write((byte *) buf + write_sz, len - write_sz); - if (ret < 0) { - if (errno == EINTR) - continue; - return ret; - } - write_sz += ret; - } while (write_sz != len && ret != 0); - return write_sz; + return false; } ssize_t stream::writev(const iovec *iov, int iovcnt) { size_t write_sz = 0; for (int i = 0; i < iovcnt; ++i) { - auto ret = writeFully(iov[i].iov_base, iov[i].iov_len); - if (ret < 0) - return ret; - write_sz += ret; + if (!write(iov[i].iov_base, iov[i].iov_len)) + return write_sz; + write_sz += iov[i].iov_len; } return write_sz; } @@ -105,7 +91,7 @@ ssize_t fp_stream::read(void *buf, size_t len) { return ret ? ret : (ferror(fp.get()) ? -1 : 0); } -ssize_t fp_stream::write(const void *buf, size_t len) { +ssize_t fp_stream::do_write(const void *buf, size_t len) { return fwrite(buf, 1, len, fp.get()); } @@ -117,12 +103,11 @@ ssize_t filter_stream::read(void *buf, size_t len) { return base->read(buf, len); } -ssize_t filter_stream::write(const void *buf, size_t len) { +bool filter_stream::write(const void *buf, size_t len) { return base->write(buf, len); } -ssize_t chunk_out_stream::write(const void *_in, size_t len) { - ssize_t ret = 0; +bool chunk_out_stream::write(const void *_in, size_t len) { auto in = static_cast(_in); while (len) { if (buf_off + len >= chunk_sz) { @@ -140,10 +125,8 @@ ssize_t chunk_out_stream::write(const void *_in, size_t len) { in += chunk_sz; len -= chunk_sz; } - auto r = write_chunk(src, chunk_sz); - if (r < 0) - return ret; - ret += r; + if (!write_chunk(src, chunk_sz)) + return false; } else { // Buffer internally if (!_buf) { @@ -154,7 +137,7 @@ ssize_t chunk_out_stream::write(const void *_in, size_t len) { break; } } - return ret; + return true; } void chunk_out_stream::close() { @@ -177,12 +160,12 @@ ssize_t byte_stream::read(void *buf, size_t len) { return len; } -ssize_t byte_stream::write(const void *buf, size_t len) { +bool byte_stream::write(const void *buf, size_t len) { resize(_pos + len); memcpy(_buf + _pos, buf, len); _pos += len; _len = std::max(_len, _pos); - return len; + return true; } off_t byte_stream::seek(off_t off, int whence) { @@ -227,7 +210,7 @@ ssize_t fd_stream::readv(const iovec *iov, int iovcnt) { return ::readv(fd, iov, iovcnt); } -ssize_t fd_stream::write(const void *buf, size_t len) { +ssize_t fd_stream::do_write(const void *buf, size_t len) { return ::write(fd, buf, len); } @@ -238,3 +221,18 @@ ssize_t fd_stream::writev(const iovec *iov, int iovcnt) { off_t fd_stream::seek(off_t off, int whence) { return lseek(fd, off, whence); } + +bool file_stream::write(const void *buf, size_t len) { + size_t write_sz = 0; + ssize_t ret; + do { + ret = do_write((byte *) buf + write_sz, len - write_sz); + if (ret < 0) { + if (errno == EINTR) + continue; + return false; + } + write_sz += ret; + } while (write_sz != len && ret != 0); + return true; +}