Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions integrations/mongodb_atlas/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ name = "mongodb-atlas-haystack"
dynamic = ["version"]
description = "An integration of MongoDB Atlas with Haystack"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "Apache-2.0"
keywords = []
authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
Expand All @@ -24,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"haystack-ai>=2.11.0",
"haystack-ai>=2.22.0",
"pymongo[srv]>=4.13.0"
]

Expand Down Expand Up @@ -80,7 +79,6 @@ disallow_incomplete_defs = true


[tool.ruff]
target-version = "py39"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -129,10 +127,6 @@ ignore = [
# Allow assert statements
"S101",
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.lint.isort]
known-first-party = ["haystack_integrations"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -45,9 +45,9 @@ def __init__(
self,
*,
document_store: MongoDBAtlasDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
Create the MongoDBAtlasEmbeddingRetriever component.
Expand Down Expand Up @@ -110,8 +110,8 @@ def from_dict(cls, data: dict[str, Any]) -> "MongoDBAtlasEmbeddingRetriever":
def run(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, list[Document]]:
"""
Retrieve documents from the MongoDBAtlasDocumentStore, based on the provided embedding similarity.
Expand All @@ -138,8 +138,8 @@ def run(
async def run_async(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, list[Document]]:
"""
Asynchronously retrieve documents from the MongoDBAtlasDocumentStore, based on the provided embedding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -43,9 +43,9 @@ def __init__(
self,
*,
document_store: MongoDBAtlasDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
:param document_store: An instance of MongoDBAtlasDocumentStore.
Expand Down Expand Up @@ -103,12 +103,12 @@ def from_dict(cls, data: dict[str, Any]) -> "MongoDBAtlasFullTextRetriever":
@component.output_types(documents=list[Document])
def run(
self,
query: Union[str, list[str]],
fuzzy: Optional[dict[str, int]] = None,
match_criteria: Optional[Literal["any", "all"]] = None,
score: Optional[dict[str, dict]] = None,
synonyms: Optional[str] = None,
filters: Optional[dict[str, Any]] = None,
query: str | list[str],
fuzzy: dict[str, int] | None = None,
match_criteria: Literal["any", "all"] | None = None,
score: dict[str, dict] | None = None,
synonyms: str | None = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> dict[str, list[Document]]:
"""
Expand Down Expand Up @@ -153,12 +153,12 @@ def run(
@component.output_types(documents=list[Document])
async def run_async(
self,
query: Union[str, list[str]],
fuzzy: Optional[dict[str, int]] = None,
match_criteria: Optional[Literal["any", "all"]] = None,
score: Optional[dict[str, dict]] = None,
synonyms: Optional[str] = None,
filters: Optional[dict[str, Any]] = None,
query: str | list[str],
fuzzy: dict[str, int] | None = None,
match_criteria: Literal["any", "all"] | None = None,
score: dict[str, dict] | None = None,
synonyms: str | None = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> dict[str, list[Document]]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses.document import Document
Expand Down Expand Up @@ -105,10 +105,10 @@ def __init__(
self.full_text_search_index = full_text_search_index
self.embedding_field = embedding_field
self.content_field = content_field
self._connection: Optional[MongoClient] = None
self._connection_async: Optional[AsyncMongoClient] = None
self._collection: Optional[Collection] = None
self._collection_async: Optional[AsyncCollection] = None
self._connection: MongoClient | None = None
self._connection_async: AsyncMongoClient | None = None
self._collection: Collection | None = None
self._collection_async: AsyncCollection | None = None

def __del__(self) -> None:
"""
Expand All @@ -118,7 +118,7 @@ def __del__(self) -> None:
self._connection.close()

@property
def connection(self) -> Union[AsyncMongoClient, MongoClient]:
def connection(self) -> AsyncMongoClient | MongoClient:
if self._connection:
return self._connection
if self._connection_async:
Expand All @@ -127,7 +127,7 @@ def connection(self) -> Union[AsyncMongoClient, MongoClient]:
raise DocumentStoreError(msg)

@property
def collection(self) -> Union[AsyncCollection, Collection]:
def collection(self) -> AsyncCollection | Collection:
if self._collection:
return self._collection
if self._collection_async:
Expand Down Expand Up @@ -278,7 +278,7 @@ async def count_documents_async(self) -> int:
assert self._collection_async is not None
return await self._collection_async.count_documents({})

def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Returns the documents that match the filters provided.

Expand All @@ -294,7 +294,7 @@ def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Doc
documents = list(self._collection.find(filters))
return [self._mongo_doc_to_haystack_doc(doc) for doc in documents]

async def filter_documents_async(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
async def filter_documents_async(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Asynchronously returns the documents that match the filters provided.

Expand Down Expand Up @@ -332,7 +332,7 @@ def write_documents(self, documents: list[Document], policy: DuplicatePolicy = D
policy = DuplicatePolicy.FAIL

mongo_documents = [self._haystack_doc_to_mongo_doc(doc) for doc in documents]
operations: list[Union[UpdateOne, InsertOne, ReplaceOne]]
operations: list[UpdateOne | InsertOne | ReplaceOne]
written_docs = len(documents)

if policy == DuplicatePolicy.SKIP:
Expand Down Expand Up @@ -377,7 +377,7 @@ async def write_documents_async(

mongo_documents = [self._haystack_doc_to_mongo_doc(doc) for doc in documents]

operations: list[Union[UpdateOne, InsertOne, ReplaceOne]]
operations: list[UpdateOne | InsertOne | ReplaceOne]
written_docs = len(documents)

if policy == DuplicatePolicy.SKIP:
Expand Down Expand Up @@ -636,7 +636,7 @@ async def delete_all_documents_async(self, *, recreate_collection: bool = False)
def _embedding_retrieval(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> list[Document]:
"""
Expand Down Expand Up @@ -686,7 +686,7 @@ def _embedding_retrieval(
return documents

async def _embedding_retrieval_async(
self, query_embedding: list[float], filters: Optional[dict[str, Any]] = None, top_k: int = 10
self, query_embedding: list[float], filters: dict[str, Any] | None = None, top_k: int = 10
) -> list[Document]:
"""
Asynchronously find the documents that are most similar to the provided `query_embedding` by using a vector
Expand Down Expand Up @@ -738,12 +738,12 @@ async def _embedding_retrieval_async(

def _fulltext_retrieval(
self,
query: Union[str, list[str]],
fuzzy: Optional[dict[str, int]] = None,
match_criteria: Optional[Literal["any", "all"]] = None,
score: Optional[dict[str, dict]] = None,
synonyms: Optional[str] = None,
filters: Optional[dict[str, Any]] = None,
query: str | list[str],
fuzzy: dict[str, int] | None = None,
match_criteria: Literal["any", "all"] | None = None,
score: dict[str, dict] | None = None,
synonyms: str | None = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> list[Document]:
"""
Expand Down Expand Up @@ -831,12 +831,12 @@ def _fulltext_retrieval(

async def _fulltext_retrieval_async(
self,
query: Union[str, list[str]],
fuzzy: Optional[dict[str, int]] = None,
match_criteria: Optional[Literal["any", "all"]] = None,
score: Optional[dict[str, dict]] = None,
synonyms: Optional[str] = None,
filters: Optional[dict[str, Any]] = None,
query: str | list[str],
fuzzy: dict[str, int] | None = None,
match_criteria: Literal["any", "all"] | None = None,
score: dict[str, dict] | None = None,
synonyms: str | None = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> list[Document]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
from time import sleep
from typing import Union
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -157,7 +156,7 @@ def test_synonyms_retrieval(self, document_store: MongoDBAtlasDocumentStore):
assert results[0].score >= results[1].score

@pytest.mark.parametrize("query", ["", []])
def test_empty_query_raises_value_error(self, query: Union[str, list], document_store: MongoDBAtlasDocumentStore):
def test_empty_query_raises_value_error(self, query: str | list, document_store: MongoDBAtlasDocumentStore):
with pytest.raises(ValueError):
document_store._fulltext_retrieval(query=query)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import os
from time import sleep
from typing import Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
Expand Down Expand Up @@ -116,7 +115,7 @@ async def test_synonyms_retrieval_async(self, document_store: MongoDBAtlasDocume

@pytest.mark.parametrize("query", ["", []])
async def test_empty_query_raises_value_error_async(
self, query: Union[str, list], document_store: MongoDBAtlasDocumentStore
self, query: str | list, document_store: MongoDBAtlasDocumentStore
):
with pytest.raises(ValueError):
await document_store._fulltext_retrieval_async(query=query)
Expand Down