Skip to content

Commit dc029bc

Browse files
authored
feat: add datafusion pipeline python apis (#31)
* feat: add datafusion pipeline python apis * test: add tests for datafusion-related python apis
1 parent 26dfb0e commit dc029bc

3 files changed

Lines changed: 209 additions & 21 deletions

File tree

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Tests for explain_datafusion API."""
2+
3+
import pyarrow as pa
4+
import pytest
5+
from lance_graph import CypherQuery, GraphConfig
6+
7+
8+
@pytest.fixture
9+
def person_data():
10+
"""Create simple Person dataset for testing."""
11+
people_table = pa.table(
12+
{
13+
"person_id": [1, 2, 3, 4],
14+
"name": ["Alice", "Bob", "Carol", "David"],
15+
"age": [28, 34, 29, 42],
16+
}
17+
)
18+
19+
config = GraphConfig.builder().with_node_label("Person", "person_id").build()
20+
21+
return config, people_table
22+
23+
24+
def test_explain_simple_query(person_data):
25+
"""Test explain output contains all expected sections."""
26+
config, people = person_data
27+
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
28+
plan = query.explain_datafusion({"Person": people})
29+
30+
# Verify the plan is a non-empty string
31+
assert isinstance(plan, str)
32+
assert len(plan) > 0
33+
34+
# Verify it contains expected sections
35+
assert "Cypher Query:" in plan
36+
assert "MATCH (p:Person) RETURN p.name, p.age" in plan
37+
assert "graph_logical_plan" in plan
38+
assert "logical_plan" in plan
39+
assert "physical_plan" in plan
40+
41+
# Verify table format
42+
assert "+" in plan and "|" in plan
43+
44+
45+
def test_explain_with_clauses(person_data):
46+
"""Test explain output includes query clauses (WHERE, ORDER BY, LIMIT)."""
47+
config, people = person_data
48+
query = CypherQuery(
49+
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name ORDER BY p.age LIMIT 2"
50+
).with_config(config)
51+
plan = query.explain_datafusion({"Person": people})
52+
53+
assert isinstance(plan, str)
54+
assert "WHERE p.age > 30" in plan
55+
assert "ORDER BY" in plan
56+
assert "LIMIT" in plan
57+
58+
59+
def test_explain_error_handling(person_data):
60+
"""Test explain error handling for missing config and datasets."""
61+
config, people = person_data
62+
63+
# Missing config
64+
query_no_config = CypherQuery("MATCH (p:Person) RETURN p.name")
65+
with pytest.raises(ValueError, match="Graph configuration is required"):
66+
query_no_config.explain_datafusion({"Person": people})
67+
68+
# Missing datasets
69+
query_with_config = CypherQuery("MATCH (p:Person) RETURN p.name").with_config(
70+
config
71+
)
72+
with pytest.raises(ValueError, match="No input datasets provided"):
73+
query_with_config.explain_datafusion({})

python/python/tests/test_graph.py

Lines changed: 68 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,37 +62,55 @@ def graph_env(tmp_path):
6262
return config, datasets, people_table
6363

6464

65-
def test_basic_node_selection(graph_env):
65+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
66+
def test_basic_node_selection(graph_env, execute_method):
6667
config, datasets, _ = graph_env
6768
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
68-
result = query.execute({"Person": datasets["Person"]})
69+
result = getattr(query, execute_method)({"Person": datasets["Person"]})
6970
data = result.to_pydict()
70-
assert len(data["name"]) == 4
71-
assert set(data.keys()) == {"name", "age"}
72-
assert "Alice" in set(data["name"])
7371

74-
75-
def test_filtered_query(graph_env):
72+
# TODO: remove these if/else statements once execute() also returns
73+
# Cypher dot notation
74+
if execute_method == "execute":
75+
# execute() returns unqualified names for simple queries
76+
assert set(data.keys()) == {"name", "age"}
77+
assert len(data["name"]) == 4
78+
assert "Alice" in set(data["name"])
79+
else:
80+
# execute_datafusion() returns Cypher dot notation
81+
assert set(data.keys()) == {"p.name", "p.age"}
82+
assert len(data["p.name"]) == 4
83+
assert "Alice" in set(data["p.name"])
84+
85+
86+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
87+
def test_filtered_query(graph_env, execute_method):
7688
config, datasets, _ = graph_env
7789
query = CypherQuery(
7890
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age"
7991
).with_config(config)
80-
result = query.execute({"Person": datasets["Person"]})
92+
result = getattr(query, execute_method)({"Person": datasets["Person"]})
8193
data = result.to_pydict()
82-
assert len(data["name"]) == 2
83-
assert set(data["name"]) == {"Bob", "David"}
84-
assert all(age > 30 for age in data["age"])
94+
95+
if execute_method == "execute":
96+
assert len(data["name"]) == 2
97+
assert set(data["name"]) == {"Bob", "David"}
98+
assert all(age > 30 for age in data["age"])
99+
else:
100+
assert len(data["p.name"]) == 2
101+
assert set(data["p.name"]) == {"Bob", "David"}
102+
assert all(age > 30 for age in data["p.age"])
85103

86104

87-
def test_relationship_query(graph_env):
105+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
106+
def test_relationship_query(graph_env, execute_method):
88107
config, datasets, _ = graph_env
89-
# Alias outputs to stable column names regardless of internal qualification
90108
query = CypherQuery(
91109
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) "
92110
"RETURN p.person_id AS person_id, p.name AS name, c.company_id AS company_id"
93111
).with_config(config)
94112

95-
result = query.execute(
113+
result = getattr(query, execute_method)(
96114
{
97115
"Person": datasets["Person"],
98116
"Company": datasets["Company"],
@@ -105,7 +123,8 @@ def test_relationship_query(graph_env):
105123
assert data["company_id"] == [101, 101, 102, 103]
106124

107125

108-
def test_friendship_direct_and_network(graph_env):
126+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
127+
def test_friendship_direct_and_network(graph_env, execute_method):
109128
config, datasets, _ = graph_env
110129
# Direct friends of Alice (person_id = 1)
111130
query_direct = CypherQuery(
@@ -114,7 +133,7 @@ def test_friendship_direct_and_network(graph_env):
114133
"RETURN b.person_id AS friend_id"
115134
).with_config(config)
116135

117-
result_direct = query_direct.execute(
136+
result_direct = getattr(query_direct, execute_method)(
118137
{
119138
"Person": datasets["Person"],
120139
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -129,7 +148,7 @@ def test_friendship_direct_and_network(graph_env):
129148
"RETURN f.person_id AS person1_id, t.person_id AS person2_id"
130149
).with_config(config)
131150

132-
result_edges = query_edges.execute(
151+
result_edges = getattr(query_edges, execute_method)(
133152
{
134153
"Person": datasets["Person"],
135154
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -140,15 +159,16 @@ def test_friendship_direct_and_network(graph_env):
140159
assert got == {(1, 2), (1, 3), (2, 4), (3, 4)}
141160

142161

143-
def test_two_hop_friends_of_friends(graph_env):
162+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
163+
def test_two_hop_friends_of_friends(graph_env, execute_method):
144164
config, datasets, _ = graph_env
145165
query = CypherQuery(
146166
"MATCH (a:Person)-[:FRIEND_OF]->(b:Person)-[:FRIEND_OF]->(c:Person) "
147167
"WHERE a.person_id = 1 "
148168
"RETURN a.person_id AS a_id, b.person_id AS b_id, c.person_id AS c_id"
149169
).with_config(config)
150170

151-
result = query.execute(
171+
result = getattr(query, execute_method)(
152172
{
153173
"Person": datasets["Person"],
154174
"FRIEND_OF": datasets["FRIEND_OF"],
@@ -158,15 +178,42 @@ def test_two_hop_friends_of_friends(graph_env):
158178
assert set(data["c_id"]) == {4}
159179

160180

161-
def test_variable_length_path(graph_env):
181+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
182+
def test_variable_length_path(graph_env, execute_method):
162183
config, datasets, _ = graph_env
163184
query = CypherQuery(
164185
"MATCH (p1:Person)-[:FRIEND_OF*1..2]-(p2:Person) "
165186
"RETURN p1.person_id AS p1, p2.person_id AS p2"
166187
).with_config(config)
167-
_ = query.execute(
188+
_ = getattr(query, execute_method)(
168189
{
169190
"Person": datasets["Person"],
170191
"FRIEND_OF": datasets["FRIEND_OF"],
171192
}
172193
)
194+
195+
196+
@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
197+
def test_distinct_clause(graph_env, execute_method):
198+
config, datasets, _ = graph_env
199+
query = CypherQuery(
200+
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN DISTINCT c.company_name"
201+
).with_config(config)
202+
203+
result = getattr(query, execute_method)(
204+
{
205+
"Person": datasets["Person"],
206+
"Company": datasets["Company"],
207+
"WORKS_FOR": datasets["WORKS_FOR"],
208+
}
209+
)
210+
data = result.to_pydict()
211+
212+
if execute_method == "execute":
213+
# execute() returns qualified column names for relationship queries
214+
assert len(data["c__company_name"]) == 3
215+
assert set(data["c__company_name"]) == {"TechCorp", "DataInc", "CloudSoft"}
216+
else:
217+
# execute_datafusion() returns Cypher dot notation
218+
assert len(data["c.company_name"]) == 3
219+
assert set(data["c.company_name"]) == {"TechCorp", "DataInc", "CloudSoft"}

python/src/graph.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,74 @@ impl CypherQuery {
292292
record_batch_to_python_table(py, &result_batch)
293293
}
294294

295+
/// Execute query using the DataFusion planner with in-memory datasets
296+
///
297+
/// Parameters
298+
/// ----------
299+
/// datasets : dict
300+
/// Dictionary mapping table names to in-memory tables (pyarrow.Table, LanceDataset, etc.)
301+
/// Keys should match node labels and relationship types in the graph config.
302+
///
303+
/// Returns
304+
/// -------
305+
/// pyarrow.Table
306+
/// Query results as Arrow table
307+
///
308+
/// Raises
309+
/// ------
310+
/// ValueError
311+
/// If the query is invalid or datasets are missing
312+
/// RuntimeError
313+
/// If query execution fails
314+
fn execute_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
315+
// Convert datasets to Arrow RecordBatch map
316+
let arrow_datasets = python_datasets_to_batches(datasets)?;
317+
318+
// Clone for async move
319+
let inner_query = self.inner.clone();
320+
321+
// Execute via runtime
322+
let result_batch = RT
323+
.block_on(Some(py), inner_query.execute_datafusion(arrow_datasets))?
324+
.map_err(graph_error_to_pyerr)?;
325+
326+
record_batch_to_python_table(py, &result_batch)
327+
}
328+
329+
/// Explain query using the DataFusion planner with in-memory datasets
330+
///
331+
/// Parameters
332+
/// ----------
333+
/// datasets : dict
334+
/// Dictionary mapping table names to in-memory tables (pyarrow.Table, LanceDataset, etc.)
335+
/// Keys should match node labels and relationship types in the graph config.
336+
///
337+
/// Returns
338+
/// -------
339+
/// str
340+
/// Graph logical plan, DataFusion logical plan, and DataFusion physical plan as a single string
341+
///
342+
/// Raises
343+
/// ------
344+
/// ValueError
345+
/// If the query is invalid or datasets are missing
346+
/// RuntimeError
347+
/// If query explain fails
348+
fn explain_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
349+
// Convert datasets to Arrow RecordBatch map
350+
let arrow_datasets = python_datasets_to_batches(datasets)?;
351+
352+
// Clone for async move
353+
let inner_query = self.inner.clone();
354+
355+
// Build the explain output via runtime
356+
let plan = RT
357+
.block_on(Some(py), inner_query.explain_datafusion(arrow_datasets))?
358+
.map_err(graph_error_to_pyerr)?;
359+
360+
Ok(plan)
361+
}
362+
295363
/// Get variables used in the query
296364
fn variables(&self) -> Vec<String> {
297365
self.inner.variables()

0 commit comments

Comments
 (0)