Skip to content

Commit bab4aa7

Browse files
committed
fixed transcript retreival bugs.
1 parent 28c8a09 commit bab4aa7

File tree

2 files changed

+52
-38
lines changed

2 files changed

+52
-38
lines changed

main_chat/chat_route.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _route_question(question: str) -> Dict[str, Any]:
277277
"3. If mode is 'rag' or 'hybrid':\n"
278278
" - transcript_tags: array of 0-2 strings OR null (valid tags: safety, violence, youth, media, community, displacement, government, structural racism)\n"
279279
" - policy_sources: array of strings OR null (valid: 'Boston Anti-Displacement Plan Analysis.txt', 'Boston Slow Streets Plan Analysis.txt', 'Imagine Boston 2030 Analysis.txt')\n"
280-
" - folder_categories: array of strings OR null (valid: newsletters, policy, transcripts)\n"
280+
" - folder_categories: array of strings OR null (valid: newsletters, policies, transcripts)\n"
281281
" - k: integer between 3 and 10 (default 5, minimum 5 for event queries)\n"
282282
"4. For crime questions using 'hybrid' mode (Rule 1b or 1c): transcript_tags MUST include at least one of: 'safety' or 'violence'\n"
283283
" For crime questions using 'sql' mode (Rule 1a): transcript_tags, policy_sources, and folder_categories MUST be null\n"
@@ -583,7 +583,6 @@ def _run_rag(question: str, plan: Dict[str, Any], conversation_history: Optional
583583
combined_chunks: List[str] = []
584584
combined_meta: List[Dict[str, Any]] = []
585585

586-
# Increase k for better source diversity
587586
retrieval_k = max(k * 3, 20)
588587

589588
# ========================================================================
@@ -599,7 +598,7 @@ def _run_rag(question: str, plan: Dict[str, Any], conversation_history: Optional
599598
print(f" ⚠ Transcript retrieval error: {e}")
600599

601600
# ========================================================================
602-
# POLICIES - retrieve from CLIENT_UPLOAD with folder_category filter
601+
# POLICIES - retrieve from client_upload with folder_category filter
603602
# ========================================================================
604603
try:
605604
if sources:
@@ -618,11 +617,17 @@ def _run_rag(question: str, plan: Dict[str, Any], conversation_history: Optional
618617
folders_list = folders if isinstance(folders, list) else [folders]
619618

620619
for folder in folders_list:
621-
p_res = rag_retrieval.retrieve(question, k=retrieval_k, doc_type="client_upload", folder_category=folder)
622-
p_chunks = p_res.get("chunks", [])
623-
print(f" Found {len(p_chunks)} chunks from folder: {folder}")
624-
combined_chunks.extend(p_chunks)
625-
combined_meta.extend(p_res.get("metadata", []))
620+
# Map folder to correct doc_type
621+
if folder == "transcripts":
622+
# Transcripts have doc_type="transcript" (already retrieved above)
623+
continue
624+
else:
625+
# Other folders (policies, newsletters) are client_upload
626+
p_res = rag_retrieval.retrieve(question, k=retrieval_k, doc_type="client_upload", folder_category=folder)
627+
p_chunks = p_res.get("chunks", [])
628+
print(f" Found {len(p_chunks)} chunks from folder: {folder}")
629+
combined_chunks.extend(p_chunks)
630+
combined_meta.extend(p_res.get("metadata", []))
626631
else:
627632
# No specific sources or folders - search all policies by default
628633
print(" 📚 No specific policy sources, searching all policies")

main_chat/rag_pipeline/rag_retrieval.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def load_vectordb():
3838
def retrieve(query, k=5, doc_type=None, tags=None, source=None, folder_category=None, min_score=None, vectordb=None):
3939
"""
4040
Universal retrieval with flexible metadata filtering.
41-
Added 'folder_category' parameter to filter by Google Drive subfolder.
4241
"""
4342
# Defensive clamp: Chroma requires k >= 1
4443
try:
@@ -52,45 +51,45 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, folder_category=
5251
vectordb = load_vectordb()
5352

5453
# Build filter dictionary
55-
filter_conditions = []
54+
filter_dict = None
55+
56+
# Build list of filter conditions
57+
filters = []
5658

5759
# Doc type filter
58-
doc_filter = None
59-
if isinstance(doc_type, (list, tuple)):
60-
doc_types = [dt for dt in doc_type if dt]
61-
if len(doc_types) == 1:
62-
doc_filter = {"doc_type": doc_types[0]}
63-
elif len(doc_types) > 1:
64-
doc_filter = {"$or": [{"doc_type": dt} for dt in doc_types]}
65-
elif doc_type:
66-
doc_filter = {"doc_type": doc_type}
67-
68-
if doc_filter:
69-
filter_conditions.append(doc_filter)
60+
if doc_type:
61+
if isinstance(doc_type, (list, tuple)):
62+
doc_types = [dt for dt in doc_type if dt]
63+
if len(doc_types) == 1:
64+
filters.append({"doc_type": doc_types[0]})
65+
elif len(doc_types) > 1:
66+
filters.append({"$or": [{"doc_type": dt} for dt in doc_types]})
67+
else:
68+
filters.append({"doc_type": doc_type})
7069

7170
# Source filter
7271
if source:
73-
filter_conditions.append({"source": source})
72+
filters.append({"source": source})
7473

75-
# Folder category filter (NEW)
74+
# Folder category filter
7675
if folder_category:
7776
if isinstance(folder_category, (list, tuple)):
7877
if len(folder_category) == 1:
79-
filter_conditions.append({"folder_category": folder_category[0]})
78+
filters.append({"folder_category": folder_category[0]})
8079
elif len(folder_category) > 1:
81-
filter_conditions.append({"$or": [{"folder_category": f} for f in folder_category]})
80+
filters.append({"$or": [{"folder_category": f} for f in folder_category]})
8281
else:
83-
filter_conditions.append({"folder_category": folder_category})
82+
filters.append({"folder_category": folder_category})
8483

85-
# Combine all conditions
86-
filter_dict = None
87-
if len(filter_conditions) == 1:
88-
filter_dict = filter_conditions[0]
89-
elif len(filter_conditions) > 1:
90-
filter_dict = {"$and": filter_conditions}
84+
# Combine filters
85+
if len(filters) == 1:
86+
filter_dict = filters[0]
87+
elif len(filters) > 1:
88+
filter_dict = {"$and": filters}
9189

90+
# Retrieve with or without min_score
9291
if min_score is not None:
93-
results_with_scores = vectordb.similarity_search_with_score(query, k=k * 3 if tags else k, filter=filter_dict if filter_dict else None)
92+
results_with_scores = vectordb.similarity_search_with_score(query, k=k * 3 if tags else k, filter=filter_dict)
9493

9594
if tags:
9695
filtered_results = []
@@ -106,9 +105,14 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, folder_category=
106105

107106
filtered_results = [(doc, score) for doc, score in results_with_scores if score <= min_score]
108107

109-
return {"chunks": [doc.page_content for doc, _ in filtered_results[:k]], "metadata": [doc.metadata for doc, _ in filtered_results[:k]], "scores": [score for _, score in filtered_results[:k]], "query": query}
108+
return {
109+
"chunks": [doc.page_content for doc, _ in filtered_results[:k]],
110+
"metadata": [doc.metadata for doc, _ in filtered_results[:k]],
111+
"scores": [score for _, score in filtered_results[:k]],
112+
"query": query,
113+
}
110114
else:
111-
results = vectordb.similarity_search(query, k=k * 3 if tags else k, filter=filter_dict if filter_dict else None)
115+
results = vectordb.similarity_search(query, k=k * 3 if tags else k, filter=filter_dict)
112116

113117
if tags:
114118
filtered_results = []
@@ -122,12 +126,17 @@ def retrieve(query, k=5, doc_type=None, tags=None, source=None, folder_category=
122126
if filtered_results:
123127
results = filtered_results
124128

125-
return {"chunks": [doc.page_content for doc in results[:k]], "metadata": [doc.metadata for doc in results[:k]], "scores": None, "query": query}
129+
return {
130+
"chunks": [doc.page_content for doc in results[:k]],
131+
"metadata": [doc.metadata for doc in results[:k]],
132+
"scores": None,
133+
"query": query,
134+
}
126135

127136

128137
def retrieve_transcripts(query, tags=None, k=5):
129138
"""Convenience function for transcript-only search."""
130-
return retrieve(query, k=k, doc_type="transcripts", tags=tags)
139+
return retrieve(query, k=k, doc_type="transcript", tags=tags)
131140

132141

133142
def retrieve_policies(query, k=5, source=None):

0 commit comments

Comments
 (0)