Skip to content

Commit 943cf44

Browse files
committed
feat: Enable multi-collection query generation by updating context handling to pass a list of selected collections to the backend.
1 parent 605119d commit 943cf44

4 files changed

Lines changed: 38 additions & 13 deletions

File tree

backend/models/schemas.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,7 @@ class QueryPrompt(BaseModel):
3030
user_input: str
3131
account_id: str # Added for cross-collection schema fetching
3232
db_context: DBContext
33-
collection_context: CollectionContext | None = (
34-
None # Optional context for specific collection
35-
)
33+
collection_context: list[CollectionContext] = [] # List of contexts for selected collections
3634
intermediate_context: object | None = (
3735
None # Optional intermediate context for complex queries
3836
)

backend/routes/query.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,20 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
3131
try:
3232
user_token = authorization.replace("Bearer ", "")
3333
access_token = exchange_token_obo(user_token)
34-
# Fetch schema summary
35-
# Note: We need the connection string to fetch the schema
36-
# Ideally we might cache this, but for now we fetch it live
37-
schema_summary = get_database_schema_summary(prompt.account_id, prompt.db_context.name, access_token)
34+
# Use provided contexts if available to avoid re-fetch and ensure consistency
35+
if prompt.collection_context:
36+
summary = []
37+
for ctx in prompt.collection_context:
38+
doc_str = str(ctx.sampleDocument) if ctx.sampleDocument else "No documents found"
39+
summary.append(f"Collection: {ctx.name}\nSample Document: {doc_str}")
40+
schema_summary = "\n\n".join(summary)
41+
# Fallback: fetch schema summary from DB
42+
else:
43+
schema_summary = get_database_schema_summary(
44+
prompt.account_id,
45+
prompt.db_context.name,
46+
access_token
47+
)
3848
except Exception as e:
3949
print(f"Error fetching schema context: {e}")
4050
schema_summary = "Could not fetch schema summary."
@@ -44,8 +54,8 @@ def nl2query(prompt: QueryPrompt = Body(...), authorization: str = Header(...)):
4454
prompt.user_input,
4555
collections,
4656
prompt.db_context.name,
47-
prompt.collection_context,
48-
prompt.intermediate_context,
57+
collection_context=None,
58+
intermediate_context=prompt.intermediate_context,
4959
all_collections_schema=schema_summary
5060
)
5161

frontend/pages/QueryGeneratorPage.tsx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,19 @@ const QueryGeneratorPage: React.FC<QueryGeneratorPageProps> = ({ name, email, on
619619
throw new Error("No account ID available for query generation.");
620620
}
621621

622-
const result = await generateMongoQuery(prompt, accountId, connectedDbInfo ?? undefined, collectionCtx, intermediateContext?.data);
622+
// Map selected collection names to their full info objects
623+
const selectedCollectionInfos = selectedCollections
624+
.map(name => collectionDetailsMap[name])
625+
.filter((info): info is CollectionInfo => !!info);
626+
627+
const result = await generateMongoQuery(
628+
prompt,
629+
accountId,
630+
connectedDbInfo ?? undefined,
631+
collectionCtx,
632+
intermediateContext?.data,
633+
selectedCollectionInfos // Pass full info objects
634+
);
623635
setQueryResult(result);
624636
setIntermediateContext(null); // Clear context after use
625637

@@ -1058,7 +1070,7 @@ const QueryGeneratorPage: React.FC<QueryGeneratorPageProps> = ({ name, email, on
10581070
if (isPromptUnchanged) return 'Query Generated';
10591071
if (selectedCollections.length > 0) {
10601072
if (selectedCollections.length === 1) return `Generate Query for ${selectedCollections[0]} collection`;
1061-
return `Generate Query for ${selectedCollections.length} collections`;
1073+
return `Generate Query across ${selectedCollections.length} collections`;
10621074
}
10631075
return 'Generate Query';
10641076
}, [isLoading, selectedCollections, isPromptUnchanged]);

frontend/services/geminiService.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ export const generateMongoQuery = async (
1818
userInput: string,
1919
accountId: string,
2020
dbInfo?: DbInfo,
21-
collectionContext?: CollectionInfo,
21+
collectionContext?: CollectionInfo, // Kept for backward compatibility/single select
2222
intermediateContext?: any,
23+
selectedCollections: CollectionInfo[] = [],
2324
): Promise<QueryResultData> => {
2425
// --- DEVELOPMENT MOCK ---
2526
if (!USE_MSAL_AUTH) {
@@ -66,7 +67,11 @@ export const generateMongoQuery = async (
6667
user_input: userInput,
6768
account_id: accountId,
6869
db_context: dbInfo, // Send DB context to the backend for more accurate queries
69-
collection_context: collectionContext, // Optional: send collection context if available
70+
// Send mapping of selected collections as context
71+
collection_context: selectedCollections.length > 0 ? selectedCollections.map(col => ({
72+
name: col.name || "",
73+
sampleDocument: col.sampleDocument
74+
})) : (collectionContext ? [collectionContext] : []), // Fallback to single context as array
7075
intermediate_context: intermediateContext, // Optional: send data from a previous query
7176
}),
7277
});

0 commit comments

Comments
 (0)