Skip to content

Commit 5c21144

Browse files
authored
Merge pull request #37 from ChingEnLin/feat/cross-collection-lookup-prompt
Cosmos-safe $lookup prompting for cross-collection query generation
2 parents 1d6e972 + 9b5c976 commit 5c21144

3 files changed

Lines changed: 130 additions & 18 deletions

File tree

backend/routes/query.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
execute_mongo_query,
2424
transform_mongo_result,
2525
get_database_schema_summary,
26+
SCHEMA_FETCH_FAILED,
2627
)
2728
from models.analyze import AnalyzeRequest, AnalyzeResponse
2829
from services.analyze_service import analyze_query_result
@@ -36,9 +37,12 @@
3637
DeleteResult,
3738
)
3839
import ast
40+
import logging
3941
import re
4042
from google import genai
4143

44+
logger = logging.getLogger(__name__)
45+
4246
router = APIRouter()
4347

4448

@@ -101,6 +105,26 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
101105

102106
collections = [col.name for col in prompt.db_context.collections]
103107

108+
relationship_context = ""
109+
if (
110+
len(collections) > 1
111+
and schema_summary
112+
and schema_summary != SCHEMA_FETCH_FAILED
113+
):
114+
try:
115+
rels = generate_schema_relationships(schema_summary, model=prompt.model)
116+
if rels.relationships:
117+
relationship_context = "\n".join(
118+
f"- {r.source_collection}.{r.source_field} -> "
119+
f"{r.target_collection}.{r.target_field} "
120+
f"(confidence={r.confidence:.2f}) — {r.description}"
121+
for r in rels.relationships
122+
)
123+
except Exception as e:
124+
logger.warning(
125+
"Relationship inference failed; continuing without it: %s", e
126+
)
127+
104128
return run_query_generator(
105129
user_input=prompt.user_input,
106130
database=prompt.db_context.name,
@@ -110,6 +134,7 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
110134
connection_string=connection_string,
111135
max_iterations=prompt.max_iterations,
112136
model=prompt.model,
137+
relationship_context=relationship_context,
113138
)
114139

115140

backend/services/mongo_service.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from bson import ObjectId
2+
3+
SCHEMA_FETCH_FAILED = "Could not fetch schema summary."
24
import pymongo
35
from pymongo.results import (
46
UpdateResult,
@@ -26,23 +28,23 @@ def execute_mongo_query(connection_string: str, database: str, query: str):
2628
return query_result
2729

2830

31+
def _convert_object_ids(value):
32+
# Recursively stringify ObjectIds inside nested dicts/lists so FastAPI can
33+
# serialize $lookup-joined documents (which carry nested _id ObjectIds).
34+
if isinstance(value, ObjectId):
35+
return str(value)
36+
if isinstance(value, dict):
37+
return {k: _convert_object_ids(v) for k, v in value.items()}
38+
if isinstance(value, list):
39+
return [_convert_object_ids(v) for v in value]
40+
return value
41+
42+
2943
def transform_mongo_result(result):
30-
# If result is a list of dicts, convert ObjectIds
3144
if isinstance(result, list):
32-
if result and isinstance(result[0], dict):
33-
for doc in result:
34-
for k, v in doc.items():
35-
if isinstance(v, ObjectId):
36-
doc[k] = str(v)
37-
return result
38-
else:
39-
# List of primitives (e.g., from distinct)
40-
return result
45+
return _convert_object_ids(result)
4146
elif isinstance(result, dict):
42-
for k, v in result.items():
43-
if isinstance(v, ObjectId):
44-
result[k] = str(v)
45-
return result
47+
return _convert_object_ids(result)
4648
elif isinstance(result, InsertOneResult):
4749
return {"inserted_id": str(result.inserted_id)}
4850
elif isinstance(result, InsertManyResult):
@@ -90,4 +92,4 @@ def get_database_schema_summary(
9092
return "\n\n".join(summary)
9193
except Exception as e:
9294
print(f"Error fetching schema summary: {e}")
93-
return "Could not fetch schema summary."
95+
return SCHEMA_FETCH_FAILED

backend/services/react_agent_service.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class AgentState(TypedDict):
5454
database: str
5555
collections: list[str]
5656
schema_context: str
57+
relationship_context: str
5758
intermediate_context: dict
5859
connection_string: str
5960
model: str
@@ -71,25 +72,92 @@ class AgentState(TypedDict):
7172
# Define LLM clients and prompts
7273
client = genai.Client()
7374

75+
CROSS_COLLECTION_GUIDANCE = """
76+
77+
Cross-Collection Join Guidance (MULTIPLE collections selected):
78+
- Target: Azure Cosmos DB (MongoDB API). IMPORTANT — Cosmos does NOT support `let` or `pipeline` inside `$lookup`. Only the plain `localField`/`foreignField` form works. Never emit `let` or a `pipeline` array inside `$lookup` — it will fail with `CommandNotSupported`.
79+
- Produce ONE aggregate() pipeline. Pick the most-filtering collection as the driver.
80+
- Use the Inferred Relationships block above to choose join fields. Do NOT invent field names that are not in the schema or relationships.
81+
- TYPE-MISMATCH HANDLING (this is the #1 reason $lookup returns empty arrays on Cosmos): if one side is an ObjectId and the other is a string, you MUST pre-convert with an `$addFields` stage BEFORE `$lookup`, then $lookup on the converted field. Do not rely on $lookup to coerce types.
82+
- String → ObjectId: `{"$addFields": {"<field>_oid": {"$toObjectId": "$<field>"}}}` (or `$map` over an array of strings).
83+
- ObjectId → String: `{"$addFields": {"<field>_str": {"$toString": "$<field>"}}}`.
84+
- After $lookup, $unwind the joined array (use `preserveNullAndEmptyArrays: True` for LEFT-JOIN semantics) before filtering on joined fields, OR filter with `joined.field` dotted notation.
85+
- Always include a $limit (<=50) unless the user explicitly asks for all rows.
86+
87+
Example A — simple equality join (same field type on both sides):
88+
```python
89+
db['orders'].aggregate([
90+
{"$match": {"status": "paid"}},
91+
{"$lookup": {
92+
"from": "users",
93+
"localField": "userId",
94+
"foreignField": "_id",
95+
"as": "user"
96+
}},
97+
{"$unwind": {"path": "$user", "preserveNullAndEmptyArrays": True}},
98+
{"$limit": 20}
99+
])
100+
```
101+
102+
Example B — array-of-strings → ObjectId join (Cosmos-safe; pre-convert via $addFields):
103+
```python
104+
db['patient-cohort'].aggregate([
105+
{"$addFields": {
106+
"patient_oids": {"$map": {"input": "$patient_ids", "in": {"$toObjectId": "$$this"}}}
107+
}},
108+
{"$lookup": {
109+
"from": "patient",
110+
"localField": "patient_oids",
111+
"foreignField": "_id",
112+
"as": "patients_info"
113+
}},
114+
{"$match": {"patients_info.origin_ethnicity": "caucasian"}},
115+
{"$limit": 50}
116+
])
117+
```
118+
119+
Example C — single ObjectId → string join (Cosmos-safe; pre-convert the other side):
120+
```python
121+
db['orders'].aggregate([
122+
{"$addFields": {"userId_str": {"$toString": "$userId"}}},
123+
{"$lookup": {
124+
"from": "users",
125+
"localField": "userId_str",
126+
"foreignField": "external_id",
127+
"as": "user"
128+
}},
129+
{"$unwind": "$user"},
130+
{"$limit": 20}
131+
])
132+
```
133+
"""
134+
74135
GENERATE_PROMPT = """
75-
You are an expert MongoDB architect.
136+
You are an expert MongoDB architect.
76137
Your task is to generate a PyMongo query based on the user's request.
77138
78139
User Request: {user_input}
79140
Database: {database}
80141
Collections: {collections}
81142
Schema summary: {schema_context}
143+
Inferred Relationships (foreign keys between collections):
144+
{relationship_context}
82145
Intermediate Context (optional): {intermediate_context}
83146
84-
Previous Evaluation Feedback (if this is a retry):
147+
Previous Attempt (if this is a retry — DO NOT repeat the same query verbatim):
148+
{previous_query}
149+
150+
Previous Evaluation Feedback (if this is a retry):
85151
{evaluation}
86152
87153
Instructions:
88-
1. Write ONLY the PyMongo query code.
154+
0. If a Previous Attempt is shown above, your new query MUST be materially different — change the join form, fields, types, or stages in response to the critique. Producing the same query (or a trivial reformat) is a failure. If the previous $lookup returned empty arrays for every input doc, the cause is almost always a type mismatch on the join key: add an `$addFields` stage BEFORE the `$lookup` that applies `$toObjectId` (string→OID) or `$toString` (OID→string), then $lookup on the converted field. NEVER use `let` or `pipeline` inside `$lookup` — Cosmos rejects both with CommandNotSupported.
155+
1. Write ONLY the PyMongo query code.
89156
2. Use variables appropriately (e.g., db['collection_name'].find(...) or db['collection_name'].aggregate(...)).
90157
3. Do not include markdown formatting or explanations, just return the code, but you can use ```python if you must.
91158
4. If the user is asking for a visualization, ensure the query retrieves the necessary data format.
92159
5. You are allowed to use both find() and aggregate() pipelines. If using aggregate(), be mindful of the resulting data size and include $limit stages if applicable.
160+
{cross_collection_guidance}
93161
"""
94162

95163
EVALUATE_PROMPT = """
@@ -105,6 +173,12 @@ class AgentState(TypedDict):
105173
Determine if this query successfully answers the user's request based on the code and the result.
106174
If there is an error in the query result, or if it clearly does not match the intent, it is NOT valid.
107175
If it is a write action, we cannot test the result, but you should evaluate if the code looks correct for the user's intent.
176+
If the query uses $lookup, specifically verify (target is Azure Cosmos DB MongoDB API):
177+
- The query does NOT use `let` or `pipeline` inside `$lookup` — Cosmos rejects both with `CommandNotSupported`. If you see either, it is NOT valid; recommend pre-converting the type with an `$addFields` stage and then using plain `localField`/`foreignField`.
178+
- The `localField`/`foreignField` reference fields that actually exist on each side.
179+
- When the joined array is empty for every input doc, suspect a type mismatch (ObjectId vs string). The query is NOT valid; recommend an `$addFields` stage BEFORE `$lookup` that applies `$toObjectId` (string→OID) or `$toString` (OID→string), and a `$map` if the local field is an array of strings.
180+
- Do not rationalize an empty result with "maybe no such records exist" when the filter is on a field inside the joined array — empty joined arrays will make the downstream $match drop everything; treat that as a join-correctness failure unless the join itself is clearly populated.
181+
- The driving collection is the right one (smallest filtered set first).
108182
109183
Respond in JSON format:
110184
{{
@@ -117,13 +191,22 @@ class AgentState(TypedDict):
117191
def generate_query_node(state: AgentState):
118192
logger.info(f"--- GENERATE NODE (Iteration {state.get('iterations', 0)}) ---")
119193

194+
is_multi_collection = len(state.get("collections", [])) > 1
195+
is_retry = state.get("iterations", 0) > 0
196+
previous_query = state.get("generated_query", "") if is_retry else ""
120197
prompt = GENERATE_PROMPT.format(
121198
user_input=state["user_input"],
122199
database=state["database"],
123200
collections=", ".join(state["collections"]),
124201
schema_context=state["schema_context"],
202+
relationship_context=state.get("relationship_context")
203+
or "None (single collection or no inferred relationships).",
125204
intermediate_context=state.get("intermediate_context", {}),
205+
previous_query=previous_query or "None (this is the first attempt).",
126206
evaluation=state.get("evaluation", "None"),
207+
cross_collection_guidance=(
208+
CROSS_COLLECTION_GUIDANCE if is_multi_collection else ""
209+
),
127210
)
128211

129212
try:
@@ -277,13 +360,15 @@ def run_query_generator(
277360
connection_string: str,
278361
max_iterations: int = 3,
279362
model: str = "gemini-2.5-flash",
363+
relationship_context: str = "",
280364
):
281365
logger.info(f"Starting ReAct Agent workflow for input: '{user_input}'")
282366
initial_state = {
283367
"user_input": user_input,
284368
"database": database,
285369
"collections": collections,
286370
"schema_context": schema_context,
371+
"relationship_context": relationship_context,
287372
"intermediate_context": intermediate_context,
288373
"connection_string": connection_string,
289374
"model": model,

0 commit comments

Comments
 (0)