Skip to content

Commit 4d07710

Browse files
Ryzhtusanakin87
andauthored
feat: add async support for Chroma + improve typing (#1697)
* Async support for chroma DocumentStore and Retrievers * Added test coverage (work in progress) * progress and refactorings * more refactoring * typing * fix * skip async tests on windows * fix --------- Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
1 parent cfd5254 commit 4d07710

9 files changed

Lines changed: 882 additions & 451 deletions

File tree

.github/workflows/chroma.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,9 @@ jobs:
4949
- name: Install Hatch
5050
run: pip install --upgrade hatch
5151

52-
# TODO: Once this integration is properly typed, use hatch run test:types
53-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5452
- name: Lint
5553
if: matrix.python-version == '3.9' && runner.os == 'Linux'
56-
run: hatch run fmt-check && hatch run lint:typing
54+
run: hatch run fmt-check && hatch run test:types
5755

5856
- name: Generate docs
5957
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/chroma/pyproject.toml

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,6 @@ cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
6868

6969
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
7070

71-
# TODO: remove lint environment once this integration is properly typed
72-
# test environment should be used instead
73-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
74-
[tool.hatch.envs.lint]
75-
installer = "uv"
76-
detached = true
77-
dependencies = [
78-
"pip",
79-
"black>=23.1.0",
80-
"mypy>=1.0.0",
81-
"ruff>=0.0.243",
82-
"numpy", # we need the stubs from the main package
83-
]
84-
85-
[tool.hatch.envs.lint.scripts]
86-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
87-
8871
[tool.hatch.metadata]
8972
allow-direct-references = true
9073

@@ -176,14 +159,8 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
176159

177160
[tool.pytest.ini_options]
178161
minversion = "6.0"
179-
markers = ["unit: unit tests", "integration: integration tests"]
162+
markers = ["integration: integration tests"]
180163

181164
[[tool.mypy.overrides]]
182-
module = [
183-
"chromadb.*",
184-
"haystack.*",
185-
"haystack_integrations.*",
186-
"pytest.*",
187-
"numpy.*",
188-
]
165+
module = ["haystack_integrations.*"]
189166
ignore_missing_imports = true

integrations/chroma/src/haystack_integrations/components/retrievers/chroma/retriever.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,33 @@ def run(
9090
top_k = top_k or self.top_k
9191
return {"documents": self.document_store.search([query], top_k, filters)[0]}
9292

93+
@component.output_types(documents=List[Document])
94+
async def run_async(
95+
self,
96+
query: str,
97+
filters: Optional[Dict[str, Any]] = None,
98+
top_k: Optional[int] = None,
99+
):
100+
"""
101+
Asynchronously run the retriever on the given input data.
102+
103+
Asynchronous methods are only supported for HTTP connections.
104+
105+
:param query: The input data for the retriever. In this case, a plain-text query.
106+
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
107+
the `filter_policy` chosen at retriever initialization. See init method docstring for more
108+
details.
109+
:param top_k: The maximum number of documents to retrieve.
110+
If not specified, the default value from the constructor is used.
111+
:returns: A dictionary with the following keys:
112+
- `documents`: List of documents returned by the search engine.
113+
114+
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
115+
"""
116+
filters = apply_filter_policy(self.filter_policy, self.filters, filters)
117+
top_k = top_k or self.top_k
118+
return {"documents": await self.document_store.search_async([query], top_k, filters)[0]}
119+
93120
@classmethod
94121
def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryTextRetriever":
95122
"""
@@ -126,11 +153,31 @@ def to_dict(self) -> Dict[str, Any]:
126153

127154

128155
@component
129-
class ChromaEmbeddingRetriever(ChromaQueryTextRetriever):
156+
class ChromaEmbeddingRetriever:
130157
"""
131158
A component for retrieving documents from a [Chroma database](https://docs.trychroma.com/) using embeddings.
132159
"""
133160

161+
def __init__(
162+
self,
163+
document_store: ChromaDocumentStore,
164+
filters: Optional[Dict[str, Any]] = None,
165+
top_k: int = 10,
166+
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
167+
):
168+
"""
169+
:param document_store: an instance of `ChromaDocumentStore`.
170+
:param filters: filters to narrow down the search space.
171+
:param top_k: the maximum number of documents to retrieve.
172+
:param filter_policy: Policy to determine how filters are applied.
173+
"""
174+
self.filters = filters or {}
175+
self.top_k = top_k
176+
self.document_store = document_store
177+
self.filter_policy = (
178+
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
179+
)
180+
134181
@component.output_types(documents=List[Document])
135182
def run(
136183
self,
@@ -157,3 +204,66 @@ def run(
157204

158205
query_embeddings = [query_embedding]
159206
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k, filters)[0]}
207+
208+
@component.output_types(documents=List[Document])
209+
async def run_async(
210+
self,
211+
query_embedding: List[float],
212+
filters: Optional[Dict[str, Any]] = None,
213+
top_k: Optional[int] = None,
214+
):
215+
"""
216+
Asynchronously run the retriever on the given input data.
217+
218+
Asynchronous methods are only supported for HTTP connections.
219+
220+
:param query_embedding: the query embeddings.
221+
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
222+
the `filter_policy` chosen at retriever initialization. See init method docstring for more
223+
details.
224+
:param top_k: the maximum number of documents to retrieve.
225+
If not specified, the default value from the constructor is used.
226+
227+
:returns: a dictionary with the following keys:
228+
- `documents`: List of documents returned by the search engine.
229+
"""
230+
filters = apply_filter_policy(self.filter_policy, self.filters, filters)
231+
232+
top_k = top_k or self.top_k
233+
234+
query_embeddings = [query_embedding]
235+
return {"documents": await self.document_store.search_embeddings_async(query_embeddings, top_k, filters)[0]}
236+
237+
@classmethod
238+
def from_dict(cls, data: Dict[str, Any]) -> "ChromaEmbeddingRetriever":
239+
"""
240+
Deserializes the component from a dictionary.
241+
242+
:param data:
243+
Dictionary to deserialize from.
244+
:returns:
245+
Deserialized component.
246+
"""
247+
document_store = ChromaDocumentStore.from_dict(data["init_parameters"]["document_store"])
248+
data["init_parameters"]["document_store"] = document_store
249+
# Pipelines serialized with old versions of the component might not
250+
# have the filter_policy field.
251+
if filter_policy := data["init_parameters"].get("filter_policy"):
252+
data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(filter_policy)
253+
254+
return default_from_dict(cls, data)
255+
256+
def to_dict(self) -> Dict[str, Any]:
257+
"""
258+
Serializes the component to a dictionary.
259+
260+
:returns:
261+
Dictionary with serialized data.
262+
"""
263+
return default_to_dict(
264+
self,
265+
filters=self.filters,
266+
top_k=self.top_k,
267+
filter_policy=self.filter_policy.value,
268+
document_store=self.document_store.to_dict(),
269+
)

0 commit comments

Comments
 (0)