Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llm/android/LlamaDemo/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ android {
versionName = "1.0"

testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"

// Automatically set instrumentation arguments based on model preset
val preset = modelPresets[modelPreset]
if (preset != null) {
testInstrumentationRunnerArguments["modelFile"] = preset["pteFile"] as String
testInstrumentationRunnerArguments["tokenizerFile"] = preset["tokenizerFile"] as String
}

vectorDrawables { useSupportLibrary = true }
externalNativeBuild { cmake { cppFlags += "" } }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import static androidx.test.espresso.matcher.ViewMatchers.withId;
import static androidx.test.espresso.matcher.ViewMatchers.withText;
import static org.hamcrest.Matchers.anything;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.not;
Expand Down Expand Up @@ -118,12 +118,12 @@ public void testModelLoadingWorkflow() throws Exception {
// Step 3: Click model selection button and select the model file
onView(withId(R.id.modelImageButton)).perform(click());
// Select the model file matching the configured filename
onData(hasToString(containsString(modelFile))).inRoot(isDialog()).perform(click());
onData(hasToString(endsWith(modelFile))).inRoot(isDialog()).perform(click());

// Step 4: Click tokenizer selection button and select the tokenizer file
onView(withId(R.id.tokenizerImageButton)).perform(click());
// Select the tokenizer file matching the configured filename
onData(hasToString(containsString(tokenizerFile))).inRoot(isDialog()).perform(click());
onData(hasToString(endsWith(tokenizerFile))).inRoot(isDialog()).perform(click());

// Step 5: Click load model button
onView(withId(R.id.loadModelButton)).perform(click());
Expand Down Expand Up @@ -165,13 +165,13 @@ public void testSendMessageAndReceiveResponse() throws Exception {
// Select model - choose the configured model file
onView(withId(R.id.modelImageButton)).perform(click());
Thread.sleep(300); // Wait for dialog to appear
onData(hasToString(containsString(modelFile))).inRoot(isDialog()).perform(click());
onData(hasToString(endsWith(modelFile))).inRoot(isDialog()).perform(click());
Thread.sleep(300); // Wait for dialog to dismiss and UI to update

// Select tokenizer - choose the configured tokenizer file
onView(withId(R.id.tokenizerImageButton)).perform(click());
Thread.sleep(300); // Wait for dialog to appear
onData(hasToString(containsString(tokenizerFile))).inRoot(isDialog()).perform(click());
onData(hasToString(endsWith(tokenizerFile))).inRoot(isDialog()).perform(click());
Thread.sleep(300); // Wait for dialog to dismiss and UI to update

// Verify load button is now enabled
Expand All @@ -198,9 +198,10 @@ public void testSendMessageAndReceiveResponse() throws Exception {
// Click send button
onView(withId(R.id.sendButton)).perform(click());

// --- Wait for response and validate ---
// Wait 50 seconds for model to generate response
Thread.sleep(50000);
// --- Wait for response ---
// Poll until we have some response text (at least 50 characters)
boolean hasResponse = waitForResponseLength(scenario, 50, 60000);
assertTrue("Model should generate a response", hasResponse);

// Extract all messages from the list
AtomicInteger messageCount = new AtomicInteger(0);
Expand Down Expand Up @@ -265,6 +266,143 @@ private boolean waitForModelLoaded(ActivityScenario<MainActivity> scenario, long
return false;
}

/**
 * Tests stopping generation mid-way:
 * 1. Load model
 * 2. Send a message to start generation
 * 3. Wait for generation to start (model response text appears)
 * 4. Click the send button again while generating, which acts as "stop"
 * 5. Verify generation stops (response text stops growing)
 * 6. Verify partial response was received
 */
@Test
public void testStopGeneration() throws Exception {
  try (ActivityScenario<MainActivity> scenario = ActivityScenario.launch(MainActivity.class)) {
    // Wait for activity to fully load
    Thread.sleep(1000);

    // Dismiss the "Please Select a Model" dialog
    onView(withText(android.R.string.ok)).inRoot(isDialog()).perform(click());

    // --- Load model ---
    onView(withId(R.id.settings)).perform(click());
    Thread.sleep(500);

    // Select model
    onView(withId(R.id.modelImageButton)).perform(click());
    Thread.sleep(300);
    onData(hasToString(endsWith(modelFile))).inRoot(isDialog()).perform(click());
    Thread.sleep(300);

    // Select tokenizer
    onView(withId(R.id.tokenizerImageButton)).perform(click());
    Thread.sleep(300);
    onData(hasToString(endsWith(tokenizerFile))).inRoot(isDialog()).perform(click());
    Thread.sleep(300);

    // Load model
    onView(withId(R.id.loadModelButton)).perform(click());
    onView(withText(android.R.string.yes)).inRoot(isDialog()).perform(click());

    // Wait for model to load
    boolean modelLoaded = waitForModelLoaded(scenario, 60000);
    assertTrue("Model should be loaded successfully", modelLoaded);

    // --- Send a message to start generation ---
    onView(withId(R.id.editTextMessage))
        .perform(
            typeText("Write a very long story about a brave knight"),
            ViewActions.closeSoftKeyboard());
    onView(withId(R.id.sendButton)).perform(click());

    // --- Wait for generation to start (some response text appears) ---
    boolean generationStarted = waitForResponseStarted(scenario, 30000);
    assertTrue("Generation should start (some response text should appear)", generationStarted);

    // --- Wait for some text to generate (at least 20 characters) ---
    boolean hasEnoughText = waitForResponseLength(scenario, 20, 30000);
    assertTrue("Should generate some text before stopping", hasEnoughText);

    // --- Click stop button (same view as send; it is clicked again while generating) ---
    onView(withId(R.id.sendButton)).perform(click());

    // --- Wait for generation to stop ---
    // Give it a moment to process the stop
    Thread.sleep(1000);

    // Verify the response actually stops growing: two consecutive snapshots taken 500 ms
    // apart must match. Bounded to 20 checks (about 10 s) so a stop that never takes
    // effect fails the test instead of passing silently.
    String partialResponse = readLatestModelResponse(scenario);
    boolean generationStopped = false;
    for (int check = 0; check < 20 && !generationStopped; check++) {
      Thread.sleep(500);
      String latest = readLatestModelResponse(scenario);
      generationStopped = latest.equals(partialResponse);
      partialResponse = latest;
    }
    assertTrue("Response text should stop growing after stop is clicked", generationStopped);

    // Log the partial response
    android.util.Log.i("STOP_TEST", "Partial response after stop: " + partialResponse);

    // We should have received some tokens before stopping
    assertTrue("Should have received some response before stopping", !partialResponse.isEmpty());
  }
}

/**
 * Reads the text of the last model response currently held by the messages list adapter.
 *
 * <p>A message counts as a model response when it was not sent by the user and its text
 * contains none of the status phrases "Successfully loaded", "Loading model" or
 * "To get started" (the same filter used by {@code waitForResponseLength}).
 *
 * @param scenario the activity scenario
 * @return the text of the last matching message, or an empty string if there is none
 */
private String readLatestModelResponse(ActivityScenario<MainActivity> scenario) {
  AtomicReference<String> latest = new AtomicReference<>("");
  scenario.onActivity(
      activity -> {
        ListView messagesView = activity.findViewById(R.id.messages_view);
        if (messagesView == null || messagesView.getAdapter() == null) {
          return;
        }
        for (int i = 0; i < messagesView.getAdapter().getCount(); i++) {
          Object item = messagesView.getAdapter().getItem(i);
          if (!(item instanceof Message)) {
            continue;
          }
          Message message = (Message) item;
          String text = message.getText();
          // Skip user messages, messages without text, and status/system messages.
          if (message.getIsSent()
              || text == null
              || text.contains("Successfully loaded")
              || text.contains("Loading model")
              || text.contains("To get started")) {
            continue;
          }
          // Later matches overwrite earlier ones, so the last response wins.
          latest.set(text);
        }
      });
  return latest.get();
}

/**
 * Blocks until the model has produced at least one character of response text.
 *
 * <p>Delegates to {@code waitForResponseLength} with a minimum length of one character.
 *
 * @param scenario the activity scenario
 * @param timeoutMs maximum time to wait in milliseconds
 * @return true if response text appeared, false if timeout
 */
private boolean waitForResponseStarted(ActivityScenario<MainActivity> scenario, long timeoutMs)
    throws InterruptedException {
  // A single character is enough to show that generation has begun.
  final int firstCharacter = 1;
  boolean started = waitForResponseLength(scenario, firstCharacter, timeoutMs);
  return started;
}

/**
 * Waits for the model response to reach a minimum length.
 *
 * <p>Polls the messages list every 200 ms. A message counts as a model response when it was
 * not sent by the user and its text contains none of the status phrases
 * "Successfully loaded", "Loading model" or "To get started". When several messages match,
 * the length of the last one is used.
 *
 * @param scenario the activity scenario
 * @param minLength minimum response length in characters
 * @param timeoutMs maximum time to wait in milliseconds
 * @return true if response reached minimum length, false if timeout
 * @throws InterruptedException if the thread is interrupted while sleeping between polls
 */
private boolean waitForResponseLength(
    ActivityScenario<MainActivity> scenario, int minLength, long timeoutMs)
    throws InterruptedException {
  // System.nanoTime() is monotonic, so the timeout is unaffected by wall-clock adjustments
  // that could shorten or extend a wait based on System.currentTimeMillis().
  final long startNanos = System.nanoTime();
  final AtomicInteger responseLength = new AtomicInteger(0);
  while ((System.nanoTime() - startNanos) / 1_000_000L < timeoutMs) {
    // Reset so a previous poll's result cannot leak into this one.
    responseLength.set(0);
    scenario.onActivity(
        activity -> {
          ListView messagesView = activity.findViewById(R.id.messages_view);
          if (messagesView == null || messagesView.getAdapter() == null) {
            return;
          }
          for (int i = 0; i < messagesView.getAdapter().getCount(); i++) {
            Object item = messagesView.getAdapter().getItem(i);
            if (!(item instanceof Message)) {
              continue;
            }
            Message message = (Message) item;
            String text = message.getText();
            // Look for a model response (not sent, has text, not a system message)
            if (!message.getIsSent()
                && text != null
                && !text.contains("Successfully loaded")
                && !text.contains("Loading model")
                && !text.contains("To get started")) {
              responseLength.set(text.length());
            }
          }
        });
    if (responseLength.get() >= minLength) {
      return true;
    }
    Thread.sleep(200); // Poll every 200ms
  }
  return false;
}

/**
* Writes the model response to logcat with a special tag for extraction.
* The response can be extracted from logcat using: grep "LLAMA_RESPONSE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ private void addSelectedImagesToChatThread(List<Uri> selectedImageUri) {
}

private void onModelRunStarted() {
mSendButton.setClickable(false);
mSendButton.setClickable(true);
mSendButton.setImageResource(R.drawable.baseline_stop_24);
mSendButton.setOnClickListener(
view -> {
Expand Down
4 changes: 1 addition & 3 deletions llm/android/LlamaDemo/scripts/run-ci-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ LOGCAT_PID=$!
echo "=== Starting Gradle ==="
./gradlew connectedCheck \
-PskipModelDownload=true \
-PmodelPreset="$MODEL_PRESET" \
-Pandroid.testInstrumentationRunnerArguments.modelFile="$MODEL_FILE" \
-Pandroid.testInstrumentationRunnerArguments.tokenizerFile="$TOKENIZER_FILE"
-PmodelPreset="$MODEL_PRESET"
TEST_EXIT_CODE=$?

echo "=== Model directory after Gradle ==="
Expand Down