Skip to content

Commit f5d6b01

Browse files
committed
feat(ai-logic): add hybrid on-device inference sample
1 parent 7da89ea commit f5d6b01

File tree

9 files changed

+370
-3
lines changed

9 files changed

+370
-3
lines changed

firebase-ai/app/build.gradle.kts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ android {
1212

1313
defaultConfig {
1414
applicationId = "com.google.firebase.quickstart.ai"
15-
minSdk = 23
15+
minSdk = 26
1616
targetSdk = 36
1717
versionCode = 1
1818
versionName = "1.0"
@@ -73,6 +73,7 @@ dependencies {
7373
// Firebase
7474
implementation(platform(libs.firebase.bom))
7575
implementation(libs.firebase.ai)
76+
implementation(libs.firebase.ai.ondevice)
7677

7778
// Image loading
7879
implementation(libs.coil.compose)

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import androidx.navigation.compose.NavHost
2727
import androidx.navigation.compose.composable
2828
import androidx.navigation.compose.rememberNavController
2929
import com.google.firebase.quickstart.ai.feature.live.BidiViewModel
30+
import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel
3031
import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenViewModel
3132
import com.google.firebase.quickstart.ai.feature.text.ChatViewModel
3233
import com.google.firebase.quickstart.ai.feature.text.ServerPromptTemplateViewModel
@@ -36,6 +37,7 @@ import com.google.firebase.quickstart.ai.ui.ImagenScreen
3637
import com.google.firebase.quickstart.ai.ui.ServerPromptScreen
3738
import com.google.firebase.quickstart.ai.ui.StreamRealtimeScreen
3839
import com.google.firebase.quickstart.ai.ui.StreamRealtimeVideoScreen
40+
import com.google.firebase.quickstart.ai.ui.HybridInferenceScreen
3941
import com.google.firebase.quickstart.ai.ui.SvgScreen
4042
import com.google.firebase.quickstart.ai.ui.navigation.FIREBASE_AI_SAMPLES
4143
import com.google.firebase.quickstart.ai.ui.navigation.MainMenuScreen
@@ -123,6 +125,10 @@ class MainActivity : ComponentActivity() {
123125
StreamRealtimeVideoScreen(it)
124126
}
125127
}
128+
129+
ScreenType.HYBRID -> {
130+
(vm as? HybridInferenceViewModel)?.let { HybridInferenceScreen(it) }
131+
}
126132
}
127133
}
128134
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.google.firebase.quickstart.ai.feature.hybrid
2+
3+
import kotlinx.serialization.Serializable
4+
5+
@Serializable
6+
data class Expense(
7+
val id: String,
8+
val name: String,
9+
val price: Double
10+
)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package com.google.firebase.quickstart.ai.feature.hybrid
2+
3+
import android.graphics.Bitmap
4+
import android.util.Log
5+
import androidx.lifecycle.ViewModel
6+
import androidx.lifecycle.viewModelScope
7+
import com.google.firebase.Firebase
8+
import com.google.firebase.ai.InferenceMode
9+
import com.google.firebase.ai.OnDeviceConfig
10+
import com.google.firebase.ai.ai
11+
import com.google.firebase.ai.ondevice.DownloadStatus
12+
import com.google.firebase.ai.ondevice.FirebaseAIOnDevice
13+
import com.google.firebase.ai.ondevice.OnDeviceModelStatus
14+
import com.google.firebase.ai.type.GenerativeBackend
15+
import com.google.firebase.ai.type.PublicPreviewAPI
16+
import com.google.firebase.ai.type.content
17+
import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState
18+
import kotlinx.coroutines.flow.MutableStateFlow
19+
import kotlinx.coroutines.flow.StateFlow
20+
import kotlinx.coroutines.flow.asStateFlow
21+
import kotlinx.coroutines.flow.update
22+
import kotlinx.coroutines.launch
23+
import kotlinx.serialization.Serializable
24+
import java.util.UUID
25+
26+
@Serializable
27+
object HybridInferenceRoute
28+
29+
@OptIn(PublicPreviewAPI::class)
30+
class HybridInferenceViewModel : ViewModel() {
31+
private val _uiState = MutableStateFlow(HybridInferenceUiState(
32+
expenses = listOf(
33+
Expense(UUID.randomUUID().toString(), "Lunch", 15.50),
34+
Expense(UUID.randomUUID().toString(), "Coffee", 4.75)
35+
)
36+
))
37+
val uiState: StateFlow<HybridInferenceUiState> = _uiState.asStateFlow()
38+
39+
private val model = Firebase.ai(backend = GenerativeBackend.googleAI())
40+
.generativeModel(
41+
modelName = "gemini-3.1-flash-lite-preview",
42+
onDeviceConfig = OnDeviceConfig(mode = InferenceMode.PREFER_ON_DEVICE)
43+
)
44+
45+
init {
46+
checkAndDownloadModel()
47+
}
48+
49+
private fun checkAndDownloadModel() {
50+
viewModelScope.launch {
51+
try {
52+
val status = FirebaseAIOnDevice.checkStatus()
53+
updateStatus(status)
54+
55+
if (status == OnDeviceModelStatus.DOWNLOADABLE) {
56+
FirebaseAIOnDevice.download().collect { downloadStatus ->
57+
when (downloadStatus) {
58+
is DownloadStatus.DownloadStarted -> {
59+
_uiState.update { it.copy(modelStatus = "Downloading model...") }
60+
}
61+
is DownloadStatus.DownloadInProgress -> {
62+
val progress = downloadStatus.totalBytesDownloaded
63+
_uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
64+
}
65+
is DownloadStatus.DownloadCompleted -> {
66+
_uiState.update { it.copy(modelStatus = "Model ready") }
67+
}
68+
is DownloadStatus.DownloadFailed -> {
69+
_uiState.update { it.copy(modelStatus = "Download failed", errorMessage = "Model download failed") }
70+
}
71+
}
72+
}
73+
}
74+
} catch (e: Exception) {
75+
Log.e("HybridVM", "Error checking model status", e)
76+
_uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
77+
}
78+
}
79+
}
80+
81+
private fun updateStatus(status: OnDeviceModelStatus) {
82+
val statusText = when (status) {
83+
OnDeviceModelStatus.AVAILABLE -> "Model available"
84+
OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable"
85+
OnDeviceModelStatus.DOWNLOADING -> "Model downloading..."
86+
OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable"
87+
else -> "Unknown"
88+
}
89+
_uiState.update { it.copy(modelStatus = statusText) }
90+
}
91+
92+
fun scanReceipt(bitmap: Bitmap) {
93+
viewModelScope.launch {
94+
_uiState.update { it.copy(isScanning = true, errorMessage = null) }
95+
try {
96+
val prompt = content {
97+
image(bitmap)
98+
text("Extract the store name and the total price from this receipt. Output only in CSV format like 'Store,Price'. Example: 'Starbucks,5.50'")
99+
}
100+
101+
val response = model.generateContent(prompt)
102+
val text = response.text
103+
Log.d("HybridVM", "Response is: $text")
104+
if (text != null) {
105+
parseAndAddExpense(text)
106+
} else {
107+
_uiState.update { it.copy(errorMessage = "Could not extract data") }
108+
}
109+
} catch (e: Exception) {
110+
Log.e("HybridVM", "Error scanning receipt", e)
111+
_uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
112+
} finally {
113+
_uiState.update { it.copy(isScanning = false) }
114+
}
115+
}
116+
}
117+
118+
private fun parseAndAddExpense(text: String) {
119+
// Simple parsing: "Store, Price"
120+
val parts = text
121+
// Sometimes the output contains single quotes
122+
.replace("'", "")
123+
.split(",", limit = 2)
124+
if (parts.size >= 2) {
125+
val name = parts[0].trim()
126+
val priceStr = parts[1].trim()
127+
.replace("$", "")
128+
.replace(",", "")
129+
val price = priceStr.toDoubleOrNull() ?: 0.0
130+
131+
val newExpense = Expense(UUID.randomUUID().toString(), name, price)
132+
_uiState.update { it.copy(expenses = it.expenses + newExpense) }
133+
} else {
134+
_uiState.update { it.copy(errorMessage = "Unexpected AI output format: $text") }
135+
}
136+
}
137+
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
package com.google.firebase.quickstart.ai.ui
2+
3+
import android.graphics.BitmapFactory
4+
import androidx.activity.compose.rememberLauncherForActivityResult
5+
import androidx.activity.result.PickVisualMediaRequest
6+
import androidx.activity.result.contract.ActivityResultContracts
7+
import androidx.compose.foundation.layout.Arrangement
8+
import androidx.compose.foundation.layout.Box
9+
import androidx.compose.foundation.layout.Column
10+
import androidx.compose.foundation.layout.Row
11+
import androidx.compose.foundation.layout.Spacer
12+
import androidx.compose.foundation.layout.fillMaxSize
13+
import androidx.compose.foundation.layout.fillMaxWidth
14+
import androidx.compose.foundation.layout.height
15+
import androidx.compose.foundation.layout.padding
16+
import androidx.compose.foundation.layout.size
17+
import androidx.compose.foundation.lazy.LazyColumn
18+
import androidx.compose.foundation.lazy.items
19+
import androidx.compose.material.icons.Icons
20+
import androidx.compose.material.icons.filled.Add
21+
import androidx.compose.material.icons.filled.CameraAlt
22+
import androidx.compose.material.icons.filled.ReceiptLong
23+
import androidx.compose.material3.Card
24+
import androidx.compose.material3.CardDefaults
25+
import androidx.compose.material3.CircularProgressIndicator
26+
import androidx.compose.material3.FloatingActionButton
27+
import androidx.compose.material3.HorizontalDivider
28+
import androidx.compose.material3.Icon
29+
import androidx.compose.material3.MaterialTheme
30+
import androidx.compose.material3.Scaffold
31+
import androidx.compose.material3.Text
32+
import androidx.compose.runtime.Composable
33+
import androidx.compose.runtime.collectAsState
34+
import androidx.compose.runtime.getValue
35+
import androidx.compose.ui.Alignment
36+
import androidx.compose.ui.Modifier
37+
import androidx.compose.ui.graphics.Color
38+
import androidx.compose.ui.platform.LocalContext
39+
import androidx.compose.ui.text.font.FontWeight
40+
import androidx.compose.ui.unit.dp
41+
import androidx.compose.ui.unit.sp
42+
import androidx.lifecycle.viewmodel.compose.viewModel
43+
import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel
44+
45+
@Composable
46+
fun HybridInferenceScreen(
47+
viewModel: HybridInferenceViewModel = viewModel()
48+
) {
49+
val uiState by viewModel.uiState.collectAsState()
50+
val context = LocalContext.current
51+
52+
val launcher = rememberLauncherForActivityResult(
53+
contract = ActivityResultContracts.PickVisualMedia(),
54+
onResult = { uri ->
55+
uri?.let {
56+
try {
57+
context.contentResolver.openInputStream(it)?.use { stream ->
58+
val bitmap = BitmapFactory.decodeStream(stream)
59+
bitmap?.let { viewModel.scanReceipt(it) }
60+
}
61+
} catch (e: Exception) {
62+
// Handle error
63+
}
64+
}
65+
}
66+
)
67+
68+
Scaffold(
69+
floatingActionButton = {
70+
FloatingActionButton(
71+
onClick = {
72+
launcher.launch(PickVisualMediaRequest(ActivityResultContracts.PickVisualMedia.ImageOnly))
73+
},
74+
containerColor = MaterialTheme.colorScheme.primary,
75+
contentColor = MaterialTheme.colorScheme.onPrimary
76+
) {
77+
if (uiState.isScanning) {
78+
CircularProgressIndicator(
79+
modifier = Modifier.size(24.dp),
80+
color = MaterialTheme.colorScheme.onPrimary,
81+
strokeWidth = 2.dp
82+
)
83+
} else {
84+
Icon(Icons.Default.CameraAlt, contentDescription = "Scan Receipt")
85+
}
86+
}
87+
}
88+
) { padding ->
89+
Column(
90+
modifier = Modifier
91+
.fillMaxSize()
92+
.padding(padding)
93+
.padding(16.dp)
94+
) {
95+
// Model Status Card
96+
Card(
97+
modifier = Modifier.fillMaxWidth(),
98+
colors = CardDefaults.cardColors(
99+
containerColor = MaterialTheme.colorScheme.secondaryContainer
100+
)
101+
) {
102+
Row(
103+
modifier = Modifier.padding(12.dp),
104+
verticalAlignment = Alignment.CenterVertically
105+
) {
106+
Icon(
107+
Icons.Default.ReceiptLong,
108+
contentDescription = null,
109+
tint = MaterialTheme.colorScheme.onSecondaryContainer
110+
)
111+
Spacer(modifier = Modifier.size(12.dp))
112+
Column {
113+
Text(
114+
"Hybrid AI Status",
115+
style = MaterialTheme.typography.labelMedium,
116+
color = MaterialTheme.colorScheme.onSecondaryContainer
117+
)
118+
Text(
119+
uiState.modelStatus,
120+
style = MaterialTheme.typography.bodyMedium,
121+
fontWeight = FontWeight.Bold,
122+
color = MaterialTheme.colorScheme.onSecondaryContainer
123+
)
124+
}
125+
}
126+
}
127+
128+
Spacer(modifier = Modifier.height(16.dp))
129+
130+
Text(
131+
"Expenses",
132+
style = MaterialTheme.typography.headlineSmall,
133+
fontWeight = FontWeight.Bold
134+
)
135+
136+
Spacer(modifier = Modifier.height(8.dp))
137+
138+
if (uiState.expenses.isEmpty()) {
139+
Box(modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center) {
140+
Text("No expenses yet. Scan a receipt to add one.", color = Color.Gray)
141+
}
142+
} else {
143+
LazyColumn(
144+
modifier = Modifier.weight(1f),
145+
verticalArrangement = Arrangement.spacedBy(8.dp)
146+
) {
147+
items(uiState.expenses) { expense ->
148+
ExpenseItem(expense.name, expense.price)
149+
}
150+
}
151+
}
152+
153+
if (uiState.errorMessage != null) {
154+
Spacer(modifier = Modifier.height(8.dp))
155+
Text(
156+
text = uiState.errorMessage!!,
157+
color = MaterialTheme.colorScheme.error,
158+
style = MaterialTheme.typography.bodySmall
159+
)
160+
}
161+
}
162+
}
163+
}
164+
165+
@Composable
166+
fun ExpenseItem(name: String, price: Double) {
167+
Card(
168+
modifier = Modifier.fillMaxWidth(),
169+
elevation = CardDefaults.cardElevation(defaultElevation = 2.dp)
170+
) {
171+
Row(
172+
modifier = Modifier
173+
.padding(16.dp)
174+
.fillMaxWidth(),
175+
horizontalArrangement = Arrangement.SpaceBetween,
176+
verticalAlignment = Alignment.CenterVertically
177+
) {
178+
Text(
179+
name,
180+
style = MaterialTheme.typography.bodyLarge,
181+
fontWeight = FontWeight.Medium
182+
)
183+
Text(
184+
"$${String.format("%.2f", price)}",
185+
style = MaterialTheme.typography.bodyLarge,
186+
fontWeight = FontWeight.Bold,
187+
color = MaterialTheme.colorScheme.primary
188+
)
189+
}
190+
}
191+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.google.firebase.quickstart.ai.ui
2+
3+
import com.google.firebase.quickstart.ai.feature.hybrid.Expense
4+
5+
data class HybridInferenceUiState(
6+
val expenses: List<Expense> = emptyList(),
7+
val isScanning: Boolean = false,
8+
val modelStatus: String = "Checking model status...",
9+
val errorMessage: String? = null
10+
)

0 commit comments

Comments
 (0)