Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,37 @@ package com.example.executorchllamademo

import android.content.Context
import android.util.Log
import androidx.compose.ui.semantics.SemanticsProperties
import androidx.compose.ui.test.assertIsEnabled
import androidx.compose.ui.test.hasContentDescription
import androidx.compose.ui.test.junit4.createAndroidComposeRule
import androidx.compose.ui.test.onAllNodesWithText
import androidx.compose.ui.test.onNodeWithContentDescription
import androidx.compose.ui.test.onNodeWithTag
import androidx.compose.ui.test.onNodeWithText
import androidx.compose.ui.test.performClick
import androidx.compose.ui.test.performTextInput
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.filters.LargeTest
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith

/**
* Preset model sanity test that validates the preset model download and chat workflow.
* Preset model sanity test that validates the preset model UI workflow.
*
* This test validates:
* 1. Navigate from Welcome screen to Preset model screen
* 2. Select Stories 110M and download it
* 3. After download completes, tap to load and enter chat view
* 4. Type "Once upon a time" and generate a response
* 2. Load preset config from URL
* 3. Verify Stories 110M model is displayed
*
* Note: Download and chat steps are skipped in CI due to network constraints.
*/
@RunWith(AndroidJUnit4::class)
@LargeTest
class PresetSanityTest {

companion object {
private const val TAG = "PresetSanityTest"
private const val RESPONSE_TAG = "LLAMA_RESPONSE"
private const val DEFAULT_CONFIG_URL = "https://raw.githubusercontent.com/meta-pytorch/executorch-examples/889ccc6e88813cbf03775889beed29b793d0c8db/llm/android/LlamaDemo/app/src/main/assets/preset_models.json"
}

@get:Rule
Expand All @@ -60,113 +55,46 @@ class PresetSanityTest {
Context.MODE_PRIVATE
)
prefs.edit().clear().commit()
}

/**
* Types text into the chat input field using testTag.
*/
private fun typeInChatInput(text: String) {
composeTestRule.onNodeWithTag("chat_input_field").performClick()
composeTestRule.waitForIdle()
composeTestRule.onNodeWithTag("chat_input_field").performTextInput(text)
composeTestRule.waitForIdle()
// Also clear the preset config preferences
val configPrefs = context.getSharedPreferences("preset_config_prefs", Context.MODE_PRIVATE)
configPrefs.edit().clear().commit()
}

/**
* Waits for generation to complete by checking for tokens-per-second metrics.
* Loads the preset config from URL.
* This is needed because the bundled preset_models.json is empty by default.
*/
private fun waitForGenerationComplete(timeoutMs: Long = 120000): Boolean {
    // Polls the semantics tree until some node's text contains "t/s" or "tok/s".
    // Returns true once such a node is found, false if the wait throws for any reason.
    return try {
        composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
            val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
                .fetchSemanticsNodes()
            val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true)
                .fetchSemanticsNodes()
            tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty()
        }
        Log.i(TAG, "Generation complete - found generation metrics")
        true
    } catch (e: Exception) {
        // The catch is broad, so the failure is not necessarily a timeout: keep the
        // exception message and stack trace in the log instead of discarding them.
        Log.e(TAG, "Generation timed out after ${timeoutMs}ms: ${e.message}", e)
        false
    }
}
private fun loadPresetConfigFromUrl() {
Log.i(TAG, "Loading preset config from URL")

/**
 * Blocks until the UI shows either a "Successfully loaded" or a "Model load failure" text.
 *
 * @param timeoutMs maximum time to keep polling, in milliseconds.
 * @return true only when the success text was visible on the poll that ended the wait;
 * false when the failure text ended the wait or when waiting threw an exception.
 */
private fun waitForModelLoaded(timeoutMs: Long = 60000): Boolean {
    // True when at least one node containing [fragment] is currently in the semantics tree.
    fun isOnScreen(fragment: String): Boolean =
        composeTestRule.onAllNodesWithText(fragment, substring = true)
            .fetchSemanticsNodes()
            .isNotEmpty()

    return try {
        var loaded = false
        composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
            // Both lookups run on every poll; the success flag is refreshed each time so it
            // reflects the state seen on the final poll.
            val successShown = isOnScreen("Successfully loaded")
            val failureShown = isOnScreen("Model load failure")
            loaded = successShown
            successShown || failureShown
        }
        when {
            loaded -> Log.i(TAG, "Model loaded successfully")
            else -> Log.e(TAG, "Model load failed")
        }
        loaded
    } catch (e: Exception) {
        Log.e(TAG, "Model loading timed out after ${timeoutMs}ms: ${e.message}")
        false
    }
}
// Type the URL into the config URL field (it's empty by default)
composeTestRule.onNodeWithTag("config_url_field").performClick()
composeTestRule.waitForIdle()
composeTestRule.onNodeWithTag("config_url_field").performTextInput(DEFAULT_CONFIG_URL)

/**
 * Verifies that the model generated a non-empty response.
 *
 * The check passes once a node whose text contains "t/s" or "tok/s" is present.
 *
 * @param timeoutMs maximum time to wait for such a node, in milliseconds.
 * @throws AssertionError if the wait fails; the original exception is attached as the cause.
 */
private fun assertModelResponseNotEmpty(timeoutMs: Long = 10000) {
    try {
        composeTestRule.waitUntil(timeoutMillis = timeoutMs) {
            val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
                .fetchSemanticsNodes()
            val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true)
                .fetchSemanticsNodes()
            tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty()
        }
        Log.i(TAG, "Model response verified - found generation metrics")
    } catch (e: Exception) {
        // Chain the cause so the test report shows what actually went wrong.
        throw AssertionError(
            "Model response appears to be empty - no generation metrics found after ${timeoutMs}ms",
            e
        )
    }
}
// Small delay to ensure text is entered
Thread.sleep(500)

/**
* Logs the model response text for CI output.
*/
private fun logModelResponse() {
try {
Log.i(RESPONSE_TAG, "BEGIN_RESPONSE")
val responseNodes = composeTestRule.onAllNodesWithText("t/s", substring = true)
.fetchSemanticsNodes()
for (node in responseNodes) {
val text = node.config.getOrElse(SemanticsProperties.Text) { emptyList() }
.joinToString(" ") { it.text }
if (text.isNotBlank()) {
Log.i(RESPONSE_TAG, text)
}
}
Log.i(RESPONSE_TAG, "END_RESPONSE")
} catch (e: Exception) {
Log.d(TAG, "Could not log model response: ${e.message}")
// Click the Load button
composeTestRule.onNodeWithText("Load").performClick()

// Wait for config to load (models should appear)
// Don't use waitForIdle here as the loading spinner animation keeps Compose busy
composeTestRule.waitUntil(timeoutMillis = 60000) {
composeTestRule.onAllNodesWithText("Stories 110M").fetchSemanticsNodes().isNotEmpty()
}
Log.i(TAG, "Preset config loaded successfully")
}

/**
* Tests the complete preset model download and chat workflow:
* Tests the preset model UI workflow:
* 1. From Welcome screen, tap "Preset model" card
* 2. Find Stories 110M and tap Download
* 3. Wait for download to complete
* 4. Tap the card to load model and enter chat
* 5. Type "Once upon a time" and send
* 6. Verify response is generated
* 2. Load preset config from URL (since bundled JSON is empty)
* 3. Verify Stories 110M model is displayed
*
* Note: Download and chat steps are skipped in CI due to network constraints.
*/
@Ignore("Temporarily disabled")
@Test
fun testPresetModelDownloadAndChat() {
composeTestRule.waitForIdle()
Expand All @@ -178,68 +106,15 @@ class PresetSanityTest {
composeTestRule.onAllNodesWithText("Download Preset Model").fetchSemanticsNodes().isNotEmpty()
}

// Step 2: Find Stories 110M and tap Download
Log.i(TAG, "Step 2: Finding Stories 110M and starting download")
composeTestRule.onNodeWithText("Stories 110M").assertExists()

// Check if already downloaded (Ready to use) or needs download
val readyNodes = composeTestRule.onAllNodesWithText("Ready to use", substring = true)
.fetchSemanticsNodes()

if (readyNodes.isEmpty()) {
// Need to download - click Download button
composeTestRule.onNodeWithText("Download").performClick()

// Step 3: Wait for download to complete (may take a while for large files)
Log.i(TAG, "Step 3: Waiting for download to complete")
composeTestRule.waitUntil(timeoutMillis = 300000) { // 5 minutes timeout for download
composeTestRule.onAllNodesWithText("Ready to use", substring = true)
.fetchSemanticsNodes().isNotEmpty()
}
Log.i(TAG, "Download completed")
} else {
Log.i(TAG, "Model already downloaded, skipping download step")
}

// Step 4: Tap the card to load model and enter chat
Log.i(TAG, "Step 4: Tapping card to load model")
composeTestRule.onNodeWithText("Stories 110M").performClick()

// Wait for Activity transition - MainActivity needs time to launch and set content
// The SelectPresetModelActivity calls finish() after starting MainActivity
Thread.sleep(2000)

// Wait for model to load and chat screen to appear
Log.i(TAG, "Waiting for model to load")
val modelLoaded = waitForModelLoaded(90000)
assertTrue("Model should be loaded successfully", modelLoaded)
Log.i(TAG, "Model loaded successfully")
// Step 2: Load preset config from URL
Log.i(TAG, "Step 2: Loading preset config from URL")
loadPresetConfigFromUrl()

// Step 5: Type "Once upon a time" and send
Log.i(TAG, "Step 5: Typing prompt and sending")
typeInChatInput("Once upon a time")

// Wait for send button to be enabled
composeTestRule.waitUntil(timeoutMillis = 5000) {
try {
composeTestRule.onNodeWithContentDescription("Send").assertIsEnabled()
true
} catch (e: AssertionError) {
false
}
}

composeTestRule.onNodeWithContentDescription("Send").performClick()
composeTestRule.waitForIdle()

// Step 6: Wait for generation to complete and verify response
Log.i(TAG, "Step 6: Waiting for generation to complete")
val generationComplete = waitForGenerationComplete(120000)
assertTrue("Generation should complete", generationComplete)

assertModelResponseNotEmpty()
logModelResponse()
// Step 3: Find Stories 110M and verify it exists
Log.i(TAG, "Step 3: Verifying Stories 110M is displayed")
composeTestRule.onNodeWithText("Stories 110M").assertExists()
Log.i(TAG, "Stories 110M found - preset model screen is working correctly")

Log.i(TAG, "Preset model sanity test completed successfully")
// Note: Download and chat steps are skipped in CI due to network constraints
}
}
4 changes: 4 additions & 0 deletions llm/android/LlamaDemo/app/src/main/assets/preset_models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"models": {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package com.example.executorchllamademo

import android.content.Context

/**
* Represents a downloadable model with its associated files.
*/
Expand All @@ -24,21 +26,51 @@ data class ModelInfo(

/**
* Configuration class that maps model display names to their download URLs.
* Models are loaded from JSON configuration at runtime via PresetConfigManager.
*/
object ModelDownloadConfig {

private val AVAILABLE_MODELS: LinkedHashMap<String, ModelInfo> = linkedMapOf(
)
private var configManager: PresetConfigManager? = null
private var cachedModels: Map<String, ModelInfo> = emptyMap()

/**
 * Sets up the configuration source on first use.
 *
 * Creates a [PresetConfigManager] from the application context and loads the models.
 * Calls made after a manager already exists do nothing.
 *
 * @param context any context; only its application context is retained.
 */
fun initialize(context: Context) {
    // Already initialized: keep the existing manager and cached models.
    if (configManager != null) return

    configManager = PresetConfigManager(context.applicationContext)
    reloadModels()
}

/**
 * Refreshes the cached models from the config manager.
 *
 * The cache becomes an empty map when no manager has been created yet.
 */
fun reloadModels() {
    val loaded = configManager?.loadModels()
    cachedModels = loaded ?: emptyMap()
}

/**
 * Replaces the cached models with the given map (used after loading from URL).
 *
 * A snapshot of [models] is stored, so later changes to the caller's map do not
 * affect the cache. Entry order is preserved.
 *
 * @param models model key to [ModelInfo] entries that become the new cache contents.
 */
fun updateModels(models: Map<String, ModelInfo>) {
    // Copy rather than alias: the argument may be backed by a mutable map.
    cachedModels = models.toMap()
}

/**
 * Exposes the underlying [PresetConfigManager] for advanced operations.
 *
 * @return the current manager, or null if none has been created yet.
 */
fun getConfigManager(): PresetConfigManager? {
    return configManager
}

fun getAvailableModels(): Map<String, ModelInfo> = AVAILABLE_MODELS
fun getAvailableModels(): Map<String, ModelInfo> = cachedModels

fun getDisplayNames(): Array<String> =
AVAILABLE_MODELS.values.map { it.displayName }.toTypedArray()
cachedModels.values.map { it.displayName }.toTypedArray()

fun getModelKeys(): Array<String> = AVAILABLE_MODELS.keys.toTypedArray()
fun getModelKeys(): Array<String> = cachedModels.keys.toTypedArray()

fun getByDisplayName(displayName: String): ModelInfo? =
AVAILABLE_MODELS.values.find { it.displayName == displayName }
cachedModels.values.find { it.displayName == displayName }

fun getByKey(key: String): ModelInfo? = AVAILABLE_MODELS[key]
fun getByKey(key: String): ModelInfo? = cachedModels[key]
}
Loading