diff --git a/app/src/main/java/com/topjohnwu/magisk/core/Const.kt b/app/src/main/java/com/topjohnwu/magisk/core/Const.kt index d75620d30..bcd96628e 100644 --- a/app/src/main/java/com/topjohnwu/magisk/core/Const.kt +++ b/app/src/main/java/com/topjohnwu/magisk/core/Const.kt @@ -28,6 +28,7 @@ object Const { const val MIN_VERCODE = 19000 const val PROVIDER_CONNECT = 20200 const val DYNAMIC_PATH = 20400 + const val SEPOLICY_REDESIGN = 20416 } object ID { diff --git a/app/src/main/java/com/topjohnwu/magisk/core/su/SuRequestHandler.kt b/app/src/main/java/com/topjohnwu/magisk/core/su/SuRequestHandler.kt index 4267b1206..642960e56 100644 --- a/app/src/main/java/com/topjohnwu/magisk/core/su/SuRequestHandler.kt +++ b/app/src/main/java/com/topjohnwu/magisk/core/su/SuRequestHandler.kt @@ -2,48 +2,61 @@ package com.topjohnwu.magisk.core.su import android.content.Intent import android.content.pm.PackageManager +import android.net.LocalServerSocket import android.net.LocalSocket import android.net.LocalSocketAddress -import android.os.CountDownTimer import androidx.collection.ArrayMap import com.topjohnwu.magisk.BuildConfig import com.topjohnwu.magisk.core.Config import com.topjohnwu.magisk.core.Const +import com.topjohnwu.magisk.core.Info import com.topjohnwu.magisk.core.magiskdb.PolicyDao import com.topjohnwu.magisk.core.model.MagiskPolicy import com.topjohnwu.magisk.core.model.toPolicy import com.topjohnwu.magisk.extensions.now +import com.topjohnwu.superuser.Shell +import com.topjohnwu.superuser.internal.UiThreadHandler import timber.log.Timber import java.io.* +import java.util.concurrent.Callable import java.util.concurrent.TimeUnit abstract class SuRequestHandler( private val packageManager: PackageManager, private val policyDB: PolicyDao ) { - private val socket: LocalSocket = LocalSocket() - private lateinit var out: DataOutputStream + private lateinit var socket: LocalSocket + private lateinit var output: DataOutputStream private lateinit var input: DataInputStream - protected var timer: CountDownTimer = DefaultCountDown() - set(value) { - field.cancel() - field = value - field.start() - } protected lateinit var policy: MagiskPolicy private set abstract fun onStart() fun start(intent: Intent): Boolean { - val socketName = intent.getStringExtra("socket") ?: return false + val name = intent.getStringExtra("socket") ?: return false try { - socket.connect(LocalSocketAddress(socketName, LocalSocketAddress.Namespace.ABSTRACT)) - out = DataOutputStream(BufferedOutputStream(socket.outputStream)) + if (Info.env.magiskVersionCode >= Const.Version.SEPOLICY_REDESIGN) { + val server = LocalServerSocket(name) + val futureSocket = Shell.EXECUTOR.submit(Callable { server.accept() }) + try { + socket = futureSocket.get(1, TimeUnit.SECONDS) + } catch (e: Exception) { + // Timeout or any IO errors + throw e + } finally { + server.close() + } + } else { + socket = LocalSocket() + socket.connect(LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT)) + } + output = DataOutputStream(BufferedOutputStream(socket.outputStream)) input = DataInputStream(BufferedInputStream(socket.inputStream)) - val map = readRequest() + val map = Shell.EXECUTOR.submit(Callable { readRequest() }) + .runCatching { get(1, TimeUnit.SECONDS) }.getOrNull() ?: return false val uid = map["uid"]?.toIntOrNull() ?: return false policy = uid.toPolicy(packageManager) } catch (e: Exception) { @@ -65,9 +78,7 @@ abstract class SuRequestHandler( return true } } - - timer.start() - onStart() + UiThreadHandler.run { onStart() } return true } @@ -81,23 +92,22 @@ abstract class SuRequestHandler( policy.until = until policy.uid = policy.uid % 100000 + Const.USER_ID * 100000 - if (until >= 0) - policyDB.update(policy).blockingAwait() - - try { - out.writeInt(policy.policy) - out.flush() - } catch (e: IOException) { - Timber.e(e) - } finally { - runCatching { - input.close() - out.close() - socket.close() + Shell.EXECUTOR.submit { + try { + output.writeInt(policy.policy) + output.flush() + } catch (e: IOException) { + Timber.e(e) + } finally { + if (until >= 0) + policyDB.update(policy).blockingAwait() + runCatching { + input.close() + output.close() + socket.close() + } } } - - timer.cancel() } @Throws(IOException::class) @@ -118,11 +128,4 @@ abstract class SuRequestHandler( return ret } - private inner class DefaultCountDown - : CountDownTimer(TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) { - override fun onFinish() { - respond(MagiskPolicy.DENY, 0) - } - override fun onTick(remains: Long) {} - } } diff --git a/app/src/main/java/com/topjohnwu/magisk/ui/surequest/SuRequestActivity.kt b/app/src/main/java/com/topjohnwu/magisk/ui/surequest/SuRequestActivity.kt index 93f2eb886..e6d3a0c5b 100644 --- a/app/src/main/java/com/topjohnwu/magisk/ui/surequest/SuRequestActivity.kt +++ b/app/src/main/java/com/topjohnwu/magisk/ui/surequest/SuRequestActivity.kt @@ -12,10 +12,12 @@ import com.topjohnwu.magisk.R import com.topjohnwu.magisk.core.su.SuCallbackHandler import com.topjohnwu.magisk.core.su.SuCallbackHandler.REQUEST import com.topjohnwu.magisk.databinding.ActivityRequestBinding +import com.topjohnwu.magisk.extensions.subscribeK import com.topjohnwu.magisk.model.events.DieEvent import com.topjohnwu.magisk.model.events.ViewActionEvent import com.topjohnwu.magisk.model.events.ViewEvent import com.topjohnwu.magisk.ui.base.BaseUIActivity +import io.reactivex.Single import org.koin.androidx.viewmodel.ext.android.viewModel open class SuRequestActivity : BaseUIActivity() { @@ -37,8 +39,11 @@ open class SuRequestActivity : BaseUIActivity> { binding, _, item -> - item.bind(binding) - } - - val adapter = BindingListViewAdapter>(1).apply { - itemBinding = this@SuRequestViewModel.itemBinding + private val items = res.getStringArray(R.array.allow_timeout).map { SpinnerRvItem(it) } + val adapter = BindingListViewAdapter(1).apply { + itemBinding = ItemBinding.of { binding, _, item -> + item.bind(binding) + } setItems(items) } @@ -81,7 +78,11 @@ class SuRequestViewModel( private inner class Handler : SuRequestHandler(pm, policyDB) { + private lateinit var timer: CountDownTimer + fun respond(action: Int) { + timer.cancel() + val pos = selectedItemPosition.value timeoutPrefs.edit().putInt(policy.packageName, pos).apply() respond(action, Config.Value.TIMEOUT_LIST[pos]) @@ -96,30 +97,36 @@ class SuRequestViewModel( } override fun onStart() { - res.getStringArray(R.array.allow_timeout) - .map { SpinnerRvItem(it) } - .let { items.update(it) } - icon.value = policy.applicationInfo.loadIcon(pm) title.value = policy.appName packageName.value = policy.packageName - selectedItemPosition.value = timeoutPrefs.getInt(policy.packageName, 0) - - // Override timer - val millis = SECONDS.toMillis(Config.suDefaultTimeout.toLong()) - timer = object : CountDownTimer(millis, 1000) { - override fun onTick(remains: Long) { - if (remains <= millis - 1000) { - grantEnabled.value = true - } - denyText.value = "${res.getString(R.string.deny)} (${(remains / 1000) + 1})" - } - - override fun onFinish() { - denyText.value = res.getString(R.string.deny) - respond(DENY) - } + UiThreadHandler.handler.post { + // Delay is required to properly do selection + selectedItemPosition.value = timeoutPrefs.getInt(policy.packageName, 0) } + + // Set timer + val millis = SECONDS.toMillis(Config.suDefaultTimeout.toLong()) + timer = SuTimer(millis, 1000).apply { start() } + } + + private inner class SuTimer( + private val millis: Long, + interval: Long + ) : CountDownTimer(millis, interval) { + + override fun onTick(remains: Long) { + if (!grantEnabled.value && remains <= millis - 1000) { + grantEnabled.value = true + } + denyText.value = "${res.getString(R.string.deny)} (${(remains / 1000) + 1})" + } + + override fun onFinish() { + denyText.value = res.getString(R.string.deny) + respond(DENY) + } + } } diff --git a/app/src/main/java/com/topjohnwu/magisk/utils/DiffObservableList.kt b/app/src/main/java/com/topjohnwu/magisk/utils/DiffObservableList.kt index b658d08fe..ada4c908a 100644 --- a/app/src/main/java/com/topjohnwu/magisk/utils/DiffObservableList.kt +++ b/app/src/main/java/com/topjohnwu/magisk/utils/DiffObservableList.kt @@ -86,8 +86,7 @@ open class DiffObservableList( @MainThread fun update(newItems: List) { val diffResult = doCalculateDiff(list, newItems) - list = newItems.toMutableList() - diffResult.dispatchUpdatesTo(listCallback) + update(newItems, diffResult) } override fun addOnListChangedCallback(listener: ObservableList.OnListChangedCallback>) { diff --git a/app/src/main/res/layout/activity_request.xml b/app/src/main/res/layout/activity_request.xml index e05d15dd0..051d9fcc7 100644 --- a/app/src/main/res/layout/activity_request.xml +++ b/app/src/main/res/layout/activity_request.xml @@ -24,7 +24,7 @@ android:minWidth="350dp" android:orientation="vertical"> - - - - - - #include -#define ABS_SOCKET_LEN(sun) (sizeof(sa_family_t) + strlen(sun->sun_path + 1) + 1) +static size_t socket_len(sockaddr_un *sun) { + if (sun->sun_path[0]) + return sizeof(sa_family_t) + strlen(sun->sun_path) + 1; + else + return sizeof(sa_family_t) + strlen(sun->sun_path + 1) + 1; +} -socklen_t setup_sockaddr(struct sockaddr_un *sun, const char *name) { +socklen_t setup_sockaddr(sockaddr_un *sun, const char *name) { memset(sun, 0, sizeof(*sun)); sun->sun_family = AF_LOCAL; strcpy(sun->sun_path + 1, name); - return ABS_SOCKET_LEN(sun); -} - -int create_rand_socket(struct sockaddr_un *sun) { - memset(sun, 0, sizeof(*sun)); - sun->sun_family = AF_LOCAL; - gen_rand_str(sun->sun_path + 1, sizeof(sun->sun_path) - 1); - int fd = xsocket(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0); - xbind(fd, (struct sockaddr*) sun, ABS_SOCKET_LEN(sun)); - xlisten(fd, 1); - return fd; + return socket_len(sun); } int socket_accept(int sockfd, int timeout) { diff --git a/native/jni/include/socket.hpp b/native/jni/include/socket.hpp index 8d15c7d6f..eb4d6a69d 100644 --- a/native/jni/include/socket.hpp +++ b/native/jni/include/socket.hpp @@ -3,8 +3,8 @@ #include #include -socklen_t setup_sockaddr(struct sockaddr_un *sun, const char *name); -int create_rand_socket(struct sockaddr_un *sun); +socklen_t setup_sockaddr(sockaddr_un *sun, const char *name); +int create_app_socket(sockaddr_un *sun); int socket_accept(int sockfd, int timeout); void get_client_cred(int fd, struct ucred *cred); int recv_fd(int sockfd); diff --git a/native/jni/su/connect.cpp b/native/jni/su/connect.cpp index 9ff6fc5fa..8e26607d2 100644 --- a/native/jni/su/connect.cpp +++ b/native/jni/su/connect.cpp @@ -1,9 +1,5 @@ #include #include -#include -#include -#include -#include #include #include @@ -92,9 +88,9 @@ public: } }; -static bool check_error(int fd) { +static bool check_no_error(int fd) { char buf[1024]; - unique_ptr out(xfdopen(fd, "r"), fclose); + auto out = xopen_file(fd, "r"); while (fgets(buf, sizeof(buf), out.get())) { if (strncmp(buf, "Error", 5) == 0) return false; @@ -123,7 +119,7 @@ static void exec_cmd(const char *action, vector &data, .argv = args.data() }; exec_command_sync(exec); - if (check_error(exec.fd)) + if (check_no_error(exec.fd)) return; } @@ -143,7 +139,7 @@ static void exec_cmd(const char *action, vector &data, // Then try start activity without component name strcpy(target, info->str[SU_MANAGER].data()); exec_command_sync(exec); - if (check_error(exec.fd)) + if (check_no_error(exec.fd)) return; } @@ -183,12 +179,31 @@ void app_notify(const su_context &ctx) { } } -void app_socket(const char *socket, const shared_ptr &info) { +int app_socket(const char *name, const shared_ptr &info) { vector extras; extras.reserve(1); - extras.emplace_back("socket", socket); + extras.emplace_back("socket", name); exec_cmd("request", extras, info, PKG_ACTIVITY); + + sockaddr_un addr; + size_t len = setup_sockaddr(&addr, name); + int fd = xsocket(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0); + bool connected = false; + // Try at most 60 seconds + for (int i = 0; i < 600; ++i) { + if (connect(fd, reinterpret_cast(&addr), len) == 0) { + connected = true; + break; + } + usleep(100000); // 100ms + } + if (connected) { + return fd; + } else { + close(fd); + return -1; + } } void socket_send_request(int fd, const shared_ptr &info) { diff --git a/native/jni/su/su.hpp b/native/jni/su/su.hpp index 280affd1e..470fcd611 100644 --- a/native/jni/su/su.hpp +++ b/native/jni/su/su.hpp @@ -68,5 +68,5 @@ struct su_context { void app_log(const su_context &ctx); void app_notify(const su_context &ctx); -void app_socket(const char *socket, const std::shared_ptr &info); +int app_socket(const char *name, const std::shared_ptr &info); void socket_send_request(int fd, const std::shared_ptr &info); diff --git a/native/jni/su/su_daemon.cpp b/native/jni/su/su_daemon.cpp index a46e5ed2f..d3cb2888c 100644 --- a/native/jni/su/su_daemon.cpp +++ b/native/jni/su/su_daemon.cpp @@ -139,12 +139,9 @@ static shared_ptr get_su_info(unsigned uid) { } // If still not determined, ask manager - struct sockaddr_un addr; - int sockfd = create_rand_socket(&addr); - - // Connect manager - app_socket(addr.sun_path + 1, info); - int fd = socket_accept(sockfd, 60); + char socket_name[32]; + gen_rand_str(socket_name, sizeof(socket_name)); + int fd = app_socket(socket_name, info); if (fd < 0) { info->access.policy = DENY; } else { @@ -153,7 +150,6 @@ static shared_ptr get_su_info(unsigned uid) { info->access.policy = ret < 0 ? DENY : static_cast(ret); close(fd); } - close(sockfd); return info; } diff --git a/native/jni/utils/files.hpp b/native/jni/utils/files.hpp index b8f15f9bc..2e490792f 100644 --- a/native/jni/utils/files.hpp +++ b/native/jni/utils/files.hpp @@ -133,3 +133,7 @@ static inline sFILE open_file(const char *path, const char *mode) { static inline sFILE xopen_file(const char *path, const char *mode) { return sFILE(xfopen(path, mode), fclose); } + +static inline sFILE xopen_file(int fd, const char *mode) { + return sFILE(xfdopen(fd, mode), fclose); +}