Skip to content

Commit 8a84947

Browse files
authored
refactor(firebase-ai): use explicit backing fields and the new onDeviceExtension (#2804)
* refactor(firebase-ai): use explicit backing fields Stable in Kotlin 2.4.0 * refactor(firebase-ai): use new onDeviceExtension in hybrid sample
1 parent 40e343c commit 8a84947

7 files changed

Lines changed: 66 additions & 65 deletions

File tree

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,40 @@
11
package com.google.firebase.quickstart.ai.feature.hybrid
22

33
import android.graphics.Bitmap
4-
import android.util.Log
54
import androidx.lifecycle.ViewModel
65
import androidx.lifecycle.viewModelScope
76
import com.google.firebase.Firebase
7+
import com.google.firebase.ai.DownloadStatus
88
import com.google.firebase.ai.InferenceMode
99
import com.google.firebase.ai.InferenceSource
1010
import com.google.firebase.ai.OnDeviceConfig
11+
import com.google.firebase.ai.OnDeviceModelStatus
1112
import com.google.firebase.ai.ai
12-
import com.google.firebase.ai.ondevice.DownloadStatus
13-
import com.google.firebase.ai.ondevice.FirebaseAIOnDevice
14-
import com.google.firebase.ai.ondevice.OnDeviceModelStatus
1513
import com.google.firebase.ai.type.GenerativeBackend
1614
import com.google.firebase.ai.type.PublicPreviewAPI
1715
import com.google.firebase.ai.type.content
1816
import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState
1917
import kotlinx.coroutines.flow.MutableStateFlow
2018
import kotlinx.coroutines.flow.StateFlow
21-
import kotlinx.coroutines.flow.asStateFlow
2219
import kotlinx.coroutines.flow.update
2320
import kotlinx.coroutines.launch
2421
import kotlinx.serialization.Serializable
2522
import kotlinx.serialization.json.Json
26-
import java.util.UUID
2723

2824
@Serializable
2925
object HybridInferenceRoute
3026

3127
@OptIn(PublicPreviewAPI::class)
3228
class HybridInferenceViewModel : ViewModel() {
33-
private val _uiState = MutableStateFlow(
34-
HybridInferenceUiState(
35-
expenses = listOf(
36-
Expense("Lunch", 15.50, "Example data"),
37-
Expense("Coffee", 4.75, "Example data")
29+
val uiState: StateFlow<HybridInferenceUiState>
30+
field = MutableStateFlow(
31+
HybridInferenceUiState(
32+
expenses = listOf(
33+
Expense("Lunch", 15.50, "Example data"),
34+
Expense("Coffee", 4.75, "Example data")
35+
)
3836
)
3937
)
40-
)
41-
val uiState: StateFlow<HybridInferenceUiState> = _uiState.asStateFlow()
4238

4339
private val model = Firebase.ai(backend = GenerativeBackend.googleAI()).generativeModel(
4440
modelName = "gemini-3.1-flash-lite",
@@ -52,27 +48,27 @@ class HybridInferenceViewModel : ViewModel() {
5248
private fun checkAndDownloadModel() {
5349
viewModelScope.launch {
5450
try {
55-
val status = FirebaseAIOnDevice.checkStatus()
51+
val status = model.onDeviceExtension?.checkStatus()
5652
updateStatus(status)
5753

5854
if (status == OnDeviceModelStatus.DOWNLOADABLE) {
59-
FirebaseAIOnDevice.download().collect { downloadStatus ->
55+
model.onDeviceExtension?.download()?.collect { downloadStatus ->
6056
when (downloadStatus) {
6157
is DownloadStatus.DownloadStarted -> {
62-
_uiState.update { it.copy(modelStatus = "Downloading model...") }
58+
uiState.update { it.copy(modelStatus = "Downloading model...") }
6359
}
6460

6561
is DownloadStatus.DownloadInProgress -> {
6662
val progress = downloadStatus.totalBytesDownloaded
67-
_uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
63+
uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
6864
}
6965

7066
is DownloadStatus.DownloadCompleted -> {
71-
_uiState.update { it.copy(modelStatus = "Model ready") }
67+
uiState.update { it.copy(modelStatus = "Model ready") }
7268
}
7369

7470
is DownloadStatus.DownloadFailed -> {
75-
_uiState.update {
71+
uiState.update {
7672
it.copy(
7773
modelStatus = "Download failed", errorMessage = "Model download failed"
7874
)
@@ -82,25 +78,24 @@ class HybridInferenceViewModel : ViewModel() {
8278
}
8379
}
8480
} catch (e: Exception) {
85-
_uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
81+
uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
8682
}
8783
}
8884
}
8985

90-
private fun updateStatus(status: OnDeviceModelStatus) {
86+
private fun updateStatus(status: OnDeviceModelStatus?) {
9187
val statusText = when (status) {
9288
OnDeviceModelStatus.AVAILABLE -> "Model available"
9389
OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable"
9490
OnDeviceModelStatus.DOWNLOADING -> "Model downloading..."
95-
OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable"
96-
else -> "Unknown"
91+
else -> "On-device model unavailable"
9792
}
98-
_uiState.update { it.copy(modelStatus = statusText) }
93+
uiState.update { it.copy(modelStatus = statusText) }
9994
}
10095

10196
fun scanReceipt(bitmap: Bitmap) {
10297
viewModelScope.launch {
103-
_uiState.update { it.copy(isScanning = true, errorMessage = null) }
98+
uiState.update { it.copy(isScanning = true, errorMessage = null) }
10499
try {
105100
val prompt = content {
106101
image(bitmap)
@@ -124,16 +119,15 @@ class HybridInferenceViewModel : ViewModel() {
124119
} else {
125120
"Cloud"
126121
}
127-
Log.d("HybridVM", "$inferenceMode response: $text")
128122
if (text != null) {
129123
parseAndAddExpense(text, inferenceMode)
130124
} else {
131-
_uiState.update { it.copy(errorMessage = "Could not extract data") }
125+
uiState.update { it.copy(errorMessage = "Could not extract data") }
132126
}
133127
} catch (e: Exception) {
134-
_uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
128+
uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
135129
} finally {
136-
_uiState.update { it.copy(isScanning = false) }
130+
uiState.update { it.copy(isScanning = false) }
137131
}
138132
}
139133
}
@@ -145,9 +139,9 @@ class HybridInferenceViewModel : ViewModel() {
145139
.replace("```", "")
146140
try {
147141
val newExpense = Json.decodeFromString<Expense>(json).copy(inferenceMode = inferenceMode)
148-
_uiState.update { it.copy(expenses = it.expenses + newExpense) }
142+
uiState.update { it.copy(expenses = it.expenses + newExpense) }
149143
} catch (e: Exception) {
150-
_uiState.update { it.copy(errorMessage = e.localizedMessage) }
144+
uiState.update { it.copy(errorMessage = e.localizedMessage) }
151145
}
152146
}
153147
}

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/AudioSummarizationViewModel.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class AudioSummarizationViewModel : ChatViewModel() {
4040
)
4141
}
4242
))
43-
_messages.value = chat.history.map { UiChatMessage(it) }
44-
_uiState.value = ChatUiState.Success
43+
updateMessages(chat.history.map { UiChatMessage(it) })
44+
updateUiState(ChatUiState.Success)
4545
}
4646

4747
override suspend fun performSendMessage(prompt: Content, currentMessages: List<UiChatMessage>) {

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/ChatViewModel.kt

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,27 @@ import com.google.firebase.quickstart.ai.ui.ChatUiState
1111
import com.google.firebase.quickstart.ai.ui.UiChatMessage
1212
import kotlinx.coroutines.flow.MutableStateFlow
1313
import kotlinx.coroutines.flow.StateFlow
14-
import kotlinx.coroutines.flow.asStateFlow
1514
import kotlinx.coroutines.launch
1615

1716
@OptIn(PublicPreviewAPI::class)
1817
abstract class ChatViewModel : ViewModel() {
1918

20-
protected val _uiState = MutableStateFlow<ChatUiState>(ChatUiState.Success)
21-
val uiState: StateFlow<ChatUiState> = _uiState.asStateFlow()
19+
val uiState: StateFlow<ChatUiState>
20+
field = MutableStateFlow<ChatUiState>(ChatUiState.Success)
2221

23-
protected val _messages = MutableStateFlow<List<UiChatMessage>>(emptyList())
24-
val messages: StateFlow<List<UiChatMessage>> = _messages.asStateFlow()
22+
val messages: StateFlow<List<UiChatMessage>>
23+
field = MutableStateFlow<List<UiChatMessage>>(emptyList())
2524

26-
protected val _attachments = MutableStateFlow<List<Attachment>>(emptyList())
27-
val attachments: StateFlow<List<Attachment>> = _attachments.asStateFlow()
25+
val attachments: StateFlow<List<Attachment>>
26+
field = MutableStateFlow<List<Attachment>>(emptyList())
27+
28+
protected fun updateUiState(state: ChatUiState) {
29+
uiState.value = state
30+
}
31+
32+
protected fun updateMessages(list: List<UiChatMessage>) {
33+
messages.value = list
34+
}
2835

2936
abstract val initialPrompt: String
3037

@@ -40,14 +47,14 @@ abstract class ChatViewModel : ViewModel() {
4047
.text(userMessage)
4148
.build()
4249

43-
_messages.value = _messages.value + UiChatMessage(prompt)
50+
messages.value = messages.value + UiChatMessage(prompt)
4451

4552
viewModelScope.launch {
46-
_uiState.value = ChatUiState.Loading
53+
uiState.value = ChatUiState.Loading
4754
try {
48-
performSendMessage(prompt, _messages.value)
55+
performSendMessage(prompt, messages.value)
4956
} catch (e: Exception) {
50-
_uiState.value = ChatUiState.Error(e.localizedMessage ?: "Unknown error")
57+
uiState.value = ChatUiState.Error(e.localizedMessage ?: "Unknown error")
5158
} finally {
5259
contentBuilder = Content.Builder() // reset the builder
5360
}
@@ -76,13 +83,13 @@ abstract class ChatViewModel : ViewModel() {
7683
&& candidate.groundingMetadata?.groundingChunks?.isNotEmpty() == true
7784
&& candidate.groundingMetadata?.searchEntryPoint == null
7885
) {
79-
_uiState.value = ChatUiState.Error(
86+
uiState.value = ChatUiState.Error(
8087
"Could not display the response because it was missing required attribution components."
8188
)
8289
} else {
83-
_messages.value = currentMessages + UiChatMessage(candidate.content, candidate.groundingMetadata)
84-
_attachments.value = emptyList()
85-
_uiState.value = ChatUiState.Success
90+
messages.value = currentMessages + UiChatMessage(candidate.content, candidate.groundingMetadata)
91+
attachments.value = emptyList()
92+
uiState.value = ChatUiState.Success
8693
}
8794
}
8895

@@ -98,7 +105,7 @@ abstract class ChatViewModel : ViewModel() {
98105
contentBuilder.inlineData(fileInBytes, mimeType ?: "text/plain")
99106
}
100107

101-
_attachments.value = _attachments.value + Attachment(fileName ?: "Unnamed file")
108+
attachments.value = attachments.value + Attachment(fileName ?: "Unnamed file")
102109
}
103110

104111
protected fun decodeBitmapFromImage(input: ByteArray) =

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/ServerPromptTemplateViewModel.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class ServerPromptTemplateViewModel : ViewModel() {
2222
val initialPrompt = "Jane Doe"
2323
val allowEmptyPrompt = false
2424

25-
private val _uiState = MutableStateFlow<ServerPromptUiState>(ServerPromptUiState.Success())
26-
val uiState: StateFlow<ServerPromptUiState> = _uiState.asStateFlow()
25+
val uiState: StateFlow<ServerPromptUiState>
26+
field = MutableStateFlow<ServerPromptUiState>(ServerPromptUiState.Success())
2727

2828
private var templateGenerativeModel: TemplateGenerativeModel
2929

@@ -35,13 +35,13 @@ class ServerPromptTemplateViewModel : ViewModel() {
3535

3636
fun generate(inputText: String) {
3737
viewModelScope.launch {
38-
_uiState.value = ServerPromptUiState.Loading
38+
uiState.value = ServerPromptUiState.Loading
3939
try {
4040
val response = templateGenerativeModel
4141
.generateContent("input-system-instructions", mapOf("customerName" to inputText))
42-
_uiState.value = ServerPromptUiState.Success(response.text)
42+
uiState.value = ServerPromptUiState.Success(response.text)
4343
} catch (e: Exception) {
44-
_uiState.value = ServerPromptUiState.Error(
44+
uiState.value = ServerPromptUiState.Error(
4545
if (e.localizedMessage?.contains("not found") == true) {
4646
"""
4747
Template was not found, please verify that your project contains a template

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/SvgViewModel.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ import kotlinx.coroutines.launch
2323
object SvgRoute
2424

2525
class SvgViewModel : ViewModel() {
26-
private val _uiState = MutableStateFlow<SvgUiState>(SvgUiState.Success())
27-
val uiState: StateFlow<SvgUiState> = _uiState.asStateFlow()
26+
val uiState: StateFlow<SvgUiState>
27+
field = MutableStateFlow<SvgUiState>(SvgUiState.Success())
2828

2929
private val generativeModel: GenerativeModel
3030

@@ -53,8 +53,8 @@ class SvgViewModel : ViewModel() {
5353
}
5454

5555
fun generateSVG(prompt: String) {
56-
val currentSvgs = (_uiState.value as? SvgUiState.Success)?.svgs ?: emptyList()
57-
_uiState.value = SvgUiState.Loading
56+
val currentSvgs = (uiState.value as? SvgUiState.Success)?.svgs ?: emptyList()
57+
uiState.value = SvgUiState.Loading
5858
viewModelScope.launch(Dispatchers.IO) {
5959
try {
6060
val response = generativeModel.generateContent(prompt)
@@ -64,12 +64,12 @@ class SvgViewModel : ViewModel() {
6464
?.removeSuffix("```")
6565
?.trimIndent()
6666
if (newSvg != null) {
67-
_uiState.value = SvgUiState.Success(listOf(newSvg) + currentSvgs)
67+
uiState.value = SvgUiState.Success(listOf(newSvg) + currentSvgs)
6868
} else {
69-
_uiState.value = SvgUiState.Success(currentSvgs)
69+
uiState.value = SvgUiState.Success(currentSvgs)
7070
}
7171
} catch (e: Exception) {
72-
_uiState.value = SvgUiState.Error(e.localizedMessage ?: "Unknown error")
72+
uiState.value = SvgUiState.Error(e.localizedMessage ?: "Unknown error")
7373
}
7474
}
7575
}

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/TravelTipsViewModel.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class TravelTipsViewModel : ChatViewModel() {
5959
)
6060
)
6161

62-
_messages.value = chat.history.map { UiChatMessage(it) }
63-
_uiState.value = ChatUiState.Success
62+
updateMessages(chat.history.map { UiChatMessage(it) })
63+
updateUiState(ChatUiState.Success)
6464
}
6565

6666
override suspend fun performSendMessage(prompt: Content, currentMessages: List<UiChatMessage>) {

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/text/VideoSummarizationViewModel.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class VideoSummarizationViewModel : ChatViewModel() {
3333
}
3434
)
3535

36-
_messages.value = chatHistory.map { UiChatMessage(it) }
37-
_uiState.value = ChatUiState.Success
36+
updateMessages(chatHistory.map { UiChatMessage(it) })
37+
updateUiState(ChatUiState.Success)
3838

3939
val generativeModel = Firebase.ai(
4040
backend = GenerativeBackend.googleAI()

0 commit comments

Comments
 (0)