Skip to content

Commit 5598fa0

Browse files
committed
feat: expose the vector rerank method in the CypherEngine API
1 parent 17148d4 commit 5598fa0

2 files changed

Lines changed: 124 additions & 0 deletions

File tree

crates/lance-graph-python/src/graph.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,50 @@ impl CypherEngine {
10561056
}
10571057
}
10581058

1059+
/// Execute Cypher query with vector reranking
1060+
///
1061+
/// Convenience method combining graph traversal and vector similarity ranking.
1062+
/// See CypherQuery.execute_with_vector_rerank for detailed documentation.
1063+
///
1064+
/// Parameters
1065+
/// ----------
1066+
/// query : str
1067+
/// Cypher query string
1068+
/// vector_search : VectorSearch
1069+
/// Vector search configuration
1070+
///
1071+
/// Returns
1072+
/// -------
1073+
/// pyarrow.Table
1074+
/// Results sorted by vector similarity
1075+
fn execute_with_vector_rerank(
1076+
&self,
1077+
py: Python,
1078+
query: &str,
1079+
vector_search: &VectorSearch,
1080+
) -> PyResult<PyObject> {
1081+
// Parse query and execute with cached catalog/context
1082+
let cypher_query = RustCypherQuery::new(query)
1083+
.map_err(graph_error_to_pyerr)?
1084+
.with_config(self.config.clone());
1085+
1086+
let catalog = self.catalog.clone();
1087+
let context = self.context.as_ref().clone();
1088+
let vs = vector_search.inner.clone();
1089+
1090+
// Execute query to get candidates, then apply vector reranking
1091+
let result_batch = RT
1092+
.block_on(Some(py), async move {
1093+
let candidates = cypher_query
1094+
.execute_with_catalog_and_context(catalog, context)
1095+
.await?;
1096+
vs.search(&candidates).await
1097+
})?
1098+
.map_err(graph_error_to_pyerr)?;
1099+
1100+
record_batch_to_python_table(py, &result_batch)
1101+
}
1102+
10591103
fn __repr__(&self) -> String {
10601104
format!(
10611105
"CypherEngine(nodes={}, relationships={})",

python/python/tests/test_vector_search.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,83 @@ def test_vector_search_different_query_vectors(vector_env):
317317
.search(table)
318318
)
319319
assert results3.to_pydict()["name"][0] == "Doc4"
320+
321+
322+
def test_cypher_engine_execute_with_vector_rerank(vector_env):
323+
"""Test CypherEngine.execute_with_vector_rerank basic functionality."""
324+
from lance_graph import CypherEngine
325+
326+
config, datasets, _ = vector_env
327+
engine = CypherEngine(config, datasets)
328+
329+
results = engine.execute_with_vector_rerank(
330+
"MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding",
331+
VectorSearch("d.embedding")
332+
.query_vector([1.0, 0.0, 0.0])
333+
.metric(DistanceMetric.L2)
334+
.top_k(2),
335+
)
336+
337+
data = results.to_pydict()
338+
assert len(data["d.name"]) == 2
339+
assert data["d.name"][0] == "Doc1"
340+
341+
342+
def test_cypher_engine_vs_cypher_query_vector_rerank_equivalence(vector_env):
343+
"""Test that CypherEngine produces same results as CypherQuery for vector rerank."""
344+
from lance_graph import CypherEngine
345+
346+
config, datasets, _ = vector_env
347+
348+
query_text = "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding"
349+
vector_search = (
350+
VectorSearch("d.embedding")
351+
.query_vector([1.0, 0.0, 0.0])
352+
.metric(DistanceMetric.L2)
353+
.top_k(2)
354+
)
355+
356+
# Execute with CypherQuery
357+
query = CypherQuery(query_text).with_config(config)
358+
result_query = query.execute_with_vector_rerank(datasets, vector_search)
359+
360+
# Execute with CypherEngine
361+
engine = CypherEngine(config, datasets)
362+
result_engine = engine.execute_with_vector_rerank(query_text, vector_search)
363+
364+
# Results should be identical
365+
assert result_query.to_pydict() == result_engine.to_pydict()
366+
367+
368+
def test_cypher_engine_vector_rerank_multiple_queries(vector_env):
369+
"""Test that CypherEngine efficiently handles multiple vector rerank queries."""
370+
from lance_graph import CypherEngine
371+
372+
config, datasets, _ = vector_env
373+
engine = CypherEngine(config, datasets)
374+
375+
# Execute multiple different queries using the same cached engine
376+
results1 = engine.execute_with_vector_rerank(
377+
"MATCH (d:Document) RETURN d.id, d.name, d.embedding",
378+
VectorSearch("d.embedding")
379+
.query_vector([1.0, 0.0, 0.0])
380+
.metric(DistanceMetric.L2)
381+
.top_k(2),
382+
)
383+
384+
results2 = engine.execute_with_vector_rerank(
385+
"MATCH (d:Document) WHERE d.category = 'science' RETURN d.id, d.name, d.embedding",
386+
VectorSearch("d.embedding")
387+
.query_vector([0.0, 1.0, 0.0])
388+
.metric(DistanceMetric.Cosine)
389+
.top_k(1),
390+
)
391+
392+
data1 = results1.to_pydict()
393+
data2 = results2.to_pydict()
394+
395+
assert len(data1["d.name"]) == 2
396+
assert data1["d.name"][0] == "Doc1"
397+
398+
assert len(data2["d.name"]) == 1
399+
assert data2["d.name"][0] == "Doc3"

0 commit comments

Comments
 (0)