Use argument binding for query

This commit is contained in:
topjohnwu 2025-01-02 01:04:44 -08:00 committed by John Wu
parent 2722875190
commit 8e1a44e7eb
6 changed files with 115 additions and 115 deletions

View File

@ -7,41 +7,30 @@
#include <sqlite.hpp>
#include <core.hpp>
#define DB_VERSION 12
#define DB_VERSION 12
#define DB_VERSION_STR "12"
using namespace std;
#define DBLOGV(...)
//#define DBLOGV(...) LOGD("magiskdb: " __VA_ARGS__)
struct db_result {
db_result() = default;
db_result(const char *s) : err(s) {}
db_result(int code) : err(code == SQLITE_OK ? "" : (sqlite3_errstr(code) ?: "")) {}
operator bool() {
if (!err.empty()) {
LOGE("sqlite3: %s\n", err.data());
return false;
}
return true;
}
private:
string err;
};
static int sql_exec(sqlite3 *db, const char *sql, sql_exec_callback callback = nullptr, void *v = nullptr) {
return sql_exec(db, sql, nullptr, nullptr, callback, v);
#define sql_chk_log(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) { \
LOGE("sqlite3(db.cpp:%d): %s\n", __LINE__, sqlite3_errstr(rc)); \
return false; \
}
static db_result open_and_init_db_impl(sqlite3 **dbOut) {
if (!load_sqlite())
return "Cannot load libsqlite.so";
static bool open_and_init_db_impl(sqlite3 **dbOut) {
if (!load_sqlite()) {
LOGE("sqlite3: Cannot load libsqlite.so\n");
return false;
}
unique_ptr<sqlite3, decltype(sqlite3_close)> db(nullptr, sqlite3_close);
{
sqlite3 *sql;
fn_run_ret(sqlite3_open_v2, MAGISKDB, &sql,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr);
sql_chk_log(sqlite3_open_v2, MAGISKDB, &sql,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_FULLMUTEX, nullptr);
db.reset(sql);
}
@ -50,10 +39,11 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
auto ver_cb = [](void *ver, auto, DbValues &data) {
*static_cast<int *>(ver) = data.get_int(0);
};
fn_run_ret(sql_exec, db.get(), "PRAGMA user_version", ver_cb, &ver);
sql_chk_log(sql_exec, db.get(), "PRAGMA user_version", nullptr, nullptr, ver_cb, &ver);
if (ver > DB_VERSION) {
// Don't support downgrading database
return "Downgrading database is not supported";
LOGE("sqlite3: Downgrading database is not supported\n");
return false;
}
auto create_policy = [&] {
@ -90,17 +80,17 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
// 12: rebuild table `policies` to drop column `package_name`
if (/* 0, 1, 2, 3, 4, 5, 6 */ ver <= 6) {
fn_run_ret(create_policy);
fn_run_ret(create_settings);
fn_run_ret(create_strings);
fn_run_ret(create_denylist);
sql_chk_log(create_policy);
sql_chk_log(create_settings);
sql_chk_log(create_strings);
sql_chk_log(create_denylist);
// Directly jump to latest
ver = DB_VERSION;
upgrade = true;
}
if (ver == 7) {
fn_run_ret(sql_exec, db.get(),
sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;"
"ALTER TABLE hidelist RENAME TO hidelist_tmp;"
"CREATE TABLE IF NOT EXISTS hidelist "
@ -113,7 +103,7 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
upgrade = true;
}
if (ver == 8) {
fn_run_ret(sql_exec, db.get(),
sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;"
"ALTER TABLE hidelist RENAME TO hidelist_tmp;"
"CREATE TABLE IF NOT EXISTS hidelist "
@ -125,20 +115,20 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
upgrade = true;
}
if (ver == 9) {
fn_run_ret(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr);
sql_chk_log(sql_exec, db.get(), "DROP TABLE IF EXISTS logs", nullptr, nullptr);
ver = 10;
upgrade = true;
}
if (ver == 10) {
fn_run_ret(sql_exec, db.get(),
sql_chk_log(sql_exec, db.get(),
"DROP TABLE IF EXISTS hidelist;"
"DELETE FROM settings WHERE key='magiskhide';");
fn_run_ret(create_denylist);
sql_chk_log(create_denylist);
ver = 11;
upgrade = true;
}
if (ver == 11) {
fn_run_ret(sql_exec, db.get(),
sql_chk_log(sql_exec, db.get(),
"BEGIN TRANSACTION;"
"ALTER TABLE policies RENAME TO policies_tmp;"
"CREATE TABLE IF NOT EXISTS policies "
@ -154,20 +144,16 @@ static db_result open_and_init_db_impl(sqlite3 **dbOut) {
if (upgrade) {
// Set version
char query[32];
sprintf(query, "PRAGMA user_version=%d", ver);
fn_run_ret(sql_exec, db.get(), query);
sql_chk_log(sql_exec, db.get(), "PRAGMA user_version=" DB_VERSION_STR);
}
*dbOut = db.release();
return {};
return true;
}
sqlite3 *open_and_init_db() {
sqlite3 *db = nullptr;
if (!open_and_init_db_impl(&db))
return nullptr;
return db;
return open_and_init_db_impl(&db) ? db : nullptr;
}
static sqlite3 *get_db() {
@ -183,13 +169,17 @@ static sqlite3 *get_db() {
return db;
}
bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn) {
bool db_exec(const char *sql, DbArgs args, db_exec_callback exec_fn) {
using db_bind_callback = std::function<int(int, DbStatement&)>;
if (sqlite3 *db = get_db()) {
db_bind_callback bind_fn = {};
sql_bind_callback bind_cb = nullptr;
if (bind_fn) {
bind_cb = [](void *v, int index, DbStatement &stmt) {
if (!args.empty()) {
bind_fn = std::ref(args);
bind_cb = [](void *v, int index, DbStatement &stmt) -> int {
auto fn = static_cast<db_bind_callback*>(v);
fn->operator()(index, stmt);
return fn->operator()(index, stmt);
};
}
sql_exec_callback exec_cb = nullptr;
@ -199,52 +189,41 @@ bool db_exec(const char *sql, db_bind_callback bind_fn, db_exec_callback exec_fn
fn->operator()(columns, data);
};
}
db_result res = sql_exec(db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn);
return res;
sql_chk_log(sql_exec, db, sql, bind_cb, &bind_fn, exec_cb, &exec_fn);
return true;
}
return false;
}
int get_db_settings(db_settings &cfg, int key) {
bool res;
bool get_db_settings(db_settings &cfg, int key) {
if (key >= 0) {
char query[128];
ssprintf(query, sizeof(query), "SELECT * FROM settings WHERE key='%s'", DB_SETTING_KEYS[key]);
res = db_exec(query, cfg);
return db_exec("SELECT * FROM settings WHERE key=?", { DB_SETTING_KEYS[key] }, cfg);
} else {
res = db_exec("SELECT * FROM settings", cfg);
return db_exec("SELECT * FROM settings", {}, cfg);
}
return res ? 0 : 1;
}
int set_db_settings(int key, int value) {
char sql[128];
ssprintf(sql, sizeof(sql), "INSERT OR REPLACE INTO settings VALUES ('%s', %d)",
DB_SETTING_KEYS[key], value);
return db_exec(sql) ? 0 : 1;
bool set_db_settings(int key, int value) {
return db_exec(
"INSERT OR REPLACE INTO settings (key,value) VALUES(?,?)",
{ DB_SETTING_KEYS[key], value });
}
int get_db_strings(db_strings &str, int key) {
bool res;
bool get_db_strings(db_strings &str, int key) {
if (key >= 0) {
char query[128];
ssprintf(query, sizeof(query), "SELECT * FROM strings WHERE key='%s'", DB_STRING_KEYS[key]);
res = db_exec(query, str);
return db_exec("SELECT * FROM strings WHERE key=?", { DB_STRING_KEYS[key] }, str);
} else {
res = db_exec("SELECT * FROM strings", str);
return db_exec("SELECT * FROM strings", {}, str);
}
return res ? 0 : 1;
}
void rm_db_strings(int key) {
char query[128];
ssprintf(query, sizeof(query), "DELETE FROM strings WHERE key == '%s'", DB_STRING_KEYS[key]);
db_exec(query);
bool rm_db_strings(int key) {
return db_exec("DELETE FROM strings WHERE key=?", { DB_STRING_KEYS[key] });
}
void exec_sql(owned_fd client) {
string sql = read_string(client);
db_exec(sql.data(), [fd = (int) client](StringSlice columns, DbValues &data) {
db_exec(sql.data(), {}, [fd = (int) client](StringSlice columns, DbValues &data) {
string out;
for (int i = 0; i < columns.size(); ++i) {
if (i != 0) out += '|';
@ -306,3 +285,16 @@ void db_strings::operator()(StringSlice columns, DbValues &data) {
su_manager = val;
}
}
int DbArgs::operator()(int index, DbStatement &stmt) {
if (curr < args.size()) {
const auto &arg = args[curr++];
switch (arg.type) {
case DbArg::INT:
return stmt.bind_int64(index, arg.int_val);
case DbArg::TEXT:
return stmt.bind_text(index, arg.str_val);
}
}
return SQLITE_OK;
}

View File

@ -222,7 +222,7 @@ static bool ensure_data() {
LOGI("denylist: initializing internal data structures\n");
default_new(pkg_to_procs_);
bool res = db_exec("SELECT * FROM denylist", [](StringSlice columns, DbValues &data) {
bool res = db_exec("SELECT * FROM denylist", {}, [](StringSlice columns, DbValues &data) {
const char *package_name;
const char *process;
for (int i = 0; i < columns.size(); ++i) {

View File

@ -88,23 +88,41 @@ struct db_strings {
********************/
using db_exec_callback = std::function<void(StringSlice, DbValues&)>;
using db_bind_callback = std::function<void(int, DbStatement&)>;
int get_db_settings(db_settings &cfg, int key = -1);
int set_db_settings(int key, int value);
int get_db_strings(db_strings &str, int key = -1);
void rm_db_strings(int key);
struct DbArg {
enum {
INT,
TEXT,
} type;
union {
int64_t int_val;
rust::Str str_val;
};
DbArg(int64_t v) : type(INT), int_val(v) {}
DbArg(const char *v) : type(TEXT), str_val(v) {}
};
struct DbArgs {
DbArgs() : curr(0) {}
DbArgs(std::initializer_list<DbArg> list) : args(list), curr(0) {}
int operator()(int index, DbStatement &stmt);
bool empty() const { return args.empty(); }
private:
std::vector<DbArg> args;
size_t curr;
};
bool get_db_settings(db_settings &cfg, int key = -1);
bool set_db_settings(int key, int value);
bool get_db_strings(db_strings &str, int key = -1);
bool rm_db_strings(int key);
void exec_sql(owned_fd client);
bool db_exec(const char *sql, db_bind_callback bind_fn = {}, db_exec_callback exec_fn = {});
static inline bool db_exec(const char *sql, db_exec_callback exec_fn) {
return db_exec(sql, {}, std::move(exec_fn));
}
bool db_exec(const char *sql, DbArgs args = {}, db_exec_callback exec_fn = {});
template<typename T>
concept DbData = requires(T t, StringSlice s, DbValues &v) { t(s, v); };
template<DbData T>
bool db_exec(const char *sql, T &data) {
return db_exec(sql, (db_exec_callback) std::ref(data));
bool db_exec(const char *sql, DbArgs args, T &data) {
return db_exec(sql, std::move(args), (db_exec_callback) std::ref(data));
}

View File

@ -17,27 +17,24 @@ extern int (*sqlite3_open_v2)(const char *filename, sqlite3 **ppDb, int flags, c
extern int (*sqlite3_close)(sqlite3 *db);
extern const char *(*sqlite3_errstr)(int);
// Transparent wrapper of sqlite3_stmt
// Transparent wrappers of sqlite3_stmt
struct DbValues {
const char *get_text(int index);
int get_int(int index);
~DbValues() = delete;
};
struct DbStatement {
int bind_text(int index, const char *val);
int bind_text(int index, rust::Str val);
int bind_int64(int index, int64_t val);
~DbStatement() = delete;
};
using StringSlice = rust::Slice<rust::String>;
using sql_bind_callback = void(*)(void*, int, DbStatement&);
using sql_bind_callback = int(*)(void*, int, DbStatement&);
using sql_exec_callback = void(*)(void*, StringSlice, DbValues&);
#define fn_run_ret(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc
bool load_sqlite();
sqlite3 *open_and_init_db();
int sql_exec(sqlite3 *db, rust::Str zSql,
sql_bind_callback bind_cb, void *bind_cookie,
sql_exec_callback exec_cb, void *exec_cookie);
sql_bind_callback bind_cb = nullptr, void *bind_cookie = nullptr,
sql_exec_callback exec_cb = nullptr, void *exec_cookie = nullptr);

View File

@ -91,8 +91,10 @@ bool load_sqlite() {
}
using StringVec = rust::Vec<rust::String>;
using sql_bind_callback_real = int(*)(void*, int, sqlite3_stmt*);
using sql_exec_callback_real = void(*)(void*, StringSlice, sqlite3_stmt*);
using sql_bind_callback_real = void(*)(void*, int, sqlite3_stmt*);
#define sql_chk(fn, ...) if (int rc = fn(__VA_ARGS__); rc != SQLITE_OK) return rc
int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_cookie, sql_exec_callback exec_cb, void *exec_cookie) {
const char *sql = zSql.begin();
@ -102,7 +104,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_
// Step 1: prepare statement
{
sqlite3_stmt *st = nullptr;
fn_run_ret(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql);
sql_chk(sqlite3_prepare_v2, db, sql, zSql.end() - sql, &st, &sql);
if (st == nullptr) continue;
stmt.reset(st);
}
@ -112,7 +114,7 @@ int sql_exec(sqlite3 *db, rust::Str zSql, sql_bind_callback bind_cb, void *bind_
if (int count = sqlite3_bind_parameter_count(stmt.get())) {
auto real_cb = reinterpret_cast<sql_bind_callback_real>(bind_cb);
for (int i = 1; i <= count; ++i) {
real_cb(bind_cookie, i, stmt.get());
sql_chk(real_cb, bind_cookie, i, stmt.get());
}
}
}
@ -155,7 +157,3 @@ int DbStatement::bind_int64(int index, int64_t val) {
int DbStatement::bind_text(int index, rust::Str val) {
return sqlite3_bind_text(reinterpret_cast<sqlite3_stmt*>(this), index, val.data(), val.size(), nullptr);
}
int DbStatement::bind_text(int index, const char *val) {
return sqlite3_bind_text(reinterpret_cast<sqlite3_stmt*>(this), index, val, -1, nullptr);
}

View File

@ -76,11 +76,10 @@ void su_info::check_db() {
}
if (eval_uid > 0) {
char query[256];
ssprintf(query, sizeof(query),
"SELECT policy, logging, notification FROM policies "
"WHERE uid=%d AND (until=0 OR until>%li)", eval_uid, time(nullptr));
if (!db_exec(query, access))
bool res = db_exec(
"SELECT policy, logging, notification FROM policies "
"WHERE uid=? AND (until=0 OR until>?)", { eval_uid, time(nullptr) }, access);
if (!res)
return;
}
@ -127,15 +126,11 @@ bool uid_granted_root(int uid) {
break;
}
char query[256];
ssprintf(query, sizeof(query),
"SELECT policy FROM policies WHERE uid=%d AND (until=0 OR until>%li)",
uid, time(nullptr));
su_access access;
access.policy = QUERY;
if (!db_exec(query, access))
return false;
return access.policy == ALLOW;
bool granted = false;
db_exec("SELECT policy FROM policies WHERE uid=? AND (until=0 OR until>?)",
{ uid, time(nullptr) },
[&](auto, DbValues &data) { granted = data.get_int(0) == ALLOW; });
return granted;
}
struct policy_uid_list : public vector<int> {
@ -147,7 +142,7 @@ struct policy_uid_list : public vector<int> {
void prune_su_access() {
cached.reset();
policy_uid_list uids;
if (!db_exec("SELECT uid FROM policies", uids))
if (!db_exec("SELECT uid FROM policies", {}, uids))
return;
vector<bool> app_no_list = get_app_no_list();
vector<int> rm_uids;