Skip to content

Commit e4c211a

Browse files
committed
Fixes to RAG retrieval for policy prompts.
1 parent 510a2d3 commit e4c211a

File tree

2 files changed

+64
-39
lines changed

2 files changed

+64
-39
lines changed

main_chat/chat_route.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,7 @@ def _check_if_needs_new_data(
167167
"Return ONLY valid JSON with keys: needs_new_data (boolean) and reason (brief string explaining your decision)."
168168
)
169169

170-
user_prompt = (
171-
"Conversation History:\n" + (history_context if history_context else "(No previous conversation)") + "\n\n"
172-
"Cached Data:\n" + cache_summary + "\n\n"
173-
"Current Question: " + question + "\n\n"
174-
"Analyze if this question can be answered from the conversation history and/or cached data above, or if it needs new data retrieval.\n"
175-
"Return JSON only."
176-
)
170+
user_prompt = "Conversation History:\n" + (history_context if history_context else "(No previous conversation)") + "\n\n" "Cached Data:\n" + cache_summary + "\n\n" "Current Question: " + question + "\n\n" "Analyze if this question can be answered from the conversation history and/or cached data above, or if it needs new data retrieval.\n" "Return JSON only."
177171

178172
default_result = {"needs_new_data": True, "reason": "Error analyzing question, defaulting to new data"}
179173

@@ -266,11 +260,11 @@ def _route_question(question: str) -> Dict[str, Any]:
266260
" - If question mentions a SPECIFIC policy by name (e.g., 'Anti-Displacement Plan', 'Slow Streets', 'Imagine Boston 2030'):\n"
267261
" * Set policy_sources to that specific document (e.g., ['Boston Anti-Displacement Plan Analysis.txt'])\n"
268262
" * ALWAYS also add relevant transcript_tags for community perspective\n"
269-
" - If question is GENERAL about policy/planning but doesn't name a specific document:\n"
270-
" * Examples: 'What are current policy issues?', 'What policies affect housing?', 'What is being planned for the neighborhood?'\n"
263+
" - If question is GENERAL about policy/policies/planning/housing (doesn't name a specific document):\n"
264+
" * Examples: 'What are current policy issues?', 'What policies affect housing?', 'What is being planned for the neighborhood?', 'What does the city say about displacement?'\n"
271265
" * Set policy_sources to null (will search ALL policy documents)\n"
272-
" * Add relevant transcript_tags\n"
273-
" * Set k to at least 10 to get diverse policy coverage\n\n"
266+
" * Set folder_categories to ['policies'] to search the policy folder\n"
267+
" * Set k to at least 15 to get diverse policy coverage\n\n"
274268
"RULE 6: COMBINED DATA + CONTEXT → 'hybrid' mode\n"
275269
" - Questions that explicitly ask for BOTH numbers/data AND context/explanation\n"
276270
" - Examples: 'How many homicides and what concerns come up?', 'Show trends and how policies address them'\n"
@@ -304,7 +298,7 @@ def _route_question(question: str) -> Dict[str, Any]:
304298
"Question:\n" + question + "\n\n"
305299
"Policy sources include: 'Boston Anti-Displacement Plan Analysis.txt', 'Boston Slow Streets Plan Analysis.txt', 'Imagine Boston 2030 Analysis.txt'.\n"
306300
"Transcript tags include: safety, violence, youth, media, community, displacement, government, structural racism.\n"
307-
"Folder categories (for client uploads): newsletters, policy, transcripts.\n"
301+
"Folder categories (for client uploads): newsletters, policies, transcripts.\n"
308302
"Output JSON only."
309303
)
310304

@@ -584,39 +578,54 @@ def _run_rag(question: str, plan: Dict[str, Any], conversation_history: Optional
584578
k = int(plan.get("k", 5))
585579
tags = plan.get("transcript_tags")
586580
sources = plan.get("policy_sources")
581+
folders = plan.get("folder_categories")
587582

588583
combined_chunks: List[str] = []
589584
combined_meta: List[Dict[str, Any]] = []
590585

591586
# Increase k for better source diversity
592-
# Retrieve more chunks to get more diverse sources in the citations
593-
# Use at least 20 chunks to ensure we get multiple unique sources
594-
retrieval_k = max(k * 3, 20) # At least 20 chunks for source diversity
587+
retrieval_k = max(k * 3, 20)
595588

596-
# transcripts
589+
# ========================================================================
590+
# TRANSCRIPTS - retrieve with tags
591+
# ========================================================================
597592
try:
598593
t_res = rag_retrieval.retrieve_transcripts(question, tags=tags, k=retrieval_k)
599594
t_chunks = t_res.get("chunks", [])
600-
print(f" 📝 Transcripts: {len(t_chunks)} chunks found")
595+
print(f" 📄 Transcripts: {len(t_chunks)} chunks found")
601596
combined_chunks.extend(t_chunks)
602597
combined_meta.extend(t_res.get("metadata", []))
603598
except Exception as e:
604599
print(f" ⚠ Transcript retrieval error: {e}")
605600

606-
# policies
601+
# ========================================================================
602+
# POLICIES - retrieve from CLIENT_UPLOAD with folder_category filter
603+
# ========================================================================
607604
try:
608605
if sources:
609-
print(f" 🔍 Policy sources requested: {sources}")
610-
# When specific policy sources are requested
606+
# Specific policy documents requested by name
607+
print(f" 📚 Policy sources requested: {sources}")
611608
for src in sources:
612-
print(f" 🔍 Querying policy source: {src}")
609+
print(f" 📚 Querying policy source: {src}")
613610
p_res = rag_retrieval.retrieve_policies(question, k=retrieval_k, source=src)
614611
p_chunks = p_res.get("chunks", [])
615612
print(f" Found {len(p_chunks)} chunks from {src}")
616613
combined_chunks.extend(p_chunks)
617614
combined_meta.extend(p_res.get("metadata", []))
615+
elif folders:
616+
# Folder categories specified (e.g., ["policies", "newsletters"])
617+
print(f" 📁 Folder categories requested: {folders}")
618+
folders_list = folders if isinstance(folders, list) else [folders]
619+
620+
for folder in folders_list:
621+
p_res = rag_retrieval.retrieve(question, k=retrieval_k, doc_type="client_upload", folder_category=folder)
622+
p_chunks = p_res.get("chunks", [])
623+
print(f" Found {len(p_chunks)} chunks from folder: {folder}")
624+
combined_chunks.extend(p_chunks)
625+
combined_meta.extend(p_res.get("metadata", []))
618626
else:
619-
print(" 🔍 No specific policy sources, searching all policies")
627+
# No specific sources or folders - search all policies by default
628+
print(" 📚 No specific policy sources, searching all policies")
620629
p_res = rag_retrieval.retrieve_policies(question, k=retrieval_k)
621630
p_chunks = p_res.get("chunks", [])
622631
print(f" Policies: {len(p_chunks)} chunks found")

main_chat/rag_pipeline/rag_retrieval.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def load_vectordb():
3535
return vectordb
3636

3737

38-
def retrieve(query, k=5, doc_type=None, tags=None, source=None, min_score=None, vectordb=None):
38+
def retrieve(query, k=5, doc_type=None, tags=None, source=None, folder_category=None, min_score=None, vectordb=None):
3939
"""
4040
Universal retrieval with flexible metadata filtering.
41-
[... rest of function unchanged ...]
41+
Added 'folder_category' parameter to filter by Google Drive subfolder.
4242
"""
4343
# Defensive clamp: Chroma requires k >= 1
4444
try:
@@ -52,8 +52,9 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, min_score=None,
5252
vectordb = load_vectordb()
5353

5454
# Build filter dictionary
55-
filter_dict = None
55+
filter_conditions = []
5656

57+
# Doc type filter
5758
doc_filter = None
5859
if isinstance(doc_type, (list, tuple)):
5960
doc_types = [dt for dt in doc_type if dt]
@@ -64,17 +65,29 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, min_score=None,
6465
elif doc_type:
6566
doc_filter = {"doc_type": doc_type}
6667

67-
if doc_filter and source:
68-
filter_dict = {
69-
"$and": [
70-
doc_filter,
71-
{"source": source},
72-
]
73-
}
74-
elif doc_filter:
75-
filter_dict = doc_filter
76-
elif source:
77-
filter_dict = {"source": source}
68+
if doc_filter:
69+
filter_conditions.append(doc_filter)
70+
71+
# Source filter
72+
if source:
73+
filter_conditions.append({"source": source})
74+
75+
# Folder category filter (NEW)
76+
if folder_category:
77+
if isinstance(folder_category, (list, tuple)):
78+
if len(folder_category) == 1:
79+
filter_conditions.append({"folder_category": folder_category[0]})
80+
elif len(folder_category) > 1:
81+
filter_conditions.append({"$or": [{"folder_category": f} for f in folder_category]})
82+
else:
83+
filter_conditions.append({"folder_category": folder_category})
84+
85+
# Combine all conditions
86+
filter_dict = None
87+
if len(filter_conditions) == 1:
88+
filter_dict = filter_conditions[0]
89+
elif len(filter_conditions) > 1:
90+
filter_dict = {"$and": filter_conditions}
7891

7992
if min_score is not None:
8093
results_with_scores = vectordb.similarity_search_with_score(query, k=k * 3 if tags else k, filter=filter_dict if filter_dict else None)
@@ -114,12 +127,15 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, min_score=None,
114127

115128
def retrieve_transcripts(query, tags=None, k=5):
116129
"""Convenience function for transcript-only search."""
117-
return retrieve(query, k=k, doc_type="transcript", tags=tags)
130+
return retrieve(query, k=k, doc_type="transcripts", tags=tags)
118131

119132

120133
def retrieve_policies(query, k=5, source=None):
121-
"""Convenience function for policy-only search."""
122-
return retrieve(query, k=k, doc_type="policy", source=source)
134+
"""
135+
Convenience function for policy-only search.
136+
Searches CLIENT_UPLOAD documents in the 'policies' folder category.
137+
"""
138+
return retrieve(query, k=k, doc_type="client_upload", folder_category="policies", source=source) # Changed from "policy" to "policies"
123139

124140

125141
def format_results(result_dict):

0 commit comments

Comments
 (0)