Remove seek support from streams

This commit is contained in:
topjohnwu 2024-02-28 11:07:53 -08:00
parent 987e5f5413
commit 625a1d6f44
9 changed files with 56 additions and 114 deletions

View File

@ -53,22 +53,20 @@ struct in_stream {
virtual ~in_stream() = default;
};
// A channel is something that is writable, readable, and seekable
struct channel : public out_stream, public in_stream {
virtual off_t seek(off_t off, int whence) = 0;
virtual ~channel() = default;
// A stream is something that is writable and readable
struct stream : public out_stream, public in_stream {
virtual ~stream() = default;
};
using channel_ptr = std::unique_ptr<channel>;
using stream_ptr = std::unique_ptr<stream>;
// Byte channel that dynamically allocates memory
class byte_channel : public channel {
// Byte stream that dynamically allocates memory
class byte_stream : public stream {
public:
byte_channel(heap_data &data) : _data(data) {}
byte_stream(heap_data &data) : _data(data) {}
ssize_t read(void *buf, size_t len) override;
bool write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override;
private:
heap_data &_data;
@ -78,13 +76,12 @@ private:
void resize(size_t new_sz, bool zero = false);
};
class rust_vec_channel : public channel {
class rust_vec_stream : public stream {
public:
rust_vec_channel(rust::Vec<uint8_t> &data) : _data(data) {}
rust_vec_stream(rust::Vec<uint8_t> &data) : _data(data) {}
ssize_t read(void *buf, size_t len) override;
bool write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override;
private:
rust::Vec<uint8_t> &_data;
@ -93,21 +90,20 @@ private:
void ensure_size(size_t sz, bool zero = false);
};
class file_channel : public channel {
class file_stream : public stream {
public:
bool write(const void *buf, size_t len) final;
protected:
virtual ssize_t do_write(const void *buf, size_t len) = 0;
};
// File channel but does not close the file descriptor at any time
class fd_channel : public file_channel {
// File stream but does not close the file descriptor at any time
class fd_stream : public file_stream {
public:
fd_channel(int fd) : fd(fd) {}
fd_stream(int fd) : fd(fd) {}
ssize_t read(void *buf, size_t len) override;
ssize_t readv(const iovec *iov, int iovcnt) override;
ssize_t writev(const iovec *iov, int iovcnt) override;
off_t seek(off_t off, int whence) override;
protected:
ssize_t do_write(const void *buf, size_t len) override;
private:
@ -115,26 +111,25 @@ private:
};
/* ****************************************
* Bridge between channel class and C stdio
* Bridge between stream class and C stdio
* ****************************************/
// sFILE -> channel_ptr
class fp_channel final : public file_channel {
// sFILE -> stream_ptr
class fp_stream final : public file_stream {
public:
fp_channel(FILE *fp = nullptr) : fp(fp, fclose) {}
fp_channel(sFILE &&fp) : fp(std::move(fp)) {}
fp_stream(FILE *fp = nullptr) : fp(fp, fclose) {}
fp_stream(sFILE &&fp) : fp(std::move(fp)) {}
ssize_t read(void *buf, size_t len) override;
off_t seek(off_t off, int whence) override;
protected:
ssize_t do_write(const void *buf, size_t len) override;
private:
sFILE fp;
};
// channel_ptr -> sFILE
sFILE make_channel_fp(channel_ptr &&strm);
// stream_ptr -> sFILE
sFILE make_stream_fp(stream_ptr &&strm);
template <class T, class... Args>
sFILE make_channel_fp(Args &&... args) {
return make_channel_fp(channel_ptr(new T(std::forward<Args>(args)...)));
sFILE make_stream_fp(Args &&... args) {
return make_stream_fp(stream_ptr(new T(std::forward<Args>(args)...)));
}

View File

@ -205,7 +205,7 @@ private:
uint8_t arr[N];
};
class byte_channel;
class byte_stream;
struct heap_data : public byte_data {
ALLOW_MOVE_ONLY(heap_data)
@ -214,8 +214,8 @@ struct heap_data : public byte_data {
explicit heap_data(size_t sz) : byte_data(calloc(sz, 1), sz) {}
~heap_data() { free(_buf); }
// byte_channel needs to reallocate the internal buffer
friend byte_channel;
// byte_stream needs to reallocate the internal buffer
friend byte_stream;
};
struct owned_fd {

View File

@ -7,30 +7,25 @@
using namespace std;
static int strm_read(void *v, char *buf, int len) {
auto strm = static_cast<channel *>(v);
auto strm = static_cast<stream *>(v);
return strm->read(buf, len);
}
static int strm_write(void *v, const char *buf, int len) {
auto strm = static_cast<channel *>(v);
auto strm = static_cast<stream *>(v);
if (!strm->write(buf, len))
return -1;
return len;
}
static fpos_t strm_seek(void *v, fpos_t off, int whence) {
auto strm = static_cast<channel *>(v);
return strm->seek(off, whence);
}
static int strm_close(void *v) {
auto strm = static_cast<channel *>(v);
auto strm = static_cast<stream *>(v);
delete strm;
return 0;
}
sFILE make_channel_fp(channel_ptr &&strm) {
auto fp = make_file(funopen(strm.release(), strm_read, strm_write, strm_seek, strm_close));
sFILE make_stream_fp(stream_ptr &&strm) {
auto fp = make_file(funopen(strm.release(), strm_read, strm_write, nullptr, strm_close));
setbuf(fp.get(), nullptr);
return fp;
}
@ -71,19 +66,15 @@ ssize_t out_stream::writev(const iovec *iov, int iovcnt) {
return write_sz;
}
ssize_t fp_channel::read(void *buf, size_t len) {
ssize_t fp_stream::read(void *buf, size_t len) {
auto ret = fread(buf, 1, len, fp.get());
return ret ? ret : (ferror(fp.get()) ? -1 : 0);
}
ssize_t fp_channel::do_write(const void *buf, size_t len) {
ssize_t fp_stream::do_write(const void *buf, size_t len) {
return fwrite(buf, 1, len, fp.get());
}
off_t fp_channel::seek(off_t off, int whence) {
return fseek(fp.get(), off, whence);
}
bool filter_out_stream::write(const void *buf, size_t len) {
return base->write(buf, len);
}
@ -131,14 +122,14 @@ void chunk_out_stream::finalize() {
}
}
ssize_t byte_channel::read(void *buf, size_t len) {
ssize_t byte_stream::read(void *buf, size_t len) {
len = std::min((size_t) len, _data._sz- _pos);
memcpy(buf, _data.buf() + _pos, len);
_pos += len;
return len;
}
bool byte_channel::write(const void *buf, size_t len) {
bool byte_stream::write(const void *buf, size_t len) {
resize(_pos + len);
memcpy(_data.buf() + _pos, buf, len);
_pos += len;
@ -146,27 +137,7 @@ bool byte_channel::write(const void *buf, size_t len) {
return true;
}
off_t byte_channel::seek(off_t off, int whence) {
off_t np;
switch (whence) {
case SEEK_CUR:
np = _pos + off;
break;
case SEEK_END:
np = _data._sz+ off;
break;
case SEEK_SET:
np = off;
break;
default:
return -1;
}
resize(np, true);
_pos = np;
return np;
}
void byte_channel::resize(size_t new_sz, bool zero) {
void byte_stream::resize(size_t new_sz, bool zero) {
bool resize = false;
size_t old_cap = _cap;
while (new_sz > _cap) {
@ -180,41 +151,21 @@ void byte_channel::resize(size_t new_sz, bool zero) {
}
}
ssize_t rust_vec_channel::read(void *buf, size_t len) {
ssize_t rust_vec_stream::read(void *buf, size_t len) {
len = std::min<size_t>(len, _data.size() - _pos);
memcpy(buf, _data.data() + _pos, len);
_pos += len;
return len;
}
bool rust_vec_channel::write(const void *buf, size_t len) {
bool rust_vec_stream::write(const void *buf, size_t len) {
ensure_size(_pos + len);
memcpy(_data.data() + _pos, buf, len);
_pos += len;
return true;
}
off_t rust_vec_channel::seek(off_t off, int whence) {
off_t np;
switch (whence) {
case SEEK_CUR:
np = _pos + off;
break;
case SEEK_END:
np = _data.size() + off;
break;
case SEEK_SET:
np = off;
break;
default:
return -1;
}
ensure_size(np, true);
_pos = np;
return np;
}
void rust_vec_channel::ensure_size(size_t sz, bool zero) {
void rust_vec_stream::ensure_size(size_t sz, bool zero) {
size_t old_sz = _data.size();
if (sz > old_sz) {
resize_vec(_data, sz);
@ -223,27 +174,23 @@ void rust_vec_channel::ensure_size(size_t sz, bool zero) {
}
}
ssize_t fd_channel::read(void *buf, size_t len) {
ssize_t fd_stream::read(void *buf, size_t len) {
return ::read(fd, buf, len);
}
ssize_t fd_channel::readv(const iovec *iov, int iovcnt) {
ssize_t fd_stream::readv(const iovec *iov, int iovcnt) {
return ::readv(fd, iov, iovcnt);
}
ssize_t fd_channel::do_write(const void *buf, size_t len) {
ssize_t fd_stream::do_write(const void *buf, size_t len) {
return ::write(fd, buf, len);
}
ssize_t fd_channel::writev(const iovec *iov, int iovcnt) {
ssize_t fd_stream::writev(const iovec *iov, int iovcnt) {
return ::writev(fd, iov, iovcnt);
}
off_t fd_channel::seek(off_t off, int whence) {
return lseek(fd, off, whence);
}
bool file_channel::write(const void *buf, size_t len) {
bool file_stream::write(const void *buf, size_t len) {
size_t write_sz = 0;
ssize_t ret;
do {

View File

@ -16,14 +16,14 @@ using namespace std;
#define SHA_DIGEST_SIZE 20
static void decompress(format_t type, int fd, const void *in, size_t size) {
auto ptr = get_decoder(type, make_unique<fd_channel>(fd));
auto ptr = get_decoder(type, make_unique<fd_stream>(fd));
ptr->write(in, size);
}
static off_t compress(format_t type, int fd, const void *in, size_t size) {
auto prev = lseek(fd, 0, SEEK_CUR);
{
auto strm = get_encoder(type, make_unique<fd_channel>(fd));
auto strm = get_encoder(type, make_unique<fd_stream>(fd));
strm->write(in, size);
}
auto now = lseek(fd, 0, SEEK_CUR);

View File

@ -661,7 +661,7 @@ void decompress(char *infile, const char *outfile) {
}
FILE *out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
strm = get_decoder(type, make_unique<fp_channel>(out_fp));
strm = get_decoder(type, make_unique<fp_stream>(out_fp));
if (ext) *ext = '.';
}
if (!strm->write(buf, len))
@ -702,7 +702,7 @@ void compress(const char *method, const char *infile, const char *outfile) {
out_fp = outfile == "-"sv ? stdout : xfopen(outfile, "we");
}
auto strm = get_encoder(fmt, make_unique<fp_channel>(out_fp));
auto strm = get_encoder(fmt, make_unique<fp_stream>(out_fp));
char buf[4096];
size_t len;
@ -726,7 +726,7 @@ bool decompress(rust::Slice<const uint8_t> buf, int fd) {
return false;
}
auto strm = get_decoder(type, make_unique<fd_channel>(fd));
auto strm = get_decoder(type, make_unique<fd_stream>(fd));
if (!strm->write(buf.data(), buf.length())) {
return false;
}
@ -734,7 +734,7 @@ bool decompress(rust::Slice<const uint8_t> buf, int fd) {
}
bool xz(rust::Slice<const uint8_t> buf, rust::Vec<uint8_t> &out) {
auto strm = get_encoder(XZ, make_unique<rust_vec_channel>(out));
auto strm = get_encoder(XZ, make_unique<rust_vec_stream>(out));
if (!strm->write(buf.data(), buf.length())) {
return false;
}
@ -747,7 +747,7 @@ bool unxz(rust::Slice<const uint8_t> buf, rust::Vec<uint8_t> &out) {
LOGE("Input file is not in xz format!\n");
return false;
}
auto strm = get_decoder(XZ, make_unique<rust_vec_channel>(out));
auto strm = get_decoder(XZ, make_unique<rust_vec_stream>(out));
if (!strm->write(buf.data(), buf.length())) {
return false;
}

View File

@ -220,7 +220,7 @@ bool check_two_stage() {
void unxz_init(const char *init_xz, const char *init) {
LOGD("unxz %s -> %s\n", init_xz, init);
int fd = xopen(init, O_WRONLY | O_CREAT, 0777);
fd_channel ch(fd);
fd_stream ch(fd);
unxz(ch, mmap_data{init_xz});
close(fd);
clone_attr(init_xz, init);

View File

@ -196,7 +196,7 @@ static void extract_files(bool sbin) {
mmap_data magisk(m32);
unlink(m32);
int fd = xopen("magisk32", O_WRONLY | O_CREAT, 0755);
fd_channel ch(fd);
fd_stream ch(fd);
unxz(ch, magisk);
close(fd);
}
@ -204,7 +204,7 @@ static void extract_files(bool sbin) {
mmap_data magisk(m64);
unlink(m64);
int fd = xopen("magisk64", O_WRONLY | O_CREAT, 0755);
fd_channel ch(fd);
fd_stream ch(fd);
unxz(ch, magisk);
close(fd);
xsymlink("./magisk64", "magisk");
@ -215,7 +215,7 @@ static void extract_files(bool sbin) {
mmap_data stub(stub_xz);
unlink(stub_xz);
int fd = xopen("stub.apk", O_WRONLY | O_CREAT, 0);
fd_channel ch(fd);
fd_stream ch(fd);
unxz(ch, stub);
close(fd);
}

View File

@ -42,7 +42,7 @@ static void dump_preload() {
int fd = xopen("/dev/preload.so", O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644);
if (fd < 0)
return;
fd_channel ch(fd);
fd_stream ch(fd);
if (!unxz(ch, byte_view(init_ld_xz, sizeof(init_ld_xz))))
return;
close(fd);

View File

@ -241,7 +241,7 @@ bool sepolicy::to_file(const char *file) {
// No partial writes are allowed to /sys/fs/selinux/load, thus the reason why we
// first dump everything into memory, then directly call write system call
heap_data data;
auto fp = make_channel_fp<byte_channel>(data);
auto fp = make_stream_fp<byte_stream>(data);
policy_file_t pf;
policy_file_init(&pf);