From 8e1a44e7eb245fe2deae8eb00f1ce11dbf526820 Mon Sep 17 00:00:00 2001 From: topjohnwu Date: Thu, 2 Jan 2025 01:04:44 -0800 Subject: [PATCH] Use argument binding for query --- native/src/core/db.cpp | 136 ++++++++++++++--------------- native/src/core/deny/utils.cpp | 2 +- native/src/core/include/db.hpp | 42 ++++++--- native/src/core/include/sqlite.hpp | 13 ++- native/src/core/sqlite.cpp | 12 ++- native/src/core/su/su_daemon.cpp | 25 +++--- 6 files changed, 115 insertions(+), 115 deletions(-) diff --git a/native/src/core/db.cpp b/native/src/core/db.cpp index d40c2b36f..8a57d4070 100644 --- a/native/src/core/db.cpp +++ b/native/src/core/db.cpp @@ -7,41 +7,30 @@ #include #include -#define DB_VERSION 12 +#define DB_VERSION 12 +#define DB_VERSION_STR "12" using namespace std; #define DBLOGV(...) //#define DBLOGV(...) LOGD("magiskdb: " __VA_ARGS__) -struct db_result { - db_result() = default; - db_result(const char *s) : err(s) {} - db_result(int code) : err(code == SQLITE_OK ? "" : (sqlite3_errstr(code) ?: "")) {} - operator bool() { - if (!err.empty()) { - LOGE("sqlite3: %s\n", err.data()); - return false; - } - return true; - } -private: - string err; -}; - -static int sql_exec(sqlite3 *db, const char *sql, sql_exec_callback callback = nullptr, void *v = nullptr) { - return sql_exec(db, sql, nullptr, nullptr, callback, v); +#define sql_chk_log(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) { \ + LOGE("sqlite3(db.cpp:%d): %s\n", __LINE__, sqlite3_errstr(rc)); \ + return false; \ } -static db_result open_and_init_db_impl(sqlite3 **dbOut) { - if (!load_sqlite()) - return "Cannot load libsqlite.so"; +static bool open_and_init_db_impl(sqlite3 **dbOut) { + if (!load_sqlite()) { + LOGE("sqlite3: Cannot load libsqlite.so\n"); + return false; + } unique_ptr db(nullptr, sqlite3_close); { sqlite3 *sql; - fn_run_ret(sqlite3_open_v2, MAGISKDB, &sql, - SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr); + sql_chk_log(sqlite3_open_v2, MAGISKDB, &sql, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr); db.reset(sql); } @@ -50,10 +39,11 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) { auto ver_cb = [](void *ver, auto, DbValues &data) { *static_cast(ver) = data.get_int(0); }; - fn_run_ret(sql_exec, db.get(), "PRAGMA user_version", ver_cb, &ver); + sql_chk_log(sql_exec, db.get(), "PRAGMA user_version", nullptr, nullptr, ver_cb, &ver); if (ver > DB_VERSION) { // Don't support downgrading database - return "Downgrading database is not supported"; + LOGE("sqlite3: Downgrading database is not supported\n"); + return false; } auto create_policy = [&] { @@ -90,17 +80,17 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) { // 12: rebuild table `policies` to drop column `package_name` if (/* 0, 1, 2, 3, 4, 5, 6 */ ver <= 6) { - fn_run_ret(create_policy); - fn_run_ret(create_settings); - fn_run_ret(create_strings); - fn_run_ret(create_denylist); + sql_chk_log(create_policy); + sql_chk_log(create_settings); + sql_chk_log(create_strings); + sql_chk_log(create_denylist); // Directly jump to latest ver = DB_VERSION; upgrade = true; } if (ver == 7) { - fn_run_ret(sql_exec, db.get(), + sql_chk_log(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;" "CREATE TABLE IF NOT EXISTS hidelist " @@ -113,7 +103,7 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) { upgrade = true; } if (ver == 8) { - fn_run_ret(sql_exec, db.get(), + sql_chk_log(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE hidelist RENAME TO hidelist_tmp;" "CREATE TABLE IF NOT EXISTS hidelist " @@ -125,20 +115,20 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) { upgrade = true; } if (ver == 9) { - fn_run_ret(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr); + sql_chk_log(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr); ver = 10; upgrade = true; } if (ver == 10) { - fn_run_ret(sql_exec, db.get(), + sql_chk_log(sql_exec, db.get(), "DROP TABLE IF EXISTS hidelist;" "DELETE FROM settings WHERE key='magiskhide';"); - fn_run_ret(create_denylist); + sql_chk_log(create_denylist); ver = 11; upgrade = true; } if (ver == 11) { - fn_run_ret(sql_exec, db.get(), + sql_chk_log(sql_exec, db.get(), "BEGIN TRANSACTION;" "ALTER TABLE policies RENAME TO policies_tmp;" "CREATE TABLE IF NOT EXISTS policies " @@ -154,20 +144,16 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) { if (upgrade) { // Set version - char query[32]; - sprintf(query, "PRAGMA user_version=%d", ver); - fn_run_ret(sql_exec, db.get(), query); + sql_chk_log(sql_exec, db.get(), "PRAGMA user_version=" DB_VERSION_STR); } *dbOut = db.release(); - return {}; + return true; } sqlite3 *open_and_init_db() { sqlite3 *db = nullptr; - if (!open_and_init_db_impl(&db)) - return nullptr; - return db; + return open_and_init_db_impl(&db) ? db : nullptr; } static sqlite3 *get_db() { @@ -183,13 +169,17 @@ static sqlite3 *get_db() { return db; } -bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn) { +bool db_exec(const char *sql, DbArgs args, db_exec_callback exec_fn) { + using db_bind_callback = std::function; + if (sqlite3 *db = get_db()) { + db_bind_callback bind_fn = {}; sql_bind_callback bind_cb = nullptr; - if (bind_fn) { - bind_cb = [](void *v, int index, DbStatement &stmt) { + if (!args.empty()) { + bind_fn = std::ref(args); + bind_cb = [](void *v, int index, DbStatement &stmt) -> int { auto fn = static_cast(v); - fn->operator()(index, stmt); + return fn->operator()(index, stmt); }; } sql_exec_callback exec_cb = nullptr; @@ -199,52 +189,41 @@ bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn fn->operator()(columns, data); }; } - db_result res = sql_exec(db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn); - return res; + sql_chk_log(sql_exec, db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn); + return true; } return false; } -int get_db_settings(db_settings &cfg, int key) { - bool res; +bool get_db_settings(db_settings &cfg, int key) { if (key >= 0) { - char query[128]; - ssprintf(query, sizeof(query), "SELECT * FROM settings WHERE key='%s'", DB_SETTING_KEYS[key]); - res = db_exec(query, cfg); + return db_exec("SELECT * FROM settings WHERE key=?", { DB_SETTING_KEYS[key] }, cfg); } else { - res = db_exec("SELECT * FROM settings", cfg); + return db_exec("SELECT * FROM settings", {}, cfg); } - return res ? 0 : 1; } -int set_db_settings(int key, int value) { - char sql[128]; - ssprintf(sql, sizeof(sql), "INSERT OR REPLACE INTO settings VALUES ('%s', %d)", - DB_SETTING_KEYS[key], value); - return db_exec(sql) ? 0 : 1; +bool set_db_settings(int key, int value) { + return db_exec( + "INSERT OR REPLACE INTO settings (key,value) VALUES(?,?)", + { DB_SETTING_KEYS[key], value }); } -int get_db_strings(db_strings &str, int key) { - bool res; +bool get_db_strings(db_strings &str, int key) { if (key >= 0) { - char query[128]; - ssprintf(query, sizeof(query), "SELECT * FROM strings WHERE key='%s'", DB_STRING_KEYS[key]); - res = db_exec(query, str); + return db_exec("SELECT * FROM strings WHERE key=?", { DB_STRING_KEYS[key] }, str); } else { - res = db_exec("SELECT * FROM strings", str); + return db_exec("SELECT * FROM strings", {}, str); } - return res ? 0 : 1; } -void rm_db_strings(int key) { - char query[128]; - ssprintf(query, sizeof(query), "DELETE FROM strings WHERE key == '%s'", DB_STRING_KEYS[key]); - db_exec(query); +bool rm_db_strings(int key) { + return db_exec("DELETE FROM strings WHERE key=?", { DB_STRING_KEYS[key] }); } void exec_sql(owned_fd client) { string sql = read_string(client); - db_exec(sql.data(), [fd = (int) client](StringSlice columns, DbValues &data) { + db_exec(sql.data(), {}, [fd = (int) client](StringSlice columns, DbValues &data) { string out; for (int i = 0; i < columns.size(); ++i) { if (i != 0) out += '|'; @@ -306,3 +285,16 @@ void db_strings::operator()(StringSlice columns, DbValues &data) { su_manager = val; } } + +int DbArgs::operator()(int index, DbStatement &stmt) { + if (curr < args.size()) { + const auto &arg = args[curr++]; + switch (arg.type) { + case DbArg::INT: + return stmt.bind_int64(index, arg.int_val); + case DbArg::TEXT: + return stmt.bind_text(index, arg.str_val); + } + } + return SQLITE_OK; +} diff --git a/native/src/core/deny/utils.cpp b/native/src/core/deny/utils.cpp index 6eefaba78..3a907aa1c 100644 --- a/native/src/core/deny/utils.cpp +++ b/native/src/core/deny/utils.cpp @@ -222,7 +222,7 @@ static bool ensure_data() { LOGI("denylist: initializing internal data structures\n"); default_new(pkg_to_procs_); - bool res = db_exec("SELECT * FROM denylist", [](StringSlice columns, DbValues &data) { + bool res = db_exec("SELECT * FROM denylist", {}, [](StringSlice columns, DbValues &data) { const char *package_name; const char *process; for (int i = 0; i < columns.size(); ++i) { diff --git a/native/src/core/include/db.hpp b/native/src/core/include/db.hpp index 74ad57bdc..244b8a7f6 100644 --- a/native/src/core/include/db.hpp +++ b/native/src/core/include/db.hpp @@ -88,23 +88,41 @@ struct db_strings { ********************/ using db_exec_callback = std::function; -using db_bind_callback = std::function; -int get_db_settings(db_settings &cfg, int key = -1); -int set_db_settings(int key, int value); -int get_db_strings(db_strings &str, int key = -1); -void rm_db_strings(int key); +struct DbArg { + enum { + INT, + TEXT, + } type; + union { + int64_t int_val; + rust::Str str_val; + }; + DbArg(int64_t v) : type(INT), int_val(v) {} + DbArg(const char *v) : type(TEXT), str_val(v) {} +}; + +struct DbArgs { + DbArgs() : curr(0) {} + DbArgs(std::initializer_list list) : args(list), curr(0) {} + int operator()(int index, DbStatement &stmt); + bool empty() const { return args.empty(); } +private: + std::vector args; + size_t curr; +}; + +bool get_db_settings(db_settings &cfg, int key = -1); +bool set_db_settings(int key, int value); +bool get_db_strings(db_strings &str, int key = -1); +bool rm_db_strings(int key); void exec_sql(owned_fd client); -bool db_exec(const char *sql, db_bind_callback bind_fn = {}, db_exec_callback exec_fn = {}); - -static inline bool db_exec(const char *sql, db_exec_callback exec_fn) { - return db_exec(sql, {}, std::move(exec_fn)); -} +bool db_exec(const char *sql, DbArgs args = {}, db_exec_callback exec_fn = {}); template concept DbData = requires(T t, StringSlice s, DbValues &v) { t(s, v); }; template -bool db_exec(const char *sql, T &data) { - return db_exec(sql, (db_exec_callback) std::ref(data)); +bool db_exec(const char *sql, DbArgs args, T &data) { + return db_exec(sql, std::move(args), (db_exec_callback) std::ref(data)); } diff --git a/native/src/core/include/sqlite.hpp b/native/src/core/include/sqlite.hpp index be4cc34f7..8eec3338c 100644 --- a/native/src/core/include/sqlite.hpp +++ b/native/src/core/include/sqlite.hpp @@ -17,27 +17,24 @@ extern int (*sqlite3_open_v2)(const char *filename, sqlite3 **ppDb, int flags, c extern int (*sqlite3_close)(sqlite3 *db); extern const char *(*sqlite3_errstr)(int); -// Transparent wrapper of sqlite3_stmt +// Transparent wrappers of sqlite3_stmt struct DbValues { const char *get_text(int index); int get_int(int index); ~DbValues() = delete; }; - struct DbStatement { - int bind_text(int index, const char *val); int bind_text(int index, rust::Str val); int bind_int64(int index, int64_t val); + ~DbStatement() = delete; }; using StringSlice = rust::Slice; -using sql_bind_callback = void(*)(void*, int, DbStatement&); +using sql_bind_callback = int(*)(void*, int, DbStatement&); using sql_exec_callback = void(*)(void*, StringSlice, DbValues&); -#define fn_run_ret(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc - bool load_sqlite(); sqlite3 *open_and_init_db(); int sql_exec(sqlite3 *db, rust::Str zSql, - sql_bind_callback bind_cb, void *bind_cookie, - sql_exec_callback exec_cb, void *exec_cookie); + sql_bind_callback bind_cb = nullptr, void *bind_cookie = nullptr, + sql_exec_callback exec_cb = nullptr, void *exec_cookie = nullptr); diff --git a/native/src/core/sqlite.cpp b/native/src/core/sqlite.cpp index 2d39d96a3..926a954a4 100644 --- a/native/src/core/sqlite.cpp +++ b/native/src/core/sqlite.cpp @@ -91,8 +91,10 @@ bool load_sqlite() { } using StringVec = rust::Vec; +using sql_bind_callback_real = int(*)(void*, int, sqlite3_stmt*); using sql_exec_callback_real = void(*)(void*, StringSlice, sqlite3_stmt*); -using sql_bind_callback_real = void(*)(void*, int, sqlite3_stmt*); + +#define sql_chk(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_cookie, sql_exec_callback exec_cb, void *exec_cookie) { const char *sql = zSql.begin(); @@ -102,7 +104,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_ // Step 1: prepare statement { sqlite3_stmt *st = nullptr; - fn_run_ret(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql); + sql_chk(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql); if (st == nullptr) continue; stmt.reset(st); } @@ -112,7 +114,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_ if (int count = sqlite3_bind_parameter_count(stmt.get())) { auto real_cb = reinterpret_cast(bind_cb); for (int i = 1; i <= count; ++i) { - real_cb(bind_cookie, i, stmt.get()); + sql_chk(real_cb, bind_cookie, i, stmt.get()); } } } @@ -155,7 +157,3 @@ int DbStatement::bind_int64(int index, int64_t val) { int DbStatement::bind_text(int index, rust::Str val) { return sqlite3_bind_text(reinterpret_cast(this), index, val.data(), val.size(), nullptr); } - -int DbStatement::bind_text(int index, const char *val) { - return sqlite3_bind_text(reinterpret_cast(this), index, val, -1, nullptr); -} diff --git a/native/src/core/su/su_daemon.cpp b/native/src/core/su/su_daemon.cpp index 0913943f3..2e966f2aa 100644 --- a/native/src/core/su/su_daemon.cpp +++ b/native/src/core/su/su_daemon.cpp @@ -76,11 +76,10 @@ void su_info::check_db() { } if (eval_uid > 0) { - char query[256]; - ssprintf(query, sizeof(query), - "SELECT policy, logging, notification FROM policies " - "WHERE uid=%d AND (until=0 OR until>%li)", eval_uid, time(nullptr)); - if (!db_exec(query, access)) + bool res = db_exec( + "SELECT policy, logging, notification FROM policies " + "WHERE uid=? AND (until=0 OR until>?)", { eval_uid, time(nullptr) }, access); + if (!res) return; } @@ -127,15 +126,11 @@ bool uid_granted_root(int uid) { break; } - char query[256]; - ssprintf(query, sizeof(query), - "SELECT policy FROM policies WHERE uid=%d AND (until=0 OR until>%li)", - uid, time(nullptr)); - su_access access; - access.policy = QUERY; - if (!db_exec(query, access)) - return false; - return access.policy == ALLOW; + bool granted = false; + db_exec("SELECT policy FROM policies WHERE uid=? AND (until=0 OR until>?)", + { uid, time(nullptr) }, + [&](auto, DbValues &data) { granted = data.get_int(0) == ALLOW; }); + return granted; } struct policy_uid_list : public vector { @@ -147,7 +142,7 @@ struct policy_uid_list : public vector { void prune_su_access() { cached.reset(); policy_uid_list uids; - if (!db_exec("SELECT uid FROM policies", uids)) + if (!db_exec("SELECT uid FROM policies", {}, uids)) return; vector app_no_list = get_app_no_list(); vector rm_uids;