Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/python/tests/test_explain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for explain_datafusion API."""
"""Tests for explain API."""

import pyarrow as pa
import pytest
Expand All @@ -25,7 +25,7 @@ def test_explain_simple_query(person_data):
"""Test explain output contains all expected sections."""
config, people = person_data
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
plan = query.explain_datafusion({"Person": people})
plan = query.explain({"Person": people})

# Verify the plan is a non-empty string
assert isinstance(plan, str)
Expand All @@ -48,7 +48,7 @@ def test_explain_with_clauses(person_data):
query = CypherQuery(
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name ORDER BY p.age LIMIT 2"
).with_config(config)
plan = query.explain_datafusion({"Person": people})
plan = query.explain({"Person": people})

assert isinstance(plan, str)
assert "WHERE p.age > 30" in plan
Expand All @@ -63,11 +63,11 @@ def test_explain_error_handling(person_data):
# Missing config
query_no_config = CypherQuery("MATCH (p:Person) RETURN p.name")
with pytest.raises(ValueError, match="Graph configuration is required"):
query_no_config.explain_datafusion({"Person": people})
query_no_config.explain({"Person": people})

# Missing datasets
query_with_config = CypherQuery("MATCH (p:Person) RETURN p.name").with_config(
config
)
with pytest.raises(ValueError, match="No input datasets provided"):
query_with_config.explain_datafusion({})
query_with_config.explain({})
41 changes: 19 additions & 22 deletions python/python/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,41 +62,38 @@ def graph_env(tmp_path):
return config, datasets, people_table


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_basic_node_selection(graph_env, execute_method):
def test_basic_node_selection(graph_env):
config, datasets, _ = graph_env
query = CypherQuery("MATCH (p:Person) RETURN p.name, p.age").with_config(config)
result = getattr(query, execute_method)({"Person": datasets["Person"]})
result = query.execute({"Person": datasets["Person"]})
data = result.to_pydict()

assert set(data.keys()) == {"p.name", "p.age"}
assert len(data["p.name"]) == 4
assert "Alice" in set(data["p.name"])


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_filtered_query(graph_env, execute_method):
def test_filtered_query(graph_env):
config, datasets, _ = graph_env
query = CypherQuery(
"MATCH (p:Person) WHERE p.age > 30 RETURN p.name, p.age"
).with_config(config)
result = getattr(query, execute_method)({"Person": datasets["Person"]})
result = query.execute({"Person": datasets["Person"]})
data = result.to_pydict()

assert len(data["p.name"]) == 2
assert set(data["p.name"]) == {"Bob", "David"}
assert all(age > 30 for age in data["p.age"])


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_relationship_query(graph_env, execute_method):
def test_relationship_query(graph_env):
config, datasets, _ = graph_env
query = CypherQuery(
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) "
"RETURN p.person_id AS person_id, p.name AS name, c.company_id AS company_id"
).with_config(config)

result = getattr(query, execute_method)(
result = query.execute(
{
"Person": datasets["Person"],
"Company": datasets["Company"],
Expand All @@ -109,8 +106,7 @@ def test_relationship_query(graph_env, execute_method):
assert data["company_id"] == [101, 101, 102, 103]


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_friendship_direct_and_network(graph_env, execute_method):
def test_friendship_direct_and_network(graph_env):
config, datasets, _ = graph_env
# Direct friends of Alice (person_id = 1)
query_direct = CypherQuery(
Expand All @@ -119,7 +115,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
"RETURN b.person_id AS friend_id"
).with_config(config)

result_direct = getattr(query_direct, execute_method)(
result_direct = query_direct.execute(
{
"Person": datasets["Person"],
"FRIEND_OF": datasets["FRIEND_OF"],
Expand All @@ -134,7 +130,7 @@ def test_friendship_direct_and_network(graph_env, execute_method):
"RETURN f.person_id AS person1_id, t.person_id AS person2_id"
).with_config(config)

result_edges = getattr(query_edges, execute_method)(
result_edges = query_edges.execute(
{
"Person": datasets["Person"],
"FRIEND_OF": datasets["FRIEND_OF"],
Expand All @@ -145,16 +141,15 @@ def test_friendship_direct_and_network(graph_env, execute_method):
assert got == {(1, 2), (1, 3), (2, 4), (3, 4)}


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_two_hop_friends_of_friends(graph_env, execute_method):
def test_two_hop_friends_of_friends(graph_env):
config, datasets, _ = graph_env
query = CypherQuery(
"MATCH (a:Person)-[:FRIEND_OF]->(b:Person)-[:FRIEND_OF]->(c:Person) "
"WHERE a.person_id = 1 "
"RETURN a.person_id AS a_id, b.person_id AS b_id, c.person_id AS c_id"
).with_config(config)

result = getattr(query, execute_method)(
result = query.execute(
{
"Person": datasets["Person"],
"FRIEND_OF": datasets["FRIEND_OF"],
Expand All @@ -164,29 +159,31 @@ def test_two_hop_friends_of_friends(graph_env, execute_method):
assert set(data["c_id"]) == {4}


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_variable_length_path(graph_env, execute_method):
def test_variable_length_path(graph_env):
config, datasets, _ = graph_env
query = CypherQuery(
"MATCH (p1:Person)-[:FRIEND_OF*1..2]-(p2:Person) "
"RETURN p1.person_id AS p1, p2.person_id AS p2"
).with_config(config)
_ = getattr(query, execute_method)(

result = query.execute(
{
"Person": datasets["Person"],
"FRIEND_OF": datasets["FRIEND_OF"],
}
)
data = result.to_pydict()
got = set(zip(data["p1"], data["p2"]))
assert got == {(1, 2), (1, 3), (2, 4), (3, 4), (1, 4)}


@pytest.mark.parametrize("execute_method", ["execute", "execute_datafusion"])
def test_distinct_clause(graph_env, execute_method):
def test_distinct_clause(graph_env):
config, datasets, _ = graph_env
query = CypherQuery(
"MATCH (p:Person)-[:WORKS_FOR]->(c:Company) RETURN DISTINCT c.company_name"
).with_config(config)

result = getattr(query, execute_method)(
result = query.execute(
{
"Person": datasets["Person"],
"Company": datasets["Company"],
Expand Down
92 changes: 51 additions & 41 deletions python/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_schema::Schema;
use lance_graph::{
CypherQuery as RustCypherQuery, GraphConfig as RustGraphConfig, GraphError as RustGraphError,
ExecutionStrategy as RustExecutionStrategy, CypherQuery as RustCypherQuery,
GraphConfig as RustGraphConfig, GraphError as RustGraphError,
};
use pyo3::{
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
Expand All @@ -34,6 +35,28 @@ use serde_json::Value as JsonValue;

use crate::RT;

/// Execution strategy for Cypher queries
#[pyclass(name = "ExecutionStrategy", module = "lance.graph")]
#[derive(Clone, Copy)]
pub enum ExecutionStrategy {
/// Use DataFusion query planner (default, full feature support)
DataFusion,
/// Use simple single-table executor (legacy, limited features)
Simple,
/// Use Lance native executor (not yet implemented)
LanceNative,
}

impl From<ExecutionStrategy> for RustExecutionStrategy {
fn from(strategy: ExecutionStrategy) -> Self {
match strategy {
ExecutionStrategy::DataFusion => RustExecutionStrategy::DataFusion,
ExecutionStrategy::Simple => RustExecutionStrategy::Simple,
ExecutionStrategy::LanceNative => RustExecutionStrategy::LanceNative,
}
}
}

/// Convert GraphError to PyErr
fn graph_error_to_pyerr(err: RustGraphError) -> PyErr {
match &err {
Expand Down Expand Up @@ -267,6 +290,8 @@ impl CypherQuery {
/// ----------
/// datasets : dict
/// Dictionary mapping table names to Lance datasets
/// strategy : ExecutionStrategy, optional
/// Execution strategy to use (defaults to DataFusion)
///
/// Returns
/// -------
Expand All @@ -277,56 +302,40 @@ impl CypherQuery {
/// ------
/// RuntimeError
/// If query execution fails
fn execute(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
// Convert datasets to Arrow batches while holding the GIL - same as before
let arrow_datasets = python_datasets_to_batches(datasets)?;

// Clone the inner query for use in the async block
let inner_query = self.inner.clone();

// Use RT.block_on with Some(py) like the scanner to_pyarrow method
let result_batch = RT
.block_on(Some(py), inner_query.execute(arrow_datasets))?
.map_err(graph_error_to_pyerr)?;

record_batch_to_python_table(py, &result_batch)
}

/// Execute query using the DataFusion planner with in-memory datasets
///
/// Parameters
/// ----------
/// datasets : dict
/// Dictionary mapping table names to in-memory tables (pyarrow.Table, LanceDataset, etc.)
/// Keys should match node labels and relationship types in the graph config.
///
/// Returns
/// -------
/// pyarrow.Table
/// Query results as Arrow table
/// Examples
/// --------
/// >>> # Default strategy (DataFusion)
/// >>> result = query.execute(datasets)
///
/// Raises
/// ------
/// ValueError
/// If the query is invalid or datasets are missing
/// RuntimeError
/// If query execution fails
fn execute_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
// Convert datasets to Arrow RecordBatch map
/// >>> # Explicit strategy
/// >>> from lance.graph import ExecutionStrategy
/// >>> result = query.execute(datasets, strategy=ExecutionStrategy.Simple)
#[pyo3(signature = (datasets, strategy=None))]
fn execute(
&self,
py: Python,
datasets: &Bound<'_, PyDict>,
strategy: Option<ExecutionStrategy>,
) -> PyResult<PyObject> {
// Convert datasets to Arrow batches while holding the GIL
let arrow_datasets = python_datasets_to_batches(datasets)?;

// Clone for async move
// Convert Python strategy to Rust strategy
let rust_strategy = strategy.map(|s| s.into());

// Clone the inner query for use in the async block
let inner_query = self.inner.clone();

// Execute via runtime
// Use RT.block_on with Some(py) like the scanner to_pyarrow method
let result_batch = RT
.block_on(Some(py), inner_query.execute_datafusion(arrow_datasets))?
.block_on(Some(py), inner_query.execute(arrow_datasets, rust_strategy))?
.map_err(graph_error_to_pyerr)?;

record_batch_to_python_table(py, &result_batch)
}

/// Explain query uusing the DataFusion planner with in-memory datasets
/// Explain query using the DataFusion planner with in-memory datasets
///
/// Parameters
/// ----------
Expand All @@ -345,7 +354,7 @@ impl CypherQuery {
/// If the query is invalid or datasets are missing
/// RuntimeError
/// If query explain fails
fn explain_datafusion(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
fn explain(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
// Convert datasets to Arrow RecordBatch map
let arrow_datasets = python_datasets_to_batches(datasets)?;

Expand All @@ -354,7 +363,7 @@ impl CypherQuery {

// Execute via runtime
let plan = RT
.block_on(Some(py), inner_query.explain_datafusion(arrow_datasets))?
.block_on(Some(py), inner_query.explain(arrow_datasets))?
.map_err(graph_error_to_pyerr)?;

Ok(plan)
Expand Down Expand Up @@ -562,6 +571,7 @@ fn record_batch_to_python_table(
pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
let graph_module = PyModule::new(py, "graph")?;

graph_module.add_class::<ExecutionStrategy>()?;
graph_module.add_class::<GraphConfig>()?;
graph_module.add_class::<GraphConfigBuilder>()?;
graph_module.add_class::<CypherQuery>()?;
Expand Down
8 changes: 6 additions & 2 deletions rust/lance-graph/benches/graph_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use futures::TryStreamExt;
use lance::dataset::{Dataset, WriteMode, WriteParams};
use lance_graph::{CypherQuery, GraphConfig};
use lance_graph::{CypherQuery, ExecutionStrategy, GraphConfig};
use tempfile::TempDir;

fn create_people_batch() -> RecordBatch {
Expand Down Expand Up @@ -71,7 +71,11 @@ fn execute_cypher_query(
q: &CypherQuery,
datasets: HashMap<String, RecordBatch>,
) -> RecordBatch {
rt.block_on(async move { q.execute(datasets).await.unwrap() })
rt.block_on(async move {
q.execute(datasets, Some(ExecutionStrategy::Simple))
.await
.unwrap()
})
}

fn make_people_batch(n: usize) -> RecordBatch {
Expand Down
2 changes: 1 addition & 1 deletion rust/lance-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ pub const MAX_VARIABLE_LENGTH_HOPS: u32 = 20;

pub use config::{GraphConfig, NodeMapping, RelationshipMapping};
pub use error::{GraphError, Result};
pub use query::CypherQuery;
pub use query::{CypherQuery, ExecutionStrategy};
Loading
Loading