Skip to content

Commit b648517

Browse files
feat(chroma): support DuplicatePolicy in write_documents and use async mixin tests (#3245)
Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent d26a08b commit b648517

3 files changed

Lines changed: 142 additions & 324 deletions

File tree

integrations/chroma/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ classifiers = [
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
2626
dependencies = [
27-
"haystack-ai>=2.26.1",
27+
"haystack-ai>=2.28.0",
2828
"chromadb>=1.5.4"
2929
]
3030

integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py

Lines changed: 80 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from chromadb.config import Settings
1212
from haystack import default_from_dict, default_to_dict, logging
1313
from haystack.dataclasses import Document
14-
from haystack.document_stores.errors import DocumentStoreError
14+
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
1515
from haystack.document_stores.types import DuplicatePolicy
1616
from haystack.utils.misc import _normalize_metadata_field_name
1717

@@ -573,62 +573,123 @@ def _convert_document_to_chroma(doc: Document) -> dict[str, Any] | None:
573573
def write_documents(
574574
self,
575575
documents: list[Document],
576-
policy: DuplicatePolicy = DuplicatePolicy.FAIL,
576+
policy: DuplicatePolicy = DuplicatePolicy.NONE,
577577
) -> int:
578578
"""
579-
Writes (or overwrites) documents into the store.
579+
Writes documents into the store.
580580
581581
:param documents:
582582
A list of documents to write into the document store.
583583
:param policy:
584-
Not supported at the moment.
584+
How to handle documents whose `id` already exists in the store:
585+
- `NONE` (default): treated as `FAIL`.
586+
- `OVERWRITE`: replace the existing document.
587+
- `SKIP`: keep the existing document and skip the new one.
588+
- `FAIL`: raise `DuplicateDocumentError`.
585589
586590
:raises ValueError:
587591
When input is not valid.
592+
:raises DuplicateDocumentError:
593+
When `policy` is `FAIL` (or `NONE`) and any document `id` already exists.
588594
589595
:returns:
590-
The number of documents written
596+
The number of documents written.
591597
"""
592598
self._ensure_initialized()
593599
assert self._collection is not None
594600

595-
for doc in documents:
596-
data = ChromaDocumentStore._convert_document_to_chroma(doc)
597-
if data is not None:
598-
self._collection.add(**data)
601+
if policy == DuplicatePolicy.NONE:
602+
policy = DuplicatePolicy.FAIL
599603

600-
return len(documents)
604+
chroma_payloads: list[dict[str, Any]] = [
605+
p for p in (ChromaDocumentStore._convert_document_to_chroma(doc) for doc in documents) if p is not None
606+
]
607+
if not chroma_payloads:
608+
return 0
609+
610+
if policy in (DuplicatePolicy.FAIL, DuplicatePolicy.SKIP):
611+
existing_ids = set(self._collection.get(ids=[p["ids"][0] for p in chroma_payloads])["ids"])
612+
payloads_to_write = self._apply_duplicate_policy(chroma_payloads, existing_ids, policy)
613+
else:
614+
payloads_to_write = chroma_payloads
615+
616+
for payload in payloads_to_write:
617+
if policy == DuplicatePolicy.OVERWRITE:
618+
self._collection.upsert(**payload)
619+
else:
620+
self._collection.add(**payload)
621+
622+
return len(payloads_to_write)
601623

602624
async def write_documents_async(
603625
self,
604626
documents: list[Document],
605-
policy: DuplicatePolicy = DuplicatePolicy.FAIL,
627+
policy: DuplicatePolicy = DuplicatePolicy.NONE,
606628
) -> int:
607629
"""
608-
Asynchronously writes (or overwrites) documents into the store.
630+
Asynchronously writes documents into the store.
609631
610632
Asynchronous methods are only supported for HTTP connections.
611633
612634
:param documents:
613635
A list of documents to write into the document store.
614636
:param policy:
615-
Not supported at the moment.
637+
How to handle documents whose `id` already exists in the store:
638+
- `NONE` (default): treated as `FAIL`.
639+
- `OVERWRITE`: replace the existing document.
640+
- `SKIP`: keep the existing document and skip the new one.
641+
- `FAIL`: raise `DuplicateDocumentError`.
616642
617643
:raises ValueError:
618644
When input is not valid.
645+
:raises DuplicateDocumentError:
646+
When `policy` is `FAIL` (or `NONE`) and any document `id` already exists.
619647
620648
:returns:
621-
The number of documents written
649+
The number of documents written.
622650
"""
623651
await self._ensure_initialized_async()
624652
assert self._async_collection is not None
625653

626-
for doc in documents:
627-
data = ChromaDocumentStore._convert_document_to_chroma(doc)
628-
if data is not None:
629-
await self._async_collection.add(**data)
654+
if policy == DuplicatePolicy.NONE:
655+
policy = DuplicatePolicy.FAIL
656+
657+
chroma_payloads: list[dict[str, Any]] = [
658+
p for p in (ChromaDocumentStore._convert_document_to_chroma(doc) for doc in documents) if p is not None
659+
]
660+
if not chroma_payloads:
661+
return 0
630662

631-
return len(documents)
663+
if policy in (DuplicatePolicy.FAIL, DuplicatePolicy.SKIP):
664+
existing = await self._async_collection.get(ids=[p["ids"][0] for p in chroma_payloads])
665+
existing_ids = set(existing["ids"])
666+
payloads_to_write = self._apply_duplicate_policy(chroma_payloads, existing_ids, policy)
667+
else:
668+
payloads_to_write = chroma_payloads
669+
670+
for payload in payloads_to_write:
671+
if policy == DuplicatePolicy.OVERWRITE:
672+
await self._async_collection.upsert(**payload)
673+
else:
674+
await self._async_collection.add(**payload)
675+
676+
return len(payloads_to_write)
677+
678+
@staticmethod
679+
def _apply_duplicate_policy(
680+
payloads: list[dict[str, Any]],
681+
existing_ids: set[str],
682+
policy: DuplicatePolicy,
683+
) -> list[dict[str, Any]]:
684+
if policy == DuplicatePolicy.FAIL:
685+
duplicates = [p["ids"][0] for p in payloads if p["ids"][0] in existing_ids]
686+
if duplicates:
687+
msg = f"Documents with ids {duplicates} already exist in the document store."
688+
raise DuplicateDocumentError(msg)
689+
return payloads
690+
if policy == DuplicatePolicy.SKIP:
691+
return [p for p in payloads if p["ids"][0] not in existing_ids]
692+
return payloads
632693

633694
def delete_documents(self, document_ids: list[str]) -> None:
634695
"""

0 commit comments

Comments
 (0)