From ee4dad7a13c7181928e2588f657e25015fd1a8f0 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Wed, 7 Jun 2023 16:49:40 -0700 Subject: [PATCH] Bridge C++ bytes with Rust &[u8] --- native/src/base/files.cpp | 58 +------------------- native/src/base/files.hpp | 64 +--------------------- native/src/base/logging.cpp | 2 +- native/src/base/misc.cpp | 44 +++++++++++++++ native/src/base/misc.hpp | 103 ++++++++++++++++++++++++++++++++---- native/src/base/stream.cpp | 2 +- native/src/sepolicy/api.cpp | 4 +- 7 files changed, 145 insertions(+), 132 deletions(-) diff --git a/native/src/base/files.cpp b/native/src/base/files.cpp index 9c53e7603..fa7b3a828 100644 --- a/native/src/base/files.cpp +++ b/native/src/base/files.cpp @@ -7,13 +7,12 @@ #include #include -#include #include using namespace std; int fd_pathat(int dirfd, const char *name, char *path, size_t size) { - if (fd_path(dirfd, u8_mut_slice(path, size)) < 0) + if (fd_path(dirfd, byte_data(path, size)) < 0) return -1; auto len = strlen(path); path[len] = '/'; @@ -434,61 +433,6 @@ sFILE make_file(FILE *fp) { return sFILE(fp, [](FILE *fp){ return fp ? fclose(fp) : 1; }); } -byte_view::byte_view(string_view s, bool with_nul) -: byte_view(static_cast(s.data()), s.length()) { - if (with_nul && s[s.length()] == '\0') { - ++_sz; - } -} - -bool byte_view::contains(byte_view pattern) const { - if (_buf == nullptr) - return false; - for (uint8_t *p = _buf, *eof = _buf + _sz; p < eof; ++p) { - if (memcmp(p, pattern.buf(), pattern.sz()) == 0) { - return true; - } - } - return false; -} - -bool byte_view::equals(byte_view o) const { - return _sz == o._sz && memcmp(_buf, o._buf, _sz) == 0; -} - -void byte_view::swap(byte_view &o) { - std::swap(_buf, o._buf); - std::swap(_sz, o._sz); -} - -heap_data byte_view::clone() const { - heap_data copy(_sz); - memcpy(copy._buf, _buf, _sz); - return copy; -} - -vector byte_data::patch(byte_view from, byte_view to) { - vector v; - if (_buf == nullptr) - return v; - auto p = _buf; - auto eof = _buf + _sz; - while (p < eof) { - p = static_cast(memmem(p, eof - p, from.buf(), from.sz())); - if (p == nullptr) - return v; - memset(p, 0, from.sz()); - memcpy(p, to.buf(), to.sz()); - v.push_back(p - _buf); - p += from.sz(); - } - return v; -} - -void heap_data::realloc(size_t sz) { - _buf = static_cast(::realloc(_buf, sz)); -} - mmap_data::mmap_data(const char *name, bool rw) { int fd = xopen(name, (rw ? O_RDWR : O_RDONLY) | O_CLOEXEC); if (fd < 0) diff --git a/native/src/base/files.hpp b/native/src/base/files.hpp index bb3713ae3..41a79f636 100644 --- a/native/src/base/files.hpp +++ b/native/src/base/files.hpp @@ -6,7 +6,7 @@ #include #include -#include "xwrap.hpp" +#include "misc.hpp" template static inline T align_to(T v, int a) { @@ -41,68 +41,8 @@ struct mount_info { std::string fs_option; }; -struct heap_data; - -struct byte_view { - byte_view() : _buf(nullptr), _sz(0) {} - byte_view(const void *buf, size_t sz) : _buf((uint8_t *) buf), _sz(sz) {} - - // byte_view, or any of its sub-type, can be copied as byte_view - byte_view(const byte_view &o) : _buf(o._buf), _sz(o._sz) {} - - // String as bytes - byte_view(std::string_view s, bool with_nul = true); - byte_view(const char *s, bool with_nul = true) - : byte_view(std::string_view(s), with_nul) {} - byte_view(const std::string &s, bool with_nul = true) - : byte_view(std::string_view(s), with_nul) {} - - // Vector as bytes - byte_view(const std::vector &v) : byte_view(v.data(), v.size()) {} - - const uint8_t *buf() const { return _buf; } - size_t sz() const { return _sz; } - - bool contains(byte_view pattern) const; - bool equals(byte_view o) const; - heap_data clone() const; - -protected: - uint8_t *_buf; - size_t _sz; - - void swap(byte_view &o); -}; - -struct byte_data : public byte_view { - byte_data() = default; - byte_data(void *buf, size_t sz) : byte_view(buf, sz) {} - - using byte_view::buf; - using byte_view::sz; - uint8_t *buf() { return _buf; } - size_t &sz() { return _sz; } - - std::vector patch(byte_view from, byte_view to); -}; - -#define MOVE_ONLY(clazz) \ -clazz() = default; \ -clazz(const clazz&) = delete; \ -clazz(clazz &&o) { swap(o); } \ -clazz& operator=(clazz &&o) { swap(o); return *this; } - -struct heap_data : public byte_data { - MOVE_ONLY(heap_data) - - explicit heap_data(size_t sz) : byte_data(malloc(sz), sz) {} - ~heap_data() { free(_buf); } - - void realloc(size_t sz); -}; - struct mmap_data : public byte_data { - MOVE_ONLY(mmap_data) + ALLOW_MOVE_ONLY(mmap_data) explicit mmap_data(const char *name, bool rw = false); mmap_data(int fd, size_t sz, bool rw = false) { init(fd, sz, rw); } diff --git a/native/src/base/logging.cpp b/native/src/base/logging.cpp index 8f25df642..fa031d0ed 100644 --- a/native/src/base/logging.cpp +++ b/native/src/base/logging.cpp @@ -15,7 +15,7 @@ static int fmt_and_log_with_rs(LogLevel level, const char *fmt, va_list ap) { buf[0] = '\0'; // Fortify logs when a fatal error occurs. Do not run through fortify again int len = std::min(__call_bypassing_fortify(vsnprintf)(buf, sz, fmt, ap), sz - 1); - log_with_rs(level, u8_slice(buf, len)); + log_with_rs(level, byte_view(buf, len)); return len; } diff --git a/native/src/base/misc.cpp b/native/src/base/misc.cpp index 06289b324..12fdbcfe4 100644 --- a/native/src/base/misc.cpp +++ b/native/src/base/misc.cpp @@ -13,6 +13,50 @@ using namespace std; +bool byte_view::contains(byte_view pattern) const { + if (_buf == nullptr) + return false; + for (uint8_t *p = _buf, *eof = _buf + _sz; p < eof; ++p) { + if (memcmp(p, pattern.buf(), pattern.sz()) == 0) { + return true; + } + } + return false; +} + +bool byte_view::equals(byte_view o) const { + return _sz == o._sz && memcmp(_buf, o._buf, _sz) == 0; +} + +heap_data byte_view::clone() const { + heap_data copy(_sz); + memcpy(copy._buf, _buf, _sz); + return copy; +} + +void byte_data::swap(byte_data &o) { + std::swap(_buf, o._buf); + std::swap(_sz, o._sz); +} + +vector byte_data::patch(byte_view from, byte_view to) { + vector v; + if (_buf == nullptr) + return v; + auto p = _buf; + auto eof = _buf + _sz; + while (p < eof) { + p = static_cast(memmem(p, eof - p, from.buf(), from.sz())); + if (p == nullptr) + return v; + memset(p, 0, from.sz()); + memcpy(p, to.buf(), to.sz()); + v.push_back(p - _buf); + p += from.sz(); + } + return v; +} + int fork_dont_care() { if (int pid = xfork()) { waitpid(pid, nullptr, 0); diff --git a/native/src/base/misc.hpp b/native/src/base/misc.hpp index bf6c13f32..eae95ffb4 100644 --- a/native/src/base/misc.hpp +++ b/native/src/base/misc.hpp @@ -7,10 +7,18 @@ #include #include +#include "xwrap.hpp" + #define DISALLOW_COPY_AND_MOVE(clazz) \ -clazz(const clazz &) = delete; \ +clazz(const clazz &) = delete; \ clazz(clazz &&) = delete; +#define ALLOW_MOVE_ONLY(clazz) \ +clazz() = default; \ +clazz(const clazz&) = delete; \ +clazz(clazz &&o) { swap(o); } \ +clazz& operator=(clazz &&o) { swap(o); return *this; } + class mutex_guard { DISALLOW_COPY_AND_MOVE(mutex_guard) public: @@ -119,15 +127,92 @@ struct StringCmp { bool operator()(std::string_view a, std::string_view b) const { return a < b; } }; -template -rust::Slice u8_mut_slice(T *buf, size_t sz) { - return rust::Slice(reinterpret_cast(buf), sz); -} +struct heap_data; -template -rust::Slice u8_slice(T *buf, size_t sz) { - return rust::Slice(reinterpret_cast(buf), sz); -} +// Interchangeable as `&[u8]` in Rust +struct byte_view { + byte_view() : _buf(nullptr), _sz(0) {} + byte_view(const void *buf, size_t sz) : _buf((uint8_t *) buf), _sz(sz) {} + + // byte_view, or any of its subclass, can be copied as byte_view + byte_view(const byte_view &o) : _buf(o._buf), _sz(o._sz) {} + + // Bridging to Rust slice + byte_view(rust::Slice o) : byte_view(o.data(), o.size()) {} + operator rust::Slice() { return rust::Slice(_buf, _sz); } + + // String as bytes + byte_view(const char *s, bool with_nul = true) + : byte_view(std::string_view(s), with_nul, false) {} + byte_view(const std::string &s, bool with_nul = true) + : byte_view(std::string_view(s), with_nul, false) {} + byte_view(std::string_view s, bool with_nul = true) + : byte_view(s, with_nul, true /* string_view is not guaranteed to null terminate */ ) {} + + // Vector as bytes + byte_view(const std::vector &v) : byte_view(v.data(), v.size()) {} + + const uint8_t *buf() const { return _buf; } + size_t sz() const { return _sz; } + + bool contains(byte_view pattern) const; + bool equals(byte_view o) const; + heap_data clone() const; + +protected: + uint8_t *_buf; + size_t _sz; + +private: + byte_view(std::string_view s, bool with_nul, bool check_nul) + : byte_view(static_cast(s.data()), s.length()) { + if (with_nul) { + if (check_nul && s[s.length()] != '\0') + return; + ++_sz; + } + } +}; + +// Interchangeable as `&mut [u8]` in Rust +struct byte_data : public byte_view { + byte_data() = default; + byte_data(void *buf, size_t sz) : byte_view(buf, sz) {} + + // We don't want mutable references to be copied or moved around; pass bytes as byte_view + // Subclasses are free to implement their own constructors + byte_data(const byte_data &) = delete; + byte_data(byte_data &&) = delete; + + // Transparent conversion from common C++ types to mutable byte references + byte_data(std::string &s, bool with_nul = true) + : byte_data(s.data(), with_nul ? s.length() + 1 : s.length()) {} + byte_data(std::vector &v) : byte_data(v.data(), v.size()) {} + + // Bridging to Rust slice + byte_data(rust::Slice o) : byte_data(o.data(), o.size()) {} + operator rust::Slice() { return rust::Slice(_buf, _sz); } + + using byte_view::buf; + using byte_view::sz; + uint8_t *buf() { return _buf; } + size_t &sz() { return _sz; } + + void swap(byte_data &o); + std::vector patch(byte_view from, byte_view to); +}; + +class byte_channel; + +struct heap_data : public byte_data { + ALLOW_MOVE_ONLY(heap_data) + + explicit heap_data(size_t sz) : byte_data(malloc(sz), sz) {} + ~heap_data() { free(_buf); } + + // byte_channel needs to reallocate the internal buffer + friend byte_channel; +}; uint64_t parse_uint64_hex(std::string_view s); int parse_int(std::string_view s); diff --git a/native/src/base/stream.cpp b/native/src/base/stream.cpp index d69bde34c..20edff22c 100644 --- a/native/src/base/stream.cpp +++ b/native/src/base/stream.cpp @@ -174,7 +174,7 @@ void byte_channel::resize(size_t new_sz, bool zero) { resize = true; } if (resize) { - _data.realloc(_cap); + _data._buf = static_cast(::realloc(_data._buf, _cap)); if (zero) memset(_data.buf() + old_cap, 0, _cap - old_cap); } diff --git a/native/src/sepolicy/api.cpp b/native/src/sepolicy/api.cpp index 81de494c3..ae9550318 100644 --- a/native/src/sepolicy/api.cpp +++ b/native/src/sepolicy/api.cpp @@ -105,9 +105,9 @@ bool sepolicy::exists(const char *type) { } void sepolicy::load_rule_file(const char *file) { - rust::load_rule_file(*this, u8_slice(file, strlen(file))); + rust::load_rule_file(*this, byte_view(file, false)); } void sepolicy::load_rules(const std::string &rules) { - rust::load_rules(*this, u8_slice(rules.data(), rules.length())); + rust::load_rules(*this, byte_view(rules, false)); }