Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> None:
"""
Create a QdrantEmbeddingRetriever component.

Expand Down Expand Up @@ -136,7 +136,7 @@ def run(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Run the Embedding Retriever on the given input data.

Expand Down Expand Up @@ -180,7 +180,7 @@ async def run_async(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Asynchronously run the Embedding Retriever on the given input data.

Expand Down Expand Up @@ -252,7 +252,7 @@ def __init__(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> None:
"""
Create a QdrantSparseEmbeddingRetriever component.

Expand Down Expand Up @@ -342,7 +342,7 @@ def run(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Run the Sparse Embedding Retriever on the given input data.

Expand Down Expand Up @@ -391,7 +391,7 @@ async def run_async(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Asynchronously run the Sparse Embedding Retriever on the given input data.

Expand Down Expand Up @@ -473,7 +473,7 @@ def __init__(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> None:
"""
Create a QdrantHybridRetriever component.

Expand Down Expand Up @@ -557,7 +557,7 @@ def run(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Run the Sparse Embedding Retriever on the given input data.

Expand Down Expand Up @@ -606,7 +606,7 @@ async def run_async(
score_threshold: Optional[float] = None,
group_by: Optional[str] = None,
group_size: Optional[int] = None,
):
) -> Dict[str, List[Document]]:
"""
Asynchronously run the Sparse Embedding Retriever on the given input data.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
from copy import deepcopy
from itertools import islice
from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Set, Union
from typing import Any, AsyncGenerator, ClassVar, Dict, Generator, List, Optional, Set, Tuple, Union

import numpy as np
import qdrant_client
Expand All @@ -18,6 +19,7 @@
from .converters import (
DENSE_VECTORS_NAME,
SPARSE_VECTORS_NAME,
QdrantPoint,
convert_haystack_documents_to_qdrant_points,
convert_id,
convert_qdrant_point_to_haystack_document,
Expand All @@ -34,7 +36,7 @@ class QdrantStoreError(DocumentStoreError):
FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]]


def get_batches_from_generator(iterable, n):
def get_batches_from_generator(iterable: List, n: int) -> Generator:
"""
Batch elements of an iterable into fixed-length chunks or blocks.
"""
Expand Down Expand Up @@ -127,10 +129,10 @@ def __init__(
write_batch_size: int = 100,
scroll_size: int = 10_000,
payload_fields_to_index: Optional[List[dict]] = None,
):
) -> None:
"""
:param location:
If `memory` - use in-memory Qdrant instance.
If `":memory:"` - use in-memory Qdrant instance.
If `str` - use it as a URL parameter.
If `None` - use default values for host and port.
:param url:
Expand Down Expand Up @@ -164,7 +166,7 @@ def __init__(
Dimension of the embeddings.
:param on_disk:
Whether to store the collection on disk.
:param use_sparse_embedding:
:param use_sparse_embeddings:
If set to `True`, enables support for sparse embeddings.
:param sparse_idf:
If set to `True`, computes the Inverse Document Frequency (IDF) when using sparse embeddings.
Expand Down Expand Up @@ -232,7 +234,6 @@ def __init__(
self.path = path
self.force_disable_check_same_thread = force_disable_check_same_thread
self.metadata = metadata or {}
self.api_key = api_key

# Store the Qdrant collection specific attributes
self.shard_number = shard_number
Expand All @@ -258,9 +259,10 @@ def __init__(
self.write_batch_size = write_batch_size
self.scroll_size = scroll_size

def _initialize_client(self):
def _initialize_client(self) -> None:
if self._client is None:
client_params = self._prepare_client_params()
# This step adds the api-key and User-Agent to metadata
self._client = qdrant_client.QdrantClient(**client_params)
# Make sure the collection is properly set up
self._set_up_collection(
Expand All @@ -274,7 +276,7 @@ def _initialize_client(self):
self.payload_fields_to_index,
)

async def _initialize_async_client(self):
async def _initialize_async_client(self) -> None:
"""
Returns the asynchronous Qdrant client, initializing it if necessary.
"""
Expand Down Expand Up @@ -628,8 +630,6 @@ def get_documents_by_id(

:param ids:
A list of document IDs to retrieve.
:param index:
The name of the index to retrieve documents from.
:returns:
A list of documents.
"""
Expand Down Expand Up @@ -661,8 +661,6 @@ async def get_documents_by_id_async(

:param ids:
A list of document IDs to retrieve.
:param index:
The name of the index to retrieve documents from.
:returns:
A list of documents.
"""
Expand Down Expand Up @@ -1210,7 +1208,7 @@ def get_distance(self, similarity: str) -> rest.Distance:
)
raise QdrantStoreError(msg) from ke

def _create_payload_index(self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None):
def _create_payload_index(self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None) -> None:
"""
Create payload index for the collection if payload_fields_to_index is provided
See: https://qdrant.tech/documentation/concepts/indexing/#payload-index
Expand All @@ -1229,7 +1227,7 @@ def _create_payload_index(self, collection_name: str, payload_fields_to_index: O

async def _create_payload_index_async(
self, collection_name: str, payload_fields_to_index: Optional[List[dict]] = None
):
) -> None:
"""
Asynchronously create payload index for the collection if payload_fields_to_index is provided
See: https://qdrant.tech/documentation/concepts/indexing/#payload-index
Expand Down Expand Up @@ -1257,7 +1255,7 @@ def _set_up_collection(
sparse_idf: bool,
on_disk: bool = False,
payload_fields_to_index: Optional[List[dict]] = None,
):
) -> None:
"""
Sets up the Qdrant collection with the specified parameters.
:param collection_name:
Expand Down Expand Up @@ -1313,7 +1311,7 @@ async def _set_up_collection_async(
sparse_idf: bool,
on_disk: bool = False,
payload_fields_to_index: Optional[List[dict]] = None,
):
) -> None:
"""
Asynchronously sets up the Qdrant collection with the specified parameters.
:param collection_name:
Expand Down Expand Up @@ -1367,7 +1365,7 @@ def recreate_collection(
on_disk: Optional[bool] = None,
use_sparse_embeddings: Optional[bool] = None,
sparse_idf: bool = False,
):
) -> None:
"""
Recreates the Qdrant collection with the specified parameters.

Expand Down Expand Up @@ -1410,7 +1408,7 @@ async def recreate_collection_async(
on_disk: Optional[bool] = None,
use_sparse_embeddings: Optional[bool] = None,
sparse_idf: bool = False,
):
) -> None:
"""
Asynchronously recreates the Qdrant collection with the specified parameters.

Expand Down Expand Up @@ -1449,7 +1447,7 @@ def _handle_duplicate_documents(
self,
documents: List[Document],
policy: DuplicatePolicy = None,
):
) -> List[Document]:
"""
Checks whether any of the passed documents is already existing in the chosen index and returns a list of
documents that are not in the index yet.
Expand All @@ -1476,7 +1474,7 @@ async def _handle_duplicate_documents_async(
self,
documents: List[Document],
policy: DuplicatePolicy = None,
):
) -> List[Document]:
"""
Asynchronously checks whether any of the passed documents is already existing
in the chosen index and returns a list of
Expand Down Expand Up @@ -1521,7 +1519,7 @@ def _drop_duplicate_documents(self, documents: List[Document]) -> List[Document]

return _documents

def _prepare_collection_params(self):
def _prepare_collection_params(self) -> Dict[str, Any]:
"""
Prepares the common parameters for collection creation.
"""
Expand All @@ -1537,26 +1535,31 @@ def _prepare_collection_params(self):
"init_from": self.init_from,
}

def _prepare_client_params(self):
def _prepare_client_params(self) -> Dict[str, Any]:
"""
Prepares the common parameters for client initialization.

"""
return {
"location": self.location,
"url": self.url,
"port": self.port,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
"https": self.https,
"api_key": self.api_key.resolve_value() if self.api_key else None,
"prefix": self.prefix,
"timeout": self.timeout,
"host": self.host,
"path": self.path,
"metadata": self.metadata,
"force_disable_check_same_thread": self.force_disable_check_same_thread,
}
# NOTE: We need to use deepcopy here to avoid modifying the original class attributes.
# For example, the resolved api key is added to metadata by the QdrantClient class when using a hosted
# Qdrant service, which means running to_dict() exposes the api key.
return deepcopy(
{
"location": self.location,
"url": self.url,
"port": self.port,
"grpc_port": self.grpc_port,
"prefer_grpc": self.prefer_grpc,
"https": self.https,
"api_key": self.api_key.resolve_value() if self.api_key else None,
"prefix": self.prefix,
"timeout": self.timeout,
"host": self.host,
"path": self.path,
"metadata": self.metadata,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be enough to only deepcopy metadata? Or maybe something like: "metadata": {**self.metadata}?

All other params are primitives and they should not be affected when the QDrantClient changes them. Or am I missing something?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that's fair we could just target the metadata field in this case

"force_disable_check_same_thread": self.force_disable_check_same_thread,
}
)

def _prepare_collection_config(
self,
Expand All @@ -1565,7 +1568,7 @@ def _prepare_collection_config(
on_disk: Optional[bool] = None,
use_sparse_embeddings: Optional[bool] = None,
sparse_idf: bool = False,
):
) -> Tuple[Dict[str, rest.VectorParams], Optional[Dict[str, rest.SparseVectorParams]]]:
"""
Prepares the configuration for creating or recreating a Qdrant collection.

Expand Down Expand Up @@ -1595,9 +1598,12 @@ def _prepare_collection_config(

return vectors_config, sparse_vectors_config

def _validate_filters(self, filters: Optional[Union[Dict[str, Any], rest.Filter]] = None):
def _validate_filters(self, filters: Optional[Union[Dict[str, Any], rest.Filter]] = None) -> None:
"""
Validates the filters provided for querying.

:param filters: Filters to validate. Can be a dictionary or an instance of `qdrant_client.http.models.Filter`.
:raises ValueError: If the filters are not in the correct format or syntax.
"""
if filters and not isinstance(filters, dict) and not isinstance(filters, rest.Filter):
msg = "Filter must be a dictionary or an instance of `qdrant_client.http.models.Filter`"
Expand All @@ -1607,7 +1613,7 @@ def _validate_filters(self, filters: Optional[Union[Dict[str, Any], rest.Filter]
msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
raise ValueError(msg)

def _process_query_point_results(self, results, scale_score: bool = False):
def _process_query_point_results(self, results: List[QdrantPoint], scale_score: bool = False) -> List[Document]:
"""
Processes query results from Qdrant.
"""
Expand All @@ -1627,7 +1633,7 @@ def _process_query_point_results(self, results, scale_score: bool = False):

return documents

def _process_group_results(self, groups):
def _process_group_results(self, groups: List[rest.PointGroup]) -> List[Document]:
"""
Processes grouped query results from Qdrant.

Expand All @@ -1647,7 +1653,7 @@ def _validate_collection_compatibility(
collection_info,
distance,
embedding_dim: int,
):
) -> None:
"""
Validates that an existing collection is compatible with the current configuration.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ def convert_filters_to_qdrant(


def build_filters_for_repeated_operators(
must_clauses,
should_clauses,
must_not_clauses,
qdrant_filter,
must_clauses: List,
should_clauses: List,
must_not_clauses: List,
qdrant_filter: List[models.Filter],
) -> List[models.Filter]:
"""
Flattens the nested lists of clauses by creating separate Filters for each clause of a logical operator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger.setLevel(python_logging.INFO)


def migrate_to_sparse_embeddings_support(old_document_store: QdrantDocumentStore, new_index: str):
def migrate_to_sparse_embeddings_support(old_document_store: QdrantDocumentStore, new_index: str) -> None:
"""
Utility function to migrate an existing `QdrantDocumentStore` to a new one with support for sparse embeddings.

Expand Down
Loading