Update su request process

Due to changes in ec3705f2ed187863efc34af5415495e1ee7775d2, the app can
no longer communicate with the dameon through a socket opened on the
daemon side due to SELinux restrictions. The workaround here is to have
the daemon decide a socket name, send it to the app, have the app create
the socket server, then finally the daemon connects to the app through
the socket.
This commit is contained in:
topjohnwu 2020-06-19 03:52:25 -07:00
parent b71b4bd4e5
commit 43146b8316
12 changed files with 136 additions and 110 deletions

View File

@ -28,6 +28,7 @@ object Const {
const val MIN_VERCODE = 19000
const val PROVIDER_CONNECT = 20200
const val DYNAMIC_PATH = 20400
const val SEPOLICY_REDESIGN = 20416
}
object ID {

View File

@ -2,48 +2,61 @@ package com.topjohnwu.magisk.core.su
import android.content.Intent
import android.content.pm.PackageManager
import android.net.LocalServerSocket
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.CountDownTimer
import androidx.collection.ArrayMap
import com.topjohnwu.magisk.BuildConfig
import com.topjohnwu.magisk.core.Config
import com.topjohnwu.magisk.core.Const
import com.topjohnwu.magisk.core.Info
import com.topjohnwu.magisk.core.magiskdb.PolicyDao
import com.topjohnwu.magisk.core.model.MagiskPolicy
import com.topjohnwu.magisk.core.model.toPolicy
import com.topjohnwu.magisk.extensions.now
import com.topjohnwu.superuser.Shell
import com.topjohnwu.superuser.internal.UiThreadHandler
import timber.log.Timber
import java.io.*
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit
abstract class SuRequestHandler(
private val packageManager: PackageManager,
private val policyDB: PolicyDao
) {
private val socket: LocalSocket = LocalSocket()
private lateinit var out: DataOutputStream
private lateinit var socket: LocalSocket
private lateinit var output: DataOutputStream
private lateinit var input: DataInputStream
protected var timer: CountDownTimer = DefaultCountDown()
set(value) {
field.cancel()
field = value
field.start()
}
protected lateinit var policy: MagiskPolicy
private set
abstract fun onStart()
fun start(intent: Intent): Boolean {
val socketName = intent.getStringExtra("socket") ?: return false
val name = intent.getStringExtra("socket") ?: return false
try {
socket.connect(LocalSocketAddress(socketName, LocalSocketAddress.Namespace.ABSTRACT))
out = DataOutputStream(BufferedOutputStream(socket.outputStream))
if (Info.env.magiskVersionCode >= Const.Version.SEPOLICY_REDESIGN) {
val server = LocalServerSocket(name)
val futureSocket = Shell.EXECUTOR.submit(Callable { server.accept() })
try {
socket = futureSocket.get(1, TimeUnit.SECONDS)
} catch (e: Exception) {
// Timeout or any IO errors
throw e
} finally {
server.close()
}
} else {
socket = LocalSocket()
socket.connect(LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT))
}
output = DataOutputStream(BufferedOutputStream(socket.outputStream))
input = DataInputStream(BufferedInputStream(socket.inputStream))
val map = readRequest()
val map = Shell.EXECUTOR.submit(Callable { readRequest() })
.runCatching { get(1, TimeUnit.SECONDS) }.getOrNull() ?: return false
val uid = map["uid"]?.toIntOrNull() ?: return false
policy = uid.toPolicy(packageManager)
} catch (e: Exception) {
@ -65,9 +78,7 @@ abstract class SuRequestHandler(
return true
}
}
timer.start()
onStart()
UiThreadHandler.run { onStart() }
return true
}
@ -81,23 +92,22 @@ abstract class SuRequestHandler(
policy.until = until
policy.uid = policy.uid % 100000 + Const.USER_ID * 100000
if (until >= 0)
policyDB.update(policy).blockingAwait()
Shell.EXECUTOR.submit {
try {
out.writeInt(policy.policy)
out.flush()
output.writeInt(policy.policy)
output.flush()
} catch (e: IOException) {
Timber.e(e)
} finally {
if (until >= 0)
policyDB.update(policy).blockingAwait()
runCatching {
input.close()
out.close()
output.close()
socket.close()
}
}
timer.cancel()
}
}
@Throws(IOException::class)
@ -118,11 +128,4 @@ abstract class SuRequestHandler(
return ret
}
private inner class DefaultCountDown
: CountDownTimer(TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) {
override fun onFinish() {
respond(MagiskPolicy.DENY, 0)
}
override fun onTick(remains: Long) {}
}
}

View File

@ -12,10 +12,12 @@ import com.topjohnwu.magisk.R
import com.topjohnwu.magisk.core.su.SuCallbackHandler
import com.topjohnwu.magisk.core.su.SuCallbackHandler.REQUEST
import com.topjohnwu.magisk.databinding.ActivityRequestBinding
import com.topjohnwu.magisk.extensions.subscribeK
import com.topjohnwu.magisk.model.events.DieEvent
import com.topjohnwu.magisk.model.events.ViewActionEvent
import com.topjohnwu.magisk.model.events.ViewEvent
import com.topjohnwu.magisk.ui.base.BaseUIActivity
import io.reactivex.Single
import org.koin.androidx.viewmodel.ext.android.viewModel
open class SuRequestActivity : BaseUIActivity<SuRequestViewModel, ActivityRequestBinding>() {
@ -37,8 +39,11 @@ open class SuRequestActivity : BaseUIActivity<SuRequestViewModel, ActivityReques
super.onCreate(savedInstanceState)
fun showRequest() {
if (!viewModel.handleRequest(intent))
finish()
Single.fromCallable {
viewModel.handleRequest(intent)
}.subscribeK {
if (!it) finish()
}
}
fun runHandler(action: String?) {

View File

@ -13,12 +13,11 @@ import com.topjohnwu.magisk.core.model.MagiskPolicy.Companion.ALLOW
import com.topjohnwu.magisk.core.model.MagiskPolicy.Companion.DENY
import com.topjohnwu.magisk.core.su.SuRequestHandler
import com.topjohnwu.magisk.core.utils.BiometricHelper
import com.topjohnwu.magisk.databinding.ComparableRvItem
import com.topjohnwu.magisk.model.entity.recycler.SpinnerRvItem
import com.topjohnwu.magisk.model.events.DieEvent
import com.topjohnwu.magisk.ui.base.BaseViewModel
import com.topjohnwu.magisk.utils.DiffObservableList
import com.topjohnwu.magisk.utils.KObservableField
import com.topjohnwu.superuser.internal.UiThreadHandler
import me.tatarka.bindingcollectionadapter2.BindingListViewAdapter
import me.tatarka.bindingcollectionadapter2.ItemBinding
import java.util.concurrent.TimeUnit.SECONDS
@ -41,13 +40,11 @@ class SuRequestViewModel(
val grantEnabled = KObservableField(false)
private val items = DiffObservableList(ComparableRvItem.callback)
private val itemBinding = ItemBinding.of<ComparableRvItem<*>> { binding, _, item ->
private val items = res.getStringArray(R.array.allow_timeout).map { SpinnerRvItem(it) }
val adapter = BindingListViewAdapter<SpinnerRvItem>(1).apply {
itemBinding = ItemBinding.of { binding, _, item ->
item.bind(binding)
}
val adapter = BindingListViewAdapter<ComparableRvItem<*>>(1).apply {
itemBinding = this@SuRequestViewModel.itemBinding
setItems(items)
}
@ -81,7 +78,11 @@ class SuRequestViewModel(
private inner class Handler : SuRequestHandler(pm, policyDB) {
private lateinit var timer: CountDownTimer
fun respond(action: Int) {
timer.cancel()
val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy.packageName, pos).apply()
respond(action, Config.Value.TIMEOUT_LIST[pos])
@ -96,20 +97,26 @@ class SuRequestViewModel(
}
override fun onStart() {
res.getStringArray(R.array.allow_timeout)
.map { SpinnerRvItem(it) }
.let { items.update(it) }
icon.value = policy.applicationInfo.loadIcon(pm)
title.value = policy.appName
packageName.value = policy.packageName
UiThreadHandler.handler.post {
// Delay is required to properly do selection
selectedItemPosition.value = timeoutPrefs.getInt(policy.packageName, 0)
}
// Override timer
// Set timer
val millis = SECONDS.toMillis(Config.suDefaultTimeout.toLong())
timer = object : CountDownTimer(millis, 1000) {
timer = SuTimer(millis, 1000).apply { start() }
}
private inner class SuTimer(
private val millis: Long,
interval: Long
) : CountDownTimer(millis, interval) {
override fun onTick(remains: Long) {
if (remains <= millis - 1000) {
if (!grantEnabled.value && remains <= millis - 1000) {
grantEnabled.value = true
}
denyText.value = "${res.getString(R.string.deny)} (${(remains / 1000) + 1})"
@ -119,7 +126,7 @@ class SuRequestViewModel(
denyText.value = res.getString(R.string.deny)
respond(DENY)
}
}
}
}

View File

@ -86,8 +86,7 @@ open class DiffObservableList<T>(
@MainThread
fun update(newItems: List<T>) {
val diffResult = doCalculateDiff(list, newItems)
list = newItems.toMutableList()
diffResult.dispatchUpdatesTo(listCallback)
update(newItems, diffResult)
}
override fun addOnListChangedCallback(listener: ObservableList.OnListChangedCallback<out ObservableList<T>>) {

View File

@ -24,7 +24,7 @@
android:minWidth="350dp"
android:orientation="vertical">
<androidx.appcompat.widget.AppCompatTextView
<TextView
android:id="@+id/request_title"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
@ -45,7 +45,7 @@
android:paddingStart="10dp"
android:paddingEnd="10dp">
<androidx.appcompat.widget.AppCompatImageView
<ImageView
android:id="@+id/app_icon"
style="@style/WidgetFoundation.Icon"
android:layout_gravity="center_vertical"
@ -65,7 +65,7 @@
android:gravity="center_vertical"
android:orientation="vertical">
<androidx.appcompat.widget.AppCompatTextView
<TextView
android:id="@+id/app_name"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
@ -78,7 +78,7 @@
android:textStyle="bold"
tools:text="Magisk" />
<androidx.appcompat.widget.AppCompatTextView
<TextView
android:id="@+id/package_name"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
@ -92,16 +92,17 @@
</LinearLayout>
</LinearLayout>
<androidx.appcompat.widget.AppCompatSpinner
<Spinner
android:id="@+id/timeout"
onTouch="@{() -> viewModel.spinnerTouched()}"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center_horizontal"
android:enabled="@{viewModel.grantEnabled}"
android:adapter="@{viewModel.adapter}"
android:selection="@={viewModel.selectedItemPosition}" />
<androidx.appcompat.widget.AppCompatTextView
<TextView
android:id="@+id/warning"
android:layout_width="wrap_content"
android:layout_height="wrap_content"

View File

@ -6,23 +6,18 @@
#include <socket.hpp>
#include <utils.hpp>
#define ABS_SOCKET_LEN(sun) (sizeof(sa_family_t) + strlen(sun->sun_path + 1) + 1)
static size_t socket_len(sockaddr_un *sun) {
if (sun->sun_path[0])
return sizeof(sa_family_t) + strlen(sun->sun_path) + 1;
else
return sizeof(sa_family_t) + strlen(sun->sun_path + 1) + 1;
}
socklen_t setup_sockaddr(struct sockaddr_un *sun, const char *name) {
socklen_t setup_sockaddr(sockaddr_un *sun, const char *name) {
memset(sun, 0, sizeof(*sun));
sun->sun_family = AF_LOCAL;
strcpy(sun->sun_path + 1, name);
return ABS_SOCKET_LEN(sun);
}
int create_rand_socket(struct sockaddr_un *sun) {
memset(sun, 0, sizeof(*sun));
sun->sun_family = AF_LOCAL;
gen_rand_str(sun->sun_path + 1, sizeof(sun->sun_path) - 1);
int fd = xsocket(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0);
xbind(fd, (struct sockaddr*) sun, ABS_SOCKET_LEN(sun));
xlisten(fd, 1);
return fd;
return socket_len(sun);
}
int socket_accept(int sockfd, int timeout) {

View File

@ -3,8 +3,8 @@
#include <sys/un.h>
#include <sys/socket.h>
socklen_t setup_sockaddr(struct sockaddr_un *sun, const char *name);
int create_rand_socket(struct sockaddr_un *sun);
socklen_t setup_sockaddr(sockaddr_un *sun, const char *name);
int create_app_socket(sockaddr_un *sun);
int socket_accept(int sockfd, int timeout);
void get_client_cred(int fd, struct ucred *cred);
int recv_fd(int sockfd);

View File

@ -1,9 +1,5 @@
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#include <stdlib.h>
#include <fcntl.h>
#include <stdio.h>
#include <daemon.hpp>
#include <utils.hpp>
@ -92,9 +88,9 @@ public:
}
};
static bool check_error(int fd) {
static bool check_no_error(int fd) {
char buf[1024];
unique_ptr<FILE, decltype(&fclose)> out(xfdopen(fd, "r"), fclose);
auto out = xopen_file(fd, "r");
while (fgets(buf, sizeof(buf), out.get())) {
if (strncmp(buf, "Error", 5) == 0)
return false;
@ -123,7 +119,7 @@ static void exec_cmd(const char *action, vector<Extra> &data,
.argv = args.data()
};
exec_command_sync(exec);
if (check_error(exec.fd))
if (check_no_error(exec.fd))
return;
}
@ -143,7 +139,7 @@ static void exec_cmd(const char *action, vector<Extra> &data,
// Then try start activity without component name
strcpy(target, info->str[SU_MANAGER].data());
exec_command_sync(exec);
if (check_error(exec.fd))
if (check_no_error(exec.fd))
return;
}
@ -183,12 +179,31 @@ void app_notify(const su_context &ctx) {
}
}
void app_socket(const char *socket, const shared_ptr<su_info> &info) {
int app_socket(const char *name, const shared_ptr<su_info> &info) {
vector<Extra> extras;
extras.reserve(1);
extras.emplace_back("socket", socket);
extras.emplace_back("socket", name);
exec_cmd("request", extras, info, PKG_ACTIVITY);
sockaddr_un addr;
size_t len = setup_sockaddr(&addr, name);
int fd = xsocket(AF_LOCAL, SOCK_STREAM | SOCK_CLOEXEC, 0);
bool connected = false;
// Try at most 60 seconds
for (int i = 0; i < 600; ++i) {
if (connect(fd, reinterpret_cast<sockaddr *>(&addr), len) == 0) {
connected = true;
break;
}
usleep(100000); // 100ms
}
if (connected) {
return fd;
} else {
close(fd);
return -1;
}
}
void socket_send_request(int fd, const shared_ptr<su_info> &info) {

View File

@ -68,5 +68,5 @@ struct su_context {
void app_log(const su_context &ctx);
void app_notify(const su_context &ctx);
void app_socket(const char *socket, const std::shared_ptr<su_info> &info);
int app_socket(const char *name, const std::shared_ptr<su_info> &info);
void socket_send_request(int fd, const std::shared_ptr<su_info> &info);

View File

@ -139,12 +139,9 @@ static shared_ptr<su_info> get_su_info(unsigned uid) {
}
// If still not determined, ask manager
struct sockaddr_un addr;
int sockfd = create_rand_socket(&addr);
// Connect manager
app_socket(addr.sun_path + 1, info);
int fd = socket_accept(sockfd, 60);
char socket_name[32];
gen_rand_str(socket_name, sizeof(socket_name));
int fd = app_socket(socket_name, info);
if (fd < 0) {
info->access.policy = DENY;
} else {
@ -153,7 +150,6 @@ static shared_ptr<su_info> get_su_info(unsigned uid) {
info->access.policy = ret < 0 ? DENY : static_cast<policy_t>(ret);
close(fd);
}
close(sockfd);
return info;
}

View File

@ -133,3 +133,7 @@ static inline sFILE open_file(const char *path, const char *mode) {
static inline sFILE xopen_file(const char *path, const char *mode) {
return sFILE(xfopen(path, mode), fclose);
}
static inline sFILE xopen_file(int fd, const char *mode) {
return sFILE(xfdopen(fd, mode), fclose);
}