Skip to content

Commit b92535f

Browse files
authored
Merge pull request #19 from ChingEnLin/dev
Release
2 parents b72ebee + 254e838 commit b92535f

File tree

13 files changed

+2440
-14
lines changed

13 files changed

+2440
-14
lines changed

backend/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import uvicorn
33
from fastapi import FastAPI
44
from fastapi.middleware.cors import CORSMiddleware
5-
from routes import query, azure, system, user_queries, data_documents
5+
from routes import query, azure, system, user_queries, data_documents, audit
66

77
app = FastAPI()
88

@@ -55,6 +55,7 @@ async def health_check():
5555
app.include_router(system.router, prefix="/system", tags=["System"])
5656
app.include_router(user_queries.router, prefix="/user", tags=["User Queries"])
5757
app.include_router(data_documents.router, prefix="/data", tags=["Data Documents"])
58+
app.include_router(audit.router, prefix="/audit", tags=["Audit"])
5859

5960
if __name__ == "__main__":
6061
uvicorn.run(app, host="0.0.0.0", port=8000)

backend/routes/audit.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from fastapi import APIRouter, Header, Body, HTTPException
2+
from pydantic import BaseModel
3+
from typing import List, Dict, Any, Optional
4+
from services.audit_service import process_audit_question
5+
from services.gemini_service import VisualizationConfig
6+
7+
router = APIRouter()
8+
9+
10+
class AuditQueryRequest(BaseModel):
11+
question: str
12+
13+
14+
class AuditQueryResponse(BaseModel):
15+
sql_query: str
16+
results: List[Dict[str, Any]]
17+
summary: str
18+
visualization: Optional[VisualizationConfig] = None
19+
20+
21+
@router.post("/query", response_model=AuditQueryResponse)
22+
def query_audit_log(
23+
body: AuditQueryRequest = Body(...), authorization: str = Header(...)
24+
):
25+
if not authorization.startswith("Bearer "):
26+
raise HTTPException(status_code=401, detail="Invalid token format")
27+
28+
# We might want to validate the token here even if we don't use it for the pg connection directly yet
29+
# user_token = authorization.replace("Bearer ", "")
30+
# access_token = exchange_token_obo(user_token)
31+
32+
response = process_audit_question(body.question)
33+
return AuditQueryResponse(**response)

backend/services/audit_service.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Dict, Any
2+
from services.pg_connection import get_connection
3+
from services.gemini_service import generate_audit_sql, summarize_audit_results
4+
5+
6+
def execute_audit_query(sql_query: str) -> list:
7+
"""
8+
Executes a read-only SQL query against the audit database.
9+
"""
10+
if not sql_query.lower().strip().startswith("select"):
11+
return [{"error": "Only SELECT queries are allowed."}]
12+
13+
try:
14+
conn = get_connection()
15+
cur = conn.cursor()
16+
cur.execute(sql_query)
17+
18+
# Get column names
19+
columns = [desc[0] for desc in cur.description]
20+
results = [dict(zip(columns, row)) for row in cur.fetchall()]
21+
22+
# Serialize datetime and json objects
23+
for row in results:
24+
for key, value in row.items():
25+
if hasattr(value, "isoformat"):
26+
row[key] = value.isoformat()
27+
elif isinstance(value, dict):
28+
# Ensure dicts (like diff_data) are kept as dicts for the frontend
29+
pass
30+
31+
cur.close()
32+
conn.close()
33+
return results
34+
except Exception as e:
35+
print(f"Error executing audit query: {e}")
36+
return [{"error": str(e)}]
37+
38+
39+
def process_audit_question(question: str) -> Dict[str, Any]:
40+
"""
41+
Orchestrates the process of answering a user's audit question:
42+
1. Generate SQL from NL question (via Gemini)
43+
2. Execute SQL
44+
3. Summarize results (via Gemini)
45+
"""
46+
sql_query = generate_audit_sql(question)
47+
48+
# If the generator returned an error query or invalid SQL, return it
49+
if "Error:" in sql_query:
50+
return {
51+
"sql_query": sql_query,
52+
"results": [],
53+
"summary": "Could not generate a valid query for your request.",
54+
}
55+
56+
results = execute_audit_query(sql_query)
57+
58+
# If execution failed
59+
if results and "error" in results[0]:
60+
return {
61+
"sql_query": sql_query,
62+
"results": [],
63+
"summary": f"Error executing query: {results[0]['error']}",
64+
}
65+
66+
summary_response = summarize_audit_results(question, sql_query, results)
67+
68+
return {
69+
"sql_query": sql_query,
70+
"results": results,
71+
"summary": summary_response.summary,
72+
"visualization": summary_response.visualization,
73+
}

backend/services/gemini_service.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,29 @@
11
from google import genai
22
from google.genai import types
33
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
4+
from pydantic import BaseModel, Field
5+
from typing import Optional, List
6+
7+
8+
class VisualizationConfig(BaseModel):
9+
available: bool = Field(description="Whether a chart is recommended for this data")
10+
type: Optional[str] = Field(
11+
description="Type of chart: 'bar', 'line', 'pie', 'scatter'"
12+
)
13+
x_key: Optional[str] = Field(description="Key for X-axis data")
14+
y_key: Optional[str] = Field(description="Key for Y-axis data")
15+
title: Optional[str] = Field(description="Title for the chart")
16+
data_keys: Optional[List[str]] = Field(
17+
description="Keys to include in the chart data points (e.g. ['count', 'date'])"
18+
)
19+
20+
21+
class AuditSummaryResponse(BaseModel):
22+
summary: str = Field(description="Markdown summary of the results")
23+
visualization: VisualizationConfig = Field(
24+
description="Configuration for data visualization"
25+
)
26+
427

528
PROMPT_TEMPLATE_QUERY = """
629
You are an assistant that converts user requests into MongoDB query code.
@@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str:
132155
response.text.strip() if hasattr(response, "text") else str(response).strip
133156
)
134157
return DebugSuggestionResponse(suggestion=suggestion)
158+
159+
160+
PROMPT_TEMPLATE_AUDIT_SQL = """
161+
You are a PostgreSQL expert. Convert the user's natural language question into a read-only SQL query for the `write_audit_log` table.
162+
Table Schema:
163+
- user_email (text): Email of the user who performed the operation.
164+
- operation (text): 'insert', 'update', or 'delete'.
165+
- database_name (text): Name of the database (format: account.database).
166+
- collection_name (text): Name of the collection.
167+
- document_id (text): ID of the affected document.
168+
- diff_data (jsonb): JSON containing the changes (for updates, it has 'before' and 'after' fields).
169+
- timestamp_utc (timestamptz): When the operation occurred.
170+
171+
User Question: "{user_input}"
172+
173+
Rules:
174+
1. Return ONLY the SQL query. No markdown, no explanations.
175+
2. The query MUST be a SELECT statement.
176+
3. Use LIMIT 100 if no limit is specified.
177+
4. If the user asks for "recent", order by timestamp_utc DESC.
178+
"""
179+
180+
PROMPT_TEMPLATE_AUDIT_SUMMARY = """
181+
You are a data analyst. Analyze the following SQL query and its results.
182+
183+
User Question: "{user_input}"
184+
SQL Query: "{sql_query}"
185+
Results:
186+
{results}
187+
188+
Tasks:
189+
1. Provide a concise markdown summary identifying patterns or answering the specific question.
190+
2. Determine if the data is suitable for visualization (e.g., time series, counts, comparisons).
191+
3. If suitable, structure a visualization configuration (type, keys, title).
192+
- For time series, prefer 'line' or 'bar'.
193+
- For categorical counts, use 'bar' or 'pie'.
194+
"""
195+
196+
197+
def generate_audit_sql(user_input: str) -> str:
198+
full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input)
199+
client = genai.Client()
200+
response = client.models.generate_content(
201+
model="gemini-2.5-flash",
202+
contents=full_prompt,
203+
config=types.GenerateContentConfig(
204+
thinking_config=types.ThinkingConfig(thinking_budget=0)
205+
),
206+
)
207+
sql = extract_python_code(response.text)
208+
# Basic safety check
209+
if not sql.lower().startswith("select"):
210+
return "SELECT 'Error: Generated query was not a SELECT statement' as error;"
211+
return sql
212+
213+
214+
def summarize_audit_results(
215+
user_input: str, sql_query: str, results: list
216+
) -> AuditSummaryResponse:
217+
# Truncate results if too large to avoid token limits
218+
results_str = str(results)[:10000]
219+
full_prompt = PROMPT_TEMPLATE_AUDIT_SUMMARY.format(
220+
user_input=user_input, sql_query=sql_query, results=results_str
221+
)
222+
client = genai.Client()
223+
response = client.models.generate_content(
224+
model="gemini-2.5-flash",
225+
contents=full_prompt,
226+
config=types.GenerateContentConfig(
227+
response_mime_type="application/json",
228+
response_schema=AuditSummaryResponse,
229+
thinking_config=types.ThinkingConfig(thinking_budget=0),
230+
),
231+
)
232+
233+
if hasattr(response, "parsed") and response.parsed:
234+
return response.parsed
235+
236+
import json
237+
238+
try:
239+
data = json.loads(response.text)
240+
return AuditSummaryResponse(**data)
241+
except Exception as e:
242+
print(f"Error parsing Gemini response: {e}")
243+
return AuditSummaryResponse(
244+
summary="Could not generate summary due to parsing error.",
245+
visualization=VisualizationConfig(available=False),
246+
)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
from unittest.mock import patch, MagicMock
3+
from services.audit_service import process_audit_question
4+
5+
6+
class TestAuditService(unittest.TestCase):
7+
8+
@patch("services.audit_service.generate_audit_sql")
9+
@patch("services.audit_service.get_connection")
10+
@patch("services.audit_service.summarize_audit_results")
11+
def test_process_audit_question_success(
12+
self, mock_summarize, mock_get_conn, mock_generate_sql
13+
):
14+
# Setup mocks
15+
mock_generate_sql.return_value = "SELECT * FROM write_audit_log LIMIT 5"
16+
17+
mock_conn = MagicMock()
18+
mock_cursor = MagicMock()
19+
mock_get_conn.return_value = mock_conn
20+
mock_conn.cursor.return_value = mock_cursor
21+
22+
# Mock DB results
23+
mock_cursor.description = [("user_email",), ("operation",)]
24+
mock_cursor.fetchall.return_value = [("test@example.com", "insert")]
25+
26+
# Mock Summary Response
27+
mock_response = MagicMock()
28+
mock_response.summary = "Summary of results."
29+
mock_response.visualization = None
30+
mock_summarize.return_value = mock_response
31+
32+
# Execute
33+
result = process_audit_question("Show me inserts")
34+
35+
# Assertions
36+
self.assertEqual(result["sql_query"], "SELECT * FROM write_audit_log LIMIT 5")
37+
self.assertEqual(len(result["results"]), 1)
38+
self.assertEqual(result["results"][0]["user_email"], "test@example.com")
39+
self.assertEqual(result["summary"], "Summary of results.")
40+
41+
mock_generate_sql.assert_called_once()
42+
mock_cursor.execute.assert_called_with("SELECT * FROM write_audit_log LIMIT 5")
43+
mock_summarize.assert_called_once()
44+
45+
@patch("services.audit_service.generate_audit_sql")
46+
def test_process_audit_question_invalid_sql(self, mock_generate_sql):
47+
mock_generate_sql.return_value = "DELETE FROM write_audit_log"
48+
49+
result = process_audit_question("Delete everything")
50+
51+
self.assertIn("Error executing query", result["summary"])
52+
self.assertIn("Only SELECT", result["summary"])
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)