Cleanup byte_channel implementation

This commit is contained in:
topjohnwu 2023-05-20 14:19:40 -07:00
parent f5aaff2b1e
commit a5768e02ea
4 changed files with 55 additions and 59 deletions

View File

@ -26,19 +26,6 @@ struct file_attr {
char con[128]; char con[128];
}; };
struct byte_data {
using str_pairs = std::initializer_list<std::pair<std::string_view, std::string_view>>;
uint8_t *buf = nullptr;
size_t sz = 0;
int patch(str_pairs list) { return patch(true, list); }
int patch(bool log, str_pairs list);
bool contains(std::string_view pattern, bool log = true) const;
protected:
void swap(byte_data &o);
};
struct mount_info { struct mount_info {
unsigned int id; unsigned int id;
unsigned int parent; unsigned int parent;
@ -56,13 +43,37 @@ struct mount_info {
std::string fs_option; std::string fs_option;
}; };
struct byte_data {
using str_pairs = std::initializer_list<std::pair<std::string_view, std::string_view>>;
uint8_t *buf = nullptr;
size_t sz = 0;
int patch(str_pairs list) { return patch(true, list); }
int patch(bool log, str_pairs list);
bool contains(std::string_view pattern, bool log = true) const;
protected:
void swap(byte_data &o);
};
#define MOVE_ONLY(clazz) \
clazz() = default; \
clazz(const clazz&) = delete; \
clazz(clazz &&o) { swap(o); } \
clazz& operator=(clazz &&o) { swap(o); return *this; }
struct heap_data : public byte_data {
MOVE_ONLY(heap_data)
explicit heap_data(size_t sz) { this->sz = sz; buf = new uint8_t[sz]; }
~heap_data() { free(buf); }
};
struct mmap_data : public byte_data { struct mmap_data : public byte_data {
mmap_data() = default; MOVE_ONLY(mmap_data)
mmap_data(const mmap_data&) = delete;
mmap_data(mmap_data &&o) { swap(o); }
mmap_data(const char *name, bool rw = false); mmap_data(const char *name, bool rw = false);
~mmap_data() { if (buf) munmap(buf, sz); } ~mmap_data() { if (buf) munmap(buf, sz); }
mmap_data& operator=(mmap_data &&other) { swap(other); return *this; }
}; };
extern "C" { extern "C" {

View File

@ -27,26 +27,23 @@ protected:
class chunk_out_stream : public filter_out_stream { class chunk_out_stream : public filter_out_stream {
public: public:
chunk_out_stream(out_strm_ptr &&base, size_t buf_sz, size_t chunk_sz) chunk_out_stream(out_strm_ptr &&base, size_t buf_sz, size_t chunk_sz)
: filter_out_stream(std::move(base)), chunk_sz(chunk_sz), buf_sz(buf_sz) {} : filter_out_stream(std::move(base)), chunk_sz(chunk_sz), data(buf_sz) {}
chunk_out_stream(out_strm_ptr &&base, size_t buf_sz = 4096) chunk_out_stream(out_strm_ptr &&base, size_t buf_sz = 4096)
: chunk_out_stream(std::move(base), buf_sz, buf_sz) {} : chunk_out_stream(std::move(base), buf_sz, buf_sz) {}
~chunk_out_stream() override { delete[] _buf; }
bool write(const void *buf, size_t len) final; bool write(const void *buf, size_t len) final;
protected: protected:
// Classes inheriting this class has to call finalize() in its destructor // Classes inheriting this class has to call finalize() in its destructor
void finalize(); void finalize();
virtual bool write_chunk(const void *buf, size_t len, bool final) = 0; virtual bool write_chunk(const void *buf, size_t len, bool final);
size_t chunk_sz; size_t chunk_sz;
private: private:
size_t buf_sz;
size_t buf_off = 0; size_t buf_off = 0;
uint8_t *_buf = nullptr; heap_data data;
}; };
struct in_stream { struct in_stream {
@ -67,21 +64,18 @@ using channel_ptr = std::unique_ptr<channel>;
// Byte channel that dynamically allocates memory // Byte channel that dynamically allocates memory
class byte_channel : public channel { class byte_channel : public channel {
public: public:
byte_channel(uint8_t *&buf, size_t &len); byte_channel(heap_data &data) : _data(data) {}
template <class Byte>
byte_channel(Byte *&buf, size_t &len) : byte_channel(reinterpret_cast<uint8_t *&>(buf), len) {}
ssize_t read(void *buf, size_t len) override; ssize_t read(void *buf, size_t len) override;
bool write(const void *buf, size_t len) override; bool write(const void *buf, size_t len) override;
off_t seek(off_t off, int whence) override; off_t seek(off_t off, int whence) override;
private: private:
uint8_t *&_buf; heap_data &_data;
size_t &_len;
size_t _pos = 0; size_t _pos = 0;
size_t _cap = 0; size_t _cap = 0;
void resize(size_t new_pos, bool zero = false); void resize(size_t new_sz, bool zero = false);
}; };
class file_channel : public channel { class file_channel : public channel {

View File

@ -95,9 +95,9 @@ bool chunk_out_stream::write(const void *_in, size_t len) {
// Enough input for a chunk // Enough input for a chunk
const uint8_t *src; const uint8_t *src;
if (buf_off) { if (buf_off) {
src = _buf; src = data.buf;
auto copy = chunk_sz - buf_off; auto copy = chunk_sz - buf_off;
memcpy(_buf + buf_off, in, copy); memcpy(data.buf + buf_off, in, copy);
in += copy; in += copy;
len -= copy; len -= copy;
buf_off = 0; buf_off = 0;
@ -110,10 +110,7 @@ bool chunk_out_stream::write(const void *_in, size_t len) {
return false; return false;
} else { } else {
// Buffer internally // Buffer internally
if (!_buf) { memcpy(data.buf + buf_off, in, len);
_buf = new uint8_t[buf_sz];
}
memcpy(_buf + buf_off, in, len);
buf_off += len; buf_off += len;
break; break;
} }
@ -121,33 +118,31 @@ bool chunk_out_stream::write(const void *_in, size_t len) {
return true; return true;
} }
bool chunk_out_stream::write_chunk(const void *buf, size_t len, bool) {
return base->write(buf, len);
}
void chunk_out_stream::finalize() { void chunk_out_stream::finalize() {
if (buf_off) { if (buf_off) {
if (!write_chunk(_buf, buf_off, true)) { if (!write_chunk(data.buf, buf_off, true)) {
LOGE("Error in finalize, file truncated\n"); LOGE("Error in finalize, file truncated\n");
} }
delete[] _buf;
_buf = nullptr;
buf_off = 0; buf_off = 0;
} }
} }
byte_channel::byte_channel(uint8_t *&buf, size_t &len) : _buf(buf), _len(len) {
buf = nullptr;
len = 0;
}
ssize_t byte_channel::read(void *buf, size_t len) { ssize_t byte_channel::read(void *buf, size_t len) {
len = std::min((size_t) len, _len - _pos); len = std::min((size_t) len, _data.sz - _pos);
memcpy(buf, _buf + _pos, len); memcpy(buf, _data.buf + _pos, len);
_pos += len;
return len; return len;
} }
bool byte_channel::write(const void *buf, size_t len) { bool byte_channel::write(const void *buf, size_t len) {
resize(_pos + len); resize(_pos + len);
memcpy(_buf + _pos, buf, len); memcpy(_data.buf + _pos, buf, len);
_pos += len; _pos += len;
_len = std::max(_len, _pos); _data.sz = std::max(_data.sz, _pos);
return true; return true;
} }
@ -158,7 +153,7 @@ off_t byte_channel::seek(off_t off, int whence) {
np = _pos + off; np = _pos + off;
break; break;
case SEEK_END: case SEEK_END:
np = _len + off; np = _data.sz + off;
break; break;
case SEEK_SET: case SEEK_SET:
np = off; np = off;
@ -171,17 +166,17 @@ off_t byte_channel::seek(off_t off, int whence) {
return np; return np;
} }
void byte_channel::resize(size_t new_pos, bool zero) { void byte_channel::resize(size_t new_sz, bool zero) {
bool resize = false; bool resize = false;
size_t old_cap = _cap; size_t old_cap = _cap;
while (new_pos > _cap) { while (new_sz > _cap) {
_cap = _cap ? (_cap << 1) - (_cap >> 1) : 1 << 12; _cap = _cap ? (_cap << 1) - (_cap >> 1) : 1 << 12;
resize = true; resize = true;
} }
if (resize) { if (resize) {
_buf = (uint8_t *) realloc(_buf, _cap); _data.buf = (uint8_t *) realloc(_data.buf, _cap);
if (zero) if (zero)
memset(_buf + old_cap, 0, _cap - old_cap); memset(_data.buf + old_cap, 0, _cap - old_cap);
} }
} }

View File

@ -234,14 +234,10 @@ sepol_impl::~sepol_impl() {
} }
bool sepolicy::to_file(const char *file) { bool sepolicy::to_file(const char *file) {
uint8_t *data;
size_t len;
// No partial writes are allowed to /sys/fs/selinux/load, thus the reason why we // No partial writes are allowed to /sys/fs/selinux/load, thus the reason why we
// first dump everything into memory, then directly call write system call // first dump everything into memory, then directly call write system call
heap_data data;
auto fp = make_channel_fp<byte_channel>(data, len); auto fp = make_channel_fp<byte_channel>(data);
run_finally fin([=]{ free(data); });
policy_file_t pf; policy_file_t pf;
policy_file_init(&pf); policy_file_init(&pf);
@ -255,7 +251,7 @@ bool sepolicy::to_file(const char *file) {
int fd = xopen(file, O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644); int fd = xopen(file, O_WRONLY | O_CREAT | O_TRUNC | O_CLOEXEC, 0644);
if (fd < 0) if (fd < 0)
return false; return false;
xwrite(fd, data, len); xwrite(fd, data.buf, data.sz);
close(fd); close(fd);
return true; return true;