Skip to content

Commit f99fc7d

Browse files
feat: add run_async to TextEmbeddingRetriever, MultiQueryEmbeddingRetriever, and MultiQueryTextRetriever (#11367)
Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent cdeec75 commit f99fc7d

7 files changed

Lines changed: 653 additions & 0 deletions

haystack/components/retrievers/multi_query_embedding_retriever.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
from concurrent.futures import ThreadPoolExecutor
67
from typing import Any
78

@@ -125,11 +126,39 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
125126
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
126127
return {"documents": docs}
127128

129+
@component.output_types(documents=list[Document])
async def run_async(
    self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents for several queries concurrently and merge the results.

    Every query is handed to `_run_one_async`, and all of them are awaited together with `asyncio.gather`.
    The documents from all queries are pooled, passed through `_deduplicate_documents`, and ordered by
    descending score. A document whose score is missing is ranked as if its score were 0.0.

    :param queries: List of text queries to process.
    :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
    :returns:
        A dictionary containing:
        - `documents`: List of retrieved documents sorted by relevance score.
    """
    if not retriever_kwargs:
        retriever_kwargs = {}

    if not self._is_warmed_up:
        self.warm_up()

    pending = [self._run_one_async(text, retriever_kwargs) for text in queries]
    per_query_hits = await asyncio.gather(*pending)

    # Queries that produced nothing (None or an empty list) contribute no documents.
    merged: list[Document] = []
    for hits in per_query_hits:
        if hits:
            merged.extend(hits)

    unique_docs = _deduplicate_documents(merged)
    ranked = sorted(unique_docs, key=lambda doc: doc.score or 0.0, reverse=True)
    return {"documents": ranked}
155+
128156
def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
129157
"""
130158
Process a single query on a separate thread.
131159
132160
:param query: The text query to process.
161+
:param retriever_kwargs: Arguments to pass to the retriever's run method.
133162
:returns:
134163
List of retrieved documents or None if no results.
135164
"""
@@ -140,6 +169,35 @@ def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = N
140169
return result["documents"]
141170
return None
142171

172+
async def _run_one_async(self, query: str, retriever_kwargs: dict[str, Any]) -> list[Document] | None:
    """
    Process a single query asynchronously.

    The query is embedded first and the resulting embedding is then passed to the retriever. For each of the
    two components, `run_async` is awaited when the component exposes a callable attribute of that name.
    Otherwise the synchronous `run` is executed in a worker thread through `asyncio.to_thread`, which copies
    the caller's `contextvars` context into that thread, so context-local state set by the caller stays
    visible inside `run`.

    :param query: The text query to process.
    :param retriever_kwargs: Arguments to pass to the retriever's run method.
    :returns:
        List of retrieved documents or None if no results.
    """
    if callable(getattr(self.query_embedder, "run_async", None)):
        embedding_result = await self.query_embedder.run_async(text=query)
    else:
        # to_thread (unlike a bare run_in_executor) propagates the current contextvars context.
        embedding_result = await asyncio.to_thread(self.query_embedder.run, text=query)

    query_embedding = embedding_result["embedding"]

    if callable(getattr(self.retriever, "run_async", None)):
        result = await self.retriever.run_async(query_embedding=query_embedding, **retriever_kwargs)
    else:
        result = await asyncio.to_thread(self.retriever.run, query_embedding=query_embedding, **retriever_kwargs)

    if result and "documents" in result:
        return result["documents"]
    return None
200+
143201
def to_dict(self) -> dict[str, Any]:
144202
"""
145203
Serializes the component to a dictionary.

haystack/components/retrievers/multi_query_text_retriever.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
from concurrent.futures import ThreadPoolExecutor
67
from typing import Any
78

@@ -105,6 +106,33 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
105106
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
106107
return {"documents": docs}
107108

109+
@component.output_types(documents=list[Document])
async def run_async(
    self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents for several text queries concurrently and merge the results.

    Every query is handed to `_run_one_async`, and all of them are awaited together with `asyncio.gather`.
    The documents from all queries are pooled, passed through `_deduplicate_documents`, and ordered by
    descending score. A document whose score is missing is ranked as if its score were 0.0.

    :param queries: List of text queries to process.
    :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
    :returns:
        A dictionary containing:
        - `documents`: List of retrieved documents sorted by relevance score.
    """
    if not retriever_kwargs:
        retriever_kwargs = {}

    if not self._is_warmed_up:
        self.warm_up()

    pending = [self._run_one_async(text, retriever_kwargs) for text in queries]
    per_query_hits = await asyncio.gather(*pending)

    # Queries that produced nothing (None or an empty list) contribute no documents.
    merged: list[Document] = []
    for hits in per_query_hits:
        if hits:
            merged.extend(hits)

    unique_docs = _deduplicate_documents(merged)
    ranked = sorted(unique_docs, key=lambda doc: doc.score or 0.0, reverse=True)
    return {"documents": ranked}
135+
108136
def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
109137
"""
110138
Process a single query on a separate thread.
@@ -119,6 +147,26 @@ def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = N
119147
return result["documents"]
120148
return None
121149

150+
async def _run_one_async(self, query: str, retriever_kwargs: dict[str, Any]) -> list[Document] | None:
    """
    Process a single query asynchronously.

    The retriever's `run_async` is awaited when the retriever exposes a callable attribute of that name.
    Otherwise its synchronous `run` is executed in a worker thread through `asyncio.to_thread`, which copies
    the caller's `contextvars` context into that thread, so context-local state set by the caller stays
    visible inside `run`.

    :param query: The text query to process.
    :param retriever_kwargs: Arguments to pass to the retriever's run method.
    :returns:
        List of retrieved documents or None if no results.
    """
    if callable(getattr(self.retriever, "run_async", None)):
        result = await self.retriever.run_async(query=query, **retriever_kwargs)
    else:
        # to_thread (unlike a bare run_in_executor) propagates the current contextvars context.
        result = await asyncio.to_thread(self.retriever.run, query=query, **retriever_kwargs)

    if result and "documents" in result:
        return result["documents"]
    return None
169+
122170
def to_dict(self) -> dict[str, Any]:
123171
"""
124172
Serializes the component to a dictionary.

haystack/components/retrievers/text_embedding_retriever.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import asyncio
56
from typing import Any
67

78
from haystack import Document, component, default_from_dict, default_to_dict
@@ -104,6 +105,47 @@ def run(
104105
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
105106
return {"documents": docs}
106107

108+
@component.output_types(documents=list[Document])
async def run_async(
    self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents using a single query asynchronously.

    The query is embedded with the text embedder and the embedding is passed to the retriever. For each of
    the two components, `run_async` is awaited when the component exposes a callable attribute of that name.
    Otherwise the synchronous `run` is executed in a worker thread through `asyncio.to_thread`, which copies
    the caller's `contextvars` context into that thread, so context-local state set by the caller stays
    visible inside `run`.

    :param query: The query to retrieve documents for.
    :param filters: A dictionary of filters to apply when retrieving documents.
    :param top_k: The maximum number of documents to return.
    :returns:
        A dictionary containing:
        - `documents`: List of retrieved documents sorted by relevance score.
    """
    if not self._is_warmed_up:
        self.warm_up()

    if callable(getattr(self.text_embedder, "run_async", None)):
        embedding_result = await self.text_embedder.run_async(text=query)
    else:
        # to_thread (unlike a bare run_in_executor) propagates the current contextvars context.
        embedding_result = await asyncio.to_thread(self.text_embedder.run, text=query)

    query_embedding = embedding_result["embedding"]

    if callable(getattr(self.retriever, "run_async", None)):
        result = await self.retriever.run_async(query_embedding=query_embedding, filters=filters, top_k=top_k)
    else:
        result = await asyncio.to_thread(
            self.retriever.run, query_embedding=query_embedding, filters=filters, top_k=top_k
        )

    docs: list[Document] = result["documents"]
    # Documents without a score are ranked as if their score were 0.0.
    docs.sort(key=lambda x: x.score or 0.0, reverse=True)
    return {"documents": docs}
148+
107149
def to_dict(self) -> dict[str, Any]:
108150
"""
109151
Serializes the component to a dictionary.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
enhancements:
3+
- |
4+
Added ``run_async`` to ``TextEmbeddingRetriever``, ``MultiQueryEmbeddingRetriever``, and
5+
``MultiQueryTextRetriever``. These components now execute natively as coroutines in
6+
``AsyncPipeline``, delegating to each wrapped component's ``run_async`` when available and
7+
falling back to a thread executor otherwise.

0 commit comments

Comments
 (0)