Skip to content

Commit c6ef8db

Browse files
committed
feat: Implement schema relationship inference and enhance NL2Query with cross-collection schema context.
1 parent a7fea1b commit c6ef8db

7 files changed

Lines changed: 540 additions & 82 deletions

File tree

backend/models/schemas.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class CollectionContext(BaseModel):
2828

2929
class QueryPrompt(BaseModel):
3030
user_input: str
31+
account_id: str # Added for cross-collection schema fetching
3132
db_context: DBContext
3233
collection_context: CollectionContext | None = (
3334
None # Optional context for specific collection
@@ -54,3 +55,22 @@ class DebugQueryRequest(BaseModel):
5455

5556
class DebugSuggestionResponse(BaseModel):
5657
suggestion: str
58+
59+
60+
class SchemaRelationshipsRequest(BaseModel):
61+
account_id: str
62+
database_name: str
63+
collection_names: list[str]
64+
65+
66+
class Relationship(BaseModel):
67+
source_collection: str
68+
source_field: str
69+
target_collection: str
70+
target_field: str
71+
description: str
72+
confidence: float # 0.0 to 1.0
73+
74+
75+
class SchemaRelationshipsResponse(BaseModel):
76+
relationships: list[Relationship]

backend/routes/query.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,72 @@
88
ExecuteInput,
99
DebugQueryRequest,
1010
DebugSuggestionResponse,
11+
SchemaRelationshipsRequest,
12+
SchemaRelationshipsResponse,
1113
)
1214
from services.gemini_service import (
1315
generate_query_from_prompt,
1416
generate_suggestion_from_query_error,
17+
generate_schema_relationships,
1518
)
16-
from services.mongo_service import execute_mongo_query, transform_mongo_result
19+
from services.mongo_service import execute_mongo_query, transform_mongo_result, get_database_schema_summary
1720
from models.analyze import AnalyzeRequest, AnalyzeResponse
1821
from services.analyze_service import analyze_query_result
1922

2023
router = APIRouter()
2124

2225

2326
@router.post("/nl2query")
24-
def nl2query(prompt: QueryPrompt = Body(...)):
27+
def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
28+
if not authorization.startswith("Bearer "):
29+
raise HTTPException(status_code=401, detail="Invalid token format")
30+
31+
try:
32+
user_token = authorization.replace("Bearer ", "")
33+
access_token = exchange_token_obo(user_token)
34+
# Fetch schema summary
35+
# Note: We need the connection string to fetch the schema
36+
# Ideally we might cache this, but for now we fetch it live
37+
schema_summary = get_database_schema_summary(prompt.account_id, prompt.db_context.name, access_token)
38+
except Exception as e:
39+
print(f"Error fetching schema context: {e}")
40+
schema_summary = "Could not fetch schema summary."
41+
2542
collections = [col.name for col in prompt.db_context.collections]
2643
return generate_query_from_prompt(
2744
prompt.user_input,
2845
collections,
2946
prompt.db_context.name,
3047
prompt.collection_context,
3148
prompt.intermediate_context,
49+
all_collections_schema=schema_summary
3250
)
3351

3452

53+
@router.post("/infer-relationships", response_model=SchemaRelationshipsResponse)
54+
def infer_relationships(request: SchemaRelationshipsRequest = Body(...), authorization: str = Header(...)):
55+
if not authorization.startswith("Bearer "):
56+
raise HTTPException(status_code=401, detail="Invalid token format")
57+
58+
try:
59+
user_token = authorization.replace("Bearer ", "")
60+
access_token = exchange_token_obo(user_token)
61+
62+
# Fetch schema summary ONLY for correct collections
63+
schema_summary = get_database_schema_summary(
64+
request.account_id,
65+
request.database_name,
66+
access_token,
67+
collection_filter=request.collection_names
68+
)
69+
70+
return generate_schema_relationships(schema_summary)
71+
72+
except Exception as e:
73+
print(f"Error inferring relationships: {e}")
74+
raise HTTPException(status_code=500, detail=str(e))
75+
76+
3577
@router.post("/execute")
3678
def execute(query: ExecuteInput = Body(...), authorization: str = Header(...)):
3779
if not authorization.startswith("Bearer "):

backend/services/gemini_service.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from google import genai
22
from google.genai import types
3-
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse
3+
from models.schemas import GeneratedCode, CollectionContext, DebugSuggestionResponse, SchemaRelationshipsResponse
44
from pydantic import BaseModel, Field
55
from typing import Optional, List
66

@@ -31,6 +31,7 @@ class AuditSummaryResponse(BaseModel):
3131
Database: {database}
3232
Available collections: {collections}
3333
Sample collection document (optional): {collection_context}
34+
Schema summary for ALL collections (for JOINs/lookups): {all_collections_schema}
3435
Intermediate context (optional): {intermediate_context}
3536
Return:
3637
only one line of pure pymongo query code (e.g., db["collection"].find(...))
@@ -112,6 +113,7 @@ def generate_query_from_prompt(
112113
database: str,
113114
collection_context: CollectionContext = None,
114115
intermediate_context: dict = None,
116+
all_collections_schema: str = ""
115117
) -> GeneratedCode:
116118
# Prune intermediate_context to remove image/large data
117119
safe_intermediate_context = (
@@ -124,6 +126,7 @@ def generate_query_from_prompt(
124126
collection_context=(
125127
collection_context.sampleDocument if collection_context else ""
126128
),
129+
all_collections_schema=all_collections_schema,
127130
intermediate_context=safe_intermediate_context,
128131
)
129132
client = genai.Client()
@@ -244,3 +247,57 @@ def summarize_audit_results(
244247
summary="Could not generate summary due to parsing error.",
245248
visualization=VisualizationConfig(available=False),
246249
)
250+
251+
252+
PROMPT_TEMPLATE_RELATIONSHIPS = """
253+
You are a database architect. Analyze the provided MongoDB document samples to identify likely foreign key relationships and JOIN conditions between collections.
254+
255+
Schema/Samples:
256+
{schema_summary}
257+
258+
Tasks:
259+
1. Identify likely relationships (e.g., `userId` in `orders` -> `_id` in `users`).
260+
2. Provide a confidence score (0.0 - 1.0) and a brief description for each.
261+
3. Return a JSON object with a "relationships" key containing a list of these findings.
262+
263+
Output Format (Json):
264+
{{
265+
"relationships": [
266+
{{
267+
"source_collection": "orders",
268+
"source_field": "userId",
269+
"target_collection": "users",
270+
"target_field": "_id",
271+
"description": "Orders belong to Users",
272+
"confidence": 0.95
273+
}}
274+
]
275+
}}
276+
"""
277+
278+
279+
def generate_schema_relationships(schema_summary: str) -> SchemaRelationshipsResponse:
280+
from models.schemas import SchemaRelationshipsResponse
281+
282+
full_prompt = PROMPT_TEMPLATE_RELATIONSHIPS.format(schema_summary=schema_summary)
283+
client = genai.Client()
284+
response = client.models.generate_content(
285+
model="gemini-2.5-flash",
286+
contents=full_prompt,
287+
config=types.GenerateContentConfig(
288+
response_mime_type="application/json",
289+
response_schema=SchemaRelationshipsResponse,
290+
thinking_config=types.ThinkingConfig(thinking_budget=0),
291+
),
292+
)
293+
294+
if hasattr(response, "parsed") and response.parsed:
295+
return response.parsed
296+
297+
import json
298+
try:
299+
data = json.loads(response.text)
300+
return SchemaRelationshipsResponse(**data)
301+
except Exception as e:
302+
print(f"Error parsing Gemini relationship response: {e}")
303+
return SchemaRelationshipsResponse(relationships=[])

backend/services/mongo_service.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,33 @@ def transform_mongo_result(result):
5656
elif isinstance(result, DeleteResult):
5757
return {"deleted_count": result.deleted_count}
5858
return result
59+
60+
61+
def get_database_schema_summary(account_id: str, database: str, access_token: str, collection_filter: list[str] = None) -> str:
62+
from services.azure_cosmos_resources import get_connection_string
63+
64+
try:
65+
connection_string = get_connection_string(account_id, access_token)
66+
client = pymongo.MongoClient(connection_string)
67+
db = client[database]
68+
summary = []
69+
70+
# Determine which collections to scan
71+
if collection_filter:
72+
target_collections = collection_filter
73+
else:
74+
target_collections = db.list_collection_names()
75+
76+
for collection_name in target_collections:
77+
# Skip system collections if scanning all (if explicit filter, try to fetch)
78+
if not collection_filter and collection_name.startswith("system."):
79+
continue
80+
81+
doc = db[collection_name].find_one()
82+
doc_str = str(doc) if doc else "No documents found"
83+
summary.append(f"Collection: {collection_name}\nSample Document: {doc_str}")
84+
85+
return "\n\n".join(summary)
86+
except Exception as e:
87+
print(f"Error fetching schema summary: {e}")
88+
return "Could not fetch schema summary."

0 commit comments

Comments
 (0)