Skip to content

Commit f9e0b2a

Browse files
authored
fix(pgvector): order retrieval by distance operator
1 parent 92810ad commit f9e0b2a

2 files changed

Lines changed: 49 additions & 7 deletions

File tree

integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1351,10 +1351,13 @@ def _check_and_build_embedding_retrieval_query(
13511351
# cosine_similarity and inner_product are modified from the result of the operator
13521352
if vector_function == "cosine_similarity":
13531353
score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score"
1354+
order_by_definition = f"embedding <=> {query_embedding_for_postgres}"
13541355
elif vector_function == "inner_product":
13551356
score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score"
1357+
order_by_definition = f"embedding <#> {query_embedding_for_postgres}"
13561358
elif vector_function == "l2_distance":
13571359
score_definition = f"embedding <-> {query_embedding_for_postgres} AS score"
1360+
order_by_definition = f"embedding <-> {query_embedding_for_postgres}"
13581361

13591362
sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format(
13601363
schema_name=Identifier(self.schema_name),
@@ -1367,13 +1370,9 @@ def _check_and_build_embedding_retrieval_query(
13671370
if filters:
13681371
sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)
13691372

1370-
# we always want to return the most similar documents first
1371-
# so when using l2_distance, the sort order must be ASC
1372-
sort_order = "ASC" if vector_function == "l2_distance" else "DESC"
1373-
1374-
sql_sort = SQL(" ORDER BY score {sort_order} LIMIT {top_k}").format(
1373+
sql_sort = SQL(" ORDER BY {order_by} ASC LIMIT {top_k}").format(
13751374
top_k=SQLLiteral(top_k),
1376-
sort_order=SQL(sort_order),
1375+
order_by=SQL(order_by_definition),
13771376
)
13781377

13791378
sql_query = sql_select + sql_where_clause + sql_sort

integrations/pgvector/tests/test_document_store.py

Lines changed: 44 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,16 @@
2525
)
2626
from haystack.utils import Secret
2727
from psycopg import Connection, Cursor, Error
28-
from psycopg.sql import SQL
28+
from psycopg.adapt import Transformer
29+
from psycopg.sql import SQL, Composed
2930

3031
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore
3132

3233

34+
def _render(composed: Composed) -> str:
    """Return ``composed`` as a plain SQL string.

    A freshly constructed ``Transformer`` is passed to ``as_string`` as the
    rendering context.
    """
    context = Transformer()
    return composed.as_string(context)
36+
37+
3338
@pytest.mark.integration
3439
class TestDocumentStore(
3540
CountDocumentsTest,
@@ -250,6 +255,44 @@ def test_check_and_build_embedding_retrieval_query_rejects_invalid_vector_functi
250255
)
251256

252257

258+
@pytest.mark.parametrize(
    ("vector_function", "score_sql", "order_by_sql"),
    [
        (
            "cosine_similarity",
            "1 - (embedding <=> '[0.1,0.2]') AS score",
            "embedding <=> '[0.1,0.2]'",
        ),
        (
            "inner_product",
            "(embedding <#> '[0.1,0.2]') * -1 AS score",
            "embedding <#> '[0.1,0.2]'",
        ),
        (
            "l2_distance",
            "embedding <-> '[0.1,0.2]' AS score",
            "embedding <-> '[0.1,0.2]'",
        ),
    ],
)
def test_check_and_build_embedding_retrieval_query_orders_by_distance_operator(
    mock_store, vector_function, score_sql, order_by_sql
):
    """Check that the retrieval query sorts by the distance operator, ascending.

    For each vector function the rendered SQL must contain the expected score
    expression, an ``ORDER BY <operator expression> ASC LIMIT 5`` clause, and
    no ``ORDER BY score``. With no filters given, the returned params are
    expected to be an empty tuple.
    """
    mock_store.embedding_dimension = 2

    built_query, bound_params = mock_store._check_and_build_embedding_retrieval_query(
        query_embedding=[0.1, 0.2],
        vector_function=vector_function,
        top_k=5,
    )

    query_text = _render(built_query)
    expected_sort_clause = f"ORDER BY {order_by_sql} ASC LIMIT 5"

    assert score_sql in query_text
    assert expected_sort_clause in query_text
    # The score alias must no longer be used as the sort key.
    assert "ORDER BY score" not in query_text
    assert bound_params == ()
294+
295+
253296
@pytest.mark.parametrize(
254297
"bad_field",
255298
[

0 commit comments

Comments
 (0)