Skip to content

Commit 1f4d417

Browse files
authored
Merge pull request #23 from ChingEnLin/dev
Release
2 parents 6b865ca + 0cd1f0a commit 1f4d417

File tree

10 files changed

+1370
-594
lines changed

10 files changed

+1370
-594
lines changed

backend/models/schemas.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class CollectionContext(BaseModel):
2828

2929
class QueryPrompt(BaseModel):
3030
user_input: str
31+
account_id: str # Added for cross-collection schema fetching
3132
db_context: DBContext
32-
collection_context: CollectionContext | None = (
33-
None # Optional context for specific collection
34-
)
33+
collection_context: list[CollectionContext] = (
34+
[]
35+
) # List of contexts for selected collections
3536
intermediate_context: object | None = (
3637
None # Optional intermediate context for complex queries
3738
)
@@ -54,3 +55,22 @@ class DebugQueryRequest(BaseModel):
5455

5556
class DebugSuggestionResponse(BaseModel):
5657
suggestion: str
58+
59+
60+
# Request body for the /infer-relationships route.
# `#` comments are used instead of a class docstring so the generated
# model schema is left exactly as it was.
class SchemaRelationshipsRequest(BaseModel):
    # Account identifier, forwarded to the schema-summary lookup.
    account_id: str
    # Name of the database whose collections are sampled.
    database_name: str
    # Collections to sample; passed through as the summary's collection filter.
    collection_names: list[str]
64+
65+
66+
# One inferred link between two collections, e.g. `userId` in `orders`
# pointing at `_id` in `users`.
# `#` comments are used instead of a class docstring so the generated
# model schema is left exactly as it was.
class Relationship(BaseModel):
    # Collection and field holding the reference.
    source_collection: str
    source_field: str
    # Collection and field being referenced.
    target_collection: str
    target_field: str
    # Short human-readable explanation of the link.
    description: str
    confidence: float  # 0.0 to 1.0
73+
74+
75+
# Response envelope for relationship inference; also used as the structured
# output schema requested from the model.
# `#` comments are used instead of a class docstring so the generated
# model schema is left exactly as it was.
class SchemaRelationshipsResponse(BaseModel):
    # All inferred relationships; empty when none were found or parsing failed.
    relationships: list[Relationship]

backend/routes/query.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,90 @@
88
ExecuteInput,
99
DebugQueryRequest,
1010
DebugSuggestionResponse,
11+
SchemaRelationshipsRequest,
12+
SchemaRelationshipsResponse,
1113
)
1214
from services.gemini_service import (
1315
generate_query_from_prompt,
1416
generate_suggestion_from_query_error,
17+
generate_schema_relationships,
18+
)
19+
from services.mongo_service import (
20+
execute_mongo_query,
21+
transform_mongo_result,
22+
get_database_schema_summary,
1523
)
16-
from services.mongo_service import execute_mongo_query, transform_mongo_result
1724
from models.analyze import AnalyzeRequest, AnalyzeResponse
1825
from services.analyze_service import analyze_query_result
1926

2027
router = APIRouter()
2128

2229

2330
def _summarize_collection_contexts(contexts) -> str:
    """Render client-supplied collection contexts as a schema summary string.

    Each context contributes its collection name and the string form of its
    sample document; entries are separated by a blank line.
    """
    sections = []
    for ctx in contexts:
        doc_str = (
            str(ctx.sampleDocument) if ctx.sampleDocument else "No documents found"
        )
        sections.append(f"Collection: {ctx.name}\nSample Document: {doc_str}")
    return "\n\n".join(sections)


@router.post("/nl2query")
def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
    """Turn a natural-language request into a query via the Gemini service.

    The schema summary handed to the model is built from the collection
    contexts sent by the client when there are any; otherwise it is fetched
    from the database. Failing to obtain a summary is not fatal: the error
    is printed and a placeholder string is sent to the model instead.

    Raises:
        HTTPException: 401 when the Authorization header does not start
            with "Bearer ".
    """
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid token format")

    try:
        # Strip only the leading scheme; str.replace would also delete any
        # later "Bearer " occurrence inside the credential itself.
        user_token = authorization.removeprefix("Bearer ")
        # NOTE(review): the exchange runs even when the client supplied
        # contexts, and a failure here discards them in favour of the
        # placeholder below - confirm that is intended.
        access_token = exchange_token_obo(user_token)
        if prompt.collection_context:
            # Use provided contexts to avoid a re-fetch and stay consistent
            # with what the client is showing.
            schema_summary = _summarize_collection_contexts(
                prompt.collection_context
            )
        else:
            # Fallback: fetch schema summary from DB
            schema_summary = get_database_schema_summary(
                prompt.account_id, prompt.db_context.name, access_token
            )
    except Exception as e:
        # Deliberate best-effort: query generation proceeds without schema.
        print(f"Error fetching schema context: {e}")
        schema_summary = "Could not fetch schema summary."

    collections = [col.name for col in prompt.db_context.collections]
    return generate_query_from_prompt(
        prompt.user_input,
        collections,
        prompt.db_context.name,
        collection_context=None,
        intermediate_context=prompt.intermediate_context,
        all_collections_schema=schema_summary,
    )
3367

3468

69+
@router.post("/infer-relationships", response_model=SchemaRelationshipsResponse)
def infer_relationships(
    request: SchemaRelationshipsRequest = Body(...), authorization: str = Header(...)
):
    """Infer likely relationships between the requested collections.

    Exchanges the caller's bearer token, fetches a sample-based schema
    summary restricted to ``request.collection_names``, and passes that
    summary to the Gemini service.

    Raises:
        HTTPException: 401 when the Authorization header does not start
            with "Bearer "; 500 for any other failure, with the error text
            as the detail. An HTTPException raised by a helper is passed
            through unchanged.
    """
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid token format")

    try:
        # Strip only the leading scheme; str.replace would also delete any
        # later "Bearer " occurrence inside the credential itself.
        user_token = authorization.removeprefix("Bearer ")
        access_token = exchange_token_obo(user_token)

        # Fetch schema summary ONLY for the requested collections
        schema_summary = get_database_schema_summary(
            request.account_id,
            request.database_name,
            access_token,
            collection_filter=request.collection_names,
        )

        return generate_schema_relationships(schema_summary)

    except HTTPException:
        # Keep the status code chosen by whichever helper raised it rather
        # than flattening it into a generic 500.
        raise
    except Exception as e:
        print(f"Error inferring relationships: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
93+
94+
3595
@router.post("/execute")
3696
def execute(query: ExecuteInput = Body(...), authorization: str = Header(...)):
3797
if not authorization.startswith("Bearer "):

backend/services/gemini_service.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from google import genai
22
from google.genai import types
3-
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
3+
from models.schemas import (
4+
GeneratedCode,
5+
CollectionContext,
6+
DebugSuggestionResponse,
7+
SchemaRelationshipsResponse,
8+
)
49
from pydantic import BaseModel, Field
510
from typing import Optional, List
611

@@ -31,6 +36,7 @@ class AuditSummaryResponse(BaseModel):
3136
Database: {database}
3237
Available collections: {collections}
3338
Sample collection document (optional): {collection_context}
39+
Schema summary for ALL collections (for JOINs/lookups): {all_collections_schema}
3440
Intermediate context (optional): {intermediate_context}
3541
Return:
3642
only one line of pure pymongo query code (e.g., db["collection"].find(...))
@@ -112,6 +118,7 @@ def generate_query_from_prompt(
112118
database: str,
113119
collection_context: CollectionContext = None,
114120
intermediate_context: dict = None,
121+
all_collections_schema: str = "",
115122
) -> GeneratedCode:
116123
# Prune intermediate_context to remove image/large data
117124
safe_intermediate_context = (
@@ -124,6 +131,7 @@ def generate_query_from_prompt(
124131
collection_context=(
125132
collection_context.sampleDocument if collection_context else ""
126133
),
134+
all_collections_schema=all_collections_schema,
127135
intermediate_context=safe_intermediate_context,
128136
)
129137
client = genai.Client()
@@ -244,3 +252,58 @@ def summarize_audit_results(
244252
summary="Could not generate summary due to parsing error.",
245253
visualization=VisualizationConfig(available=False),
246254
)
255+
256+
257+
PROMPT_TEMPLATE_RELATIONSHIPS = """
258+
You are a database architect. Analyze the provided MongoDB document samples to identify likely foreign key relationships and JOIN conditions between collections.
259+
260+
Schema/Samples:
261+
{schema_summary}
262+
263+
Tasks:
264+
1. Identify likely relationships (e.g., `userId` in `orders` -> `_id` in `users`).
265+
2. Provide a confidence score (0.0 - 1.0) and a brief description for each.
266+
3. Return a JSON object with a "relationships" key containing a list of these findings.
267+
268+
Output Format (Json):
269+
{{
270+
"relationships": [
271+
{{
272+
"source_collection": "orders",
273+
"source_field": "userId",
274+
"target_collection": "users",
275+
"target_field": "_id",
276+
"description": "Orders belong to Users",
277+
"confidence": 0.95
278+
}}
279+
]
280+
}}
281+
"""
282+
283+
284+
def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
    """Ask Gemini for likely relationships between the summarized collections.

    The summary is substituted into the relationships prompt template and
    sent with a JSON response schema. The structured result is returned when
    the response has one; otherwise the raw text is parsed as JSON. If that
    also fails, the error is printed and an empty relationship list is
    returned.
    """
    import json

    prompt_text = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
    gemini = genai.Client()
    generation_config = types.GenerateContentConfig(
        response_mime_type="application/json",
        response_schema=SchemaRelationshipsResponse,
        thinking_config=types.ThinkingConfig(thinking_budget=0),
    )
    response = gemini.models.generate_content(
        model="gemini-2.5-flash",
        contents=prompt_text,
        config=generation_config,
    )

    # Prefer the structured result when the response carries a truthy one.
    structured = getattr(response, "parsed", None)
    if structured:
        return structured

    # Otherwise fall back to decoding the raw text ourselves.
    try:
        return SchemaRelationshipsResponse(**json.loads(response.text))
    except Exception as e:
        print(f"Error parsing Gemini relationship response: {e}")
        return SchemaRelationshipsResponse(relationships=[])

backend/services/mongo_service.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,38 @@ def transform_mongo_result(result):
5656
elif isinstance(result, DeleteResult):
5757
return {"deleted_count": result.deleted_count}
5858
return result
59+
60+
61+
def get_database_schema_summary(
    account_id: str,
    database: str,
    access_token: str,
    collection_filter: list[str] | None = None,
) -> str:
    """Build a text summary holding one sample document per collection.

    Args:
        account_id: Account identifier passed to ``get_connection_string``.
        database: Name of the database to inspect.
        access_token: Token passed to ``get_connection_string``.
        collection_filter: Collections to sample. When ``None`` or empty,
            every collection is sampled except those whose name starts
            with ``"system."``.

    Returns:
        Blank-line-separated "Collection / Sample Document" sections, or
        the string "Could not fetch schema summary." if anything fails
        while connecting or reading.
    """
    from services.azure_cosmos_resources import get_connection_string

    try:
        connection_string = get_connection_string(account_id, access_token)
        # Use the client as a context manager so its connections are
        # released on every path, including errors, instead of leaking a
        # pool per call.
        with pymongo.MongoClient(connection_string) as client:
            db = client[database]

            # Determine which collections to scan
            if collection_filter:
                # Explicit filter: try every requested name as given.
                target_collections = collection_filter
            else:
                # Scanning everything: leave out system collections.
                target_collections = [
                    name
                    for name in db.list_collection_names()
                    if not name.startswith("system.")
                ]

            summary = []
            for collection_name in target_collections:
                doc = db[collection_name].find_one()
                doc_str = str(doc) if doc else "No documents found"
                summary.append(
                    f"Collection: {collection_name}\nSample Document: {doc_str}"
                )

            return "\n\n".join(summary)
    except Exception as e:
        # Best-effort by design: callers receive a placeholder string.
        print(f"Error fetching schema summary: {e}")
        return "Could not fetch schema summary."

backend/tests/test_query_routes.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
def test_nl2query(client):
1616
"""Test natural language to query conversion."""
1717
# Mock dependencies
18-
with patch("routes.query.generate_query_from_prompt") as mock_generate:
18+
with (
19+
patch("routes.query.generate_query_from_prompt") as mock_generate,
20+
patch("routes.query.exchange_token_obo") as mock_exchange,
21+
):
1922
mock_generate.return_value = {"generated_code": "db.users.find({})"}
2023

2124
# Create test data
@@ -25,11 +28,15 @@ def test_nl2query(client):
2528

2629
prompt = QueryPrompt(
2730
user_input="Find all users",
31+
account_id="test-account",
2832
db_context=db_context,
29-
collection_context=collection_context,
33+
collection_context=[collection_context],
3034
)
3135

32-
response = client.post("/query/nl2query", json=prompt.model_dump())
36+
headers = {"authorization": "Bearer valid-token"}
37+
response = client.post(
38+
"/query/nl2query", json=prompt.model_dump(), headers=headers
39+
)
3340

3441
assert response.status_code == 200
3542
data = response.json()

backend/tests/test_schemas.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,22 @@ def test_query_prompt():
9191

9292
prompt = QueryPrompt(
9393
user_input="Find all users",
94+
account_id="test-account",
9495
db_context=db_context,
95-
collection_context=collection_context,
96+
collection_context=[collection_context],
9697
intermediate_context={"key": "value"},
9798
)
9899

99100
assert prompt.user_input == "Find all users"
100101
assert prompt.db_context.name == "test-db"
101-
assert prompt.collection_context.name == "users"
102+
assert prompt.collection_context[0].name == "users"
102103
assert prompt.intermediate_context == {"key": "value"}
103104

104105
# Test with minimal required fields
105-
minimal_prompt = QueryPrompt(user_input="Find all users", db_context=db_context)
106-
assert minimal_prompt.collection_context is None
106+
minimal_prompt = QueryPrompt(
107+
user_input="Find all users", account_id="test-account", db_context=db_context
108+
)
109+
assert minimal_prompt.collection_context == []
107110
assert minimal_prompt.intermediate_context is None
108111

109112

0 commit comments

Comments
 (0)