Skip to content

Commit e1f60f7

Browse files
Copilotkirklandsign
andcommitted
Add auto-select model type based on file name partial matches
Co-authored-by: kirklandsign <107070759+kirklandsign@users.noreply.github.com>
1 parent 161c4bb commit e1f60f7

1 file changed

Lines changed: 43 additions & 0 deletions

File tree

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,56 @@ private void setupModelSelectorDialog() {
387387
(dialog, item) -> {
388388
mModelFilePath = pteFiles[item];
389389
mModelTextView.setText(getFilenameFromPath(mModelFilePath));
390+
autoSelectModelType(mModelFilePath);
390391
updateLoadModelButtonState();
391392
dialog.dismiss();
392393
});
393394

394395
modelPathBuilder.create().show();
395396
}
396397

398+
private void autoSelectModelType(String filePath) {
399+
ModelType detectedType = detectModelTypeFromFilePath(filePath);
400+
if (detectedType != null) {
401+
mModelType = detectedType;
402+
mModelTypeTextView.setText(mModelType.toString());
403+
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
404+
}
405+
}
406+
407+
static ModelType detectModelTypeFromFilePath(String filePath) {
408+
if (filePath == null || filePath.isEmpty()) {
409+
return null;
410+
}
411+
// Extract just the filename from the path
412+
String fileName = filePath;
413+
int lastSeparatorIndex = filePath.lastIndexOf('/');
414+
if (lastSeparatorIndex >= 0 && lastSeparatorIndex < filePath.length() - 1) {
415+
fileName = filePath.substring(lastSeparatorIndex + 1);
416+
}
417+
String lowerFileName = fileName.toLowerCase();
418+
// Check for more specific patterns first
419+
if (lowerFileName.contains("llama_guard") || lowerFileName.contains("llama-guard") || lowerFileName.contains("llamaguard")) {
420+
return ModelType.LLAMA_GUARD_3;
421+
}
422+
if (lowerFileName.contains("llava")) {
423+
return ModelType.LLAVA_1_5;
424+
}
425+
if (lowerFileName.contains("gemma")) {
426+
return ModelType.GEMMA_3;
427+
}
428+
if (lowerFileName.contains("llama")) {
429+
return ModelType.LLAMA_3;
430+
}
431+
if (lowerFileName.contains("qwen")) {
432+
return ModelType.QWEN_3;
433+
}
434+
if (lowerFileName.contains("voxtral")) {
435+
return ModelType.VOXTRAL;
436+
}
437+
return null;
438+
}
439+
397440
private void setupDataPathSelectorDialog() {
398441
String[] dataPathFiles =
399442
listLocalFile("/data/local/tmp/llama/", new String[] {".ptd"});

0 commit comments

Comments
 (0)