|
11 | 11 | from chromadb.config import Settings |
12 | 12 | from haystack import default_from_dict, default_to_dict, logging |
13 | 13 | from haystack.dataclasses import Document |
14 | | -from haystack.document_stores.errors import DocumentStoreError |
| 14 | +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError |
15 | 15 | from haystack.document_stores.types import DuplicatePolicy |
16 | 16 | from haystack.utils.misc import _normalize_metadata_field_name |
17 | 17 |
|
@@ -573,62 +573,123 @@ def _convert_document_to_chroma(doc: Document) -> dict[str, Any] | None: |
573 | 573 | def write_documents( |
574 | 574 | self, |
575 | 575 | documents: list[Document], |
576 | | - policy: DuplicatePolicy = DuplicatePolicy.FAIL, |
| 576 | + policy: DuplicatePolicy = DuplicatePolicy.NONE, |
577 | 577 | ) -> int: |
578 | 578 | """ |
579 | | - Writes (or overwrites) documents into the store. |
| 579 | + Writes documents into the store. |
580 | 580 |
|
581 | 581 | :param documents: |
582 | 582 | A list of documents to write into the document store. |
583 | 583 | :param policy: |
584 | | - Not supported at the moment. |
| 584 | + How to handle documents whose `id` already exists in the store: |
| 585 | + - `NONE` (default): treated as `FAIL`. |
| 586 | + - `OVERWRITE`: replace the existing document. |
| 587 | + - `SKIP`: keep the existing document and skip the new one. |
| 588 | + - `FAIL`: raise `DuplicateDocumentError`. |
585 | 589 |
|
586 | 590 | :raises ValueError: |
587 | 591 | When input is not valid. |
| 592 | + :raises DuplicateDocumentError: |
| 593 | + When `policy` is `FAIL` (or `NONE`) and any document `id` already exists. |
588 | 594 |
|
589 | 595 | :returns: |
590 | | - The number of documents written |
| 596 | + The number of documents written. |
591 | 597 | """ |
592 | 598 | self._ensure_initialized() |
593 | 599 | assert self._collection is not None |
594 | 600 |
|
595 | | - for doc in documents: |
596 | | - data = ChromaDocumentStore._convert_document_to_chroma(doc) |
597 | | - if data is not None: |
598 | | - self._collection.add(**data) |
| 601 | + if policy == DuplicatePolicy.NONE: |
| 602 | + policy = DuplicatePolicy.FAIL |
599 | 603 |
|
600 | | - return len(documents) |
| 604 | + chroma_payloads: list[dict[str, Any]] = [ |
| 605 | + p for p in (ChromaDocumentStore._convert_document_to_chroma(doc) for doc in documents) if p is not None |
| 606 | + ] |
| 607 | + if not chroma_payloads: |
| 608 | + return 0 |
| 609 | + |
| 610 | + if policy in (DuplicatePolicy.FAIL, DuplicatePolicy.SKIP): |
| 611 | + existing_ids = set(self._collection.get(ids=[p["ids"][0] for p in chroma_payloads])["ids"]) |
| 612 | + payloads_to_write = self._apply_duplicate_policy(chroma_payloads, existing_ids, policy) |
| 613 | + else: |
| 614 | + payloads_to_write = chroma_payloads |
| 615 | + |
| 616 | + for payload in payloads_to_write: |
| 617 | + if policy == DuplicatePolicy.OVERWRITE: |
| 618 | + self._collection.upsert(**payload) |
| 619 | + else: |
| 620 | + self._collection.add(**payload) |
| 621 | + |
| 622 | + return len(payloads_to_write) |
601 | 623 |
|
602 | 624 | async def write_documents_async( |
603 | 625 | self, |
604 | 626 | documents: list[Document], |
605 | | - policy: DuplicatePolicy = DuplicatePolicy.FAIL, |
| 627 | + policy: DuplicatePolicy = DuplicatePolicy.NONE, |
606 | 628 | ) -> int: |
607 | 629 | """ |
608 | | - Asynchronously writes (or overwrites) documents into the store. |
| 630 | + Asynchronously writes documents into the store. |
609 | 631 |
|
610 | 632 | Asynchronous methods are only supported for HTTP connections. |
611 | 633 |
|
612 | 634 | :param documents: |
613 | 635 | A list of documents to write into the document store. |
614 | 636 | :param policy: |
615 | | - Not supported at the moment. |
| 637 | + How to handle documents whose `id` already exists in the store: |
| 638 | + - `NONE` (default): treated as `FAIL`. |
| 639 | + - `OVERWRITE`: replace the existing document. |
| 640 | + - `SKIP`: keep the existing document and skip the new one. |
| 641 | + - `FAIL`: raise `DuplicateDocumentError`. |
616 | 642 |
|
617 | 643 | :raises ValueError: |
618 | 644 | When input is not valid. |
| 645 | + :raises DuplicateDocumentError: |
| 646 | + When `policy` is `FAIL` (or `NONE`) and any document `id` already exists. |
619 | 647 |
|
620 | 648 | :returns: |
621 | | - The number of documents written |
| 649 | + The number of documents written. |
622 | 650 | """ |
623 | 651 | await self._ensure_initialized_async() |
624 | 652 | assert self._async_collection is not None |
625 | 653 |
|
626 | | - for doc in documents: |
627 | | - data = ChromaDocumentStore._convert_document_to_chroma(doc) |
628 | | - if data is not None: |
629 | | - await self._async_collection.add(**data) |
| 654 | + if policy == DuplicatePolicy.NONE: |
| 655 | + policy = DuplicatePolicy.FAIL |
| 656 | + |
| 657 | + chroma_payloads: list[dict[str, Any]] = [ |
| 658 | + p for p in (ChromaDocumentStore._convert_document_to_chroma(doc) for doc in documents) if p is not None |
| 659 | + ] |
| 660 | + if not chroma_payloads: |
| 661 | + return 0 |
630 | 662 |
|
631 | | - return len(documents) |
| 663 | + if policy in (DuplicatePolicy.FAIL, DuplicatePolicy.SKIP): |
| 664 | + existing = await self._async_collection.get(ids=[p["ids"][0] for p in chroma_payloads]) |
| 665 | + existing_ids = set(existing["ids"]) |
| 666 | + payloads_to_write = self._apply_duplicate_policy(chroma_payloads, existing_ids, policy) |
| 667 | + else: |
| 668 | + payloads_to_write = chroma_payloads |
| 669 | + |
| 670 | + for payload in payloads_to_write: |
| 671 | + if policy == DuplicatePolicy.OVERWRITE: |
| 672 | + await self._async_collection.upsert(**payload) |
| 673 | + else: |
| 674 | + await self._async_collection.add(**payload) |
| 675 | + |
| 676 | + return len(payloads_to_write) |
| 677 | + |
| 678 | + @staticmethod |
| 679 | + def _apply_duplicate_policy( |
| 680 | + payloads: list[dict[str, Any]], |
| 681 | + existing_ids: set[str], |
| 682 | + policy: DuplicatePolicy, |
| 683 | + ) -> list[dict[str, Any]]: |
| 684 | + if policy == DuplicatePolicy.FAIL: |
| 685 | + duplicates = [p["ids"][0] for p in payloads if p["ids"][0] in existing_ids] |
| 686 | + if duplicates: |
| 687 | + msg = f"Documents with ids {duplicates} already exist in the document store." |
| 688 | + raise DuplicateDocumentError(msg) |
| 689 | + return payloads |
| 690 | + if policy == DuplicatePolicy.SKIP: |
| 691 | + return [p for p in payloads if p["ids"][0] not in existing_ids] |
| 692 | + return payloads |
632 | 693 |
|
633 | 694 | def delete_documents(self, document_ids: list[str]) -> None: |
634 | 695 | """ |
|
0 commit comments