Skip to content

Commit 4ccc896

Browse files
committed
[Android] Enable multi-model selection and fix Gemma image prefill
1 parent 1cb22f1 commit 4ccc896

5 files changed

Lines changed: 92 additions & 13 deletions

File tree

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModuleSettings.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ package com.example.executorchllamademo
1212
* Holds module-specific settings for the current model/tokenizer configuration.
1313
*/
1414
data class ModuleSettings(
15-
val modelFilePath: String = "",
15+
val modelFilePath: List<String> = emptyList(),
1616
val tokenizerFilePath: String = "",
1717
val dataPath: String = "",
1818
val temperature: Double = DEFAULT_TEMPERATURE,

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.kt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,21 @@ object PromptFormat {
9393
return "$USER_PLACEHOLDER ASSISTANT:"
9494
}
9595

96+
@JvmStatic
97+
fun getGemmaPreImagePrompt(): String {
98+
return "<start_of_turn>user\n"
99+
}
100+
101+
@JvmStatic
102+
fun getGemmaPostImagePrompt(): String {
103+
return "<end_of_image>"
104+
}
105+
106+
@JvmStatic
107+
fun getGemmaMultimodalUserPrompt(): String {
108+
return "$USER_PLACEHOLDER<end_of_turn>\n<start_of_turn>model"
109+
}
110+
96111
@JvmStatic
97112
fun getFormattedLlamaGuardPrompt(userPrompt: String): String {
98113
return getUserPromptTemplate(ModelType.LLAMA_GUARD_3)

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/ModelSettingsScreen.kt

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,12 @@ private fun ModelDialog(viewModel: ModelSettingsViewModel) {
439439
}
440440
)
441441
} else {
442-
SingleChoiceDialog(
442+
MultiChoiceDialog(
443443
title = "Select model path",
444444
options = viewModel.modelFiles.toList(),
445+
selectedOptions = viewModel.moduleSettings.modelFilePath,
445446
onSelect = { selected ->
446-
viewModel.selectModel(selected)
447-
viewModel.showModelDialog = false
447+
viewModel.toggleModel(selected)
448448
},
449449
onDismiss = { viewModel.showModelDialog = false }
450450
)
@@ -658,6 +658,51 @@ private fun InvalidPromptDialog(viewModel: ModelSettingsViewModel) {
658658
}
659659
}
660660

661+
@Composable
662+
private fun MultiChoiceDialog(
663+
title: String,
664+
options: List<String>,
665+
selectedOptions: List<String>,
666+
onSelect: (String) -> Unit,
667+
onDismiss: () -> Unit
668+
) {
669+
AlertDialog(
670+
onDismissRequest = onDismiss,
671+
title = { Text(title) },
672+
text = {
673+
Column {
674+
options.forEach { option ->
675+
val isSelected = selectedOptions.contains(option)
676+
Row(
677+
modifier = Modifier
678+
.fillMaxWidth()
679+
.clickable { onSelect(option) }
680+
.padding(vertical = 8.dp),
681+
verticalAlignment = Alignment.CenterVertically
682+
) {
683+
androidx.compose.material3.Checkbox(
684+
checked = isSelected,
685+
onCheckedChange = { onSelect(option) }
686+
)
687+
Text(
688+
text = option.substringAfterLast('/'),
689+
modifier = Modifier
690+
.padding(start = 8.dp)
691+
.weight(1f),
692+
fontSize = 14.sp
693+
)
694+
}
695+
}
696+
}
697+
},
698+
confirmButton = {
699+
TextButton(onClick = onDismiss) {
700+
Text("Done")
701+
}
702+
}
703+
)
704+
}
705+
661706
@Composable
662707
private fun SingleChoiceDialog(
663708
title: String,

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/viewmodel/ChatViewModel.kt

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,23 @@ class ChatViewModel(application: Application) : AndroidViewModel(application), L
182182
}
183183

184184
private fun loadLocalModelAndParameters(
185-
modelFilePath: String,
185+
modelFilePaths: List<String>,
186186
tokenizerFilePath: String,
187187
dataPath: String,
188188
temperature: Float
189189
) {
190190
Thread {
191-
setLocalModel(modelFilePath, tokenizerFilePath, dataPath, temperature)
191+
setLocalModel(modelFilePaths, tokenizerFilePath, dataPath, temperature)
192192
}.start()
193193
}
194194

195195
private fun setLocalModel(
196-
modelPath: String,
196+
modelPaths: List<String>,
197197
tokenizerPath: String,
198198
dataPath: String,
199199
temperature: Float
200200
) {
201+
val modelPath = modelPaths.joinToString(separator = ",")
201202
val modelLoadingMessage = Message("Loading model...", false, MessageType.SYSTEM, 0)
202203
ETLogging.getInstance().log(
203204
"Loading model $modelPath with tokenizer $tokenizerPath data path $dataPath"
@@ -317,13 +318,13 @@ class ChatViewModel(application: Application) : AndroidViewModel(application), L
317318
val processedImageList = getProcessedImagesForModel(_selectedImages)
318319
if (processedImageList.isNotEmpty()) {
319320
_messages.add(
320-
Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)
321+
Message("Llava/Gemma - Starting image Prefill.", false, MessageType.SYSTEM, 0)
321322
)
322323
executor.execute {
323324
android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_MORE_FAVORABLE)
324325
ETLogging.getInstance().log("Starting runnable prefill image")
325326
val img = processedImageList[0]
326-
ETLogging.getInstance().log("Llava start prefill image")
327+
ETLogging.getInstance().log("Start prefill image")
327328
if (currentSettingsFields.modelType == ModelType.LLAVA_1_5) {
328329
module?.prefillImages(
329330
img.getInts(),
@@ -332,12 +333,14 @@ class ChatViewModel(application: Application) : AndroidViewModel(application), L
332333
ModelUtils.VISION_MODEL_IMAGE_CHANNELS
333334
)
334335
} else if (currentSettingsFields.modelType == ModelType.GEMMA_3) {
336+
module?.prefillPrompt(PromptFormat.getGemmaPreImagePrompt())
335337
module?.prefillImages(
336338
img.getFloats(),
337339
img.width,
338340
img.height,
339341
ModelUtils.VISION_MODEL_IMAGE_CHANNELS
340342
)
343+
module?.prefillPrompt(PromptFormat.getGemmaPostImagePrompt())
341344
}
342345
}
343346
}
@@ -375,6 +378,9 @@ class ChatViewModel(application: Application) : AndroidViewModel(application), L
375378
if (currentSettingsFields.modelType == ModelType.LLAVA_1_5 && shouldAddSystemPrompt) {
376379
finalPrompt = PromptFormat.getLlavaFirstTurnUserPrompt()
377380
.replace(PromptFormat.USER_PLACEHOLDER, rawPrompt)
381+
} else if (currentSettingsFields.modelType == ModelType.GEMMA_3 && _selectedImages.isNotEmpty()) {
382+
finalPrompt = PromptFormat.getGemmaMultimodalUserPrompt()
383+
.replace(PromptFormat.USER_PLACEHOLDER, rawPrompt)
378384
} else {
379385
finalPrompt = (if (shouldAddSystemPrompt) currentSettingsFields.getFormattedSystemPrompt() else "") +
380386
currentSettingsFields.getFormattedUserPrompt(rawPrompt, thinkMode)

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/viewmodel/ModelSettingsViewModel.kt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class ModelSettingsViewModel : ViewModel() {
8686
private fun applyBackendDefaults(settings: ModuleSettings): ModuleSettings {
8787
return if (settings.backendType == BackendType.MEDIATEK) {
8888
settings.copy(
89-
modelFilePath = settings.modelFilePath.ifEmpty { "/in/mtk/llama/runner" },
89+
modelFilePath = if (settings.modelFilePath.isEmpty()) listOf("/in/mtk/llama/runner") else settings.modelFilePath,
9090
tokenizerFilePath = settings.tokenizerFilePath.ifEmpty { "/in/mtk/llama/runner" }
9191
)
9292
} else {
@@ -95,9 +95,18 @@ class ModelSettingsViewModel : ViewModel() {
9595
}
9696

9797
// Model selection
98-
fun selectModel(modelPath: String) {
99-
var newSettings = moduleSettings.copy(modelFilePath = modelPath)
100-
newSettings = autoSelectModelType(newSettings, modelPath)
98+
fun toggleModel(modelPath: String) {
99+
val currentModels = moduleSettings.modelFilePath.toMutableList()
100+
if (currentModels.contains(modelPath)) {
101+
currentModels.remove(modelPath)
102+
} else {
103+
currentModels.add(modelPath)
104+
}
105+
var newSettings = moduleSettings.copy(modelFilePath = currentModels)
106+
// Auto-select type based on the last added model or the first one
107+
if (currentModels.isNotEmpty()) {
108+
newSettings = autoSelectModelType(newSettings, modelPath)
109+
}
101110
moduleSettings = newSettings
102111
}
103112

@@ -193,6 +202,10 @@ class ModelSettingsViewModel : ViewModel() {
193202
return if (path.isEmpty()) "" else path.substringAfterLast('/')
194203
}
195204

205+
fun getFilenameFromPath(paths: List<String>): String {
206+
return if (paths.isEmpty()) "" else paths.joinToString(", ") { it.substringAfterLast('/') }
207+
}
208+
196209
// Appearance mode selection (app-wide setting)
197210
fun selectAppearanceMode(mode: AppearanceMode) {
198211
appSettings = appSettings.copy(appearanceMode = mode)

0 commit comments

Comments
 (0)