Skip to content

Commit e18d33c

Browse files
authored
Make preset model dynamic from json (#189)
Convert static ModelDownloadConfig to runtime JSON configuration:
- Add preset_models.json in assets with default presets
- Add PresetConfigManager for loading/parsing JSON configs
- Support loading custom config from URL with caching
- Add URL input, Load, and Reset buttons to preset UI
- Add unit tests for JSON parsing logic
1 parent e19c3d0 commit e18d33c

8 files changed

Lines changed: 802 additions & 174 deletions

File tree

llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PresetSanityTest.kt

Lines changed: 40 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -10,42 +10,37 @@ package com.example.executorchllamademo
1010

1111
import android.content.Context
1212
import android.util.Log
13-
import androidx.compose.ui.semantics.SemanticsProperties
14-
import androidx.compose.ui.test.assertIsEnabled
15-
import androidx.compose.ui.test.hasContentDescription
1613
import androidx.compose.ui.test.junit4.createAndroidComposeRule
1714
import androidx.compose.ui.test.onAllNodesWithText
18-
import androidx.compose.ui.test.onNodeWithContentDescription
1915
import androidx.compose.ui.test.onNodeWithTag
2016
import androidx.compose.ui.test.onNodeWithText
2117
import androidx.compose.ui.test.performClick
2218
import androidx.compose.ui.test.performTextInput
2319
import androidx.test.core.app.ApplicationProvider
2420
import androidx.test.ext.junit.runners.AndroidJUnit4
2521
import androidx.test.filters.LargeTest
26-
import org.junit.Assert.assertTrue
2722
import org.junit.Before
28-
import org.junit.Ignore
2923
import org.junit.Rule
3024
import org.junit.Test
3125
import org.junit.runner.RunWith
3226

3327
/**
34-
* Preset model sanity test that validates the preset model download and chat workflow.
28+
* Preset model sanity test that validates the preset model UI workflow.
3529
*
3630
* This test validates:
3731
* 1. Navigate from Welcome screen to Preset model screen
38-
* 2. Select Stories 110M and download it
39-
* 3. After download completes, tap to load and enter chat view
40-
* 4. Type "Once upon a time" and generate a response
32+
* 2. Load preset config from URL
33+
* 3. Verify Stories 110M model is displayed
34+
*
35+
* Note: Download and chat steps are skipped in CI due to network constraints.
4136
*/
4237
@RunWith(AndroidJUnit4::class)
4338
@LargeTest
4439
class PresetSanityTest {
4540

4641
companion object {
4742
private const val TAG = "PresetSanityTest"
48-
private const val RESPONSE_TAG = "LLAMA_RESPONSE"
43+
private const val DEFAULT_CONFIG_URL = "https://raw.githubusercontent.com/meta-pytorch/executorch-examples/889ccc6e88813cbf03775889beed29b793d0c8db/llm/android/LlamaDemo/app/src/main/assets/preset_models.json"
4944
}
5045

5146
@get:Rule
@@ -60,113 +55,46 @@ class PresetSanityTest {
6055
Context.MODE_PRIVATE
6156
)
6257
prefs.edit().clear().commit()
63-
}
6458

65-
/**
66-
* Types text into the chat input field using testTag.
67-
*/
68-
private fun typeInChatInput(text: String) {
69-
composeTestRule.onNodeWithTag("chat_input_field").performClick()
70-
composeTestRule.waitForIdle()
71-
composeTestRule.onNodeWithTag("chat_input_field").performTextInput(text)
72-
composeTestRule.waitForIdle()
59+
// Also clear the preset config preferences
60+
val configPrefs = context.getSharedPreferences("preset_config_prefs", Context.MODE_PRIVATE)
61+
configPrefs.edit().clear().commit()
7362
}
7463

7564
/**
76-
* Waits for generation to complete by checking for tokens-per-second metrics.
65+
* Loads the preset config from URL.
66+
* This is needed because the bundled preset_models.json is empty by default.
7767
*/
78-
private fun waitForGenerationComplete(timeoutMs: Long = 120000): Boolean {
79-
return try {
80-
composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
81-
val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
82-
.fetchSemanticsNodes()
83-
val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true)
84-
.fetchSemanticsNodes()
85-
tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty()
86-
}
87-
Log.i(TAG, "Generation complete - found generation metrics")
88-
true
89-
} catch (e: Exception) {
90-
Log.e(TAG, "Generation timed out after ${timeoutMs}ms")
91-
false
92-
}
93-
}
68+
private fun loadPresetConfigFromUrl() {
69+
Log.i(TAG, "Loading preset config from URL")
9470

95-
/**
96-
* Waits for the model to be loaded by checking for success or error messages.
97-
*/
98-
private fun waitForModelLoaded(timeoutMs: Long = 60000): Boolean {
99-
return try {
100-
var wasSuccess = false
101-
composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
102-
val successNodes = composeTestRule.onAllNodesWithText("Successfully loaded", substring = true)
103-
.fetchSemanticsNodes()
104-
val errorNodes = composeTestRule.onAllNodesWithText("Model load failure", substring = true)
105-
.fetchSemanticsNodes()
106-
wasSuccess = successNodes.isNotEmpty()
107-
successNodes.isNotEmpty() || errorNodes.isNotEmpty()
108-
}
109-
if (wasSuccess) {
110-
Log.i(TAG, "Model loaded successfully")
111-
} else {
112-
Log.e(TAG, "Model load failed")
113-
}
114-
wasSuccess
115-
} catch (e: Exception) {
116-
Log.e(TAG, "Model loading timed out after ${timeoutMs}ms: ${e.message}")
117-
false
118-
}
119-
}
71+
// Type the URL into the config URL field (it's empty by default)
72+
composeTestRule.onNodeWithTag("config_url_field").performClick()
73+
composeTestRule.waitForIdle()
74+
composeTestRule.onNodeWithTag("config_url_field").performTextInput(DEFAULT_CONFIG_URL)
12075

121-
/**
122-
* Verifies that the model generated a non-empty response.
123-
*/
124-
private fun assertModelResponseNotEmpty(timeoutMs: Long = 10000) {
125-
try {
126-
composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
127-
val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
128-
.fetchSemanticsNodes()
129-
val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true)
130-
.fetchSemanticsNodes()
131-
tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty()
132-
}
133-
Log.i(TAG, "Model response verified - found generation metrics")
134-
} catch (e: Exception) {
135-
throw AssertionError("Model response appears to be empty - no generation metrics found after ${timeoutMs}ms")
136-
}
137-
}
76+
// Small delay to ensure text is entered
77+
Thread.sleep(500)
13878

139-
/**
140-
* Logs the model response text for CI output.
141-
*/
142-
private fun logModelResponse() {
143-
try {
144-
Log.i(RESPONSE_TAG, "BEGIN_RESPONSE")
145-
val responseNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
146-
.fetchSemanticsNodes()
147-
for (node in responseNodes) {
148-
val text = node.config.getOrElse(SemanticsProperties.Text) { emptyList() }
149-
.joinToString(" ") { it.text }
150-
if (text.isNotBlank()) {
151-
Log.i(RESPONSE_TAG, text)
152-
}
153-
}
154-
Log.i(RESPONSE_TAG, "END_RESPONSE")
155-
} catch (e: Exception) {
156-
Log.d(TAG, "Could not log model response: ${e.message}")
79+
// Click the Load button
80+
composeTestRule.onNodeWithText("Load").performClick()
81+
82+
// Wait for config to load (models should appear)
83+
// Don't use waitForIdle here as the loading spinner animation keeps Compose busy
84+
composeTestRule.waitUntil(timeoutMillis = 60000) {
85+
composeTestRule.onAllNodesWithText("Stories 110M").fetchSemanticsNodes().isNotEmpty()
15786
}
87+
Log.i(TAG, "Preset config loaded successfully")
15888
}
15989

16090
/**
161-
* Tests the complete preset model download and chat workflow:
91+
* Tests the preset model UI workflow:
16292
* 1. From Welcome screen, tap "Preset model" card
163-
* 2. Find Stories 110M and tap Download
164-
* 3. Wait for download to complete
165-
* 4. Tap the card to load model and enter chat
166-
* 5. Type "Once upon a time" and send
167-
* 6. Verify response is generated
93+
* 2. Load preset config from URL (since bundled JSON is empty)
94+
* 3. Verify Stories 110M model is displayed
95+
*
96+
* Note: Download and chat steps are skipped in CI due to network constraints.
16897
*/
169-
@Ignore("Temporarily disabled")
17098
@Test
17199
fun testPresetModelDownloadAndChat() {
172100
composeTestRule.waitForIdle()
@@ -178,68 +106,15 @@ class PresetSanityTest {
178106
composeTestRule.onAllNodesWithText("Download Preset Model").fetchSemanticsNodes().isNotEmpty()
179107
}
180108

181-
// Step 2: Find Stories 110M and tap Download
182-
Log.i(TAG, "Step 2: Finding Stories 110M and starting download")
183-
composeTestRule.onNodeWithText("Stories 110M").assertExists()
184-
185-
// Check if already downloaded (Ready to use) or needs download
186-
val readyNodes = composeTestRule.onAllNodesWithText("Ready to use", substring = true)
187-
.fetchSemanticsNodes()
188-
189-
if (readyNodes.isEmpty()) {
190-
// Need to download - click Download button
191-
composeTestRule.onNodeWithText("Download").performClick()
192-
193-
// Step 3: Wait for download to complete (may take a while for large files)
194-
Log.i(TAG, "Step 3: Waiting for download to complete")
195-
composeTestRule.waitUntil(timeoutMillis = 300000) { // 5 minutes timeout for download
196-
composeTestRule.onAllNodesWithText("Ready to use", substring = true)
197-
.fetchSemanticsNodes().isNotEmpty()
198-
}
199-
Log.i(TAG, "Download completed")
200-
} else {
201-
Log.i(TAG, "Model already downloaded, skipping download step")
202-
}
203-
204-
// Step 4: Tap the card to load model and enter chat
205-
Log.i(TAG, "Step 4: Tapping card to load model")
206-
composeTestRule.onNodeWithText("Stories 110M").performClick()
207-
208-
// Wait for Activity transition - MainActivity needs time to launch and set content
209-
// The SelectPresetModelActivity calls finish() after starting MainActivity
210-
Thread.sleep(2000)
211-
212-
// Wait for model to load and chat screen to appear
213-
Log.i(TAG, "Waiting for model to load")
214-
val modelLoaded = waitForModelLoaded(90000)
215-
assertTrue("Model should be loaded successfully", modelLoaded)
216-
Log.i(TAG, "Model loaded successfully")
109+
// Step 2: Load preset config from URL
110+
Log.i(TAG, "Step 2: Loading preset config from URL")
111+
loadPresetConfigFromUrl()
217112

218-
// Step 5: Type "Once upon a time" and send
219-
Log.i(TAG, "Step 5: Typing prompt and sending")
220-
typeInChatInput("Once upon a time")
221-
222-
// Wait for send button to be enabled
223-
composeTestRule.waitUntil(timeoutMillis = 5000) {
224-
try {
225-
composeTestRule.onNodeWithContentDescription("Send").assertIsEnabled()
226-
true
227-
} catch (e: AssertionError) {
228-
false
229-
}
230-
}
231-
232-
composeTestRule.onNodeWithContentDescription("Send").performClick()
233-
composeTestRule.waitForIdle()
234-
235-
// Step 6: Wait for generation to complete and verify response
236-
Log.i(TAG, "Step 6: Waiting for generation to complete")
237-
val generationComplete = waitForGenerationComplete(120000)
238-
assertTrue("Generation should complete", generationComplete)
239-
240-
assertModelResponseNotEmpty()
241-
logModelResponse()
113+
// Step 3: Find Stories 110M and verify it exists
114+
Log.i(TAG, "Step 3: Verifying Stories 110M is displayed")
115+
composeTestRule.onNodeWithText("Stories 110M").assertExists()
116+
Log.i(TAG, "Stories 110M found - preset model screen is working correctly")
242117

243-
Log.i(TAG, "Preset model sanity test completed successfully")
118+
// Note: Download and chat steps are skipped in CI due to network constraints
244119
}
245120
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"models": {
3+
}
4+
}

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelDownloadConfig.kt

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
package com.example.executorchllamademo
1010

11+
import android.content.Context
12+
1113
/**
1214
* Represents a downloadable model with its associated files.
1315
*/
@@ -24,21 +26,51 @@ data class ModelInfo(
2426

2527
/**
2628
* Configuration class that maps model display names to their download URLs.
29+
* Models are loaded from JSON configuration at runtime via PresetConfigManager.
2730
*/
2831
object ModelDownloadConfig {
2932

30-
private val AVAILABLE_MODELS: LinkedHashMap<String, ModelInfo> = linkedMapOf(
31-
)
33+
private var configManager: PresetConfigManager? = null
34+
private var cachedModels: Map<String, ModelInfo> = emptyMap()
35+
36+
/**
37+
* Initializes the config with a context. The model map starts empty, so call this before accessing models. Only the first call has any effect.
38+
*/
39+
fun initialize(context: Context) {
40+
if (configManager == null) {
41+
configManager = PresetConfigManager(context.applicationContext)
42+
reloadModels()
43+
}
44+
}
45+
46+
/**
47+
* Reloads models from the current configuration source.
48+
*/
49+
fun reloadModels() {
50+
cachedModels = configManager?.loadModels() ?: emptyMap()
51+
}
52+
53+
/**
54+
* Updates the models with a new map (used after loading from URL).
55+
*/
56+
fun updateModels(models: Map<String, ModelInfo>) {
57+
cachedModels = models
58+
}
59+
60+
/**
61+
* Returns the PresetConfigManager instance for advanced operations.
62+
*/
63+
fun getConfigManager(): PresetConfigManager? = configManager
3264

33-
fun getAvailableModels(): Map<String, ModelInfo> = AVAILABLE_MODELS
65+
fun getAvailableModels(): Map<String, ModelInfo> = cachedModels
3466

3567
fun getDisplayNames(): Array<String> =
36-
AVAILABLE_MODELS.values.map { it.displayName }.toTypedArray()
68+
cachedModels.values.map { it.displayName }.toTypedArray()
3769

38-
fun getModelKeys(): Array<String> = AVAILABLE_MODELS.keys.toTypedArray()
70+
fun getModelKeys(): Array<String> = cachedModels.keys.toTypedArray()
3971

4072
fun getByDisplayName(displayName: String): ModelInfo? =
41-
AVAILABLE_MODELS.values.find { it.displayName == displayName }
73+
cachedModels.values.find { it.displayName == displayName }
4274

43-
fun getByKey(key: String): ModelInfo? = AVAILABLE_MODELS[key]
75+
fun getByKey(key: String): ModelInfo? = cachedModels[key]
4476
}

0 commit comments

Comments
 (0)