Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1351,10 +1351,13 @@ def _check_and_build_embedding_retrieval_query(
# cosine_similarity and inner_product are modified from the result of the operator
if vector_function == "cosine_similarity":
score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score"
order_by_definition = f"embedding <=> {query_embedding_for_postgres}"
elif vector_function == "inner_product":
score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score"
order_by_definition = f"embedding <#> {query_embedding_for_postgres}"
elif vector_function == "l2_distance":
score_definition = f"embedding <-> {query_embedding_for_postgres} AS score"
order_by_definition = f"embedding <-> {query_embedding_for_postgres}"

sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format(
schema_name=Identifier(self.schema_name),
Expand All @@ -1367,13 +1370,9 @@ def _check_and_build_embedding_retrieval_query(
if filters:
sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters)

# we always want to return the most similar documents first
# so when using l2_distance, the sort order must be ASC
sort_order = "ASC" if vector_function == "l2_distance" else "DESC"

sql_sort = SQL(" ORDER BY score {sort_order} LIMIT {top_k}").format(
sql_sort = SQL(" ORDER BY {order_by} ASC LIMIT {top_k}").format(
top_k=SQLLiteral(top_k),
sort_order=SQL(sort_order),
order_by=SQL(order_by_definition),
)

sql_query = sql_select + sql_where_clause + sql_sort
Expand Down
41 changes: 41 additions & 0 deletions integrations/pgvector/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,47 @@ def test_check_and_build_embedding_retrieval_query_rejects_invalid_vector_functi
)


@pytest.mark.parametrize(
("vector_function", "score_sql", "order_by_sql"),
[
(
"cosine_similarity",
"1 - (embedding <=> '[0.1,0.2]') AS score",
"embedding <=> '[0.1,0.2]'",
),
(
"inner_product",
"(embedding <#> '[0.1,0.2]') * -1 AS score",
"embedding <#> '[0.1,0.2]'",
),
(
"l2_distance",
"embedding <-> '[0.1,0.2]' AS score",
"embedding <-> '[0.1,0.2]'",
),
],
)
def test_check_and_build_embedding_retrieval_query_orders_by_distance_operator(
mock_store, vector_function, score_sql, order_by_sql
):
mock_store.embedding_dimension = 2

sql_query, params = mock_store._check_and_build_embedding_retrieval_query(
query_embedding=[0.1, 0.2],
vector_function=vector_function,
top_k=5,
)

rendered = repr(sql_query)
assert score_sql in rendered
assert "SQL(' ORDER BY ')" in rendered
assert f'SQL("{order_by_sql}")' in rendered
assert "SQL(' ASC LIMIT ')" in rendered
assert "Literal(5)" in rendered
assert "ORDER BY score" not in rendered
assert params == ()


@pytest.mark.parametrize(
"bad_field",
[
Expand Down
Loading