Skip to content

Commit c7a4650

Browse files
authored
feat: Unify the execute and explain APIs (#54)
* Unify the execute API with an execution strategy
* Unify the explain API and update tests
* Update style
1 parent 2e0bcae commit c7a4650

11 files changed

Lines changed: 598 additions & 270 deletions

python/python/tests/test_explain.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Tests for explain_datafusion API."""
1+
"""Tests for explain API."""
22

33
import pyarrow as pa
44
import pytest
@@ -25,7 +25,7 @@ def test_explain_simple_query(person_data):
2525
"""Test explain output contains all expected sections."""
2626
config, people = person_data
2727
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
28-
plan = query.explain_datafusion({"Person": people})
28+
plan = query.explain({"Person": people})
2929

3030
# Verify the plan is a non-empty string
3131
assert isinstance(plan, str)
@@ -48,7 +48,7 @@ def test_explain_with_clauses(person_data):
4848
query = CypherQuery(
4949
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name ORDER BY p.age LIMIT 2"
5050
).with_config(config)
51-
plan = query.explain_datafusion({"Person": people})
51+
plan = query.explain({"Person": people})
5252

5353
assert isinstance(plan, str)
5454
assert "WHERE p.age > 30" in plan
@@ -63,11 +63,11 @@ def test_explain_error_handling(person_data):
6363
# Missing config
6464
query_no_config = CypherQuery("MATCH (p:Person) RETURN p.name")
6565
with pytest.raises(ValueError, match="Graph configuration is required"):
66-
query_no_config.explain_datafusion({"Person": people})
66+
query_no_config.explain({"Person": people})
6767

6868
# Missing datasets
6969
query_with_config = CypherQuery("MATCH (p:Person) RETURN p.name").with_config(
7070
config
7171
)
7272
with pytest.raises(ValueError, match="No input datasets provided"):
73-
query_with_config.explain_datafusion({})
73+
query_with_config.explain({})

python/python/tests/test_graph.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,41 +62,38 @@ def graph_env(tmp_path):
6262
return config, datasets, people_table
6363

6464

65-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
66-
def test_basic_node_selection(graph_env, execute_method):
65+
def test_basic_node_selection(graph_env):
6766
config, datasets, _ = graph_env
6867
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
69-
result = getattr(query, execute_method)({"Person": datasets["Person"]})
68+
result = query.execute({"Person": datasets["Person"]})
7069
data = result.to_pydict()
7170

7271
assert set(data.keys()) == {"p.name", "p.age"}
7372
assert len(data["p.name"]) == 4
7473
assert "Alice" in set(data["p.name"])
7574

7675

77-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
78-
def test_filtered_query(graph_env, execute_method):
76+
def test_filtered_query(graph_env):
7977
config, datasets, _ = graph_env
8078
query = CypherQuery(
8179
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age"
8280
).with_config(config)
83-
result = getattr(query, execute_method)({"Person": datasets["Person"]})
81+
result = query.execute({"Person": datasets["Person"]})
8482
data = result.to_pydict()
8583

8684
assert len(data["p.name"]) == 2
8785
assert set(data["p.name"]) == {"Bob", "David"}
8886
assert all(age > 30 for age in data["p.age"])
8987

9088

91-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
92-
def test_relationship_query(graph_env, execute_method):
89+
def test_relationship_query(graph_env):
9390
config, datasets, _ = graph_env
9491
query = CypherQuery(
9592
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) "
9693
"RETURN p.person_id AS person_id, p.name AS name, c.company_id AS company_id"
9794
).with_config(config)
9895

99-
result = getattr(query, execute_method)(
96+
result = query.execute(
10097
{
10198
"Person": datasets["Person"],
10299
"Company": datasets["Company"],
@@ -109,8 +106,7 @@ def test_relationship_query(graph_env, execute_method):
109106
assert data["company_id"] == [101, 101, 102, 103]
110107

111108

112-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
113-
def test_friendship_direct_and_network(graph_env, execute_method):
109+
def test_friendship_direct_and_network(graph_env):
114110
config, datasets, _ = graph_env
115111
# Direct friends of Alice (person_id = 1)
116112
query_direct = CypherQuery(
@@ -119,7 +115,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
119115
"RETURN b.person_id AS friend_id"
120116
).with_config(config)
121117

122-
result_direct = getattr(query_direct, execute_method)(
118+
result_direct = query_direct.execute(
123119
{
124120
"Person": datasets["Person"],
125121
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -134,7 +130,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
134130
"RETURN f.person_id AS person1_id, t.person_id AS person2_id"
135131
).with_config(config)
136132

137-
result_edges = getattr(query_edges, execute_method)(
133+
result_edges = query_edges.execute(
138134
{
139135
"Person": datasets["Person"],
140136
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -145,16 +141,15 @@ def test_friendship_direct_and_network(graph_env, execute_method):
145141
assert got == {(1, 2), (1, 3), (2, 4), (3, 4)}
146142

147143

148-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
149-
def test_two_hop_friends_of_friends(graph_env, execute_method):
144+
def test_two_hop_friends_of_friends(graph_env):
150145
config, datasets, _ = graph_env
151146
query = CypherQuery(
152147
"MATCH (a:Person)-[:FRIEND_OF]->(b:Person)-[:FRIEND_OF]->(c:Person) "
153148
"WHERE a.person_id = 1 "
154149
"RETURN a.person_id AS a_id, b.person_id AS b_id, c.person_id AS c_id"
155150
).with_config(config)
156151

157-
result = getattr(query, execute_method)(
152+
result = query.execute(
158153
{
159154
"Person": datasets["Person"],
160155
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -164,29 +159,31 @@ def test_two_hop_friends_of_friends(graph_env, execute_method):
164159
assert set(data["c_id"]) == {4}
165160

166161

167-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
168-
def test_variable_length_path(graph_env, execute_method):
162+
def test_variable_length_path(graph_env):
169163
config, datasets, _ = graph_env
170164
query = CypherQuery(
171165
"MATCH (p1:Person)-[:FRIEND_OF*1..2]-(p2:Person) "
172166
"RETURN p1.person_id AS p1, p2.person_id AS p2"
173167
).with_config(config)
174-
_ = getattr(query, execute_method)(
168+
169+
result = query.execute(
175170
{
176171
"Person": datasets["Person"],
177172
"FRIEND_OF": datasets["FRIEND_OF"],
178173
}
179174
)
175+
data = result.to_pydict()
176+
got = set(zip(data["p1"], data["p2"]))
177+
assert got == {(1, 2), (1, 3), (2, 4), (3, 4), (1, 4)}
180178

181179

182-
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
183-
def test_distinct_clause(graph_env, execute_method):
180+
def test_distinct_clause(graph_env):
184181
config, datasets, _ = graph_env
185182
query = CypherQuery(
186183
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN DISTINCT c.company_name"
187184
).with_config(config)
188185

189-
result = getattr(query, execute_method)(
186+
result = query.execute(
190187
{
191188
"Person": datasets["Person"],
192189
"Company": datasets["Company"],

python/src/graph.rs

Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use arrow::ffi_stream::ArrowArrayStreamReader;
2222
use arrow_array::{RecordBatch, RecordBatchReader};
2323
use arrow_schema::Schema;
2424
use lance_graph::{
25-
CypherQuery as RustCypherQuery, GraphConfig as RustGraphConfig, GraphError as RustGraphError,
25+
ExecutionStrategy as RustExecutionStrategy, CypherQuery as RustCypherQuery,
26+
GraphConfig as RustGraphConfig, GraphError as RustGraphError,
2627
};
2728
use pyo3::{
2829
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
@@ -34,6 +35,28 @@ use serde_json::Value as JsonValue;
3435

3536
use crate::RT;
3637

38+
/// Execution strategy for Cypher queries
39+
#[pyclass(name = "ExecutionStrategy", module = "lance.graph")]
40+
#[derive(Clone, Copy)]
41+
pub enum ExecutionStrategy {
42+
/// Use DataFusion query planner (default, full feature support)
43+
DataFusion,
44+
/// Use simple single-table executor (legacy, limited features)
45+
Simple,
46+
/// Use Lance native executor (not yet implemented)
47+
LanceNative,
48+
}
49+
50+
impl From<ExecutionStrategy> for RustExecutionStrategy {
51+
fn from(strategy: ExecutionStrategy) -> Self {
52+
match strategy {
53+
ExecutionStrategy::DataFusion => RustExecutionStrategy::DataFusion,
54+
ExecutionStrategy::Simple => RustExecutionStrategy::Simple,
55+
ExecutionStrategy::LanceNative => RustExecutionStrategy::LanceNative,
56+
}
57+
}
58+
}
59+
3760
/// Convert GraphError to PyErr
3861
fn graph_error_to_pyerr(err: RustGraphError) -> PyErr {
3962
match &err {
@@ -267,6 +290,8 @@ impl CypherQuery {
267290
/// ----------
268291
/// datasets : dict
269292
/// Dictionary mapping table names to Lance datasets
293+
/// strategy : ExecutionStrategy, optional
294+
/// Execution strategy to use (defaults to DataFusion)
270295
///
271296
/// Returns
272297
/// -------
@@ -277,56 +302,40 @@ impl CypherQuery {
277302
/// ------
278303
/// RuntimeError
279304
/// If query execution fails
280-
fn execute(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
281-
// Convert datasets to Arrow batches while holding the GIL - same as before
282-
let arrow_datasets = python_datasets_to_batches(datasets)?;
283-
284-
// Clone the inner query for use in the async block
285-
let inner_query = self.inner.clone();
286-
287-
// Use RT.block_on with Some(py) like the scanner to_pyarrow method
288-
let result_batch = RT
289-
.block_on(Some(py), inner_query.execute(arrow_datasets))?
290-
.map_err(graph_error_to_pyerr)?;
291-
292-
record_batch_to_python_table(py, &result_batch)
293-
}
294-
295-
/// Execute query using the DataFusion planner with in-memory datasets
296-
///
297-
/// Parameters
298-
/// ----------
299-
/// datasets : dict
300-
/// Dictionary mapping table names to in-memory tables (pyarrow.Table, LanceDataset, etc.)
301-
/// Keys should match node labels and relationship types in the graph config.
302305
///
303-
/// Returns
304-
/// -------
305-
/// pyarrow.Table
306-
/// Query results as Arrow table
306+
/// Examples
307+
/// --------
308+
/// >>> # Default strategy (DataFusion)
309+
/// >>> result = query.execute(datasets)
307310
///
308-
/// Raises
309-
/// ------
310-
/// ValueError
311-
/// If the query is invalid or datasets are missing
312-
/// RuntimeError
313-
/// If query execution fails
314-
fn execute_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
315-
// Convert datasets to Arrow RecordBatch map
311+
/// >>> # Explicit strategy
312+
/// >>> from lance.graph import ExecutionStrategy
313+
/// >>> result = query.execute(datasets, strategy=ExecutionStrategy.Simple)
314+
#[pyo3(signature = (datasets, strategy=None))]
315+
fn execute(
316+
&self,
317+
py: Python,
318+
datasets: &Bound<'_, PyDict>,
319+
strategy: Option<ExecutionStrategy>,
320+
) -> PyResult<PyObject> {
321+
// Convert datasets to Arrow batches while holding the GIL
316322
let arrow_datasets = python_datasets_to_batches(datasets)?;
317323

318-
// Clone for async move
324+
// Convert Python strategy to Rust strategy
325+
let rust_strategy = strategy.map(|s| s.into());
326+
327+
// Clone the inner query for use in the async block
319328
let inner_query = self.inner.clone();
320329

321-
// Execute via runtime
330+
// Use RT.block_on with Some(py) like the scanner to_pyarrow method
322331
let result_batch = RT
323-
.block_on(Some(py), inner_query.execute_datafusion(arrow_datasets))?
332+
.block_on(Some(py), inner_query.execute(arrow_datasets, rust_strategy))?
324333
.map_err(graph_error_to_pyerr)?;
325334

326335
record_batch_to_python_table(py, &result_batch)
327336
}
328337

329-
/// Explain query uusing the DataFusion planner with in-memory datasets
338+
/// Explain query using the DataFusion planner with in-memory datasets
330339
///
331340
/// Parameters
332341
/// ----------
@@ -345,7 +354,7 @@ impl CypherQuery {
345354
/// If the query is invalid or datasets are missing
346355
/// RuntimeError
347356
/// If query explain fails
348-
fn explain_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
357+
fn explain(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
349358
// Convert datasets to Arrow RecordBatch map
350359
let arrow_datasets = python_datasets_to_batches(datasets)?;
351360

@@ -354,7 +363,7 @@ impl CypherQuery {
354363

355364
// Execute via runtime
356365
let plan = RT
357-
.block_on(Some(py), inner_query.explain_datafusion(arrow_datasets))?
366+
.block_on(Some(py), inner_query.explain(arrow_datasets))?
358367
.map_err(graph_error_to_pyerr)?;
359368

360369
Ok(plan)
@@ -562,6 +571,7 @@ fn record_batch_to_python_table(
562571
pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
563572
let graph_module = PyModule::new(py, "graph")?;
564573

574+
graph_module.add_class::<ExecutionStrategy>()?;
565575
graph_module.add_class::<GraphConfig>()?;
566576
graph_module.add_class::<GraphConfigBuilder>()?;
567577
graph_module.add_class::<CypherQuery>()?;

rust/lance-graph/benches/graph_execution.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arrow_schema::{DataType, Field, Schema as ArrowSchema};
2222
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
2323
use futures::TryStreamExt;
2424
use lance::dataset::{Dataset, WriteMode, WriteParams};
25-
use lance_graph::{CypherQuery, GraphConfig};
25+
use lance_graph::{CypherQuery, ExecutionStrategy, GraphConfig};
2626
use tempfile::TempDir;
2727

2828
fn create_people_batch() -> RecordBatch {
@@ -71,7 +71,11 @@ fn execute_cypher_query(
7171
q: &CypherQuery,
7272
datasets: HashMap<String, RecordBatch>,
7373
) -> RecordBatch {
74-
rt.block_on(async move { q.execute(datasets).await.unwrap() })
74+
rt.block_on(async move {
75+
q.execute(datasets, Some(ExecutionStrategy::Simple))
76+
.await
77+
.unwrap()
78+
})
7579
}
7680

7781
fn make_people_batch(n: usize) -> RecordBatch {

rust/lance-graph/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ pub const MAX_VARIABLE_LENGTH_HOPS: u32 = 20;
5353

5454
pub use config::{GraphConfig, NodeMapping, RelationshipMapping};
5555
pub use error::{GraphError, Result};
56-
pub use query::CypherQuery;
56+
pub use query::{CypherQuery, ExecutionStrategy};

0 commit comments

Comments
 (0)