Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any

Expand Down Expand Up @@ -125,11 +126,39 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
return {"documents": docs}

@component.output_types(documents=list[Document])
async def run_async(
    self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents for several queries concurrently.

    One `_run_one_async` coroutine is created per query and all of them are awaited together with
    `asyncio.gather`. Each coroutine prefers a component's `run_async` method and otherwise runs the
    synchronous `run` in a thread executor.

    :param queries: List of text queries to process.
    :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
    :returns:
        A dictionary containing:
        - `documents`: Documents from all queries, passed through `_deduplicate_documents` and sorted by
          descending score (a missing score counts as 0.0).
    """
    kwargs = retriever_kwargs or {}

    # warm_up() is a plain synchronous call; it only happens on the first use of the component.
    if not self._is_warmed_up:
        self.warm_up()

    pending = [self._run_one_async(query, kwargs) for query in queries]
    per_query = await asyncio.gather(*pending)

    # Queries that produced nothing (None or an empty list) contribute no documents.
    merged: list[Document] = []
    for batch in per_query:
        if batch:
            merged.extend(batch)

    unique = _deduplicate_documents(merged)
    unique.sort(key=lambda doc: doc.score or 0.0, reverse=True)
    return {"documents": unique}
Comment on lines +145 to +154

def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
"""
Process a single query on a separate thread.

:param query: The text query to process.
:param retriever_kwargs: Arguments to pass to the retriever's run method.
:returns:
List of retrieved documents or None if no results.
"""
Expand All @@ -140,6 +169,35 @@ def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = N
return result["documents"]
return None

async def _run_one_async(self, query: str, retriever_kwargs: dict[str, Any]) -> list[Document] | None:
"""
Process a single query asynchronously.

:param query: The text query to process.
:param retriever_kwargs: Arguments to pass to the retriever's run method.
:returns:
List of retrieved documents or None if no results.
"""
loop = asyncio.get_running_loop()

if hasattr(self.query_embedder, "run_async") and callable(self.query_embedder.run_async):
embedding_result = await self.query_embedder.run_async(text=query)
else:
embedding_result = await loop.run_in_executor(None, lambda: self.query_embedder.run(text=query))

query_embedding = embedding_result["embedding"]

if hasattr(self.retriever, "run_async") and callable(self.retriever.run_async):
result = await self.retriever.run_async(query_embedding=query_embedding, **retriever_kwargs)
else:
result = await loop.run_in_executor(
None, lambda: self.retriever.run(query_embedding=query_embedding, **retriever_kwargs)
)

if result and "documents" in result:
return result["documents"]
return None

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand Down
48 changes: 48 additions & 0 deletions haystack/components/retrievers/multi_query_text_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any

Expand Down Expand Up @@ -105,6 +106,33 @@ def run(self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
return {"documents": docs}

@component.output_types(documents=list[Document])
async def run_async(
    self, queries: list[str], retriever_kwargs: dict[str, Any] | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents for several queries concurrently.

    One `_run_one_async` coroutine is created per query and all of them are awaited together with
    `asyncio.gather`. Each coroutine prefers the retriever's `run_async` method and otherwise runs the
    synchronous `run` in a thread executor.

    :param queries: List of text queries to process.
    :param retriever_kwargs: Optional dictionary of arguments to pass to the retriever's run method.
    :returns:
        A dictionary containing:
        - `documents`: Documents from all queries, passed through `_deduplicate_documents` and sorted by
          descending score (a missing score counts as 0.0).
    """
    kwargs = retriever_kwargs or {}

    # warm_up() is a plain synchronous call; it only happens on the first use of the component.
    if not self._is_warmed_up:
        self.warm_up()

    pending = [self._run_one_async(query, kwargs) for query in queries]
    per_query = await asyncio.gather(*pending)

    # Queries that produced nothing (None or an empty list) contribute no documents.
    merged: list[Document] = []
    for batch in per_query:
        if batch:
            merged.extend(batch)

    unique = _deduplicate_documents(merged)
    unique.sort(key=lambda doc: doc.score or 0.0, reverse=True)
    return {"documents": unique}
Comment on lines +125 to +134

def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = None) -> list[Document] | None:
"""
Process a single query on a separate thread.
Expand All @@ -119,6 +147,26 @@ def _run_on_thread(self, query: str, retriever_kwargs: dict[str, Any] | None = N
return result["documents"]
return None

async def _run_one_async(self, query: str, retriever_kwargs: dict[str, Any]) -> list[Document] | None:
"""
Process a single query asynchronously.

:param query: The text query to process.
:param retriever_kwargs: Arguments to pass to the retriever's run method.
:returns:
List of retrieved documents or None if no results.
"""
loop = asyncio.get_running_loop()

if hasattr(self.retriever, "run_async") and callable(self.retriever.run_async):
result = await self.retriever.run_async(query=query, **retriever_kwargs)
else:
result = await loop.run_in_executor(None, lambda: self.retriever.run(query=query, **retriever_kwargs))

if result and "documents" in result:
return result["documents"]
return None

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand Down
42 changes: 42 additions & 0 deletions haystack/components/retrievers/text_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
Expand Down Expand Up @@ -104,6 +105,47 @@ def run(
docs.sort(key=lambda x: x.score or 0.0, reverse=True)
return {"documents": docs}

@component.output_types(documents=list[Document])
async def run_async(
    self, query: str, filters: dict[str, Any] | None = None, top_k: int | None = None
) -> dict[str, list[Document]]:
    """
    Retrieve documents using a single query asynchronously.

    The query is embedded with `self.text_embedder` and the embedding is passed to `self.retriever`.
    For each of the two components, a callable `run_async` is awaited when present; otherwise the
    synchronous `run` is executed in a worker thread.

    :param query: The query to retrieve documents for.
    :param filters: A dictionary of filters to apply when retrieving documents.
    :param top_k: The maximum number of documents to return.
    :returns:
        A dictionary containing:
        - `documents`: List of retrieved documents sorted by descending score (a missing score counts
          as 0.0).
    """
    # warm_up() is a plain synchronous call; it only happens on the first use of the component.
    if not self._is_warmed_up:
        self.warm_up()

    # asyncio.to_thread (rather than loop.run_in_executor) copies the caller's contextvars context
    # into the worker thread, so context-local state stays visible to the synchronous `run`.
    embed_async = getattr(self.text_embedder, "run_async", None)
    if callable(embed_async):
        embedding_result = await embed_async(text=query)
    else:
        embedding_result = await asyncio.to_thread(self.text_embedder.run, text=query)

    query_embedding = embedding_result["embedding"]

    retrieve_async = getattr(self.retriever, "run_async", None)
    if callable(retrieve_async):
        result = await retrieve_async(query_embedding=query_embedding, filters=filters, top_k=top_k)
    else:
        result = await asyncio.to_thread(
            self.retriever.run, query_embedding=query_embedding, filters=filters, top_k=top_k
        )

    docs: list[Document] = result["documents"]
    # Sorts the retriever's returned list in place.
    docs.sort(key=lambda x: x.score or 0.0, reverse=True)
    return {"documents": docs}

def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
enhancements:
- |
Added ``run_async`` to ``TextEmbeddingRetriever``, ``MultiQueryEmbeddingRetriever``, and
``MultiQueryTextRetriever``. These components now execute natively as coroutines in
``AsyncPipeline``, delegating to each wrapped component's ``run_async`` when available and
falling back to a thread executor otherwise.
Loading
Loading