|
29 | 29 | public class PerfTest implements LlmCallback { |
30 | 30 |
|
31 | 31 | private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; |
32 | | - private static final String TOKENIZER_BIN = "tokenizer.json"; |
| 32 | + private static final String TOKENIZER_PATH = "tokenizer.json"; |
| 33 | + private static final String MODEL_PATH = "model.pte"; |
33 | 34 |
|
34 | 35 | private final List<String> results = new ArrayList<>(); |
35 | 36 | private final List<Float> tokensPerSecond = new ArrayList<>(); |
36 | | - @Test public void testSanity() {} |
37 | 37 |
|
| 38 | + @Test |
| 39 | + public void testLoadAndGenerate() { |
| 40 | + String tokenizerPath = RESOURCE_PATH + TOKENIZER_PATH; |
| 41 | + // Find out the model name |
| 42 | + File model = new File(RESOURCE_PATH + MODEL_PATH); |
| 43 | + LlmModule mModule = new LlmModule(model.getPath(), tokenizerPath, 0.8f); |
| 44 | + // Print the model name because there might be more than one of them |
| 45 | + report("ModelName", model.getName()); |
| 46 | + |
| 47 | + int loadResult = mModule.load(); |
| 48 | + // Check that the model can be load successfully |
| 49 | + assertEquals(0, loadResult); |
| 50 | + |
| 51 | + // Run a testing prompt |
| 52 | + mModule.generate("How do you do! I'm testing llm on mobile device", PerfTest.this); |
| 53 | + } |
| 54 | + |
| 55 | + @Test |
| 56 | + public void testTokensPerSecond() { |
| 57 | + String tokenizerPath = RESOURCE_PATH + TOKENIZER_PATH; |
| 58 | + // Find out the model name |
| 59 | + File model = new File(RESOURCE_PATH + MODEL_PATH); |
| 60 | + LlmModule mModule = new LlmModule(model.getPath(), tokenizerPath, 0.8f); |
| 61 | + // Print the model name because there might be more than one of them |
| 62 | + report("ModelName", model.getName()); |
| 63 | + |
| 64 | + int loadResult = mModule.load(); |
| 65 | + // Check that the model can be load successfully |
| 66 | + assertEquals(0, loadResult); |
| 67 | + |
| 68 | + // Run a testing prompt |
| 69 | + mModule.generate("How do you do! I'm testing llm on mobile device", PerfTest.this); |
| 70 | + assertFalse(tokensPerSecond.isEmpty()); |
| 71 | + |
| 72 | + final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); |
| 73 | + report("TPS", tps); |
| 74 | + } |
38 | 75 |
|
39 | 76 | @Override |
40 | 77 | public void onResult(String result) { |
|
0 commit comments