diff --git a/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PresetSanityTest.kt b/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PresetSanityTest.kt new file mode 100644 index 0000000000..28c3449be6 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PresetSanityTest.kt @@ -0,0 +1,245 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo + +import android.content.Context +import android.util.Log +import androidx.compose.ui.semantics.SemanticsProperties +import androidx.compose.ui.test.assertIsEnabled +import androidx.compose.ui.test.hasContentDescription +import androidx.compose.ui.test.junit4.createAndroidComposeRule +import androidx.compose.ui.test.onAllNodesWithText +import androidx.compose.ui.test.onNodeWithContentDescription +import androidx.compose.ui.test.onNodeWithTag +import androidx.compose.ui.test.onNodeWithText +import androidx.compose.ui.test.performClick +import androidx.compose.ui.test.performTextInput +import androidx.test.core.app.ApplicationProvider +import androidx.test.ext.junit.runners.AndroidJUnit4 +import androidx.test.filters.LargeTest +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Ignore +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith + +/** + * Preset model sanity test that validates the preset model download and chat workflow. + * + * This test validates: + * 1. Navigate from Welcome screen to Preset model screen + * 2. Select Stories 110M and download it + * 3. After download completes, tap to load and enter chat view + * 4. Type "Once upon a time" and generate a response + */ +@RunWith(AndroidJUnit4::class) +@LargeTest +class PresetSanityTest { + + companion object { + private const val TAG = "PresetSanityTest" + private const val RESPONSE_TAG = "LLAMA_RESPONSE" + } + + @get:Rule + val composeTestRule = createAndroidComposeRule() + + @Before + fun setUp() { + // Clear SharedPreferences before test to ensure a clean state + val context = ApplicationProvider.getApplicationContext() + val prefs = context.getSharedPreferences( + context.getString(R.string.demo_pref_file_key), + Context.MODE_PRIVATE + ) + prefs.edit().clear().commit() + } + + /** + * Types text into the chat input field using testTag. + */ + private fun typeInChatInput(text: String) { + composeTestRule.onNodeWithTag("chat_input_field").performClick() + composeTestRule.waitForIdle() + composeTestRule.onNodeWithTag("chat_input_field").performTextInput(text) + composeTestRule.waitForIdle() + } + + /** + * Waits for generation to complete by checking for tokens-per-second metrics. + */ + private fun waitForGenerationComplete(timeoutMs: Long = 120000): Boolean { + return try { + composeTestRule.waitUntil(timeoutMillis = timeoutMs) { + val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true) + .fetchSemanticsNodes() + val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true) + .fetchSemanticsNodes() + tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty() + } + Log.i(TAG, "Generation complete - found generation metrics") + true + } catch (e: Exception) { + Log.e(TAG, "Generation timed out after ${timeoutMs}ms") + false + } + } + + /** + * Waits for the model to be loaded by checking for success or error messages. + */ + private fun waitForModelLoaded(timeoutMs: Long = 60000): Boolean { + return try { + var wasSuccess = false + composeTestRule.waitUntil(timeoutMillis = timeoutMs) { + val successNodes = composeTestRule.onAllNodesWithText("Successfully loaded", substring = true) + .fetchSemanticsNodes() + val errorNodes = composeTestRule.onAllNodesWithText("Model load failure", substring = true) + .fetchSemanticsNodes() + wasSuccess = successNodes.isNotEmpty() + successNodes.isNotEmpty() || errorNodes.isNotEmpty() + } + if (wasSuccess) { + Log.i(TAG, "Model loaded successfully") + } else { + Log.e(TAG, "Model load failed") + } + wasSuccess + } catch (e: Exception) { + Log.e(TAG, "Model loading timed out after ${timeoutMs}ms: ${e.message}") + false + } + } + + /** + * Verifies that the model generated a non-empty response. + */ + private fun assertModelResponseNotEmpty(timeoutMs: Long = 10000) { + try { + composeTestRule.waitUntil(timeoutMillis = timeoutMs) { + val tpsNodes = composeTestRule.onAllNodesWithText("t/s", substring = true) + .fetchSemanticsNodes() + val tokpsNodes = composeTestRule.onAllNodesWithText("tok/s", substring = true) + .fetchSemanticsNodes() + tpsNodes.isNotEmpty() || tokpsNodes.isNotEmpty() + } + Log.i(TAG, "Model response verified - found generation metrics") + } catch (e: Exception) { + throw AssertionError("Model response appears to be empty - no generation metrics found after ${timeoutMs}ms") + } + } + + /** + * Logs the model response text for CI output. + */ + private fun logModelResponse() { + try { + Log.i(RESPONSE_TAG, "BEGIN_RESPONSE") + val responseNodes = composeTestRule.onAllNodesWithText("t/s", substring = true) + .fetchSemanticsNodes() + for (node in responseNodes) { + val text = node.config.getOrElse(SemanticsProperties.Text) { emptyList() } + .joinToString(" ") { it.text } + if (text.isNotBlank()) { + Log.i(RESPONSE_TAG, text) + } + } + Log.i(RESPONSE_TAG, "END_RESPONSE") + } catch (e: Exception) { + Log.d(TAG, "Could not log model response: ${e.message}") + } + } + + /** + * Tests the complete preset model download and chat workflow: + * 1. From Welcome screen, tap "Preset model" card + * 2. Find Stories 110M and tap Download + * 3. Wait for download to complete + * 4. Tap the card to load model and enter chat + * 5. Type "Once upon a time" and send + * 6. Verify response is generated + */ + @Ignore("Temporarily disabled") + @Test + fun testPresetModelDownloadAndChat() { + composeTestRule.waitForIdle() + + // Step 1: From Welcome screen, tap "Preset model" card + Log.i(TAG, "Step 1: Navigating to Preset model screen") + composeTestRule.onNodeWithText("Preset model").performClick() + composeTestRule.waitUntil(timeoutMillis = 5000) { + composeTestRule.onAllNodesWithText("Download Preset Model").fetchSemanticsNodes().isNotEmpty() + } + + // Step 2: Find Stories 110M and tap Download + Log.i(TAG, "Step 2: Finding Stories 110M and starting download") + composeTestRule.onNodeWithText("Stories 110M").assertExists() + + // Check if already downloaded (Ready to use) or needs download + val readyNodes = composeTestRule.onAllNodesWithText("Ready to use", substring = true) + .fetchSemanticsNodes() + + if (readyNodes.isEmpty()) { + // Need to download - click Download button + composeTestRule.onNodeWithText("Download").performClick() + + // Step 3: Wait for download to complete (may take a while for large files) + Log.i(TAG, "Step 3: Waiting for download to complete") + composeTestRule.waitUntil(timeoutMillis = 300000) { // 5 minutes timeout for download + composeTestRule.onAllNodesWithText("Ready to use", substring = true) + .fetchSemanticsNodes().isNotEmpty() + } + Log.i(TAG, "Download completed") + } else { + Log.i(TAG, "Model already downloaded, skipping download step") + } + + // Step 4: Tap the card to load model and enter chat + Log.i(TAG, "Step 4: Tapping card to load model") + composeTestRule.onNodeWithText("Stories 110M").performClick() + + // Wait for Activity transition - MainActivity needs time to launch and set content + // The SelectPresetModelActivity calls finish() after starting MainActivity + Thread.sleep(2000) + + // Wait for model to load and chat screen to appear + Log.i(TAG, "Waiting for model to load") + val modelLoaded = waitForModelLoaded(90000) + assertTrue("Model should be loaded successfully", modelLoaded) + Log.i(TAG, "Model loaded successfully") + + // Step 5: Type "Once upon a time" and send + Log.i(TAG, "Step 5: Typing prompt and sending") + typeInChatInput("Once upon a time") + + // Wait for send button to be enabled + composeTestRule.waitUntil(timeoutMillis = 5000) { + try { + composeTestRule.onNodeWithContentDescription("Send").assertIsEnabled() + true + } catch (e: AssertionError) { + false + } + } + + composeTestRule.onNodeWithContentDescription("Send").performClick() + composeTestRule.waitForIdle() + + // Step 6: Wait for generation to complete and verify response + Log.i(TAG, "Step 6: Waiting for generation to complete") + val generationComplete = waitForGenerationComplete(120000) + assertTrue("Generation should complete", generationComplete) + + assertModelResponseNotEmpty() + logModelResponse() + + Log.i(TAG, "Preset model sanity test completed successfully") + } +} diff --git a/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml b/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml index f06c4dd616..ecb2b52292 100644 --- a/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml +++ b/llm/android/LlamaDemo/app/src/main/AndroidManifest.xml @@ -5,6 +5,7 @@ + @@ -27,6 +28,9 @@ + = linkedMapOf( + ) + + fun getAvailableModels(): Map = AVAILABLE_MODELS + + fun getDisplayNames(): Array = + AVAILABLE_MODELS.values.map { it.displayName }.toTypedArray() + + fun getModelKeys(): Array = AVAILABLE_MODELS.keys.toTypedArray() + + fun getByDisplayName(displayName: String): ModelInfo? = + AVAILABLE_MODELS.values.find { it.displayName == displayName } + + fun getByKey(key: String): ModelInfo? = AVAILABLE_MODELS[key] +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SelectPresetModelActivity.kt b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SelectPresetModelActivity.kt new file mode 100644 index 0000000000..4d63428fde --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SelectPresetModelActivity.kt @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo + +import android.content.Intent +import android.os.Build +import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.compose.foundation.isSystemInDarkTheme +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.material3.Surface +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import androidx.compose.ui.Modifier +import androidx.core.content.ContextCompat +import androidx.lifecycle.viewmodel.compose.viewModel +import com.example.executorchllamademo.ui.screens.SelectPresetModelScreen +import com.example.executorchllamademo.ui.theme.LlamaDemoTheme +import com.example.executorchllamademo.ui.viewmodel.SelectPresetModelViewModel + +class SelectPresetModelActivity : ComponentActivity() { + + private var appearanceMode by mutableStateOf(AppearanceMode.SYSTEM) + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + + if (Build.VERSION.SDK_INT >= 21) { + window.statusBarColor = ContextCompat.getColor(this, R.color.status_bar) + window.navigationBarColor = ContextCompat.getColor(this, R.color.nav_bar) + } + + loadAppearanceMode() + + setContent { + val isDarkTheme = when (appearanceMode) { + AppearanceMode.LIGHT -> false + AppearanceMode.DARK -> true + AppearanceMode.SYSTEM -> isSystemInDarkTheme() + } + + LlamaDemoTheme(darkTheme = isDarkTheme) { + Surface(modifier = Modifier.fillMaxSize()) { + val viewModel: SelectPresetModelViewModel = viewModel() + + LaunchedEffect(Unit) { + viewModel.initialize(this@SelectPresetModelActivity) + } + + SelectPresetModelScreen( + availableModels = viewModel.availableModels, + modelStates = viewModel.modelStates, + onBackPressed = { finish() }, + onDownloadClick = { key -> + viewModel.downloadModel(key) + }, + onDeleteClick = { key -> + viewModel.deleteModel(key) + }, + onModelClick = { key -> + if (viewModel.loadModelAndStartChat(key)) { + // Navigate to MainActivity (conversation) after loading model + startActivity(Intent(this@SelectPresetModelActivity, MainActivity::class.java)) + finish() + } + } + ) + } + } + } + } + + private fun loadAppearanceMode() { + val prefs = DemoSharedPreferences(this) + appearanceMode = prefs.getAppSettings().appearanceMode + } + + override fun onResume() { + super.onResume() + loadAppearanceMode() + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/WelcomeActivity.kt b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/WelcomeActivity.kt index 90af656e9b..b830e12cda 100644 --- a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/WelcomeActivity.kt +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/WelcomeActivity.kt @@ -51,6 +51,9 @@ class WelcomeActivity : ComponentActivity() { onLoadModelClick = { startActivity(Intent(this@WelcomeActivity, ModelSettingsActivity::class.java)) }, + onDownloadModelClick = { + startActivity(Intent(this@WelcomeActivity, SelectPresetModelActivity::class.java)) + }, onAppSettingsClick = { startActivity(Intent(this@WelcomeActivity, AppSettingsActivity::class.java)) }, diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/SelectPresetModelScreen.kt b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/SelectPresetModelScreen.kt new file mode 100644 index 0000000000..e42795a217 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/SelectPresetModelScreen.kt @@ -0,0 +1,275 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo.ui.screens + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.ArrowBack +import androidx.compose.material.icons.filled.Check +import androidx.compose.material.icons.filled.Delete +import androidx.compose.material.icons.filled.Download +import androidx.compose.material3.Button +import androidx.compose.material3.ButtonDefaults +import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults +import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.LinearProgressIndicator +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import com.example.executorchllamademo.ModelInfo +import com.example.executorchllamademo.ui.theme.LocalAppColors +import com.example.executorchllamademo.ui.viewmodel.ModelDownloadState + +@Composable +fun SelectPresetModelScreen( + availableModels: Map, + modelStates: Map, + onBackPressed: () -> Unit, + onDownloadClick: (String) -> Unit, + onDeleteClick: (String) -> Unit, + onModelClick: (String) -> Unit +) { + val appColors = LocalAppColors.current + val scrollState = rememberScrollState() + + Column( + modifier = Modifier + .fillMaxSize() + .background(appColors.settingsBackground) + ) { + // Top banner with back button + Row( + modifier = Modifier + .fillMaxWidth() + .background(appColors.navBar) + .padding(horizontal = 8.dp, vertical = 12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + IconButton(onClick = onBackPressed) { + Icon( + imageVector = Icons.Filled.ArrowBack, + contentDescription = "Back", + tint = appColors.textOnNavBar + ) + } + Text( + text = "Download Preset Model", + color = appColors.textOnNavBar, + fontSize = 18.sp, + fontWeight = FontWeight.Bold, + modifier = Modifier.weight(1f) + ) + } + + // Scrollable content + Column( + modifier = Modifier + .fillMaxSize() + .verticalScroll(scrollState) + .padding(16.dp), + verticalArrangement = Arrangement.spacedBy(12.dp) + ) { + if (availableModels.isEmpty()) { + Text( + text = "No preset models available. Stay tuned!", + fontSize = 14.sp, + color = appColors.settingsSecondaryText + ) + } else { + Text( + text = "Select a model to download and use", + fontSize = 14.sp, + color = appColors.settingsSecondaryText + ) + + Spacer(modifier = Modifier.height(8.dp)) + + availableModels.forEach { (key, modelInfo) -> + val state = modelStates[key] ?: ModelDownloadState() + val isReady = state.isModelDownloaded && state.isTokenizerDownloaded + + PresetModelCard( + modelInfo = modelInfo, + state = state, + isReady = isReady, + onDownloadClick = { onDownloadClick(key) }, + onDeleteClick = { onDeleteClick(key) }, + onCardClick = { if (isReady) onModelClick(key) } + ) + } + } + + Spacer(modifier = Modifier.height(16.dp)) + } + } +} + +@Composable +private fun PresetModelCard( + modelInfo: ModelInfo, + state: ModelDownloadState, + isReady: Boolean, + onDownloadClick: () -> Unit, + onDeleteClick: () -> Unit, + onCardClick: () -> Unit +) { + val appColors = LocalAppColors.current + + Card( + modifier = Modifier + .fillMaxWidth() + .then( + if (isReady) { + Modifier.clickable(onClick = onCardClick) + } else { + Modifier + } + ), + shape = RoundedCornerShape(12.dp), + colors = CardDefaults.cardColors( + containerColor = appColors.settingsRowBackground + ), + elevation = CardDefaults.cardElevation( + defaultElevation = 2.dp + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(16.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween + ) { + Column(modifier = Modifier.weight(1f)) { + Text( + text = modelInfo.displayName, + fontSize = 16.sp, + fontWeight = FontWeight.Bold, + color = appColors.settingsText + ) + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = getStatusText(state), + fontSize = 12.sp, + color = if (isReady) Color(0xFF4CAF50) else appColors.settingsSecondaryText + ) + } + + Spacer(modifier = Modifier.width(8.dp)) + + if (state.isDownloading) { + CircularProgressIndicator( + modifier = Modifier.size(36.dp), + strokeWidth = 3.dp + ) + } else if (isReady) { + Row( + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Filled.Check, + contentDescription = "Ready", + tint = Color(0xFF4CAF50), + modifier = Modifier.size(36.dp) + ) + Spacer(modifier = Modifier.width(8.dp)) + IconButton(onClick = onDeleteClick) { + Icon( + imageVector = Icons.Filled.Delete, + contentDescription = "Delete", + tint = Color.Red, + modifier = Modifier.size(24.dp) + ) + } + } + } else { + Button( + onClick = onDownloadClick, + colors = ButtonDefaults.buttonColors( + containerColor = appColors.navBar + ) + ) { + Icon( + imageVector = Icons.Filled.Download, + contentDescription = "Download", + modifier = Modifier.size(18.dp) + ) + Spacer(modifier = Modifier.width(4.dp)) + Text("Download") + } + } + } + + // Show progress bar when downloading + if (state.isDownloading && state.downloadProgress > 0) { + Spacer(modifier = Modifier.height(8.dp)) + @Suppress("DEPRECATION") + LinearProgressIndicator( + progress = state.downloadProgress, + modifier = Modifier.fillMaxWidth(), + ) + } + + // Show error if any + state.downloadError?.let { error -> + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = error, + fontSize = 12.sp, + color = Color.Red + ) + } + + // Show hint for ready models + if (isReady) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = "Tap to load model and start chat", + fontSize = 12.sp, + color = appColors.settingsSecondaryText + ) + } + } + } +} + +private fun getStatusText(state: ModelDownloadState): String { + return when { + state.isDownloading -> "Downloading..." + state.isModelDownloaded && state.isTokenizerDownloaded -> "Ready to use" + state.isModelDownloaded -> "Model downloaded, tokenizer missing" + state.isTokenizerDownloaded -> "Tokenizer downloaded, model missing" + else -> "Not downloaded" + } +} diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/WelcomeScreen.kt b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/WelcomeScreen.kt index d85ebf8c28..9a51282be0 100644 --- a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/WelcomeScreen.kt +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/screens/WelcomeScreen.kt @@ -23,6 +23,7 @@ import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.verticalScroll import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.CloudDownload import androidx.compose.material.icons.filled.PlayArrow import androidx.compose.material.icons.filled.Settings import androidx.compose.material3.Card @@ -41,6 +42,7 @@ import com.example.executorchllamademo.ui.theme.LocalAppColors @Composable fun WelcomeScreen( onLoadModelClick: () -> Unit = {}, + onDownloadModelClick: () -> Unit = {}, onAppSettingsClick: () -> Unit = {}, onStartChatClick: () -> Unit = {} ) { @@ -100,6 +102,14 @@ fun WelcomeScreen( onClick = onLoadModelClick ) + // Download Preset Model Card + WelcomeCard( + title = "Preset model", + description = "Download a pre-configured model from the cloud.", + icon = Icons.Filled.CloudDownload, + onClick = onDownloadModelClick + ) + // App Settings Card WelcomeCardNoDescription( title = "App Settings", diff --git a/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/viewmodel/SelectPresetModelViewModel.kt b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/viewmodel/SelectPresetModelViewModel.kt new file mode 100644 index 0000000000..e5065f0b96 --- /dev/null +++ b/llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ui/viewmodel/SelectPresetModelViewModel.kt @@ -0,0 +1,219 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.example.executorchllamademo.ui.viewmodel + +import android.content.Context +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateMapOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.setValue +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.example.executorchllamademo.DemoSharedPreferences +import com.example.executorchllamademo.ModelDownloadConfig +import com.example.executorchllamademo.ModelInfo +import com.example.executorchllamademo.ModuleSettings +import com.example.executorchllamademo.PromptFormat +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File +import java.io.FileOutputStream +import java.net.HttpURLConnection +import java.net.URL + +data class ModelDownloadState( + val isModelDownloaded: Boolean = false, + val isTokenizerDownloaded: Boolean = false, + val isDownloading: Boolean = false, + val downloadProgress: Float = 0f, + val downloadError: String? = null +) + +class SelectPresetModelViewModel : ViewModel() { + + private var context: Context? = null + private var demoSharedPreferences: DemoSharedPreferences? = null + + val availableModels: Map = ModelDownloadConfig.getAvailableModels() + + // Track download state for each model + val modelStates = mutableStateMapOf() + + var selectedModelKey by mutableStateOf(null) + private set + + fun initialize(context: Context) { + this.context = context + demoSharedPreferences = DemoSharedPreferences(context) + checkDownloadedFiles() + } + + private fun getModelsDirectory(): File { + val dir = File(context?.filesDir, "models") + if (!dir.exists()) { + dir.mkdirs() + } + return dir + } + + fun checkDownloadedFiles() { + val modelsDir = getModelsDirectory() + + availableModels.forEach { (key, modelInfo) -> + val modelFile = File(modelsDir, modelInfo.modelFilename) + val tokenizerFile = File(modelsDir, modelInfo.tokenizerFilename) + + val currentState = modelStates[key] ?: ModelDownloadState() + modelStates[key] = currentState.copy( + isModelDownloaded = modelFile.exists(), + isTokenizerDownloaded = tokenizerFile.exists() + ) + } + } + + fun isModelReady(key: String): Boolean { + val state = modelStates[key] ?: return false + return state.isModelDownloaded && state.isTokenizerDownloaded + } + + fun needsDownload(key: String): Boolean { + val state = modelStates[key] ?: return true + return !state.isModelDownloaded || !state.isTokenizerDownloaded + } + + fun isDownloading(key: String): Boolean { + return modelStates[key]?.isDownloading == true + } + + fun downloadModel(key: String) { + val modelInfo = availableModels[key] ?: return + val modelsDir = getModelsDirectory() + + val currentState = modelStates[key] ?: ModelDownloadState() + modelStates[key] = currentState.copy(isDownloading = true, downloadError = null, downloadProgress = 0f) + + viewModelScope.launch(Dispatchers.IO) { + try { + // Download model file if needed + if (!currentState.isModelDownloaded) { + val modelFile = File(modelsDir, modelInfo.modelFilename) + downloadFile(modelInfo.modelUrl, modelFile) { progress -> + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy(downloadProgress = progress * 0.5f) + } + withContext(Dispatchers.Main) { + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy(isModelDownloaded = true) + } + } + + // Download tokenizer file if needed + if (!currentState.isTokenizerDownloaded && modelInfo.hasTokenizer()) { + val tokenizerFile = File(modelsDir, modelInfo.tokenizerFilename) + downloadFile(modelInfo.tokenizerUrl, tokenizerFile) { progress -> + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy(downloadProgress = 0.5f + progress * 0.5f) + } + withContext(Dispatchers.Main) { + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy(isTokenizerDownloaded = true) + } + } + + withContext(Dispatchers.Main) { + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy(isDownloading = false, downloadProgress = 1f) + } + } catch (e: Exception) { + withContext(Dispatchers.Main) { + val state = modelStates[key] ?: ModelDownloadState() + modelStates[key] = state.copy( + isDownloading = false, + downloadError = e.message ?: "Download failed" + ) + } + } + } + } + + private fun downloadFile(urlString: String, outputFile: File, onProgress: (Float) -> Unit) { + val url = URL(urlString) + val connection = url.openConnection() as HttpURLConnection + connection.connectTimeout = 30000 + connection.readTimeout = 30000 + connection.connect() + + val contentLength = connection.contentLength + var downloadedBytes = 0L + + connection.inputStream.use { input -> + FileOutputStream(outputFile).use { output -> + val buffer = ByteArray(8192) + var bytesRead: Int + while (input.read(buffer).also { bytesRead = it } != -1) { + output.write(buffer, 0, bytesRead) + downloadedBytes += bytesRead + if (contentLength > 0) { + onProgress(downloadedBytes.toFloat() / contentLength) + } + } + } + } + } + + fun loadModelAndStartChat(key: String): Boolean { + val modelInfo = availableModels[key] ?: return false + val modelsDir = getModelsDirectory() + + val modelFile = File(modelsDir, modelInfo.modelFilename) + val tokenizerFile = File(modelsDir, modelInfo.tokenizerFilename) + + if (!modelFile.exists() || !tokenizerFile.exists()) { + return false + } + + // Save the module settings + val moduleSettings = ModuleSettings( + modelFilePath = modelFile.absolutePath, + tokenizerFilePath = tokenizerFile.absolutePath, + modelType = modelInfo.modelType, + userPrompt = PromptFormat.getUserPromptTemplate(modelInfo.modelType), + isLoadModel = true + ) + + demoSharedPreferences?.saveModuleSettings(moduleSettings) + return true + } + + fun deleteModel(key: String) { + val modelInfo = availableModels[key] ?: return + val modelsDir = getModelsDirectory() + + val modelFile = File(modelsDir, modelInfo.modelFilename) + val tokenizerFile = File(modelsDir, modelInfo.tokenizerFilename) + + // Delete both files + if (modelFile.exists()) { + modelFile.delete() + } + if (tokenizerFile.exists()) { + tokenizerFile.delete() + } + + // Update state + modelStates[key] = ModelDownloadState( + isModelDownloaded = false, + isTokenizerDownloaded = false, + isDownloading = false, + downloadProgress = 0f, + downloadError = null + ) + } +}