Skip to content

Commit 32147f9

Browse files
ChingEnLinclaude
andcommitted
feat: add dynamic model selection for LLM queries
Adds a GET /query/models endpoint that fetches available Gemini models from the Google GenAI API, and threads a model parameter through all AI service functions and request schemas so the user can switch models at runtime. Two model selector dropdowns are added to the frontend UI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8cc3d18 commit 32147f9

10 files changed

Lines changed: 162 additions & 39 deletions

File tree

backend/models/analyze.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
class AnalyzeRequest(BaseModel):
66
query_result: List[Dict[str, Any]]
7+
model: str = "gemini-2.5-flash"
78

89

910
class AnalyzeResponse(BaseModel):

backend/models/schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class QueryPrompt(BaseModel):
3939
max_iterations: int = Field(
4040
default=3, ge=1, le=10
4141
) # Server-enforced agent iteration cap
42+
model: str = "gemini-2.5-flash"
4243

4344

4445
class GeneratedCode(BaseModel):
@@ -54,6 +55,7 @@ class ExecuteInput(BaseModel):
5455
class DebugQueryRequest(BaseModel):
5556
query: str
5657
error_message: str
58+
model: str = "gemini-2.5-flash"
5759

5860

5961
class DebugSuggestionResponse(BaseModel):
@@ -64,6 +66,7 @@ class SchemaRelationshipsRequest(BaseModel):
6466
account_id: str
6567
database_name: str
6668
collection_names: list[str]
69+
model: str = "gemini-2.5-flash"
6770

6871

6972
class Relationship(BaseModel):
@@ -85,6 +88,7 @@ class EvaluateWriteRequest(BaseModel):
8588
write_result: dict
8689
account_id: str
8790
database_name: str
91+
model: str = "gemini-2.5-flash"
8892

8993

9094
class EvaluateWriteResponse(BaseModel):

backend/routes/query.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from fastapi import APIRouter, Header, Body, HTTPException
2+
from typing import List
23
from services.azure_auth import exchange_token_obo
34
from services.azure_cosmos_resources import (
45
get_connection_string,
@@ -107,6 +108,7 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
107108
intermediate_context=prompt.intermediate_context,
108109
connection_string=connection_string,
109110
max_iterations=prompt.max_iterations,
111+
model=prompt.model,
110112
)
111113

112114

@@ -129,7 +131,7 @@ def infer_relationships(
129131
collection_filter=request.collection_names,
130132
)
131133

132-
return generate_schema_relationships(schema_summary)
134+
return generate_schema_relationships(schema_summary, model=request.model)
133135

134136
except Exception as e:
135137
print(f"Error inferring relationships: {e}")
@@ -202,15 +204,17 @@ def debug(body: DebugQueryRequest = Body(...)):
202204
"""
203205
Sends a failed query and error message to Gemini for debugging suggestion.
204206
"""
205-
return generate_suggestion_from_query_error(body.query, body.error_message)
207+
return generate_suggestion_from_query_error(
208+
body.query, body.error_message, model=body.model
209+
)
206210

207211

208212
@router.post("/analyze", response_model=AnalyzeResponse)
209213
def analyze(body: AnalyzeRequest = Body(...)):
210214
"""
211215
Sends a query result to the AI for analysis and visualization suggestions.
212216
"""
213-
return analyze_query_result(body.query_result)
217+
return analyze_query_result(body.query_result, model=body.model)
214218

215219

216220
@router.post("/evaluate-write", response_model=EvaluateWriteResponse)
@@ -240,4 +244,35 @@ def evaluate_write(
240244
write_result=body.write_result,
241245
connection_string=connection_string,
242246
database_name=body.database_name,
247+
model=body.model,
243248
)
249+
250+
251+
@router.get("/models", response_model=List[str])
252+
def list_models():
253+
"""
254+
Returns available Gemini model IDs filtered to generative models.
255+
"""
256+
try:
257+
from google import genai
258+
259+
client = genai.Client()
260+
models = client.models.list()
261+
model_ids = [
262+
m.name.replace("models/", "")
263+
for m in models
264+
if m.name
265+
and "gemini" in m.name.lower()
266+
and hasattr(m, "supported_actions")
267+
and "generateContent" in (m.supported_actions or [])
268+
]
269+
if not model_ids:
270+
# Fallback: return all gemini models even if supported_actions is missing
271+
model_ids = [
272+
m.name.replace("models/", "")
273+
for m in client.models.list()
274+
if m.name and "gemini" in m.name.lower()
275+
]
276+
return sorted(set(model_ids))
277+
except Exception as e:
278+
raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")

backend/services/analyze_service.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from models.analyze import AnalyzeResponse
55

66

7-
def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
7+
def analyze_query_result(
8+
query_result: list[dict], model: str = "gemini-2.5-flash"
9+
) -> AnalyzeResponse:
810
prompt = f"""
911
You are a data analyst assistant. Given the following MongoDB query result, provide:
1012
1. A concise textual insight or summary of the data.
@@ -18,7 +20,7 @@ def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
1820
"""
1921
client = genai.Client()
2022
response = client.models.generate_content(
21-
model="gemini-2.5-flash",
23+
model=model,
2224
contents=prompt,
2325
config=types.GenerateContentConfig(
2426
thinking_config=types.ThinkingConfig(thinking_budget=0)

backend/services/evaluate_write_service.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def evaluate_write_result(
1111
write_result: dict,
1212
connection_string: str = "",
1313
database_name: str = "",
14+
model: str = "gemini-2.5-flash",
1415
) -> EvaluateWriteResponse:
1516
prompt = f"""
1617
You are an expert MongoDB database administrator and assistant.
@@ -63,7 +64,7 @@ def query_database(query: str) -> str:
6364
tools = [query_database] if connection_string else None
6465

6566
response = client.models.generate_content(
66-
model="gemini-2.5-flash",
67+
model=model,
6768
contents=prompt,
6869
config=types.GenerateContentConfig(
6970
tools=tools, thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -84,7 +85,7 @@ def query_database(query: str) -> str:
8485

8586
# Send the tool result back to the model
8687
response = client.models.generate_content(
87-
model="gemini-2.5-flash",
88+
model=model,
8889
contents=[
8990
prompt,
9091
response.candidates[0].content,

backend/services/gemini_service.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def generate_query_from_prompt(
119119
collection_context: CollectionContext = None,
120120
intermediate_context: dict = None,
121121
all_collections_schema: str = "",
122+
model: str = "gemini-2.5-flash",
122123
) -> GeneratedCode:
123124
# Prune intermediate_context to remove image/large data
124125
safe_intermediate_context = (
@@ -136,7 +137,7 @@ def generate_query_from_prompt(
136137
)
137138
client = genai.Client()
138139
response = client.models.generate_content(
139-
model="gemini-2.5-flash",
140+
model=model,
140141
contents=full_prompt,
141142
config=types.GenerateContentConfig(
142143
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
@@ -146,14 +147,16 @@ def generate_query_from_prompt(
146147
return GeneratedCode(generated_code=code)
147148

148149

149-
def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
150+
def generate_suggestion_from_query_error(
151+
query: str, error_message: str, model: str = "gemini-2.5-flash"
152+
) -> str:
150153
"""
151154
Sends a failed query and error message to Gemini for debugging suggestion.
152155
"""
153156
full_prompt = PROMPT_TEMPLATE_DEBUG.format(query=query, error_message=error_message)
154157
client = genai.Client()
155158
response = client.models.generate_content(
156-
model="gemini-2.5-flash",
159+
model=model,
157160
contents=full_prompt,
158161
config=types.GenerateContentConfig(
159162
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
@@ -202,11 +205,11 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
202205
"""
203206

204207

205-
def generate_audit_sql(user_input: str) -> str:
208+
def generate_audit_sql(user_input: str, model: str = "gemini-2.5-flash") -> str:
206209
full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
207210
client = genai.Client()
208211
response = client.models.generate_content(
209-
model="gemini-2.5-flash",
212+
model=model,
210213
contents=full_prompt,
211214
config=types.GenerateContentConfig(
212215
thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -220,7 +223,7 @@ def generate_audit_sql(user_input: str) -> str:
220223

221224

222225
def summarize_audit_results(
223-
user_input: str, sql_query: str, results: list
226+
user_input: str, sql_query: str, results: list, model: str = "gemini-2.5-flash"
224227
) -> AuditSummaryResponse:
225228
# Truncate results if too large to avoid token limits
226229
results_str = str(results)[:10000]
@@ -229,7 +232,7 @@ def summarize_audit_results(
229232
)
230233
client = genai.Client()
231234
response = client.models.generate_content(
232-
model="gemini-2.5-flash",
235+
model=model,
233236
contents=full_prompt,
234237
config=types.GenerateContentConfig(
235238
response_mime_type="application/json",
@@ -281,13 +284,15 @@ def summarize_audit_results(
281284
"""
282285

283286

284-
def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
287+
def generate_schema_relationships(
288+
schema_summary: str, model: str = "gemini-2.5-flash"
289+
) -> SchemaRelationshipsResponse:
285290
from models.schemas import SchemaRelationshipsResponse
286291

287292
full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
288293
client = genai.Client()
289294
response = client.models.generate_content(
290-
model="gemini-2.5-flash",
295+
model=model,
291296
contents=full_prompt,
292297
config=types.GenerateContentConfig(
293298
response_mime_type="application/json",

backend/services/react_agent_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class AgentState(TypedDict):
5656
schema_context: str
5757
intermediate_context: dict
5858
connection_string: str
59+
model: str
5960

6061
generated_query: str
6162
is_write_action: bool
@@ -127,7 +128,7 @@ def generate_query_node(state: AgentState):
127128

128129
try:
129130
response = client.models.generate_content(
130-
model="gemini-2.5-flash",
131+
model=state.get("model", "gemini-2.5-flash"),
131132
contents=prompt,
132133
config=types.GenerateContentConfig(
133134
thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -210,7 +211,7 @@ def evaluate_node(state: AgentState):
210211

211212
try:
212213
response = client.models.generate_content(
213-
model="gemini-2.5-flash",
214+
model=state.get("model", "gemini-2.5-flash"),
214215
contents=prompt,
215216
config=types.GenerateContentConfig(
216217
response_mime_type="application/json",
@@ -271,6 +272,7 @@ def run_query_generator(
271272
intermediate_context: dict,
272273
connection_string: str,
273274
max_iterations: int = 3,
275+
model: str = "gemini-2.5-flash",
274276
):
275277
logger.info(f"Starting ReAct Agent workflow for input: '{user_input}'")
276278
initial_state = {
@@ -280,6 +282,7 @@ def run_query_generator(
280282
"schema_context": schema_context,
281283
"intermediate_context": intermediate_context,
282284
"connection_string": connection_string,
285+
"model": model,
283286
"iterations": 0,
284287
"max_iterations": max_iterations,
285288
}

backend/tests/test_query_routes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ def test_debug_query(client):
163163
data = response.json()
164164
assert data["suggestion"] == "Check collection name"
165165

166-
mock_debug.assert_called_once_with("db.users.find({})", "Collection not found")
166+
mock_debug.assert_called_once_with(
167+
"db.users.find({})", "Collection not found", model="gemini-2.5-flash"
168+
)
167169

168170

169171
def test_analyze_query(client):

0 commit comments

Comments
 (0)