Skip to content

Commit 48c6871

Browse files
Merge branch 'main' into feat/add-count-filtering-to-ElasticSearchDocumentStore
2 parents f514842 + 32bad32 commit 48c6871

9 files changed

Lines changed: 123 additions & 20 deletions

File tree

integrations/chroma/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@
1111
## Contributing
1212

1313
Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md).
14+
15+
To run integration tests locally, you need a Chroma server running.
16+
Start one with: `docker run -p 8000:8000 chromadb/chroma:latest`.

integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import chromadb
99
from chromadb.api.models.AsyncCollection import AsyncCollection
1010
from chromadb.api.types import GetResult, Metadata, OneOrMany, QueryResult
11+
from chromadb.config import Settings
1112
from haystack import default_from_dict, default_to_dict, logging
1213
from haystack.dataclasses import Document
1314
from haystack.document_stores.errors import DocumentStoreError
@@ -40,6 +41,7 @@ def __init__(
4041
port: Optional[int] = None,
4142
distance_function: Literal["l2", "cosine", "ip"] = "l2",
4243
metadata: Optional[dict] = None,
44+
client_settings: Optional[dict[str, Any]] = None,
4345
**embedding_function_params: Any,
4446
):
4547
"""
@@ -67,6 +69,11 @@ def __init__(
6769
:param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client
6870
method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the
6971
`distance_function` parameter above.
72+
:param client_settings: a dictionary of Chroma Settings configuration options passed to
73+
`chromadb.config.Settings`. These settings configure the underlying Chroma client behavior.
74+
For available options, see [Chroma's config.py](https://github.com/chroma-core/chroma/blob/main/chromadb/config.py).
75+
**Note**: specifying these settings may interfere with standard client initialization parameters.
76+
This option is intended for advanced customization.
7077
:param embedding_function_params: additional parameters to pass to the embedding function.
7178
"""
7279

@@ -84,6 +91,7 @@ def __init__(
8491
self._embedding_function_params = embedding_function_params
8592
self._distance_function = distance_function
8693
self._metadata = metadata
94+
self._client_settings = client_settings
8795

8896
self._persist_path = persist_path
8997
self._host = host
@@ -102,18 +110,29 @@ def _ensure_initialized(self):
102110
"You cannot specify both options."
103111
)
104112
raise ValueError(error_message)
113+
114+
# Use dict to conditionally pass settings because Chroma doesn't accept settings=None
115+
client_kwargs: dict[str, Any] = {}
116+
if self._client_settings:
117+
try:
118+
client_kwargs["settings"] = Settings(**self._client_settings)
119+
except ValueError as e:
120+
msg = f"Invalid client_settings ({self._client_settings}): {e}"
121+
raise ValueError(msg) from e
122+
105123
if self._host and self._port is not None:
106124
# Remote connection via HTTP client
107125
client = chromadb.HttpClient(
108126
host=self._host,
109127
port=self._port,
128+
**client_kwargs,
110129
)
111130
elif self._persist_path is None:
112131
# In-memory storage
113-
client = chromadb.Client()
132+
client = chromadb.Client(**client_kwargs)
114133
else:
115134
# Local persistent storage
116-
client = chromadb.PersistentClient(path=self._persist_path)
135+
client = chromadb.PersistentClient(path=self._persist_path, **client_kwargs)
117136

118137
self._client = client # store client for potential future use
119138

@@ -148,9 +167,19 @@ async def _ensure_initialized_async(self):
148167
)
149168
raise ValueError(error_message)
150169

170+
# Use dict to conditionally pass settings because Chroma doesn't accept settings=None
171+
client_kwargs: dict[str, Any] = {}
172+
if self._client_settings:
173+
try:
174+
client_kwargs["settings"] = Settings(**self._client_settings)
175+
except ValueError as e:
176+
msg = f"Invalid client_settings ({self._client_settings}): {e}"
177+
raise ValueError(msg) from e
178+
151179
client = await chromadb.AsyncHttpClient(
152180
host=self._host,
153181
port=self._port,
182+
**client_kwargs,
154183
)
155184

156185
self._async_client = client # store client for potential future use
@@ -862,6 +891,7 @@ def to_dict(self) -> dict[str, Any]:
862891
host=self._host,
863892
port=self._port,
864893
distance_function=self._distance_function,
894+
client_settings=self._client_settings,
865895
**self._embedding_function_params,
866896
)
867897

integrations/chroma/tests/test_document_store.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unittest import mock
1010

1111
import pytest
12+
from chromadb.api.shared_system_client import SharedSystemClient
1213
from haystack.dataclasses import ByteStream, Document
1314
from haystack.testing.document_store import (
1415
TEST_EMBEDDING_1,
@@ -20,6 +21,19 @@
2021
from haystack_integrations.document_stores.chroma import ChromaDocumentStore
2122

2223

24+
@pytest.fixture
25+
def clear_chroma_system_cache():
26+
"""
27+
Chroma's in-memory client uses a singleton pattern with an internal cache.
28+
Once a client is created with certain settings, Chroma rejects creating another
29+
with different settings in the same process. This fixture clears the cache
30+
before and after tests that use custom client settings.
31+
"""
32+
SharedSystemClient.clear_system_cache()
33+
yield
34+
SharedSystemClient.clear_system_cache()
35+
36+
2337
class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest):
2438
"""
2539
Common test cases will be provided by `DocumentStoreBaseTests` but
@@ -75,6 +89,10 @@ def test_init_http_connection(self):
7589
assert store._host == "localhost"
7690
assert store._port == 8000
7791

92+
def test_init_with_client_settings(self):
93+
store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False})
94+
assert store._client_settings == {"anonymized_telemetry": False}
95+
7896
def test_invalid_initialization_both_host_and_persist_path(self):
7997
"""
8098
Test that providing both host and persist_path raises an error.
@@ -83,9 +101,33 @@ def test_invalid_initialization_both_host_and_persist_path(self):
83101
store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost")
84102
store._ensure_initialized()
85103

104+
def test_client_settings_applied(self, clear_chroma_system_cache):
105+
"""
106+
Chroma's in-memory client uses a singleton pattern with an internal cache.
107+
Once a client is created with certain settings, Chroma rejects creating another
108+
with different settings in the same process. We clear the cache before and after
109+
this test to avoid conflicts with other tests that use default settings.
110+
"""
111+
store = ChromaDocumentStore(client_settings={"anonymized_telemetry": False})
112+
store._ensure_initialized()
113+
assert store._client.get_settings().anonymized_telemetry is False
114+
115+
def test_invalid_client_settings(self, clear_chroma_system_cache):
116+
store = ChromaDocumentStore(
117+
client_settings={
118+
"invalid_setting_name": "some_value",
119+
"another_fake_setting": 123,
120+
}
121+
)
122+
with pytest.raises(ValueError, match="Invalid client_settings"):
123+
store._ensure_initialized()
124+
86125
def test_to_dict(self, request):
87126
ds = ChromaDocumentStore(
88-
collection_name=request.node.name, embedding_function="HuggingFaceEmbeddingFunction", api_key="1234567890"
127+
collection_name=request.node.name,
128+
embedding_function="HuggingFaceEmbeddingFunction",
129+
api_key="1234567890",
130+
client_settings={"anonymized_telemetry": False},
89131
)
90132
ds_dict = ds.to_dict()
91133
assert ds_dict == {
@@ -98,6 +140,7 @@ def test_to_dict(self, request):
98140
"port": None,
99141
"api_key": "1234567890",
100142
"distance_function": "l2",
143+
"client_settings": {"anonymized_telemetry": False},
101144
},
102145
}
103146

@@ -114,13 +157,15 @@ def test_from_dict(self):
114157
"port": None,
115158
"api_key": "1234567890",
116159
"distance_function": "l2",
160+
"client_settings": {"anonymized_telemetry": False},
117161
},
118162
}
119163

120164
ds = ChromaDocumentStore.from_dict(ds_dict)
121165
assert ds._collection_name == collection_name
122166
assert ds._embedding_function == function_name
123167
assert ds._embedding_function_params == {"api_key": "1234567890"}
168+
assert ds._client_settings == {"anonymized_telemetry": False}
124169

125170
def test_same_collection_name_reinitialization(self):
126171
ChromaDocumentStore("test_1")

integrations/chroma/tests/test_document_store_async.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
sys.platform == "win32",
1919
reason="We do not run the Chroma server on Windows and async is only supported with HTTP connections",
2020
)
21+
@pytest.mark.integration
2122
@pytest.mark.asyncio
2223
class TestDocumentStoreAsync:
2324
@pytest.fixture
@@ -96,7 +97,29 @@ async def test_comparison_equal_async(self, document_store, filterable_docs):
9697
)
9798
self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") == 100])
9899

99-
@pytest.mark.integration
100+
async def test_client_settings_applied_async(self):
101+
store = ChromaDocumentStore(
102+
host="localhost",
103+
port=8000,
104+
client_settings={"anonymized_telemetry": False},
105+
collection_name=f"{uuid.uuid1()}-async-settings",
106+
)
107+
await store._ensure_initialized_async()
108+
assert store._async_client.get_settings().anonymized_telemetry is False
109+
110+
async def test_invalid_client_settings_async(self):
111+
store = ChromaDocumentStore(
112+
host="localhost",
113+
port=8000,
114+
client_settings={
115+
"invalid_setting_name": "some_value",
116+
"another_fake_setting": 123,
117+
},
118+
collection_name=f"{uuid.uuid1()}-async-invalid",
119+
)
120+
with pytest.raises(ValueError, match="Invalid client_settings"):
121+
await store._ensure_initialized_async()
122+
100123
async def test_search_async(self):
101124
document_store = ChromaDocumentStore(host="localhost", port=8000, collection_name="my_custom_collection")
102125

integrations/chroma/tests/test_retriever.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def test_to_dict(self, request):
4141
"port": None,
4242
"api_key": "1234567890",
4343
"distance_function": "l2",
44+
"client_settings": None,
4445
},
4546
},
4647
},
@@ -131,6 +132,7 @@ def test_to_dict(self, request):
131132
"port": None,
132133
"api_key": "1234567890",
133134
"distance_function": "l2",
135+
"client_settings": None,
134136
},
135137
},
136138
},

integrations/github/tests/test_issue_commenter_tool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
7676
raise_on_failure=False,
7777
retry_attempts=3,
7878
outputs_to_string={"handler": message_handler},
79-
inputs_from_state={"repository": "repo"},
79+
inputs_from_state={"repository": "url"},
8080
outputs_to_state={"documents": {"source": "success", "handler": message_handler}},
8181
)
8282
tool_dict = tool.to_dict()
@@ -91,7 +91,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
9191
tool_dict["data"]["outputs_to_string"]["handler"]
9292
== "haystack_integrations.tools.github.utils.message_handler"
9393
)
94-
assert tool_dict["data"]["inputs_from_state"] == {"repository": "repo"}
94+
assert tool_dict["data"]["inputs_from_state"] == {"repository": "url"}
9595
assert tool_dict["data"]["outputs_to_state"]["documents"]["source"] == "success"
9696
assert (
9797
tool_dict["data"]["outputs_to_state"]["documents"]["handler"]
@@ -110,7 +110,7 @@ def test_from_dict_with_extra_params(self, monkeypatch):
110110
"raise_on_failure": False,
111111
"retry_attempts": 3,
112112
"outputs_to_string": {"handler": "haystack_integrations.tools.github.utils.message_handler"},
113-
"inputs_from_state": {"repository": "repo"},
113+
"inputs_from_state": {"repository": "url"},
114114
"outputs_to_state": {
115115
"documents": {
116116
"source": "success",
@@ -127,6 +127,6 @@ def test_from_dict_with_extra_params(self, monkeypatch):
127127
assert tool.raise_on_failure is False
128128
assert tool.retry_attempts == 3
129129
assert tool.outputs_to_string["handler"] == message_handler
130-
assert tool.inputs_from_state == {"repository": "repo"}
130+
assert tool.inputs_from_state == {"repository": "url"}
131131
assert tool.outputs_to_state["documents"]["source"] == "success"
132132
assert tool.outputs_to_state["documents"]["handler"] == message_handler

integrations/github/tests/test_issue_viewer_tool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
7272
raise_on_failure=False,
7373
retry_attempts=3,
7474
outputs_to_string={"handler": message_handler},
75-
inputs_from_state={"repository": "repo"},
75+
inputs_from_state={"repository": "url"},
7676
outputs_to_state={"documents": {"source": "documents", "handler": message_handler}},
7777
)
7878
tool_dict = tool.to_dict()
@@ -87,7 +87,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
8787
tool_dict["data"]["outputs_to_string"]["handler"]
8888
== "haystack_integrations.tools.github.utils.message_handler"
8989
)
90-
assert tool_dict["data"]["inputs_from_state"] == {"repository": "repo"}
90+
assert tool_dict["data"]["inputs_from_state"] == {"repository": "url"}
9191
assert tool_dict["data"]["outputs_to_state"]["documents"]["source"] == "documents"
9292
assert (
9393
tool_dict["data"]["outputs_to_state"]["documents"]["handler"]
@@ -106,7 +106,7 @@ def test_from_dict_with_extra_params(self, monkeypatch):
106106
"raise_on_failure": False,
107107
"retry_attempts": 3,
108108
"outputs_to_string": {"handler": "haystack_integrations.tools.github.utils.message_handler"},
109-
"inputs_from_state": {"repository": "repo"},
109+
"inputs_from_state": {"repository": "url"},
110110
"outputs_to_state": {
111111
"documents": {
112112
"source": "documents",
@@ -123,6 +123,6 @@ def test_from_dict_with_extra_params(self, monkeypatch):
123123
assert tool.raise_on_failure is False
124124
assert tool.retry_attempts == 3
125125
assert tool.outputs_to_string["handler"] == message_handler
126-
assert tool.inputs_from_state == {"repository": "repo"}
126+
assert tool.inputs_from_state == {"repository": "url"}
127127
assert tool.outputs_to_state["documents"]["source"] == "documents"
128128
assert tool.outputs_to_state["documents"]["handler"] == message_handler

integrations/github/tests/test_pr_creator_tool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
7373
github_token=Secret.from_env_var("GITHUB_TOKEN"),
7474
raise_on_failure=False,
7575
outputs_to_string={"handler": message_handler},
76-
inputs_from_state={"repository": "repo"},
76+
inputs_from_state={"repository": "issue_url"},
7777
outputs_to_state={"documents": {"source": "result", "handler": message_handler}},
7878
)
7979
tool_dict = tool.to_dict()
@@ -91,7 +91,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
9191
tool_dict["data"]["outputs_to_string"]["handler"]
9292
== "haystack_integrations.tools.github.utils.message_handler"
9393
)
94-
assert tool_dict["data"]["inputs_from_state"] == {"repository": "repo"}
94+
assert tool_dict["data"]["inputs_from_state"] == {"repository": "issue_url"}
9595
assert tool_dict["data"]["outputs_to_state"]["documents"]["source"] == "result"
9696
assert (
9797
tool_dict["data"]["outputs_to_state"]["documents"]["handler"]
@@ -109,7 +109,7 @@ def test_from_dict_with_extra_params(self, monkeypatch):
109109
"github_token": {"env_vars": ["GITHUB_TOKEN"], "strict": True, "type": "env_var"},
110110
"raise_on_failure": False,
111111
"outputs_to_string": {"handler": "haystack_integrations.tools.github.utils.message_handler"},
112-
"inputs_from_state": {"repository": "repo"},
112+
"inputs_from_state": {"repository": "issue_url"},
113113
"outputs_to_state": {
114114
"documents": {
115115
"source": "result",
@@ -125,6 +125,6 @@ def test_from_dict_with_extra_params(self, monkeypatch):
125125
assert tool.github_token == Secret.from_env_var("GITHUB_TOKEN")
126126
assert tool.raise_on_failure is False
127127
assert tool.outputs_to_string["handler"] == message_handler
128-
assert tool.inputs_from_state == {"repository": "repo"}
128+
assert tool.inputs_from_state == {"repository": "issue_url"}
129129
assert tool.outputs_to_state["documents"]["source"] == "result"
130130
assert tool.outputs_to_state["documents"]["handler"] == message_handler

integrations/github/tests/test_repo_forker_tool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
7171
github_token=Secret.from_env_var("GITHUB_TOKEN"),
7272
raise_on_failure=False,
7373
outputs_to_string={"source": "repo", "handler": message_handler},
74-
inputs_from_state={"repository": "repo"},
74+
inputs_from_state={"repository": "url"},
7575
outputs_to_state={"documents": {"source": "repo", "handler": message_handler}},
7676
)
7777
tool_dict = tool.to_dict()
@@ -89,7 +89,7 @@ def test_to_dict_with_extra_params(self, monkeypatch):
8989
tool_dict["data"]["outputs_to_string"]["handler"]
9090
== "haystack_integrations.tools.github.utils.message_handler"
9191
)
92-
assert tool_dict["data"]["inputs_from_state"] == {"repository": "repo"}
92+
assert tool_dict["data"]["inputs_from_state"] == {"repository": "url"}
9393
assert tool_dict["data"]["outputs_to_state"]["documents"]["source"] == "repo"
9494
assert (
9595
tool_dict["data"]["outputs_to_state"]["documents"]["handler"]
@@ -107,7 +107,7 @@ def test_from_dict_with_extra_params(self, monkeypatch):
107107
"github_token": {"env_vars": ["GITHUB_TOKEN"], "strict": True, "type": "env_var"},
108108
"raise_on_failure": False,
109109
"outputs_to_string": {"handler": "haystack_integrations.tools.github.utils.message_handler"},
110-
"inputs_from_state": {"repository": "repo"},
110+
"inputs_from_state": {"repository": "url"},
111111
"outputs_to_state": {
112112
"documents": {
113113
"source": "repo",
@@ -123,6 +123,6 @@ def test_from_dict_with_extra_params(self, monkeypatch):
123123
assert tool.github_token == Secret.from_env_var("GITHUB_TOKEN")
124124
assert tool.raise_on_failure is False
125125
assert tool.outputs_to_string["handler"] == message_handler
126-
assert tool.inputs_from_state == {"repository": "repo"}
126+
assert tool.inputs_from_state == {"repository": "url"}
127127
assert tool.outputs_to_state["documents"]["source"] == "repo"
128128
assert tool.outputs_to_state["documents"]["handler"] == message_handler

0 commit comments

Comments
 (0)