Skip to content

Commit 4758c73

Browse files
authored
Merge pull request #35 from ChingEnLin/worktree-multiple_model
feat: add dynamic model selection for LLM queries
2 parents 8cc3d18 + f83e159 commit 4758c73

12 files changed

Lines changed: 234 additions & 44 deletions

File tree

backend/models/analyze.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
class AnalyzeRequest(BaseModel):
66
query_result: List[Dict[str, Any]]
7+
model: str = "gemini-2.5-flash"
78

89

910
class AnalyzeResponse(BaseModel):

backend/models/schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class QueryPrompt(BaseModel):
3939
max_iterations: int = Field(
4040
default=3, ge=1, le=10
4141
) # Server-enforced agent iteration cap
42+
model: str = "gemini-2.5-flash"
4243

4344

4445
class GeneratedCode(BaseModel):
@@ -54,6 +55,7 @@ class ExecuteInput(BaseModel):
5455
class DebugQueryRequest(BaseModel):
5556
query: str
5657
error_message: str
58+
model: str = "gemini-2.5-flash"
5759

5860

5961
class DebugSuggestionResponse(BaseModel):
@@ -64,6 +66,7 @@ class SchemaRelationshipsRequest(BaseModel):
6466
account_id: str
6567
database_name: str
6668
collection_names: list[str]
69+
model: str = "gemini-2.5-flash"
6770

6871

6972
class Relationship(BaseModel):
@@ -85,6 +88,7 @@ class EvaluateWriteRequest(BaseModel):
8588
write_result: dict
8689
account_id: str
8790
database_name: str
91+
model: str = "gemini-2.5-flash"
8892

8993

9094
class EvaluateWriteResponse(BaseModel):

backend/routes/audit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class AuditQueryRequest(BaseModel):
1212
question: str
13+
model: str = "gemini-2.5-flash"
1314

1415

1516
class AuditQueryResponse(BaseModel):
@@ -35,7 +36,7 @@ def query_audit_log(
3536
if not authorization.startswith("Bearer "):
3637
raise HTTPException(status_code=401, detail="Invalid token format")
3738

38-
response = process_audit_question(body.question)
39+
response = process_audit_question(body.question, model=body.model)
3940
return AuditQueryResponse(**response)
4041

4142

backend/routes/query.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from fastapi import APIRouter, Header, Body, HTTPException
2+
from typing import List
23
from services.azure_auth import exchange_token_obo
34
from services.azure_cosmos_resources import (
45
get_connection_string,
@@ -36,6 +37,7 @@
3637
)
3738
import ast
3839
import re
40+
from google import genai
3941

4042
router = APIRouter()
4143

@@ -107,6 +109,7 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
107109
intermediate_context=prompt.intermediate_context,
108110
connection_string=connection_string,
109111
max_iterations=prompt.max_iterations,
112+
model=prompt.model,
110113
)
111114

112115

@@ -129,7 +132,7 @@ def infer_relationships(
129132
collection_filter=request.collection_names,
130133
)
131134

132-
return generate_schema_relationships(schema_summary)
135+
return generate_schema_relationships(schema_summary, model=request.model)
133136

134137
except Exception as e:
135138
print(f"Error inferring relationships: {e}")
@@ -202,15 +205,17 @@ def debug(body: DebugQueryRequest = Body(...)):
202205
"""
203206
Sends a failed query and error message to Gemini for debugging suggestion.
204207
"""
205-
return generate_suggestion_from_query_error(body.query, body.error_message)
208+
return generate_suggestion_from_query_error(
209+
body.query, body.error_message, model=body.model
210+
)
206211

207212

208213
@router.post("/analyze", response_model=AnalyzeResponse)
209214
def analyze(body: AnalyzeRequest = Body(...)):
210215
"""
211216
Sends a query result to the AI for analysis and visualization suggestions.
212217
"""
213-
return analyze_query_result(body.query_result)
218+
return analyze_query_result(body.query_result, model=body.model)
214219

215220

216221
@router.post("/evaluate-write", response_model=EvaluateWriteResponse)
@@ -240,4 +245,35 @@ def evaluate_write(
240245
write_result=body.write_result,
241246
connection_string=connection_string,
242247
database_name=body.database_name,
248+
model=body.model,
243249
)
250+
251+
252+
@router.get("/models", response_model=List[str])
def list_models():
    """
    Return the available Gemini model IDs, preferring generative models.

    Intentionally unauthenticated — model names are non-sensitive and
    this endpoint is called on page load before auth completes.

    Returns:
        A sorted, de-duplicated list of model IDs with the leading
        "models/" resource prefix removed. Models that advertise the
        "generateContent" action are preferred; if none do, every model
        whose name contains "gemini" is returned instead.

    Raises:
        HTTPException: 500 if the model listing cannot be retrieved or
            processed.
    """

    def _model_id(model) -> str:
        # Strip only the leading resource prefix; str.replace would also
        # remove a "models/" substring from the middle of a name.
        return model.name.removeprefix("models/")

    try:
        client = genai.Client()
        # Materialise the listing once so the fallback below can reuse it
        # without issuing a second request.
        gemini_models = [
            m
            for m in client.models.list()
            if m.name and "gemini" in m.name.lower()
        ]
        model_ids = [
            _model_id(m)
            for m in gemini_models
            if "generateContent" in (getattr(m, "supported_actions", None) or [])
        ]
        if not model_ids:
            # Fallback: return all gemini models if supported_actions is absent
            model_ids = [_model_id(m) for m in gemini_models]
        return sorted(set(model_ids))
    except Exception as e:
        # NOTE(review): the detail string echoes the upstream error text to
        # unauthenticated callers — confirm that it cannot contain anything
        # sensitive (e.g. credential or configuration details).
        raise HTTPException(
            status_code=500, detail=f"Failed to list models: {str(e)}"
        ) from e

backend/services/analyze_service.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from models.analyze import AnalyzeResponse
55

66

7-
def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
7+
def analyze_query_result(
8+
query_result: list[dict], model: str = "gemini-2.5-flash"
9+
) -> AnalyzeResponse:
810
prompt = f"""
911
You are a data analyst assistant. Given the following MongoDB query result, provide:
1012
1. A concise textual insight or summary of the data.
@@ -18,7 +20,7 @@ def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
1820
"""
1921
client = genai.Client()
2022
response = client.models.generate_content(
21-
model="gemini-2.5-flash",
23+
model=model,
2224
contents=prompt,
2325
config=types.GenerateContentConfig(
2426
thinking_config=types.ThinkingConfig(thinking_budget=0)

backend/services/audit_service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,16 @@ def get_recent_activity(user_email: str, limit: int = 10) -> List[Dict[str, Any]
6666
return []
6767

6868

69-
def process_audit_question(question: str) -> Dict[str, Any]:
69+
def process_audit_question(
70+
question: str, model: str = "gemini-2.5-flash"
71+
) -> Dict[str, Any]:
7072
"""
7173
Orchestrates the process of answering a user's audit question:
7274
1. Generate SQL from NL question (via Gemini)
7375
2. Execute SQL
7476
3. Summarize results (via Gemini)
7577
"""
76-
sql_query = generate_audit_sql(question)
78+
sql_query = generate_audit_sql(question, model=model)
7779

7880
# If the generator returned an error query or invalid SQL, return it
7981
if "Error:" in sql_query:
@@ -93,7 +95,9 @@ def process_audit_question(question: str) -> Dict[str, Any]:
9395
"summary": f"Error executing query: {results[0]['error']}",
9496
}
9597

96-
summary_response = summarize_audit_results(question, sql_query, results)
98+
summary_response = summarize_audit_results(
99+
question, sql_query, results, model=model
100+
)
97101

98102
return {
99103
"sql_query": sql_query,

backend/services/evaluate_write_service.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def evaluate_write_result(
1111
write_result: dict,
1212
connection_string: str = "",
1313
database_name: str = "",
14+
model: str = "gemini-2.5-flash",
1415
) -> EvaluateWriteResponse:
1516
prompt = f"""
1617
You are an expert MongoDB database administrator and assistant.
@@ -63,7 +64,7 @@ def query_database(query: str) -> str:
6364
tools = [query_database] if connection_string else None
6465

6566
response = client.models.generate_content(
66-
model="gemini-2.5-flash",
67+
model=model,
6768
contents=prompt,
6869
config=types.GenerateContentConfig(
6970
tools=tools, thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -84,7 +85,7 @@ def query_database(query: str) -> str:
8485

8586
# Send the tool result back to the model
8687
response = client.models.generate_content(
87-
model="gemini-2.5-flash",
88+
model=model,
8889
contents=[
8990
prompt,
9091
response.candidates[0].content,

backend/services/gemini_service.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def generate_query_from_prompt(
119119
collection_context: CollectionContext = None,
120120
intermediate_context: dict = None,
121121
all_collections_schema: str = "",
122+
model: str = "gemini-2.5-flash",
122123
) -> GeneratedCode:
123124
# Prune intermediate_context to remove image/large data
124125
safe_intermediate_context = (
@@ -136,7 +137,7 @@ def generate_query_from_prompt(
136137
)
137138
client = genai.Client()
138139
response = client.models.generate_content(
139-
model="gemini-2.5-flash",
140+
model=model,
140141
contents=full_prompt,
141142
config=types.GenerateContentConfig(
142143
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
@@ -146,14 +147,16 @@ def generate_query_from_prompt(
146147
return GeneratedCode(generated_code=code)
147148

148149

149-
def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
150+
def generate_suggestion_from_query_error(
151+
query: str, error_message: str, model: str = "gemini-2.5-flash"
152+
) -> str:
150153
"""
151154
Sends a failed query and error message to Gemini for debugging suggestion.
152155
"""
153156
full_prompt = PROMPT_TEMPLATE_DEBUG.format(query=query, error_message=error_message)
154157
client = genai.Client()
155158
response = client.models.generate_content(
156-
model="gemini-2.5-flash",
159+
model=model,
157160
contents=full_prompt,
158161
config=types.GenerateContentConfig(
159162
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
@@ -202,11 +205,11 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
202205
"""
203206

204207

205-
def generate_audit_sql(user_input: str) -> str:
208+
def generate_audit_sql(user_input: str, model: str = "gemini-2.5-flash") -> str:
206209
full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
207210
client = genai.Client()
208211
response = client.models.generate_content(
209-
model="gemini-2.5-flash",
212+
model=model,
210213
contents=full_prompt,
211214
config=types.GenerateContentConfig(
212215
thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -220,7 +223,7 @@ def generate_audit_sql(user_input: str) -> str:
220223

221224

222225
def summarize_audit_results(
223-
user_input: str, sql_query: str, results: list
226+
user_input: str, sql_query: str, results: list, model: str = "gemini-2.5-flash"
224227
) -> AuditSummaryResponse:
225228
# Truncate results if too large to avoid token limits
226229
results_str = str(results)[:10000]
@@ -229,7 +232,7 @@ def summarize_audit_results(
229232
)
230233
client = genai.Client()
231234
response = client.models.generate_content(
232-
model="gemini-2.5-flash",
235+
model=model,
233236
contents=full_prompt,
234237
config=types.GenerateContentConfig(
235238
response_mime_type="application/json",
@@ -281,13 +284,15 @@ def summarize_audit_results(
281284
"""
282285

283286

284-
def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
287+
def generate_schema_relationships(
288+
schema_summary: str, model: str = "gemini-2.5-flash"
289+
) -> SchemaRelationshipsResponse:
285290
from models.schemas import SchemaRelationshipsResponse
286291

287292
full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
288293
client = genai.Client()
289294
response = client.models.generate_content(
290-
model="gemini-2.5-flash",
295+
model=model,
291296
contents=full_prompt,
292297
config=types.GenerateContentConfig(
293298
response_mime_type="application/json",

backend/services/react_agent_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class AgentState(TypedDict):
5656
schema_context: str
5757
intermediate_context: dict
5858
connection_string: str
59+
model: str
5960

6061
generated_query: str
6162
is_write_action: bool
@@ -127,7 +128,7 @@ def generate_query_node(state: AgentState):
127128

128129
try:
129130
response = client.models.generate_content(
130-
model="gemini-2.5-flash",
131+
model=state.get("model", "gemini-2.5-flash"),
131132
contents=prompt,
132133
config=types.GenerateContentConfig(
133134
thinking_config=types.ThinkingConfig(thinking_budget=0)
@@ -210,7 +211,7 @@ def evaluate_node(state: AgentState):
210211

211212
try:
212213
response = client.models.generate_content(
213-
model="gemini-2.5-flash",
214+
model=state.get("model", "gemini-2.5-flash"),
214215
contents=prompt,
215216
config=types.GenerateContentConfig(
216217
response_mime_type="application/json",
@@ -271,6 +272,7 @@ def run_query_generator(
271272
intermediate_context: dict,
272273
connection_string: str,
273274
max_iterations: int = 3,
275+
model: str = "gemini-2.5-flash",
274276
):
275277
logger.info(f"Starting ReAct Agent workflow for input: '{user_input}'")
276278
initial_state = {
@@ -280,6 +282,7 @@ def run_query_generator(
280282
"schema_context": schema_context,
281283
"intermediate_context": intermediate_context,
282284
"connection_string": connection_string,
285+
"model": model,
283286
"iterations": 0,
284287
"max_iterations": max_iterations,
285288
}

0 commit comments

Comments
 (0)