Skip to content

Commit 7378d4d

Browse files
committed
Use LiteRT CompileModel
1 parent 3d6e8d9 commit 7378d4d

7 files changed

Lines changed: 209 additions & 1 deletion

File tree

gradle/libs.versions.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ tensorflowLiteSupport = "0.4.4"
8989

9090
litert = "2.1.0"
9191
litertSupportAndMetadata = "1.0.0"
92+
undercouchDownload = "5.6.0"
9293

9394
[libraries]
9495
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" }
@@ -195,3 +196,4 @@ kotlinAndroid = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
195196
kotlinKapt = { id = "org.jetbrains.kotlin.kapt", version.ref = "kotlin" }
196197
kotlinCompose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" }
197198
ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" }
199+
undercouchDownload = { id = "de.undercouch.download", version.ref = "undercouchDownload" }

subs/ai/build.gradle.kts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
plugins {
22
id("com.android.library")
33
id("org.jetbrains.kotlin.android")
4+
alias(libs.plugins.undercouchDownload)
45
}
56

67
android {
@@ -43,6 +44,9 @@ android {
4344
viewBinding = true
4445
}
4546
}
47+
// Import DownloadModels task
48+
project.ext.set("ASSET_DIR", "$projectDir/src/main/assets")
49+
apply(from = "download_model.gradle")
4650

4751
dependencies {
4852
implementation(libs.androidx.core.ktx)

subs/ai/download_model.gradle

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
tasks.register('downloadDigitClassifierModel', Download) {
2+
src 'https://storage.googleapis.com/ai-edge/interpreter-samples/digit_classifier/android/mnist_metadata.tflite'
3+
dest project.ext.ASSET_DIR + '/mnist_metadata.tflite'
4+
overwrite false
5+
}
6+
7+
preBuild.dependsOn downloadDigitClassifierModel
1.16 MB
Binary file not shown.

subs/ai/src/main/java/com/engineer/ai/DigitalClassificationActivity.kt

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,17 @@ import androidx.activity.enableEdgeToEdge
1111
import androidx.appcompat.app.AppCompatActivity
1212
import com.divyanshu.draw.widget.DrawView
1313
import com.engineer.ai.util.DigitClassifier
14+
import com.engineer.ai.util.LiteTRCompileModel
1415
import com.engineer.ai.util.toast
1516

1617
class DigitalClassificationActivity : AppCompatActivity() {
1718
private var drawView: DrawView? = null
1819
private var clearButton: Button? = null
1920
private var initButton: Button? = null
21+
private var initButton2: Button? = null
2022
private var predictedTextView: TextView? = null
2123
private var digitClassifier = DigitClassifier(this)
24+
private val liteRtCompileModel = LiteTRCompileModel(this)
2225

2326
@SuppressLint("ClickableViewAccessibility")
2427
override fun onCreate(savedInstanceState: Bundle?) {
@@ -33,6 +36,7 @@ class DigitalClassificationActivity : AppCompatActivity() {
3336
drawView?.setBackgroundColor(Color.BLACK)
3437
clearButton = findViewById(R.id.clear_button)
3538
initButton = findViewById(R.id.init_model)
39+
initButton2 = findViewById(R.id.init_model2)
3640
predictedTextView = findViewById(R.id.predicted_text)
3741

3842
// Setup clear drawing button.
@@ -61,14 +65,29 @@ class DigitalClassificationActivity : AppCompatActivity() {
6165
runOnUiThread { if (it) "init success".toast(this) else "init fail".toast(this) }
6266
}
6367
}
68+
initButton2?.setOnClickListener {
69+
liteRtCompileModel.initClassifier()
70+
"init success".toast(this)
71+
}
6472
}
6573

6674
private fun classifyDrawing() {
6775
val bitmap = drawView?.getBitmap()
6876

6977
if ((bitmap != null) && (digitClassifier.isInitialized)) {
7078
digitClassifier.classifyAsync(bitmap)
71-
.addOnSuccessListener { resultText -> predictedTextView?.text = resultText }.addOnFailureListener { e ->
79+
.addOnSuccessListener { resultText -> predictedTextView?.text = resultText }
80+
.addOnFailureListener { e ->
81+
predictedTextView?.text = getString(
82+
R.string.classification_error_message, e.localizedMessage
83+
)
84+
Log.e(TAG, "Error classifying drawing.", e)
85+
}
86+
}
87+
if(bitmap != null && liteRtCompileModel.isInitialized) {
88+
liteRtCompileModel.classifyAsync(bitmap)
89+
.addOnSuccessListener { resultText -> predictedTextView?.text = resultText }
90+
.addOnFailureListener { e ->
7291
predictedTextView?.text = getString(
7392
R.string.classification_error_message, e.localizedMessage
7493
)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
package com.engineer.ai.util
2+
3+
import android.content.Context
4+
import android.graphics.Bitmap
5+
import android.graphics.Color
6+
import android.util.Log
7+
import com.google.ai.edge.litert.Accelerator
8+
import com.google.ai.edge.litert.CompiledModel
9+
import com.google.android.gms.tasks.Task
10+
import com.google.android.gms.tasks.TaskCompletionSource
11+
import java.util.concurrent.ExecutorService
12+
import java.util.concurrent.Executors
13+
14+
/**
15+
* https://ai.google.dev/edge/litert/android?hl=zh-cn
16+
*
17+
* CompiledModel API:高性能推理的现代标准,可简化 CPU/GPU/NPU 之间的硬件加速。
18+
* 详细了解为何选择 CompiledModel API。https://ai.google.dev/edge/litert/inference?hl=zh-cn
19+
*/
20+
class LiteTRCompileModel(private val context: Context) {
21+
22+
var isInitialized = false
23+
private set
24+
private var model: CompiledModel? = null
25+
26+
/** Executor to run inference task in the background. */
27+
private val executorService: ExecutorService = Executors.newCachedThreadPool()
28+
29+
companion object {
30+
private const val TAG = "LiteTRCompileModel"
31+
32+
fun toAccelerator(acceleratorEnum: AcceleratorEnum): Accelerator {
33+
return when (acceleratorEnum) {
34+
AcceleratorEnum.CPU -> Accelerator.CPU
35+
AcceleratorEnum.GPU -> Accelerator.GPU
36+
}
37+
}
38+
}
39+
40+
enum class AcceleratorEnum {
41+
CPU, GPU,
42+
}
43+
44+
fun initClassifier(acceleratorEnum: AcceleratorEnum = AcceleratorEnum.CPU) {
45+
cleanup()
46+
try {
47+
48+
model = CompiledModel.create(
49+
context.assets,
50+
"mnist_metadata.tflite",
51+
CompiledModel.Options(toAccelerator(acceleratorEnum)),
52+
null
53+
)
54+
isInitialized = true
55+
Log.i(TAG, "Created a CompiledModel with $acceleratorEnum")
56+
57+
} catch (e: Exception) {
58+
Log.e(TAG, "Initializing CompiledModel has failed with error: ${e.message}")
59+
}
60+
}
61+
62+
fun cleanup() {
63+
model?.close()
64+
model = null
65+
}
66+
67+
fun classify(bitmap: Bitmap): String {
68+
69+
val localModel = model ?: return ""
70+
71+
try {
72+
// 1. Preprocessing
73+
// Resize to 28x28
74+
val scaledBitmap = Bitmap.createScaledBitmap(bitmap, 28, 28, true)
75+
// Convert to grayscale and normalize to 0..1
76+
val inputFloatArray = convertBitmapToFloatArray(scaledBitmap)
77+
78+
// 2. Execution
79+
val inputBuffers = localModel.createInputBuffers()
80+
val outputBuffers = localModel.createOutputBuffers()
81+
82+
inputBuffers[0].writeFloat(inputFloatArray)
83+
localModel.run(inputBuffers, outputBuffers)
84+
85+
val outputFloatArray = outputBuffers[0].readFloat()
86+
87+
// Cleanup buffers
88+
inputBuffers.forEach { it.close() }
89+
outputBuffers.forEach { it.close() }
90+
91+
// 3. Postprocessing
92+
val (digit, score) = findResult(outputFloatArray)
93+
94+
return "Prediction Result: %d\nConfidence: %2f".format(digit, score)
95+
} catch (e: Exception) {
96+
Log.e(TAG, "Error during classification: ${e.message}")
97+
return "Classification failed: ${e.message}"
98+
}
99+
return ""
100+
101+
}
102+
103+
104+
private fun convertBitmapToFloatArray(bitmap: Bitmap): FloatArray {
105+
val width = bitmap.width
106+
val height = bitmap.height
107+
val pixels = IntArray(width * height)
108+
bitmap.getPixels(pixels, 0, width, 0, 0, width, height)
109+
110+
// The original sample used ImageProcessor without grayscale conversion,
111+
// implying it sent 3 channels (RGB) to the model.
112+
// It also used NormalizeOp(0f, 1f), which effectively keeps the 0-255 range
113+
// when converting from the default Bitmap uint8 values to Float32.
114+
115+
// Target: [1, 28, 28, 3]
116+
val output = FloatArray(width * height * 3)
117+
118+
for (i in pixels.indices) {
119+
val pixel = pixels[i]
120+
121+
// Extract RGB (ignore alpha)
122+
val r = Color.red(pixel).toFloat()
123+
val g = Color.green(pixel).toFloat()
124+
val b = Color.blue(pixel).toFloat()
125+
126+
// Don't divide by 255.0f to match original 0..255 range behavior
127+
val baseIndex = i * 3
128+
output[baseIndex] = r
129+
output[baseIndex + 1] = g
130+
output[baseIndex + 2] = b
131+
}
132+
return output
133+
}
134+
135+
/**
136+
* Finds the index and value of the maximum element in a non-empty float array.
137+
*/
138+
private fun findResult(array: FloatArray): Pair<Int, Float> {
139+
if (array.isEmpty()) return Pair(-1, 0f)
140+
141+
var maxIndex = 0
142+
var maxValue = array[0]
143+
144+
for (i in array.indices) {
145+
if (array[i] > maxValue) {
146+
maxValue = array[i]
147+
maxIndex = i
148+
}
149+
}
150+
151+
return Pair(maxIndex, maxValue)
152+
}
153+
154+
155+
fun classifyAsync(bitmap: Bitmap): Task<String> {
156+
val task = TaskCompletionSource<String>()
157+
executorService.execute {
158+
val result = classify(bitmap)
159+
task.setResult(result)
160+
}
161+
return task.task
162+
}
163+
164+
fun close() {
165+
executorService.shutdown()
166+
}
167+
}

subs/ai/src/main/res/layout/activity_digital_classification.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,13 @@
4848
app:layout_constraintBottom_toBottomOf="parent"
4949
app:layout_constraintEnd_toEndOf="parent" />
5050

51+
<Button
52+
android:id="@+id/init_model2"
53+
android:layout_width="wrap_content"
54+
android:layout_height="wrap_content"
55+
android:layout_margin="10dp"
56+
android:text="init_model2"
57+
app:layout_constraintBottom_toBottomOf="parent"
58+
app:layout_constraintEnd_toStartOf="@id/init_model" />
59+
5160
</androidx.constraintlayout.widget.ConstraintLayout>

0 commit comments

Comments
 (0)