Skip to content

Commit 20aa8e9

Browse files
authored
LlmDemo java->kt migration: Tests, data classes, util, runner (#165)
1 parent ee7e57e commit 20aa8e9

27 files changed

Lines changed: 1299 additions & 1445 deletions

llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/SanityCheck.java

Lines changed: 0 additions & 82 deletions
This file was deleted.
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo
10+
11+
import android.util.Log
12+
import androidx.test.ext.junit.runners.AndroidJUnit4
13+
import androidx.test.platform.app.InstrumentationRegistry
14+
import org.junit.Assert.assertEquals
15+
import org.junit.Assert.assertFalse
16+
import org.junit.Before
17+
import org.junit.Test
18+
import org.junit.runner.RunWith
19+
import org.pytorch.executorch.extension.llm.LlmCallback
20+
import org.pytorch.executorch.extension.llm.LlmModule
21+
import java.io.File
22+
23+
/**
24+
* Sanity check test for model loading and generation.
25+
*
26+
* Model filenames can be configured via instrumentation arguments:
27+
* - modelFile: name of the .pte file (default: stories110M.pte)
28+
* - tokenizerFile: name of the tokenizer file (default: tokenizer.model)
29+
*/
30+
@RunWith(AndroidJUnit4::class)
31+
class SanityCheck : LlmCallback {
32+
33+
companion object {
34+
private const val RESOURCE_PATH = "/data/local/tmp/llama/"
35+
private const val DEFAULT_MODEL_FILE = "stories110M.pte"
36+
private const val DEFAULT_TOKENIZER_FILE = "tokenizer.model"
37+
private const val TAG = "SanityCheck"
38+
}
39+
40+
private lateinit var modelFile: String
41+
private lateinit var tokenizerFile: String
42+
private val results = mutableListOf<String>()
43+
44+
@Before
45+
fun setUp() {
46+
// Read model filenames from instrumentation arguments
47+
val args = InstrumentationRegistry.getArguments()
48+
modelFile = args.getString("modelFile", DEFAULT_MODEL_FILE) ?: DEFAULT_MODEL_FILE
49+
tokenizerFile = args.getString("tokenizerFile", DEFAULT_TOKENIZER_FILE) ?: DEFAULT_TOKENIZER_FILE
50+
Log.i(TAG, "Using model: $modelFile, tokenizer: $tokenizerFile")
51+
}
52+
53+
@Test
54+
fun testLoadAndGenerate() {
55+
val tokenizerPath = RESOURCE_PATH + tokenizerFile
56+
val model = File(RESOURCE_PATH + modelFile)
57+
val module = LlmModule(model.path, tokenizerPath, 0.8f)
58+
59+
val loadResult = module.load()
60+
// Check that the model can be loaded successfully
61+
assertEquals(0, loadResult)
62+
63+
// Run a testing prompt
64+
module.generate("How do you do! I'm testing llm on mobile device", this)
65+
66+
// Verify we got some response
67+
assertFalse("Should receive at least one result token", results.isEmpty())
68+
}
69+
70+
override fun onResult(result: String) {
71+
results.add(result)
72+
}
73+
74+
override fun onStats(result: String) {
75+
// Not measuring performance for now
76+
}
77+
}

0 commit comments

Comments
 (0)