Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/models/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

class AnalyzeRequest(BaseModel):
query_result: List[Dict[str, Any]]
model: str = "gemini-2.5-flash"


class AnalyzeResponse(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions backend/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class QueryPrompt(BaseModel):
max_iterations: int = Field(
default=3, ge=1, le=10
) # Server-enforced agent iteration cap
model: str = "gemini-2.5-flash"


class GeneratedCode(BaseModel):
Expand All @@ -54,6 +55,7 @@ class ExecuteInput(BaseModel):
class DebugQueryRequest(BaseModel):
query: str
error_message: str
model: str = "gemini-2.5-flash"


class DebugSuggestionResponse(BaseModel):
Expand All @@ -64,6 +66,7 @@ class SchemaRelationshipsRequest(BaseModel):
account_id: str
database_name: str
collection_names: list[str]
model: str = "gemini-2.5-flash"


class Relationship(BaseModel):
Expand All @@ -85,6 +88,7 @@ class EvaluateWriteRequest(BaseModel):
write_result: dict
account_id: str
database_name: str
model: str = "gemini-2.5-flash"


class EvaluateWriteResponse(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion backend/routes/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

class AuditQueryRequest(BaseModel):
question: str
model: str = "gemini-2.5-flash"


class AuditQueryResponse(BaseModel):
Expand All @@ -35,7 +36,7 @@ def query_audit_log(
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid token format")

response = process_audit_question(body.question)
response = process_audit_question(body.question, model=body.model)
return AuditQueryResponse(**response)


Expand Down
42 changes: 39 additions & 3 deletions backend/routes/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter, Header, Body, HTTPException
from typing import List
from services.azure_auth import exchange_token_obo
from services.azure_cosmos_resources import (
get_connection_string,
Expand Down Expand Up @@ -36,6 +37,7 @@
)
import ast
import re
from google import genai

router = APIRouter()

Expand Down Expand Up @@ -107,6 +109,7 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
intermediate_context=prompt.intermediate_context,
connection_string=connection_string,
max_iterations=prompt.max_iterations,
model=prompt.model,
)


Expand All @@ -129,7 +132,7 @@ def infer_relationships(
collection_filter=request.collection_names,
)

return generate_schema_relationships(schema_summary)
return generate_schema_relationships(schema_summary, model=request.model)

except Exception as e:
print(f"Error inferring relationships: {e}")
Expand Down Expand Up @@ -202,15 +205,17 @@ def debug(body: DebugQueryRequest = Body(...)):
"""
Sends a failed query and error message to Gemini for debugging suggestion.
"""
return generate_suggestion_from_query_error(body.query, body.error_message)
return generate_suggestion_from_query_error(
body.query, body.error_message, model=body.model
)


@router.post("/analyze", response_model=AnalyzeResponse)
def analyze(body: AnalyzeRequest = Body(...)):
"""
Sends a query result to the AI for analysis and visualization suggestions.
"""
return analyze_query_result(body.query_result)
return analyze_query_result(body.query_result, model=body.model)


@router.post("/evaluate-write", response_model=EvaluateWriteResponse)
Expand Down Expand Up @@ -240,4 +245,35 @@ def evaluate_write(
write_result=body.write_result,
connection_string=connection_string,
database_name=body.database_name,
model=body.model,
)


@router.get("/models", response_model=List[str])
def list_models():
"""
Returns available Gemini model IDs filtered to generative models.
Intentionally unauthenticated — model names are non-sensitive and
this endpoint is called on page load before auth completes.
"""
try:
client = genai.Client()
models = list(client.models.list())
model_ids = [
m.name.replace("models/", "")
for m in models
if m.name
and "gemini" in m.name.lower()
and hasattr(m, "supported_actions")
and "generateContent" in (m.supported_actions or [])
]
if not model_ids:
# Fallback: return all gemini models if supported_actions is absent
model_ids = [
m.name.replace("models/", "")
for m in models
if m.name and "gemini" in m.name.lower()
]
return sorted(set(model_ids))
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")
6 changes: 4 additions & 2 deletions backend/services/analyze_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from models.analyze import AnalyzeResponse


def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
def analyze_query_result(
query_result: list[dict], model: str = "gemini-2.5-flash"
) -> AnalyzeResponse:
prompt = f"""
You are a data analyst assistant. Given the following MongoDB query result, provide:
1. A concise textual insight or summary of the data.
Expand All @@ -18,7 +20,7 @@ def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse:
"""
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=prompt,
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(thinking_budget=0)
Expand Down
10 changes: 7 additions & 3 deletions backend/services/audit_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,16 @@ def get_recent_activity(user_email: str, limit: int = 10) -> List[Dict[str, Any]
return []


def process_audit_question(question: str) -> Dict[str, Any]:
def process_audit_question(
question: str, model: str = "gemini-2.5-flash"
) -> Dict[str, Any]:
"""
Orchestrates the process of answering a user's audit question:
1. Generate SQL from NL question (via Gemini)
2. Execute SQL
3. Summarize results (via Gemini)
"""
sql_query = generate_audit_sql(question)
sql_query = generate_audit_sql(question, model=model)

# If the generator returned an error query or invalid SQL, return it
if "Error:" in sql_query:
Expand All @@ -93,7 +95,9 @@ def process_audit_question(question: str) -> Dict[str, Any]:
"summary": f"Error executing query: {results[0]['error']}",
}

summary_response = summarize_audit_results(question, sql_query, results)
summary_response = summarize_audit_results(
question, sql_query, results, model=model
)

return {
"sql_query": sql_query,
Expand Down
5 changes: 3 additions & 2 deletions backend/services/evaluate_write_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def evaluate_write_result(
write_result: dict,
connection_string: str = "",
database_name: str = "",
model: str = "gemini-2.5-flash",
) -> EvaluateWriteResponse:
prompt = f"""
You are an expert MongoDB database administrator and assistant.
Expand Down Expand Up @@ -63,7 +64,7 @@ def query_database(query: str) -> str:
tools = [query_database] if connection_string else None

response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=prompt,
config=types.GenerateContentConfig(
tools=tools, thinking_config=types.ThinkingConfig(thinking_budget=0)
Expand All @@ -84,7 +85,7 @@ def query_database(query: str) -> str:

# Send the tool result back to the model
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=[
prompt,
response.candidates[0].content,
Expand Down
23 changes: 14 additions & 9 deletions backend/services/gemini_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def generate_query_from_prompt(
collection_context: CollectionContext = None,
intermediate_context: dict = None,
all_collections_schema: str = "",
model: str = "gemini-2.5-flash",
) -> GeneratedCode:
# Prune intermediate_context to remove image/large data
safe_intermediate_context = (
Expand All @@ -136,7 +137,7 @@ def generate_query_from_prompt(
)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=full_prompt,
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
Expand All @@ -146,14 +147,16 @@ def generate_query_from_prompt(
return GeneratedCode(generated_code=code)


def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
def generate_suggestion_from_query_error(
query: str, error_message: str, model: str = "gemini-2.5-flash"
) -> str:
"""
Sends a failed query and error message to Gemini for debugging suggestion.
"""
full_prompt = PROMPT_TEMPLATE_DEBUG.format(query=query, error_message=error_message)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=full_prompt,
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking
Expand Down Expand Up @@ -202,11 +205,11 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
"""


def generate_audit_sql(user_input: str) -> str:
def generate_audit_sql(user_input: str, model: str = "gemini-2.5-flash") -> str:
full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=full_prompt,
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(thinking_budget=0)
Expand All @@ -220,7 +223,7 @@ def generate_audit_sql(user_input: str) -> str:


def summarize_audit_results(
user_input: str, sql_query: str, results: list
user_input: str, sql_query: str, results: list, model: str = "gemini-2.5-flash"
) -> AuditSummaryResponse:
# Truncate results if too large to avoid token limits
results_str = str(results)[:10000]
Expand All @@ -229,7 +232,7 @@ def summarize_audit_results(
)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=full_prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
Expand Down Expand Up @@ -281,13 +284,15 @@ def summarize_audit_results(
"""


def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
def generate_schema_relationships(
schema_summary: str, model: str = "gemini-2.5-flash"
) -> SchemaRelationshipsResponse:
from models.schemas import SchemaRelationshipsResponse

full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
client = genai.Client()
response = client.models.generate_content(
model="gemini-2.5-flash",
model=model,
contents=full_prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
Expand Down
7 changes: 5 additions & 2 deletions backend/services/react_agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class AgentState(TypedDict):
schema_context: str
intermediate_context: dict
connection_string: str
model: str

generated_query: str
is_write_action: bool
Expand Down Expand Up @@ -127,7 +128,7 @@ def generate_query_node(state: AgentState):

try:
response = client.models.generate_content(
model="gemini-2.5-flash",
model=state.get("model", "gemini-2.5-flash"),
contents=prompt,
config=types.GenerateContentConfig(
thinking_config=types.ThinkingConfig(thinking_budget=0)
Expand Down Expand Up @@ -210,7 +211,7 @@ def evaluate_node(state: AgentState):

try:
response = client.models.generate_content(
model="gemini-2.5-flash",
model=state.get("model", "gemini-2.5-flash"),
contents=prompt,
config=types.GenerateContentConfig(
response_mime_type="application/json",
Expand Down Expand Up @@ -271,6 +272,7 @@ def run_query_generator(
intermediate_context: dict,
connection_string: str,
max_iterations: int = 3,
model: str = "gemini-2.5-flash",
):
logger.info(f"Starting ReAct Agent workflow for input: '{user_input}'")
initial_state = {
Expand All @@ -280,6 +282,7 @@ def run_query_generator(
"schema_context": schema_context,
"intermediate_context": intermediate_context,
"connection_string": connection_string,
"model": model,
"iterations": 0,
"max_iterations": max_iterations,
}
Expand Down
Loading
Loading