Skip to content

Commit 6128a45

Browse files
authored
Convert minibench Java files to Kotlin (#19760)
Convert BenchmarkActivity, BenchmarkMetric, LlmBenchmark, LlmModelRunner, and ModelRunner from Java to Kotlin. Differential Revision: D106195816
1 parent 0bf018f commit 6128a45

11 files changed

Lines changed: 449 additions & 550 deletions

File tree

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 0 additions & 136 deletions
This file was deleted.
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.minibench
10+
11+
import android.app.Activity
12+
import android.os.Bundle
13+
import android.os.Handler
14+
import android.os.HandlerThread
15+
import android.os.Looper
16+
import android.os.Message
17+
import android.system.Os
18+
import com.google.gson.Gson
19+
import java.io.File
20+
import java.io.FileWriter
21+
import java.io.IOException
22+
23+
class BenchmarkActivity : Activity() {
24+
25+
lateinit var model: File
26+
var numIter: Int = 0
27+
var numWarmupIter: Int = 0
28+
var tokenizerPath: String? = null
29+
var temperature: Float = 0.8f
30+
var prompt: String = "The ultimate answer"
31+
32+
private lateinit var handlerThread: HandlerThread
33+
private lateinit var handler: BenchmarkHandler
34+
35+
val results: MutableList<BenchmarkMetric> = mutableListOf()
36+
37+
override fun onCreate(savedInstanceState: Bundle?) {
38+
super.onCreate(savedInstanceState)
39+
40+
try {
41+
Os.setenv("ADSP_LIBRARY_PATH", applicationInfo.nativeLibraryDir, true)
42+
} catch (e: android.system.ErrnoException) {
43+
finish()
44+
return
45+
}
46+
47+
val intent = intent
48+
val modelDir = File(intent.getStringExtra("model_dir")!!)
49+
model = modelDir.listFiles()!!.first { it.name.endsWith(".pte") }
50+
51+
numIter = intent.getIntExtra("num_iter", 50)
52+
numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10)
53+
tokenizerPath = intent.getStringExtra("tokenizer_path")
54+
temperature = intent.getFloatExtra("temperature", 0.8f)
55+
prompt = intent.getStringExtra("prompt") ?: "The ultimate answer"
56+
57+
handlerThread = HandlerThread("ModelRunner")
58+
handlerThread.start()
59+
handler = BenchmarkHandler(handlerThread.looper, this)
60+
61+
handler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK)
62+
}
63+
64+
fun writeResult() {
65+
try {
66+
FileWriter("${filesDir}/benchmark_results.json").use { writer ->
67+
writer.write(Gson().toJson(results))
68+
}
69+
} catch (e: IOException) {
70+
e.printStackTrace()
71+
} finally {
72+
finish()
73+
}
74+
}
75+
}
76+
77+
private class BenchmarkHandler(
78+
looper: Looper,
79+
private val activity: BenchmarkActivity,
80+
) : Handler(looper) {
81+
82+
private val modelRunner = ModelRunner()
83+
84+
override fun handleMessage(msg: Message) {
85+
when (msg.what) {
86+
MESSAGE_RUN_BENCHMARK -> {
87+
modelRunner.runBenchmark(
88+
activity.model,
89+
activity.numWarmupIter,
90+
activity.numIter,
91+
activity.results,
92+
)
93+
if (activity.tokenizerPath == null) {
94+
activity.writeResult()
95+
} else {
96+
sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK)
97+
}
98+
}
99+
MESSAGE_LLM_RUN_BENCHMARK -> {
100+
LlmBenchmark(
101+
activity,
102+
activity.model.path,
103+
activity.tokenizerPath!!,
104+
activity.prompt,
105+
activity.temperature,
106+
activity.results,
107+
)
108+
}
109+
}
110+
}
111+
112+
companion object {
113+
const val MESSAGE_RUN_BENCHMARK = 1
114+
const val MESSAGE_LLM_RUN_BENCHMARK = 2
115+
}
116+
}

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java

Lines changed: 0 additions & 74 deletions
This file was deleted.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.minibench
10+
11+
import android.app.ActivityManager
12+
import android.os.Build
13+
14+
class BenchmarkMetric(
15+
val benchmarkModel: BenchmarkModel,
16+
val metric: String,
17+
val actualValue: Double,
18+
val targetValue: Double,
19+
) {
20+
data class BenchmarkModel(
21+
val name: String,
22+
val backend: String,
23+
val quantization: String,
24+
)
25+
26+
class DeviceInfo {
27+
val device: String = Build.BRAND
28+
val arch: String = Build.MODEL
29+
val os: String = "Android ${Build.VERSION.RELEASE}"
30+
val totalMem: Long = ActivityManager.MemoryInfo().totalMem
31+
val availMem: Long = ActivityManager.MemoryInfo().availMem
32+
}
33+
34+
val deviceInfo: DeviceInfo = DeviceInfo()
35+
36+
companion object {
37+
// TODO (huydhn): Figure out a way to extract the backend and quantization information from
38+
// the .pte model itself instead of parsing its name
39+
@JvmStatic
40+
fun extractBackendAndQuantization(model: String): BenchmarkModel {
41+
val pattern = Regex("(?<name>\\w+)_(?<backend>[\\w+]+)_(?<quantization>\\w+)")
42+
val match = pattern.matchEntire(model)
43+
return if (match != null) {
44+
BenchmarkModel(
45+
match.groups["name"]!!.value,
46+
match.groups["backend"]!!.value,
47+
match.groups["quantization"]!!.value,
48+
)
49+
} else {
50+
BenchmarkModel(model, "", "")
51+
}
52+
}
53+
}
54+
}

0 commit comments

Comments
 (0)