1+ package com.engineer.ai.util
2+
3+ import android.content.Context
4+ import android.graphics.Bitmap
5+ import android.graphics.Color
6+ import android.util.Log
7+ import com.google.ai.edge.litert.Accelerator
8+ import com.google.ai.edge.litert.CompiledModel
9+ import com.google.android.gms.tasks.Task
10+ import com.google.android.gms.tasks.TaskCompletionSource
11+ import java.util.concurrent.ExecutorService
12+ import java.util.concurrent.Executors
13+
14+ /* *
15+ * https://ai.google.dev/edge/litert/android?hl=zh-cn
16+ *
17+ * CompiledModel API:高性能推理的现代标准,可简化 CPU/GPU/NPU 之间的硬件加速。
18+ * 详细了解为何选择 CompiledModel API。https://ai.google.dev/edge/litert/inference?hl=zh-cn
19+ */
20+ class LiteTRCompileModel (private val context : Context ) {
21+
22+ var isInitialized = false
23+ private set
24+ private var model: CompiledModel ? = null
25+
26+ /* * Executor to run inference task in the background. */
27+ private val executorService: ExecutorService = Executors .newCachedThreadPool()
28+
29+ companion object {
30+ private const val TAG = " LiteTRCompileModel"
31+
32+ fun toAccelerator (acceleratorEnum : AcceleratorEnum ): Accelerator {
33+ return when (acceleratorEnum) {
34+ AcceleratorEnum .CPU -> Accelerator .CPU
35+ AcceleratorEnum .GPU -> Accelerator .GPU
36+ }
37+ }
38+ }
39+
40+ enum class AcceleratorEnum {
41+ CPU , GPU ,
42+ }
43+
44+ fun initClassifier (acceleratorEnum : AcceleratorEnum = AcceleratorEnum .CPU ) {
45+ cleanup()
46+ try {
47+
48+ model = CompiledModel .create(
49+ context.assets,
50+ " mnist_metadata.tflite" ,
51+ CompiledModel .Options (toAccelerator(acceleratorEnum)),
52+ null
53+ )
54+ isInitialized = true
55+ Log .i(TAG , " Created a CompiledModel with $acceleratorEnum " )
56+
57+ } catch (e: Exception ) {
58+ Log .e(TAG , " Initializing CompiledModel has failed with error: ${e.message} " )
59+ }
60+ }
61+
62+ fun cleanup () {
63+ model?.close()
64+ model = null
65+ }
66+
67+ fun classify (bitmap : Bitmap ): String {
68+
69+ val localModel = model ? : return " "
70+
71+ try {
72+ // 1. Preprocessing
73+ // Resize to 28x28
74+ val scaledBitmap = Bitmap .createScaledBitmap(bitmap, 28 , 28 , true )
75+ // Convert to grayscale and normalize to 0..1
76+ val inputFloatArray = convertBitmapToFloatArray(scaledBitmap)
77+
78+ // 2. Execution
79+ val inputBuffers = localModel.createInputBuffers()
80+ val outputBuffers = localModel.createOutputBuffers()
81+
82+ inputBuffers[0 ].writeFloat(inputFloatArray)
83+ localModel.run (inputBuffers, outputBuffers)
84+
85+ val outputFloatArray = outputBuffers[0 ].readFloat()
86+
87+ // Cleanup buffers
88+ inputBuffers.forEach { it.close() }
89+ outputBuffers.forEach { it.close() }
90+
91+ // 3. Postprocessing
92+ val (digit, score) = findResult(outputFloatArray)
93+
94+ return " Prediction Result: %d\n Confidence: %2f" .format(digit, score)
95+ } catch (e: Exception ) {
96+ Log .e(TAG , " Error during classification: ${e.message} " )
97+ return " Classification failed: ${e.message} "
98+ }
99+ return " "
100+
101+ }
102+
103+
104+ private fun convertBitmapToFloatArray (bitmap : Bitmap ): FloatArray {
105+ val width = bitmap.width
106+ val height = bitmap.height
107+ val pixels = IntArray (width * height)
108+ bitmap.getPixels(pixels, 0 , width, 0 , 0 , width, height)
109+
110+ // The original sample used ImageProcessor without grayscale conversion,
111+ // implying it sent 3 channels (RGB) to the model.
112+ // It also used NormalizeOp(0f, 1f), which effectively keeps the 0-255 range
113+ // when converting from the default Bitmap uint8 values to Float32.
114+
115+ // Target: [1, 28, 28, 3]
116+ val output = FloatArray (width * height * 3 )
117+
118+ for (i in pixels.indices) {
119+ val pixel = pixels[i]
120+
121+ // Extract RGB (ignore alpha)
122+ val r = Color .red(pixel).toFloat()
123+ val g = Color .green(pixel).toFloat()
124+ val b = Color .blue(pixel).toFloat()
125+
126+ // Don't divide by 255.0f to match original 0..255 range behavior
127+ val baseIndex = i * 3
128+ output[baseIndex] = r
129+ output[baseIndex + 1 ] = g
130+ output[baseIndex + 2 ] = b
131+ }
132+ return output
133+ }
134+
135+ /* *
136+ * Finds the index and value of the maximum element in a non-empty float array.
137+ */
138+ private fun findResult (array : FloatArray ): Pair <Int , Float > {
139+ if (array.isEmpty()) return Pair (- 1 , 0f )
140+
141+ var maxIndex = 0
142+ var maxValue = array[0 ]
143+
144+ for (i in array.indices) {
145+ if (array[i] > maxValue) {
146+ maxValue = array[i]
147+ maxIndex = i
148+ }
149+ }
150+
151+ return Pair (maxIndex, maxValue)
152+ }
153+
154+
155+ fun classifyAsync (bitmap : Bitmap ): Task <String > {
156+ val task = TaskCompletionSource <String >()
157+ executorService.execute {
158+ val result = classify(bitmap)
159+ task.setResult(result)
160+ }
161+ return task.task
162+ }
163+
164+ fun close () {
165+ executorService.shutdown()
166+ }
167+ }
0 commit comments