|
1 | 1 | from google import genai |
2 | 2 | from google.genai import types |
3 | 3 | from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse |
| 4 | +from pydantic import BaseModel, Field |
| 5 | +from typing import Optional, List |
| 6 | + |
| 7 | + |
| 8 | +class VisualizationConfig(BaseModel): |
| 9 | + available: bool = Field(description="Whether a chart is recommended for this data") |
| 10 | + type: Optional[str] = Field( |
| 11 | + description="Type of chart: 'bar', 'line', 'pie', 'scatter'" |
| 12 | + ) |
| 13 | + x_key: Optional[str] = Field(description="Key for X-axis data") |
| 14 | + y_key: Optional[str] = Field(description="Key for Y-axis data") |
| 15 | + title: Optional[str] = Field(description="Title for the chart") |
| 16 | + data_keys: Optional[List[str]] = Field( |
| 17 | + description="Keys to include in the chart data points (e.g. ['count', 'date'])" |
| 18 | + ) |
| 19 | + |
| 20 | + |
| 21 | +class AuditSummaryResponse(BaseModel): |
| 22 | + summary: str = Field(description="Markdown summary of the results") |
| 23 | + visualization: VisualizationConfig = Field( |
| 24 | + description="Configuration for data visualization" |
| 25 | + ) |
| 26 | + |
4 | 27 |
|
5 | 28 | PROMPT_TEMPLATE_QUERY = """ |
6 | 29 | You are an assistant that converts user requests into MongoDB query code. |
@@ -132,3 +155,92 @@ def generate_suggestion_from_query_error(query: str, error_message: str) -> str: |
132 | 155 | response.text.strip() if hasattr(response, "text") else str(response).strip |
133 | 156 | ) |
134 | 157 | return DebugSuggestionResponse(suggestion=suggestion) |
| 158 | + |
| 159 | + |
| 160 | +PROMPT_TEMPLATE_AUDIT_SQL = """ |
| 161 | +You are a PostgreSQL expert. Convert the user's natural language question into a read-only SQL query for the `write_audit_log` table. |
| 162 | +Table Schema: |
| 163 | +- user_email (text): Email of the user who performed the operation. |
| 164 | +- operation (text): 'insert', 'update', or 'delete'. |
| 165 | +- database_name (text): Name of the database (format: account.database). |
| 166 | +- collection_name (text): Name of the collection. |
| 167 | +- document_id (text): ID of the affected document. |
| 168 | +- diff_data (jsonb): JSON containing the changes (for updates, it has 'before' and 'after' fields). |
| 169 | +- timestamp_utc (timestamptz): When the operation occurred. |
| 170 | +
|
| 171 | +User Question: "{user_input}" |
| 172 | +
|
| 173 | +Rules: |
| 174 | +1. Return ONLY the SQL query. No markdown, no explanations. |
| 175 | +2. The query MUST be a SELECT statement. |
| 176 | +3. Use LIMIT 100 if no limit is specified. |
| 177 | +4. If the user asks for "recent", order by timestamp_utc DESC. |
| 178 | +""" |
| 179 | + |
| 180 | +PROMPT_TEMPLATE_AUDIT_SUMMARY = """ |
| 181 | +You are a data analyst. Analyze the following SQL query and its results. |
| 182 | +
|
| 183 | +User Question: "{user_input}" |
| 184 | +SQL Query: "{sql_query}" |
| 185 | +Results: |
| 186 | +{results} |
| 187 | +
|
| 188 | +Tasks: |
| 189 | +1. Provide a concise markdown summary identifying patterns or answering the specific question. |
| 190 | +2. Determine if the data is suitable for visualization (e.g., time series, counts, comparisons). |
| 191 | +3. If suitable, structure a visualization configuration (type, keys, title). |
| 192 | + - For time series, prefer 'line' or 'bar'. |
| 193 | + - For categorical counts, use 'bar' or 'pie'. |
| 194 | +""" |
| 195 | + |
| 196 | + |
| 197 | +def generate_audit_sql(user_input: str) -> str: |
| 198 | + full_prompt = PROMPT_TEMPLATE_AUDIT_SQL.format(user_input=user_input) |
| 199 | + client = genai.Client() |
| 200 | + response = client.models.generate_content( |
| 201 | + model="gemini-2.5-flash", |
| 202 | + contents=full_prompt, |
| 203 | + config=types.GenerateContentConfig( |
| 204 | + thinking_config=types.ThinkingConfig(thinking_budget=0) |
| 205 | + ), |
| 206 | + ) |
| 207 | + sql = extract_python_code(response.text) |
| 208 | + # Basic safety check |
| 209 | + if not sql.lower().startswith("select"): |
| 210 | + return "SELECT 'Error: Generated query was not a SELECT statement' as error;" |
| 211 | + return sql |
| 212 | + |
| 213 | + |
| 214 | +def summarize_audit_results( |
| 215 | + user_input: str, sql_query: str, results: list |
| 216 | +) -> AuditSummaryResponse: |
| 217 | + # Truncate results if too large to avoid token limits |
| 218 | + results_str = str(results)[:10000] |
| 219 | + full_prompt = PROMPT_TEMPLATE_AUDIT_SUMMARY.format( |
| 220 | + user_input=user_input, sql_query=sql_query, results=results_str |
| 221 | + ) |
| 222 | + client = genai.Client() |
| 223 | + response = client.models.generate_content( |
| 224 | + model="gemini-2.5-flash", |
| 225 | + contents=full_prompt, |
| 226 | + config=types.GenerateContentConfig( |
| 227 | + response_mime_type="application/json", |
| 228 | + response_schema=AuditSummaryResponse, |
| 229 | + thinking_config=types.ThinkingConfig(thinking_budget=0), |
| 230 | + ), |
| 231 | + ) |
| 232 | + |
| 233 | + if hasattr(response, "parsed") and response.parsed: |
| 234 | + return response.parsed |
| 235 | + |
| 236 | + import json |
| 237 | + |
| 238 | + try: |
| 239 | + data = json.loads(response.text) |
| 240 | + return AuditSummaryResponse(**data) |
| 241 | + except Exception as e: |
| 242 | + print(f"Error parsing Gemini response: {e}") |
| 243 | + return AuditSummaryResponse( |
| 244 | + summary="Could not generate summary due to parsing error.", |
| 245 | + visualization=VisualizationConfig(available=False), |
| 246 | + ) |
0 commit comments