@@ -25,8 +25,6 @@ import java.util.concurrent.Executors
2525
2626class DigitClassifier (private val context : Context ) {
2727
28- // private var interpreter: Interpreter? = null
29-
3028 var isInitialized = false
3129 private set
3230
@@ -40,37 +38,34 @@ class DigitClassifier(private val context: Context) {
4038 private var interpreter: InterpreterApi ? = null
4139
4240 fun initialize (cb : (Boolean ) -> Unit ) {
43- TensorFlowLiteHelper .init (context) { playServicesOk ->
44- try {
45- interpreter = TensorFlowLiteHelper .createInterpreterApi(
46- context = context,
47- modelName = " mnist.tflite" ,
48- preferPlayServices = playServicesOk
49- )
50-
51- val inter = interpreter
52- if (inter == null ) {
53- isInitialized = false
54- cb(false )
55- return @init
56- }
57-
58- val inputShape = inter.getInputTensor(0 ).shape()
59- Log .d(TAG , " input shape = ${inputShape.contentToString()} " )
60- Log .d(TAG , " elem shape = ${inter.getInputTensor(0 ).numElements()} " )
61- Log .d(TAG , " output shape = ${inter.getOutputTensor(0 ).shape().contentToString()} " )
62-
63- inputImageWidth = inputShape[1 ]
64- inputImageHeight = inputShape[2 ]
65- modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
66- isInitialized = true
67- cb(true )
68- } catch (t: Throwable ) {
69- Log .e(TAG , " Failed to initialize DigitClassifier." , t)
41+ try {
42+ interpreter = TensorFlowLiteHelper .createInterpreterApi(
43+ context = context, modelName = " mnist.tflite"
44+ )
45+
46+ val inter = interpreter
47+ if (inter == null ) {
7048 isInitialized = false
7149 cb(false )
50+ return
7251 }
52+
53+ val inputShape = inter.getInputTensor(0 ).shape()
54+ Log .d(TAG , " input shape = ${inputShape.contentToString()} " )
55+ Log .d(TAG , " elem shape = ${inter.getInputTensor(0 ).numElements()} " )
56+ Log .d(TAG , " output shape = ${inter.getOutputTensor(0 ).shape().contentToString()} " )
57+
58+ inputImageWidth = inputShape[1 ]
59+ inputImageHeight = inputShape[2 ]
60+ modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * PIXEL_SIZE
61+ isInitialized = true
62+ cb(true )
63+ } catch (t: Throwable ) {
64+ Log .e(TAG , " Failed to initialize DigitClassifier." , t)
65+ isInitialized = false
66+ cb(false )
7367 }
68+
7469 }
7570
7671
@@ -91,7 +86,8 @@ class DigitClassifier(private val context: Context) {
9186 val result = output[0 ]
9287 Log .d(TAG , " result = ${result.contentToString()} " )
9388 val maxIndex = result.indices.maxBy { result[it] }
94- val resultString = " Prediction Result: %d\n Confidence: %2f" .format(maxIndex, result[maxIndex])
89+ val resultString =
90+ " Prediction Result: %d\n Confidence: %2f" .format(maxIndex, result[maxIndex])
9591
9692 return resultString
9793
0 commit comments