Skip to content

Commit 9232748

Browse files
authored
new HybridRetriever class (#102)
* rewrite HybridRetriever class * fix types * fix HybridRetriever class inheritance issues * fix lint * multithread, as in Helia's code
1 parent ebe6180 commit 9232748

1 file changed

Lines changed: 194 additions & 34 deletions

File tree

src/retrievers/csv_chroma.py

Lines changed: 194 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,61 @@
1+
import asyncio
12
from pathlib import Path
3+
from typing import Annotated, Any, Coroutine, TypedDict
24

35
import chromadb.config
46
from langchain.chains.query_constructor.schema import AttributeInfo
5-
from langchain.retrievers import EnsembleRetriever
7+
from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
68
from langchain.retrievers.merger_retriever import MergerRetriever
79
from langchain.retrievers.self_query.base import SelfQueryRetriever
810
from langchain_chroma.vectorstores import Chroma
911
from langchain_community.document_loaders.csv_loader import CSVLoader
1012
from langchain_community.retrievers import BM25Retriever
13+
from langchain_core.documents import Document
1114
from langchain_core.embeddings import Embeddings
1215
from langchain_core.language_models.chat_models import BaseChatModel
13-
from langchain_core.retrievers import BaseRetriever
16+
from langchain_core.prompts.prompt import PromptTemplate
1417
from nltk.tokenize import word_tokenize
18+
from pydantic import AfterValidator, Field
19+
from pydantic.json_schema import SkipJsonSchema
1520

1621
# Chroma client settings used for the vector stores opened in this module;
# anonymized telemetry is switched off.
chroma_settings = chromadb.config.Settings(anonymized_telemetry=False)

# Instruction text for the query-expansion LLM call. ``{question}`` is the
# only placeholder and is filled with the user's original question.
_MULTI_QUERY_TEMPLATE = """You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database.

Given a single user question, generate **exactly 4** alternate standalone questions. These should be:

- Semantically related to the original question.
- Lexically diverse to improve retrieval via vector search and RAG-fusion.
- Biologically enriched with inferred or associated details.

Your goal is to improve recall of relevant documents by expanding the original query using:
- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1)
- Pathway or process-level context (e.g., signal transduction, apoptosis)
- Known diseases, phenotypes, or biological functions
- Cellular localization (e.g., nucleus, cytoplasm, membrane)
- Upstream/downstream molecular interactions

Rules:
- Each question must be **fully standalone** (no "this"/"it").
- Do not change the core intent—preserve the user's informational goal.
- Use appropriate biological terminology and Reactome-relevant concepts.
- Vary the **phrasing**, **focus**, or **biological angle** of each question.
- If the input is ambiguous, infer a biologically meaningful interpretation.

Output:
Return only the 4 alternative questions separated by newlines.
Do not include any explanations or metadata.

Original Question: {question}"""

# Prompt handed to the multi-query machinery to produce alternate questions.
multi_query_prompt = PromptTemplate(
    template=_MULTI_QUERY_TEMPLATE,
    input_variables=["question"],
)


def _discard_value(_value: Any) -> None:
    """Validator that replaces whatever value was supplied with ``None``."""
    return None


# Field annotation for a model attribute that should effectively not exist:
# it defaults to None, is excluded from serialization, is skipped in the JSON
# schema, and any supplied value is replaced by None after validation.
ExcludedField = SkipJsonSchema[
    Annotated[Any, Field(default=None, exclude=True), AfterValidator(_discard_value)]
]
58+
1859

1960
def list_chroma_subdirectories(directory: Path) -> list[str]:
2061
subdirectories = list(
@@ -31,40 +72,159 @@ def create_bm25_chroma_ensemble_retriever(
3172
descriptions_info: dict[str, str],
3273
field_info: dict[str, list[AttributeInfo]],
3374
) -> MergerRetriever:
34-
retriever_list: list[BaseRetriever] = []
35-
for subdirectory in list_chroma_subdirectories(embeddings_directory):
36-
# set up BM25 retriever
37-
csv_file_name = subdirectory + ".csv"
38-
reactome_csvs_dir: Path = embeddings_directory / "csv_files"
39-
loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name)
40-
data = loader.load()
41-
bm25_retriever = BM25Retriever.from_documents(
42-
data,
43-
preprocess_func=lambda text: word_tokenize(
44-
text.casefold(), language="english"
45-
),
46-
)
47-
bm25_retriever.k = 10
75+
return HybridRetriever.from_subdirectory(
76+
llm,
77+
embedding,
78+
embeddings_directory,
79+
descriptions_info=descriptions_info,
80+
field_info=field_info,
81+
include_original=True,
82+
)
4883

49-
# set up vectorstore SelfQuery retriever
50-
vectordb = Chroma(
51-
persist_directory=str(embeddings_directory / subdirectory),
52-
embedding_function=embedding,
53-
client_settings=chroma_settings,
54-
)
5584

56-
selfq_retriever = SelfQueryRetriever.from_llm(
57-
llm=llm,
58-
vectorstore=vectordb,
59-
document_contents=descriptions_info[subdirectory],
60-
metadata_field_info=field_info[subdirectory],
61-
search_kwargs={"k": 10},
62-
)
63-
rrf_retriever = EnsembleRetriever(
64-
retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8]
85+
class RetrieverDict(TypedDict):
    """The pair of retrievers built for one embeddings subdirectory."""

    # Keyword retriever built from the subdirectory's CSV documents.
    bm25: BM25Retriever
    # Self-querying retriever backed by the subdirectory's Chroma store.
    vector: SelfQueryRetriever
88+
89+
90+
class HybridRetriever(MultiQueryRetriever):
    """Multi-query retriever over per-subdirectory BM25 and vector retrievers.

    For every query, each subdirectory's BM25 retriever and self-query vector
    retriever are both invoked. Their hits are concatenated into one list per
    query, and the per-query lists of a subdirectory are fused with
    equal-weight reciprocal rank fusion. The fused results of all
    subdirectories are returned as one flat list.
    """

    # The single ``retriever`` field of MultiQueryRetriever is not used by
    # this class; ExcludedField forces it to None and keeps it out of dumps.
    retriever: ExcludedField = None
    # Subdirectory name -> its BM25 / vector retriever pair.
    _retrievers: dict[str, RetrieverDict]

    @classmethod
    def from_subdirectory(
        cls,
        llm: BaseChatModel,
        embedding: Embeddings,
        embeddings_directory: Path,
        *,
        descriptions_info: dict[str, str],
        field_info: dict[str, list[AttributeInfo]],
        include_original: bool = False,
    ) -> "HybridRetriever":
        """Build a retriever from every Chroma subdirectory of a directory.

        For each name returned by ``list_chroma_subdirectories``, a BM25
        retriever is built from ``csv_files/<name>.csv`` and a self-query
        retriever is built on the Chroma store persisted in ``<name>``.

        Args:
            llm: Chat model used for self-querying and for query expansion.
            embedding: Embedding function handed to each Chroma store.
            embeddings_directory: Directory holding the Chroma subdirectories
                and the ``csv_files`` folder.
            descriptions_info: Document-content description per subdirectory.
            field_info: Metadata attribute descriptions per subdirectory.
            include_original: Whether the original question is retrieved
                alongside the generated ones.

        Returns:
            A retriever holding one BM25/vector pair per subdirectory.

        Raises:
            ValueError: If ``embeddings_directory`` has no Chroma
                subdirectories.
            KeyError: If a subdirectory is missing from ``descriptions_info``
                or ``field_info``.
        """

        def tokenize(text: str) -> list[str]:
            # Case-folded word tokens so BM25 matching is case-insensitive.
            return word_tokenize(text.casefold(), language="english")

        csv_dir: Path = embeddings_directory / "csv_files"
        retrievers: dict[str, RetrieverDict] = {}
        for subdirectory in list_chroma_subdirectories(embeddings_directory):
            # set up BM25 retriever
            documents = CSVLoader(file_path=csv_dir / (subdirectory + ".csv")).load()
            bm25_retriever = BM25Retriever.from_documents(
                documents,
                preprocess_func=tokenize,
            )
            bm25_retriever.k = 10

            # set up vectorstore SelfQuery retriever
            vectordb = Chroma(
                persist_directory=str(embeddings_directory / subdirectory),
                embedding_function=embedding,
                client_settings=chroma_settings,
            )
            selfq_retriever = SelfQueryRetriever.from_llm(
                llm=llm,
                vectorstore=vectordb,
                document_contents=descriptions_info[subdirectory],
                metadata_field_info=field_info[subdirectory],
                search_kwargs={"k": 10},
            )

            retrievers[subdirectory] = {
                "bm25": bm25_retriever,
                "vector": selfq_retriever,
            }

        if not retrievers:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"No Chroma subdirectories found in {embeddings_directory}"
            )

        # A throwaway MultiQueryRetriever is built only to obtain its
        # llm_chain; the retriever passed to it is never queried through it.
        seed_retriever = next(iter(retrievers.values()))["bm25"]
        llm_chain = MultiQueryRetriever.from_llm(
            seed_retriever, llm, multi_query_prompt, None, include_original
        ).llm_chain
        hybrid_retriever = cls(
            llm_chain=llm_chain,
            include_original=include_original,
            _retrievers={},
        )
        # The real mapping is assigned after construction.
        hybrid_retriever._retrievers = retrievers
        return hybrid_retriever

    def weighted_reciprocal_rank(
        self, doc_lists: list[list[Document]]
    ) -> list[Document]:
        """Fuse ranked document lists with equal weights.

        Delegates to ``EnsembleRetriever.weighted_reciprocal_rank`` with a
        weight of ``1 / len(doc_lists)`` per list.

        Args:
            doc_lists: One ranked list of documents per query.

        Returns:
            The fused documents, or an empty list when ``doc_lists`` is empty.
        """
        if not doc_lists:
            # Nothing to fuse; also avoids dividing by zero below.
            return []
        list_count = len(doc_lists)
        fuser = EnsembleRetriever(
            retrievers=[], weights=[1 / list_count] * list_count
        )
        return fuser.weighted_reciprocal_rank(doc_lists)

    def retrieve_documents(self, queries: list[str], run_manager) -> list[Document]:
        """Run every query against every subdirectory and fuse the hits.

        Args:
            queries: Questions to retrieve for.
            run_manager: Callback manager; a child tagged
                ``<subdirectory>-bm25-<i>`` / ``<subdirectory>-vector-<i>``
                is created for each retriever call.

        Returns:
            The fused documents of all subdirectories, concatenated in the
            iteration order of the subdirectories.
        """
        subdirectory_docs: list[Document] = []
        for subdirectory, retrievers in self._retrievers.items():
            bm25_retriever = retrievers["bm25"]
            vector_retriever = retrievers["vector"]
            doc_lists: list[list[Document]] = []
            for i, query in enumerate(queries):
                bm25_docs = bm25_retriever.invoke(
                    query,
                    config={
                        "callbacks": run_manager.get_child(
                            tag=f"{subdirectory}-bm25-{i}"
                        )
                    },
                )
                vector_docs = vector_retriever.invoke(
                    query,
                    config={
                        "callbacks": run_manager.get_child(
                            tag=f"{subdirectory}-vector-{i}"
                        )
                    },
                )
                # NOTE(review): BM25 hits come first in the combined list, so
                # vector hits presumably enter the fusion at later ranks.
                # Confirm this ordering is intended.
                doc_lists.append(bm25_docs + vector_docs)
            subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
        return subdirectory_docs

    async def aretrieve_documents(
        self, queries: list[str], run_manager
    ) -> list[Document]:
        """Async counterpart of ``retrieve_documents``.

        Each synchronous ``invoke`` is wrapped with ``asyncio.to_thread`` and
        all of them, across every subdirectory, are awaited in one
        ``asyncio.gather`` so lookups in different subdirectories can overlap.

        Args:
            queries: Questions to retrieve for.
            run_manager: Callback manager; a child tagged
                ``<subdirectory>-bm25-<i>`` / ``<subdirectory>-vector-<i>``
                is created for each retriever call.

        Returns:
            The fused documents of all subdirectories, concatenated in the
            iteration order of the subdirectories.
        """
        # Flat list laid out as, per subdirectory and per query, the BM25
        # lookup followed by the vector lookup.
        pending: list[Coroutine[Any, Any, list[Document]]] = []
        for subdirectory, retrievers in self._retrievers.items():
            bm25_retriever = retrievers["bm25"]
            vector_retriever = retrievers["vector"]
            for i, query in enumerate(queries):
                pending.append(
                    asyncio.to_thread(
                        bm25_retriever.invoke,
                        query,
                        config={
                            "callbacks": run_manager.get_child(
                                tag=f"{subdirectory}-bm25-{i}"
                            )
                        },
                    )
                )
                pending.append(
                    asyncio.to_thread(
                        vector_retriever.invoke,
                        query,
                        config={
                            "callbacks": run_manager.get_child(
                                tag=f"{subdirectory}-vector-{i}"
                            )
                        },
                    )
                )

        # gather returns results in the order the awaitables were passed, so
        # the layout described above can be consumed sequentially.
        results = iter(await asyncio.gather(*pending))

        subdirectory_docs: list[Document] = []
        for _ in self._retrievers:
            doc_lists: list[list[Document]] = []
            for _ in queries:
                bm25_docs = next(results)
                vector_docs = next(results)
                doc_lists.append(bm25_docs + vector_docs)
            subdirectory_docs.extend(self.weighted_reciprocal_rank(doc_lists))
        return subdirectory_docs

0 commit comments

Comments
 (0)