Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions integrations/pgvector/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@ name = "pgvector-haystack"
dynamic = ["version"]
description = "An integration of pgvector (vector search extension for Postgres) with Haystack"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "Apache-2.0"
keywords = []
authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.11.0", "pgvector>=0.3.0", "psycopg[binary]"]
dependencies = ["haystack-ai>=2.22.0", "pgvector>=0.3.0", "psycopg[binary]"]

[project.urls]
Source = "https://github.com/deepset-ai/haystack-core-integrations"
Expand Down Expand Up @@ -83,7 +82,6 @@ ignore_missing_imports = true


[tool.ruff]
target-version = "py39"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -134,10 +132,6 @@ ignore = [
# ignore assertions
"S101",
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.lint.isort]
known-first-party = ["haystack_integrations"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -62,10 +62,10 @@ def __init__(
self,
*,
document_store: PgvectorDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None = None,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
:param document_store: An instance of `PgvectorDocumentStore`.
Expand Down Expand Up @@ -137,9 +137,9 @@ def from_dict(cls, data: dict[str, Any]) -> "PgvectorEmbeddingRetriever":
def run(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None = None,
) -> dict[str, list[Document]]:
"""
Retrieve documents from the `PgvectorDocumentStore`, based on their embeddings.
Expand Down Expand Up @@ -170,9 +170,9 @@ def run(
async def run_async(
self,
query_embedding: list[float],
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None = None,
) -> dict[str, list[Document]]:
"""
Asynchronously retrieve documents from the `PgvectorDocumentStore`, based on their embeddings.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document
Expand Down Expand Up @@ -52,9 +52,9 @@ def __init__(
self,
*,
document_store: PgvectorDocumentStore,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
filter_policy: str | FilterPolicy = FilterPolicy.REPLACE,
):
"""
:param document_store: An instance of `PgvectorDocumentStore`.
Expand Down Expand Up @@ -111,8 +111,8 @@ def from_dict(cls, data: dict[str, Any]) -> "PgvectorKeywordRetriever":
def run(
self,
query: str,
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, list[Document]]:
"""
Retrieve documents from the `PgvectorDocumentStore`, based on keywords.
Expand Down Expand Up @@ -141,8 +141,8 @@ def run(
async def run_async(
self,
query: str,
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
filters: dict[str, Any] | None = None,
top_k: int | None = None,
) -> dict[str, list[Document]]:
"""
Asynchronously retrieve documents from the `PgvectorDocumentStore`, based on keywords.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal, Optional, Union, overload
from typing import Any, Literal, overload

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses.document import Document
Expand Down Expand Up @@ -92,9 +92,9 @@ def __init__(
recreate_table: bool = False,
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor",
hnsw_recreate_index_if_exists: bool = False,
hnsw_index_creation_kwargs: Optional[dict[str, int]] = None,
hnsw_index_creation_kwargs: dict[str, int] | None = None,
hnsw_index_name: str = "haystack_hnsw_index",
hnsw_ef_search: Optional[int] = None,
hnsw_ef_search: int | None = None,
keyword_index_name: str = "haystack_keyword_index",
):
"""
Expand Down Expand Up @@ -175,12 +175,12 @@ def __init__(
self.keyword_index_name = keyword_index_name
self.language = language

self._connection: Optional[Connection] = None
self._async_connection: Optional[AsyncConnection] = None
self._cursor: Optional[Cursor] = None
self._async_cursor: Optional[AsyncCursor] = None
self._dict_cursor: Optional[Cursor[DictRow]] = None
self._async_dict_cursor: Optional[AsyncCursor[DictRow]] = None
self._connection: Connection | None = None
self._async_connection: AsyncConnection | None = None
self._cursor: Cursor | None = None
self._async_cursor: AsyncCursor | None = None
self._dict_cursor: Cursor[DictRow] | None = None
self._async_dict_cursor: AsyncCursor[DictRow] | None = None
self._table_initialized = False

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -250,21 +250,21 @@ async def _connection_is_valid_async(connection):

@overload
def _execute_sql(
self, cursor: Cursor, sql_query: Composed, params: Optional[tuple] = None, error_msg: str = ""
self, cursor: Cursor, sql_query: Composed, params: tuple | None = None, error_msg: str = ""
) -> Cursor: ...

@overload
def _execute_sql(
self, cursor: Cursor[DictRow], sql_query: Composed, params: Optional[tuple] = None, error_msg: str = ""
self, cursor: Cursor[DictRow], sql_query: Composed, params: tuple | None = None, error_msg: str = ""
) -> Cursor[DictRow]: ...

def _execute_sql(
self,
cursor: Union[Cursor, Cursor[DictRow]],
cursor: Cursor | Cursor[DictRow],
sql_query: Composed,
params: Optional[tuple] = None,
params: tuple | None = None,
error_msg: str = "",
) -> Union[Cursor, Cursor[DictRow]]:
) -> Cursor | Cursor[DictRow]:
"""
Internal method to execute SQL statements and handle exceptions.

Expand Down Expand Up @@ -299,21 +299,21 @@ def _execute_sql(

@overload
async def _execute_sql_async(
self, cursor: AsyncCursor, sql_query: Composed, params: Optional[tuple] = None, error_msg: str = ""
self, cursor: AsyncCursor, sql_query: Composed, params: tuple | None = None, error_msg: str = ""
) -> AsyncCursor: ...

@overload
async def _execute_sql_async(
self, cursor: AsyncCursor[DictRow], sql_query: Composed, params: Optional[tuple] = None, error_msg: str = ""
self, cursor: AsyncCursor[DictRow], sql_query: Composed, params: tuple | None = None, error_msg: str = ""
) -> AsyncCursor[DictRow]: ...

async def _execute_sql_async(
self,
cursor: Union[AsyncCursor, AsyncCursor[DictRow]],
cursor: AsyncCursor | AsyncCursor[DictRow],
sql_query: Composed,
params: Optional[tuple] = None,
params: tuple | None = None,
error_msg: str = "",
) -> Union[AsyncCursor, AsyncCursor[DictRow]]:
) -> AsyncCursor | AsyncCursor[DictRow]:
"""
Internal method to asynchronously execute SQL statements and handle exceptions.

Expand Down Expand Up @@ -759,7 +759,7 @@ async def count_documents_async(self) -> int:
return result[0]
return 0

def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
def filter_documents(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Returns the documents that match the filters provided.

Expand Down Expand Up @@ -796,7 +796,7 @@ def filter_documents(self, filters: Optional[dict[str, Any]] = None) -> list[Doc
docs = _from_pg_to_haystack_documents(records)
return docs

async def filter_documents_async(self, filters: Optional[dict[str, Any]] = None) -> list[Document]:
async def filter_documents_async(self, filters: dict[str, Any] | None = None) -> list[Document]:
"""
Asynchronously returns the documents that match the filters provided.

Expand Down Expand Up @@ -1223,7 +1223,7 @@ async def update_by_filter_async(self, filters: dict[str, Any], meta: dict[str,
raise DocumentStoreError(msg) from e

def _build_keyword_retrieval_query(
self, query: str, top_k: int, filters: Optional[dict[str, Any]] = None
self, query: str, top_k: int, filters: dict[str, Any] | None = None
) -> tuple[Composed, tuple]:
"""
Builds the SQL query and the where parameters for keyword retrieval.
Expand All @@ -1236,7 +1236,7 @@ def _build_keyword_retrieval_query(
)

where_params = ()
sql_where_clause: Union[Composed, SQL] = SQL("")
sql_where_clause: Composed | SQL = SQL("")
if filters:
sql_where_clause, where_params = _convert_filters_to_where_clause_and_params(
filters=filters, operator="AND"
Expand All @@ -1252,7 +1252,7 @@ def _keyword_retrieval(
self,
query: str,
*,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> list[Document]:
"""
Expand Down Expand Up @@ -1287,7 +1287,7 @@ async def _keyword_retrieval_async(
self,
query: str,
*,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
) -> list[Document]:
"""
Expand Down Expand Up @@ -1315,9 +1315,9 @@ async def _keyword_retrieval_async(
def _check_and_build_embedding_retrieval_query(
self,
query_embedding: list[float],
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]],
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None,
top_k: int,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
) -> tuple[Composed, tuple]:
"""
Performs checks and builds the SQL query and the where parameters for embedding retrieval.
Expand Down Expand Up @@ -1357,7 +1357,7 @@ def _check_and_build_embedding_retrieval_query(
score=SQL(score_definition),
)

sql_where_clause: Union[Composed, SQL] = SQL("")
sql_where_clause: Composed | SQL = SQL("")
params = ()
if filters:
sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)
Expand All @@ -1379,9 +1379,9 @@ def _embedding_retrieval(
self,
query_embedding: list[float],
*,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None = None,
) -> list[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
Expand Down Expand Up @@ -1413,9 +1413,9 @@ async def _embedding_retrieval_async(
self,
query_embedding: list[float],
*,
filters: Optional[dict[str, Any]] = None,
filters: dict[str, Any] | None = None,
top_k: int = 10,
vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None,
vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] | None = None,
) -> list[Document]:
"""
Asynchronously retrieves documents that are most similar to the query embedding using a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime
from itertools import chain
from typing import Any, Literal, Optional
from typing import Any, Literal

from haystack.errors import FilterError
from psycopg.sql import SQL, Composed
Expand All @@ -21,7 +21,7 @@
NO_VALUE = "no_value"


def _validate_filters(filters: Optional[dict[str, Any]] = None) -> None:
def _validate_filters(filters: dict[str, Any] | None = None) -> None:
"""
Validates the filters provided.
"""
Expand Down
2 changes: 1 addition & 1 deletion integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_halfvec_hnsw_write_documents(document_store_w_halfvec_hnsw_index: Pgvec
retrieved_docs = document_store_w_halfvec_hnsw_index.filter_documents()
retrieved_docs.sort(key=lambda x: x.id)

for original_doc, retrieved_doc in zip(documents, retrieved_docs):
for original_doc, retrieved_doc in zip(documents, retrieved_docs, strict=True):
assert original_doc.id == retrieved_doc.id
assert original_doc.content == retrieved_doc.content
assert len(original_doc.embedding) == len(retrieved_doc.embedding)
Expand Down
2 changes: 1 addition & 1 deletion integrations/pgvector/tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def assert_documents_are_equal(self, received: list[Document], expected: list[Do
assert len(received) == len(expected)
received.sort(key=lambda x: x.id)
expected.sort(key=lambda x: x.id)
for received_doc, expected_doc in zip(received, expected):
for received_doc, expected_doc in zip(received, expected, strict=True):
# we first compare the embeddings approximately
if received_doc.embedding is None:
assert expected_doc.embedding is None
Expand Down