Cleanup code for su request

This commit is contained in:
topjohnwu
2020-06-17 03:47:12 -07:00
parent 56602cb9a3
commit e7f1c03151
4 changed files with 76 additions and 109 deletions

View File

@@ -20,8 +20,8 @@ import com.topjohnwu.magisk.extensions.get
import com.topjohnwu.magisk.extensions.startActivity import com.topjohnwu.magisk.extensions.startActivity
import com.topjohnwu.magisk.extensions.startActivityWithRoot import com.topjohnwu.magisk.extensions.startActivityWithRoot
import com.topjohnwu.magisk.extensions.subscribeK import com.topjohnwu.magisk.extensions.subscribeK
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.magisk.model.entity.toLog import com.topjohnwu.magisk.model.entity.toLog
import com.topjohnwu.magisk.ui.surequest.SuRequestActivity
import com.topjohnwu.superuser.Shell import com.topjohnwu.superuser.Shell
import timber.log.Timber import timber.log.Timber
@@ -51,20 +51,8 @@ object SuCallbackHandler : ProviderCallHandler {
} }
when (action) { when (action) {
REQUEST -> { REQUEST -> handleRequest(context, data)
val intent = context.intent<SuRequestActivity>() LOG -> handleLogging(context, data)
.setAction(action)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
LOG -> handleLogs(context, data)
NOTIFY -> handleNotify(context, data) NOTIFY -> handleNotify(context, data)
TEST -> { TEST -> {
val mode = data.getInt("mode", 2) val mode = data.getInt("mode", 2)
@@ -78,13 +66,26 @@ object SuCallbackHandler : ProviderCallHandler {
private fun Any?.toInt(): Int? { private fun Any?.toInt(): Int? {
return when (this) { return when (this) {
is Int -> this is Number -> this.toInt()
is Long -> this.toInt()
else -> null else -> null
} }
} }
private fun handleLogs(context: Context, data: Bundle) { private fun handleRequest(context: Context, data: Bundle) {
val intent = context.intent<SuRequestActivity>()
.setAction(REQUEST)
.putExtras(data)
.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
.addFlags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
if (Build.VERSION.SDK_INT >= 29) {
// Android Q does not allow starting activity from background
intent.startActivityWithRoot()
} else {
intent.startActivity(context)
}
}
private fun handleLogging(context: Context, data: Bundle) {
val fromUid = data["from.uid"].toInt() ?: return val fromUid = data["from.uid"].toInt() ?: return
if (fromUid == Process.myUid()) if (fromUid == Process.myUid())
return return

View File

@@ -1,57 +0,0 @@
package com.topjohnwu.magisk.core.su
import android.net.LocalSocket
import android.net.LocalSocketAddress
import androidx.collection.ArrayMap
import timber.log.Timber
import java.io.*
abstract class SuConnector @Throws(IOException::class)
protected constructor(name: String) {
private val socket: LocalSocket = LocalSocket()
protected var out: DataOutputStream
protected var input: DataInputStream
init {
socket.connect(LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT))
out = DataOutputStream(BufferedOutputStream(socket.outputStream))
input = DataInputStream(BufferedInputStream(socket.inputStream))
}
private fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
@Throws(IOException::class)
fun readRequest(): Map<String, String> {
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
fun response() {
runCatching {
onResponse()
out.flush()
}.onFailure { Timber.e(it) }
runCatching {
input.close()
out.close()
socket.close()
}
}
@Throws(IOException::class)
protected abstract fun onResponse()
}

View File

@@ -2,7 +2,10 @@ package com.topjohnwu.magisk.core.su
import android.content.Intent import android.content.Intent
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.net.LocalSocket
import android.net.LocalSocketAddress
import android.os.CountDownTimer import android.os.CountDownTimer
import androidx.collection.ArrayMap
import com.topjohnwu.magisk.BuildConfig import com.topjohnwu.magisk.BuildConfig
import com.topjohnwu.magisk.core.Config import com.topjohnwu.magisk.core.Config
import com.topjohnwu.magisk.core.Const import com.topjohnwu.magisk.core.Const
@@ -11,43 +14,36 @@ import com.topjohnwu.magisk.core.model.MagiskPolicy
import com.topjohnwu.magisk.core.model.toPolicy import com.topjohnwu.magisk.core.model.toPolicy
import com.topjohnwu.magisk.extensions.now import com.topjohnwu.magisk.extensions.now
import timber.log.Timber import timber.log.Timber
import java.io.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
abstract class SuRequestHandler( abstract class SuRequestHandler(
private val packageManager: PackageManager, private val packageManager: PackageManager,
private val policyDB: PolicyDao private val policyDB: PolicyDao
) { ) {
protected var timer: CountDownTimer = object : CountDownTimer( private val socket: LocalSocket = LocalSocket()
TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) { private lateinit var out: DataOutputStream
override fun onFinish() { private lateinit var input: DataInputStream
respond(MagiskPolicy.DENY, 0)
} protected var timer: CountDownTimer = DefaultCountDown()
override fun onTick(remains: Long) {}
}
set(value) { set(value) {
field.cancel() field.cancel()
field = value field = value
field.start() field.start()
} }
protected lateinit var policy: MagiskPolicy protected lateinit var policy: MagiskPolicy
private set
private val cleanupTasks = mutableListOf<() -> Unit>()
private lateinit var connector: SuConnector
abstract fun onStart() abstract fun onStart()
abstract fun onRespond()
fun start(intent: Intent): Boolean { fun start(intent: Intent): Boolean {
val socketName = intent.getStringExtra("socket") ?: return false val socketName = intent.getStringExtra("socket") ?: return false
try { try {
connector = object : SuConnector(socketName) { socket.connect(LocalSocketAddress(socketName, LocalSocketAddress.Namespace.ABSTRACT))
override fun onResponse() { out = DataOutputStream(BufferedOutputStream(socket.outputStream))
out.writeInt(policy.policy) input = DataInputStream(BufferedInputStream(socket.inputStream))
} val map = readRequest()
}
val map = connector.readRequest()
val uid = map["uid"]?.toIntOrNull() ?: return false val uid = map["uid"]?.toIntOrNull() ?: return false
policy = uid.toPolicy(packageManager) policy = uid.toPolicy(packageManager)
} catch (e: Exception) { } catch (e: Exception) {
@@ -71,20 +67,10 @@ abstract class SuRequestHandler(
} }
timer.start() timer.start()
cleanupTasks.add {
timer.cancel()
}
onStart() onStart()
return true return true
} }
private fun respond() {
connector.response()
cleanupTasks.forEach { it() }
onRespond()
}
fun respond(action: Int, time: Int) { fun respond(action: Int, time: Int) {
val until = if (time > 0) val until = if (time > 0)
TimeUnit.MILLISECONDS.toSeconds(now) + TimeUnit.MINUTES.toSeconds(time.toLong()) TimeUnit.MILLISECONDS.toSeconds(now) + TimeUnit.MINUTES.toSeconds(time.toLong())
@@ -98,6 +84,45 @@ abstract class SuRequestHandler(
if (until >= 0) if (until >= 0)
policyDB.update(policy).blockingAwait() policyDB.update(policy).blockingAwait()
respond() try {
out.writeInt(policy.policy)
out.flush()
} catch (e: IOException) {
Timber.e(e)
} finally {
runCatching {
input.close()
out.close()
socket.close()
}
}
timer.cancel()
}
@Throws(IOException::class)
private fun readRequest(): Map<String, String> {
fun readString(): String {
val len = input.readInt()
val buf = ByteArray(len)
input.readFully(buf)
return String(buf, Charsets.UTF_8)
}
val ret = ArrayMap<String, String>()
while (true) {
val name = readString()
if (name == "eof")
break
ret[name] = readString()
}
return ret
}
private inner class DefaultCountDown
: CountDownTimer(TimeUnit.MINUTES.toMillis(1), TimeUnit.MINUTES.toMillis(1)) {
override fun onFinish() {
respond(MagiskPolicy.DENY, 0)
}
override fun onTick(remains: Long) {}
} }
} }

View File

@@ -85,6 +85,9 @@ class SuRequestViewModel(
val pos = selectedItemPosition.value val pos = selectedItemPosition.value
timeoutPrefs.edit().putInt(policy.packageName, pos).apply() timeoutPrefs.edit().putInt(policy.packageName, pos).apply()
respond(action, Config.Value.TIMEOUT_LIST[pos]) respond(action, Config.Value.TIMEOUT_LIST[pos])
// Kill activity after response
DieEvent().publish()
} }
fun cancelTimer() { fun cancelTimer() {
@@ -118,11 +121,6 @@ class SuRequestViewModel(
} }
} }
} }
override fun onRespond() {
// Kill activity after response
DieEvent().publish()
}
} }
} }