Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions integrations/chroma/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@ name = "chroma-haystack"
dynamic = ["version"]
description = ''
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "Apache-2.0"
keywords = []
authors = [{ name = "John Doe", email = "jd@example.com" }]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.11.0", "chromadb>=1.0.2"]
dependencies = ["haystack-ai>=2.22.0", "chromadb>=1.0.2"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme"
Expand Down Expand Up @@ -78,7 +77,6 @@ disallow_incomplete_defs = true
allow-direct-references = true

[tool.ruff]
target-version = "py39"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -129,10 +127,6 @@ ignore = [
# Allow assertions
"S101",
]
unfixable = [
# Don't touch unused imports
"F401",
]
exclude = ["example"]

[tool.ruff.lint.isort]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional, Union
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.document_stores.types import FilterPolicy
Expand Down Expand Up @@ -48,9 +48,9 @@ class ChromaQueryTextRetriever:
def __init__(
self,
document_store: ChromaDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
:param document_store: an instance of `ChromaDocumentStore`.
Expand All @@ -69,8 +69,8 @@ def __init__(
def run(
self,
query: str,
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, Any]:
"""
Run the retriever on the given input data.
Expand All @@ -94,8 +94,8 @@ def run(
async def run_async(
self,
query: str,
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, Any]:
"""
Asynchronously run the retriever on the given input data.
Expand Down Expand Up @@ -161,9 +161,9 @@ class ChromaEmbeddingRetriever:
def __init__(
self,
document_store: ChromaDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
:param document_store: an instance of `ChromaDocumentStore`.
Expand All @@ -182,8 +182,8 @@ def __init__(
def run(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, Any]:
"""
Run the retriever on the given input data.
Expand All @@ -209,8 +209,8 @@ def run(
async def run_async(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, Any]:
"""
Asynchronously run the retriever on the given input data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import Any, Literal, Optional, cast
from typing import Any, Literal, cast

import chromadb
from chromadb.api.models.AsyncCollection import AsyncCollection
Expand Down Expand Up @@ -36,12 +36,12 @@ def __init__(
self,
collection_name: str = "documents",
embedding_function: str = "default",
persist_path: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
persist_path: str | None = None,
host: str | None = None,
port: int | None = None,
distance_function: Literal["l2", "cosine", "ip"] = "l2",
metadata: Optional[dict] = None,
client_settings: Optional[dict[str, Any]] = None,
metadata: dict | None = None,
client_settings: dict[str, Any] | None = None,
**embedding_function_params: Any,
):
"""
Expand Down Expand Up @@ -97,8 +97,8 @@ def __init__(
self._host = host
self._port = port

self._collection: Optional[chromadb.Collection] = None
self._async_collection: Optional[AsyncCollection] = None
self._collection: chromadb.Collection | None = None
self._async_collection: AsyncCollection | None = None

def _ensure_initialized(self):
if not self._collection:
Expand Down Expand Up @@ -208,7 +208,7 @@ async def _ensure_initialized_async(self):
)

@staticmethod
def _prepare_get_kwargs(filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
def _prepare_get_kwargs(filters: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Prepare kwargs for Chroma get operations.
"""
Expand All @@ -226,7 +226,7 @@ def _prepare_get_kwargs(filters: Optional[dict[str, Any]] = None) -> dict[str, A
return kwargs

@staticmethod
def _prepare_query_kwargs(filters: Optional[dict[str, Any]] = None) -> dict[str, Any]:
def _prepare_query_kwargs(filters: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Prepare kwargs for Chroma query operations.
"""
Expand Down Expand Up @@ -264,7 +264,7 @@ async def count_documents_async(self) -> int:

return value

def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Returns the documents that match the filters provided.

Expand All @@ -282,7 +282,7 @@ def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Doc

return self._get_result_to_documents(result)

async def filter_documents_async(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
async def filter_documents_async(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Asynchronously returns the documents that match the filters provided.

Expand Down Expand Up @@ -353,7 +353,7 @@ def _prepare_metadata_update(
return ids_to_update, updated_metadata

@staticmethod
def _convert_document_to_chroma(doc: Document) -> Optional[dict[str, Any]]:
def _convert_document_to_chroma(doc: Document) -> dict[str, Any] | None:
"""
Converts a Haystack Document to a Chroma document.
"""
Expand Down Expand Up @@ -755,7 +755,7 @@ def search(
self,
queries: list[str],
top_k: int,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
) -> list[list[Document]]:
"""
Search the documents in the store using the provided text queries.
Expand All @@ -781,7 +781,7 @@ async def search_async(
self,
queries: list[str],
top_k: int,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
) -> list[list[Document]]:
"""
Asynchronously search the documents in the store using the provided text queries.
Expand Down Expand Up @@ -809,7 +809,7 @@ def search_embeddings(
self,
query_embeddings: list[list[float]],
top_k: int,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
) -> list[list[Document]]:
"""
Perform vector search on the stored document, pass the embeddings of the queries instead of their text.
Expand Down Expand Up @@ -837,7 +837,7 @@ async def search_embeddings_async(
self,
query_embeddings: list[list[float]],
top_k: int,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
) -> list[list[Document]]:
"""
Asynchronously perform vector search on the stored document, pass the embeddings of the queries instead of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Optional, cast
from typing import Any, cast

from chromadb.api.types import WhereDocument, validate_where, validate_where_document

Expand Down Expand Up @@ -38,8 +38,8 @@ class ChromaFilter:
"""

ids: list[str]
where: Optional[dict[str, Any]]
where_document: Optional[dict[str, Any]]
where: dict[str, Any] | None
where_document: dict[str, Any] | None


def _convert_filters(filters: dict[str, Any]) -> ChromaFilter:
Expand Down
2 changes: 1 addition & 1 deletion integrations/chroma/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def assert_documents_are_equal(self, received: list[Document], expected: list[Do
received.sort(key=operator.attrgetter("id"))
expected.sort(key=operator.attrgetter("id"))

for doc_received, doc_expected in zip(received, expected):
for doc_received, doc_expected in zip(received, expected, strict=True):
assert doc_received.content == doc_expected.content
assert doc_received.meta == doc_expected.meta

Expand Down
2 changes: 1 addition & 1 deletion integrations/chroma/tests/test_document_store_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def assert_documents_are_equal(received: list[Document], expected: list[Document
received.sort(key=operator.attrgetter("id"))
expected.sort(key=operator.attrgetter("id"))

for doc_received, doc_expected in zip(received, expected):
for doc_received, doc_expected in zip(received, expected, strict=True):
assert doc_received.content == doc_expected.content
assert doc_received.meta == doc_expected.meta

Expand Down