diff --git a/backend/models/analyze.py b/backend/models/analyze.py index c40a2a9..b8b9310 100644 --- a/backend/models/analyze.py +++ b/backend/models/analyze.py @@ -4,6 +4,7 @@ class AnalyzeRequest(BaseModel): query_result: List[Dict[str, Any]] + model: str = "gemini-2.5-flash" class AnalyzeResponse(BaseModel): diff --git a/backend/models/schemas.py b/backend/models/schemas.py index 2241b7d..ee926ba 100644 --- a/backend/models/schemas.py +++ b/backend/models/schemas.py @@ -39,6 +39,7 @@ class QueryPrompt(BaseModel): max_iterations: int = Field( default=3, ge=1, le=10 ) # Server-enforced agent iteration cap + model: str = "gemini-2.5-flash" class GeneratedCode(BaseModel): @@ -54,6 +55,7 @@ class ExecuteInput(BaseModel): class DebugQueryRequest(BaseModel): query: str error_message: str + model: str = "gemini-2.5-flash" class DebugSuggestionResponse(BaseModel): @@ -64,6 +66,7 @@ class SchemaRelationshipsRequest(BaseModel): account_id: str database_name: str collection_names: list[str] + model: str = "gemini-2.5-flash" class Relationship(BaseModel): @@ -85,6 +88,7 @@ class EvaluateWriteRequest(BaseModel): write_result: dict account_id: str database_name: str + model: str = "gemini-2.5-flash" class EvaluateWriteResponse(BaseModel): diff --git a/backend/routes/audit.py b/backend/routes/audit.py index c221fc2..ffde038 100644 --- a/backend/routes/audit.py +++ b/backend/routes/audit.py @@ -10,6 +10,7 @@ class AuditQueryRequest(BaseModel): question: str + model: str = "gemini-2.5-flash" class AuditQueryResponse(BaseModel): @@ -35,7 +36,7 @@ def query_audit_log( if not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="Invalid token format") - response = process_audit_question(body.question) + response = process_audit_question(body.question, model=body.model) return AuditQueryResponse(**response) diff --git a/backend/routes/query.py b/backend/routes/query.py index 5dde88b..386e75a 100644 --- a/backend/routes/query.py +++ b/backend/routes/query.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, Header, Body, HTTPException +from typing import List from services.azure_auth import exchange_token_obo from services.azure_cosmos_resources import ( get_connection_string, @@ -36,6 +37,7 @@ ) import ast import re +from google import genai router = APIRouter() @@ -107,6 +109,7 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)): intermediate_context=prompt.intermediate_context, connection_string=connection_string, max_iterations=prompt.max_iterations, + model=prompt.model, ) @@ -129,7 +132,7 @@ def infer_relationships( collection_filter=request.collection_names, ) - return generate_schema_relationships(schema_summary) + return generate_schema_relationships(schema_summary, model=request.model) except Exception as e: print(f"Error inferring relationships: {e}") @@ -202,7 +205,9 @@ def debug(body: DebugQueryRequest = Body(...)): """ Sends a failed query and error message to Gemini for debugging suggestion. """ - return generate_suggestion_from_query_error(body.query, body.error_message) + return generate_suggestion_from_query_error( + body.query, body.error_message, model=body.model + ) @router.post("/analyze", response_model=AnalyzeResponse) @@ -210,7 +215,7 @@ def analyze(body: AnalyzeRequest = Body(...)): """ Sends a query result to the AI for analysis and visualization suggestions. """ - return analyze_query_result(body.query_result) + return analyze_query_result(body.query_result, model=body.model) @router.post("/evaluate-write", response_model=EvaluateWriteResponse) @@ -240,4 +245,35 @@ def evaluate_write( write_result=body.write_result, connection_string=connection_string, database_name=body.database_name, + model=body.model, ) + + +@router.get("/models", response_model=List[str]) +def list_models(): + """ + Returns available Gemini model IDs filtered to generative models. + Intentionally unauthenticated — model names are non-sensitive and + this endpoint is called on page load before auth completes. + """ + try: + client = genai.Client() + models = list(client.models.list()) + model_ids = [ + m.name.replace("models/", "") + for m in models + if m.name + and "gemini" in m.name.lower() + and hasattr(m, "supported_actions") + and "generateContent" in (m.supported_actions or []) + ] + if not model_ids: + # Fallback: return all gemini models if supported_actions is absent + model_ids = [ + m.name.replace("models/", "") + for m in models + if m.name and "gemini" in m.name.lower() + ] + return sorted(set(model_ids)) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}") diff --git a/backend/services/analyze_service.py b/backend/services/analyze_service.py index 56ef9d6..50a0571 100644 --- a/backend/services/analyze_service.py +++ b/backend/services/analyze_service.py @@ -4,7 +4,9 @@ from models.analyze import AnalyzeResponse -def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse: +def analyze_query_result( + query_result: list[dict], model: str = "gemini-2.5-flash" +) -> AnalyzeResponse: prompt = f""" You are a data analyst assistant. Given the following MongoDB query result, provide: 1. A concise textual insight or summary of the data. @@ -18,7 +20,7 @@ def analyze_query_result(query_result: list[dict]) -> AnalyzeResponse: """ client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=prompt, config=types.GenerateContentConfig( thinking_config=types.ThinkingConfig(thinking_budget=0) diff --git a/backend/services/audit_service.py b/backend/services/audit_service.py index 30fa08f..17ce97e 100644 --- a/backend/services/audit_service.py +++ b/backend/services/audit_service.py @@ -66,14 +66,16 @@ def get_recent_activity(user_email: str, limit: int = 10) -> List[Dict[str, Any] return [] -def process_audit_question(question: str) -> Dict[str, Any]: +def process_audit_question( + question: str, model: str = "gemini-2.5-flash" +) -> Dict[str, Any]: """ Orchestrates the process of answering a user's audit question: 1. Generate SQL from NL question (via Gemini) 2. Execute SQL 3. Summarize results (via Gemini) """ - sql_query = generate_audit_sql(question) + sql_query = generate_audit_sql(question, model=model) # If the generator returned an error query or invalid SQL, return it if "Error:" in sql_query: @@ -93,7 +95,9 @@ def process_audit_question(question: str) -> Dict[str, Any]: "summary": f"Error executing query: {results[0]['error']}", } - summary_response = summarize_audit_results(question, sql_query, results) + summary_response = summarize_audit_results( + question, sql_query, results, model=model + ) return { "sql_query": sql_query, diff --git a/backend/services/evaluate_write_service.py b/backend/services/evaluate_write_service.py index a063586..39bc142 100644 --- a/backend/services/evaluate_write_service.py +++ b/backend/services/evaluate_write_service.py @@ -11,6 +11,7 @@ def evaluate_write_result( write_result: dict, connection_string: str = "", database_name: str = "", + model: str = "gemini-2.5-flash", ) -> EvaluateWriteResponse: prompt = f""" You are an expert MongoDB database administrator and assistant. @@ -63,7 +64,7 @@ def query_database(query: str) -> str: tools = [query_database] if connection_string else None response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=prompt, config=types.GenerateContentConfig( tools=tools, thinking_config=types.ThinkingConfig(thinking_budget=0) @@ -84,7 +85,7 @@ def query_database(query: str) -> str: # Send the tool result back to the model response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=[ prompt, response.candidates[0].content, diff --git a/backend/services/gemini_service.py b/backend/services/gemini_service.py index 0040c6d..d09b96d 100644 --- a/backend/services/gemini_service.py +++ b/backend/services/gemini_service.py @@ -119,6 +119,7 @@ def generate_query_from_prompt( collection_context: CollectionContext = None, intermediate_context: dict = None, all_collections_schema: str = "", + model: str = "gemini-2.5-flash", ) -> GeneratedCode: # Prune intermediate_context to remove image/large data safe_intermediate_context = ( @@ -136,7 +137,7 @@ def generate_query_from_prompt( ) client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=full_prompt, config=types.GenerateContentConfig( thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking @@ -146,14 +147,16 @@ def generate_query_from_prompt( return GeneratedCode(generated_code=code) -def generate_suggestion_from_query_error(query: str, error_message: str) -> str: +def generate_suggestion_from_query_error( + query: str, error_message: str, model: str = "gemini-2.5-flash" +) -> str: """ Sends a failed query and error message to Gemini for debugging suggestion. """ full_prompt = PROMPT_TEMPLATE_DEBUG.format(query=query, error_message=error_message) client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=full_prompt, config=types.GenerateContentConfig( thinking_config=types.ThinkingConfig(thinking_budget=0) # Disables thinking @@ -202,11 +205,11 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str: """ -def generate_audit_sql(user_input: str) -> str: +def generate_audit_sql(user_input: str, model: str = "gemini-2.5-flash") -> str: full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input) client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=full_prompt, config=types.GenerateContentConfig( thinking_config=types.ThinkingConfig(thinking_budget=0) @@ -220,7 +223,7 @@ def generate_audit_sql(user_input: str) -> str: def summarize_audit_results( - user_input: str, sql_query: str, results: list + user_input: str, sql_query: str, results: list, model: str = "gemini-2.5-flash" ) -> AuditSummaryResponse: # Truncate results if too large to avoid token limits results_str = str(results)[:10000] @@ -229,7 +232,7 @@ def summarize_audit_results( ) client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=full_prompt, config=types.GenerateContentConfig( response_mime_type="application/json", @@ -281,13 +284,15 @@ def summarize_audit_results( """ -def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse: +def generate_schema_relationships( + schema_summary: str, model: str = "gemini-2.5-flash" +) -> SchemaRelationshipsResponse: from models.schemas import SchemaRelationshipsResponse full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary) client = genai.Client() response = client.models.generate_content( - model="gemini-2.5-flash", + model=model, contents=full_prompt, config=types.GenerateContentConfig( response_mime_type="application/json", diff --git a/backend/services/react_agent_service.py b/backend/services/react_agent_service.py index 15d3118..a0c3400 100644 --- a/backend/services/react_agent_service.py +++ b/backend/services/react_agent_service.py @@ -56,6 +56,7 @@ class AgentState(TypedDict): schema_context: str intermediate_context: dict connection_string: str + model: str generated_query: str is_write_action: bool @@ -127,7 +128,7 @@ def generate_query_node(state: AgentState): try: response = client.models.generate_content( - model="gemini-2.5-flash", + model=state.get("model", "gemini-2.5-flash"), contents=prompt, config=types.GenerateContentConfig( thinking_config=types.ThinkingConfig(thinking_budget=0) @@ -210,7 +211,7 @@ def evaluate_node(state: AgentState): try: response = client.models.generate_content( - model="gemini-2.5-flash", + model=state.get("model", "gemini-2.5-flash"), contents=prompt, config=types.GenerateContentConfig( response_mime_type="application/json", @@ -271,6 +272,7 @@ def run_query_generator( intermediate_context: dict, connection_string: str, max_iterations: int = 3, + model: str = "gemini-2.5-flash", ): logger.info(f"Starting ReAct Agent workflow for input: '{user_input}'") initial_state = { @@ -280,6 +282,7 @@ def run_query_generator( "schema_context": schema_context, "intermediate_context": intermediate_context, "connection_string": connection_string, + "model": model, "iterations": 0, "max_iterations": max_iterations, } diff --git a/backend/tests/test_query_routes.py b/backend/tests/test_query_routes.py index 116af58..3e85fd0 100644 --- a/backend/tests/test_query_routes.py +++ b/backend/tests/test_query_routes.py @@ -163,7 +163,9 @@ def test_debug_query(client): data = response.json() assert data["suggestion"] == "Check collection name" - mock_debug.assert_called_once_with("db.users.find({})", "Collection not found") + mock_debug.assert_called_once_with( + "db.users.find({})", "Collection not found", model="gemini-2.5-flash" + ) def test_analyze_query(client): @@ -189,7 +191,10 @@ def test_analyze_query(client): assert "chartData" in data assert "chartOptions" in data - mock_analyze.assert_called_once() + mock_analyze.assert_called_once_with( + [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}], + model="gemini-2.5-flash", + ) from models.schemas import EvaluateWriteRequest @@ -289,3 +294,56 @@ def test_execute_query_write_logging(client): call_kwargs["after_data"]["query"] == "db.users.update_one({'_id': '123'}, {'$set': {'age': 30}})" ) + + +def test_list_models_with_supported_actions(client): + """Test /models returns gemini models filtered by generateContent support.""" + + class MockModel: + def __init__(self, name, supported_actions=None): + self.name = name + self.supported_actions = supported_actions + + mock_models = [ + MockModel("models/gemini-2.5-flash", ["generateContent", "countTokens"]), + MockModel("models/gemini-2.0-flash", ["generateContent"]), + MockModel("models/gemini-1.0-pro", ["countTokens"]), # no generateContent + MockModel("models/text-embedding-004", ["embedContent"]), # not gemini + MockModel(None), # name is None + ] + + with patch("routes.query.genai") as mock_genai: + mock_genai.Client.return_value.models.list.return_value = mock_models + response = client.get("/query/models") + + assert response.status_code == 200 + data = response.json() + assert "gemini-2.5-flash" in data + assert "gemini-2.0-flash" in data + assert "gemini-1.0-pro" not in data + assert "text-embedding-004" not in data + assert data == sorted(data) + + +def test_list_models_fallback_no_supported_actions(client): + """Test /models fallback when supported_actions is absent on all models.""" + + class MockModelNoActions: + def __init__(self, name): + self.name = name + + mock_models = [ + MockModelNoActions("models/gemini-2.5-flash"), + MockModelNoActions("models/gemini-2.0-flash"), + MockModelNoActions("models/text-embedding-004"), + ] + + with patch("routes.query.genai") as mock_genai: + mock_genai.Client.return_value.models.list.return_value = mock_models + response = client.get("/query/models") + + assert response.status_code == 200 + data = response.json() + assert "gemini-2.5-flash" in data + assert "gemini-2.0-flash" in data + assert "text-embedding-004" not in data diff --git a/frontend/pages/QueryGeneratorPage.tsx b/frontend/pages/QueryGeneratorPage.tsx index 0555ece..34e85a4 100644 --- a/frontend/pages/QueryGeneratorPage.tsx +++ b/frontend/pages/QueryGeneratorPage.tsx @@ -2,7 +2,7 @@ import React, { useState, useCallback, useEffect, useMemo, useRef } from 'react' import { createPortal } from 'react-dom'; import { useSearchParams, useNavigate } from 'react-router-dom'; import { parseQueryForHandover } from '../utils/queryHandover'; -import { generateMongoQuery, debugMongoQuery, analyzeQueryResult, inferSchemaRelationships, evaluateWriteResult } from '../services/geminiService'; +import { generateMongoQuery, debugMongoQuery, analyzeQueryResult, inferSchemaRelationships, evaluateWriteResult, getAvailableModels } from '../services/geminiService'; import { getAzureCosmosAccounts, getDatabasesForAccount, runMongoQuery, getCollectionInfo, clearSystemCache } from '../services/dbService'; import { getSavedQueries, saveQuery, updateSavedQuery, deleteSavedQuery } from '../services/userDataService'; import { generateIpynbContent, downloadFile } from '../services/notebookService'; @@ -507,6 +507,10 @@ const QueryGeneratorPage: React.FC = ({ name, email, on // Agent configuration const [maxIterations, setMaxIterations] = useState(3); + const [selectedModel, setSelectedModel] = useState( + () => localStorage.getItem('qp_selected_model') ?? 'gemini-2.5-flash' + ); + const [availableModels, setAvailableModels] = useState(['gemini-2.5-flash']); // State for collection details const [selectedCollections, setSelectedCollections] = useState([]); @@ -665,6 +669,9 @@ const QueryGeneratorPage: React.FC = ({ name, email, on useEffect(() => { fetchAccounts(); fetchSavedQueries(); + getAvailableModels().then(models => { + if (models.length > 0) setAvailableModels(models); + }); // Check if the user has seen the tutorial before const hasSeenTutorial = localStorage.getItem('hasSeenTutorial'); if (!hasSeenTutorial) { @@ -865,7 +872,8 @@ const QueryGeneratorPage: React.FC = ({ name, email, on collectionCtx, intermediateContext?.data, selectedCollectionInfos, - maxIterations + maxIterations, + selectedModel ); setQueryResult(result); setIntermediateContext(null); // Clear context after use @@ -901,7 +909,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on } finally { setIsLoading(false); } - }, [connectedDbInfo, codeHistory, historyIndex, intermediateContext, connectedResource, selectedAccountId, selectedCollections, collectionDetailsMap]); + }, [connectedDbInfo, codeHistory, historyIndex, intermediateContext, connectedResource, selectedAccountId, selectedCollections, collectionDetailsMap, selectedModel]); const handleGenerateQueryClick = useCallback(() => { if (intermediateContext) { @@ -975,7 +983,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on setDebuggingResult(null); try { - const result = await debugMongoQuery(editableCode, executionError); + const result = await debugMongoQuery(editableCode, executionError, selectedModel); setDebuggingResult(result); } catch (e) { if (e instanceof Error) setDebugError(e.message); @@ -983,7 +991,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on } finally { setIsDebugging(false); } - }, [editableCode, executionError]); + }, [editableCode, executionError, selectedModel]); const handleAnalyzeQuery = useCallback(async (dataToAnalyze: any) => { if (!dataToAnalyze) return; @@ -993,7 +1001,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on setAnalysisResult(null); try { - const result = await analyzeQueryResult(dataToAnalyze); + const result = await analyzeQueryResult(dataToAnalyze, selectedModel); setAnalysisResult(result); } catch (e) { if (e instanceof Error) setAnalysisError(e.message); @@ -1001,7 +1009,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on } finally { setIsAnalyzing(false); } - }, []); + }, [selectedModel]); const handleEvaluateWrite = useCallback(async () => { if (!editableCode || !executionResult || !lastSuccessfulPrompt || !selectedAccountId || !connectedDbInfo) return; @@ -1011,7 +1019,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on setWriteEvaluationResult(null); try { - const result = await evaluateWriteResult(lastSuccessfulPrompt, editableCode, executionResult, selectedAccountId, connectedDbInfo.name); + const result = await evaluateWriteResult(lastSuccessfulPrompt, editableCode, executionResult, selectedAccountId, connectedDbInfo.name, selectedModel); setWriteEvaluationResult(result); } catch (e) { if (e instanceof Error) setWriteEvaluationError(e.message); @@ -1019,7 +1027,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on } finally { setIsEvaluatingWrite(false); } - }, [lastSuccessfulPrompt, editableCode, executionResult, selectedAccountId, connectedDbInfo]); + }, [lastSuccessfulPrompt, editableCode, executionResult, selectedAccountId, connectedDbInfo, selectedModel]); const handleCollectionClick = useCallback(async (collectionName: string, event?: React.MouseEvent) => { @@ -1098,7 +1106,8 @@ const QueryGeneratorPage: React.FC = ({ name, email, on const result = await inferSchemaRelationships( connectedResource.accountId, connectedResource.databaseName, - selectedCollections + selectedCollections, + selectedModel ); setRelationships(result); } catch (e: any) { @@ -1106,7 +1115,7 @@ const QueryGeneratorPage: React.FC = ({ name, email, on } finally { setIsAnalyzingRelationships(false); } - }, [connectedResource, selectedCollections]); + }, [connectedResource, selectedCollections, selectedModel]); // Debounced effect to trigger analysis when selections change useEffect(() => { @@ -1844,6 +1853,33 @@ const QueryGeneratorPage: React.FC = ({ name, email, on aria-label={`Agent iterations: ${maxIterations}`} /> + {/* Model selector */} +
+ + +