Skip to content

Commit d843b4b

Browse files
authored
feat: support the SUM function in the datafusion planner (#36)
* feat: support the SUM function in the datafusion planner * doc: update test comments
1 parent d2056b6 commit d843b4b

2 files changed

Lines changed: 147 additions & 0 deletions

File tree

rust/lance-graph/src/datafusion_planner.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use datafusion::logical_expr::{
2222
col, lit, BinaryExpr, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
2323
};
2424
use datafusion_functions_aggregate::count::count;
25+
use datafusion_functions_aggregate::sum::sum;
2526
use std::collections::{HashMap, HashSet};
2627
use std::sync::Arc;
2728

@@ -1541,6 +1542,16 @@ impl DataFusionPlanner {
15411542
lit(0)
15421543
}
15431544
}
1545+
"sum" => {
1546+
if args.len() == 1 {
1547+
let arg_expr = Self::to_df_value_expr(&args[0]);
1548+
// Use DataFusion's sum helper function
1549+
sum(arg_expr)
1550+
} else {
1551+
// Invalid argument count - return placeholder
1552+
lit(0)
1553+
}
1554+
}
15441555
_ => {
15451556
// Unsupported function - return placeholder for now
15461557
lit(0)

rust/lance-graph/tests/test_datafusion_pipeline.rs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,3 +2610,139 @@ async fn test_count_property_without_alias_has_descriptive_name() {
26102610
result.schema()
26112611
);
26122612
}
2613+
2614+
#[tokio::test]
2615+
async fn test_sum_property() {
2616+
let person_batch = create_person_dataset();
2617+
let config = GraphConfig::builder()
2618+
.with_node_label("Person", "id")
2619+
.build()
2620+
.unwrap();
2621+
2622+
let query = CypherQuery::new("MATCH (p:Person) RETURN sum(p.age) AS total_age")
2623+
.unwrap()
2624+
.with_config(config);
2625+
2626+
let mut datasets = HashMap::new();
2627+
datasets.insert("Person".to_string(), person_batch);
2628+
2629+
let result = query.execute_datafusion(datasets).await.unwrap();
2630+
2631+
assert_eq!(result.num_rows(), 1);
2632+
let sum_col = result
2633+
.column_by_name("total_age")
2634+
.unwrap()
2635+
.as_any()
2636+
.downcast_ref::<Int64Array>()
2637+
.unwrap();
2638+
// Sum of ages: 25 + 35 + 30 + 40 + 28 = 158
2639+
assert_eq!(sum_col.value(0), 158);
2640+
}
2641+
2642+
#[tokio::test]
2643+
async fn test_sum_with_filter() {
2644+
let person_batch = create_person_dataset();
2645+
let config = GraphConfig::builder()
2646+
.with_node_label("Person", "id")
2647+
.build()
2648+
.unwrap();
2649+
2650+
let query =
2651+
CypherQuery::new("MATCH (p:Person) WHERE p.age >= 30 RETURN sum(p.age) AS total_age")
2652+
.unwrap()
2653+
.with_config(config);
2654+
2655+
let mut datasets = HashMap::new();
2656+
datasets.insert("Person".to_string(), person_batch);
2657+
2658+
let result = query.execute_datafusion(datasets).await.unwrap();
2659+
2660+
assert_eq!(result.num_rows(), 1);
2661+
let sum_col = result
2662+
.column_by_name("total_age")
2663+
.unwrap()
2664+
.as_any()
2665+
.downcast_ref::<Int64Array>()
2666+
.unwrap();
2667+
// Sum of ages >= 30: 35 + 30 + 40 = 105
2668+
assert_eq!(sum_col.value(0), 105);
2669+
}
2670+
2671+
#[tokio::test]
2672+
async fn test_sum_with_grouping() {
2673+
let person_batch = create_person_dataset();
2674+
let config = GraphConfig::builder()
2675+
.with_node_label("Person", "id")
2676+
.build()
2677+
.unwrap();
2678+
2679+
let query =
2680+
CypherQuery::new("MATCH (p:Person) RETURN p.city, sum(p.age) AS total_age ORDER BY p.city")
2681+
.unwrap()
2682+
.with_config(config);
2683+
2684+
let mut datasets = HashMap::new();
2685+
datasets.insert("Person".to_string(), person_batch);
2686+
2687+
let result = query.execute_datafusion(datasets).await.unwrap();
2688+
2689+
// Should have 5 groups: NULL, Chicago, New York, San Francisco, Seattle
2690+
assert_eq!(result.num_rows(), 5);
2691+
2692+
let city_col = result
2693+
.column_by_name("p.city")
2694+
.unwrap()
2695+
.as_any()
2696+
.downcast_ref::<StringArray>()
2697+
.unwrap();
2698+
2699+
let sum_col = result
2700+
.column_by_name("total_age")
2701+
.unwrap()
2702+
.as_any()
2703+
.downcast_ref::<Int64Array>()
2704+
.unwrap();
2705+
2706+
// Verify grouping results (ordered by city, NULL comes first)
2707+
assert!(city_col.is_null(0)); // David: 40 (NULL city)
2708+
assert_eq!(sum_col.value(0), 40);
2709+
2710+
assert_eq!(city_col.value(1), "Chicago"); // Charlie: 30
2711+
assert_eq!(sum_col.value(1), 30);
2712+
2713+
assert_eq!(city_col.value(2), "New York"); // Alice: 25
2714+
assert_eq!(sum_col.value(2), 25);
2715+
2716+
assert_eq!(city_col.value(3), "San Francisco"); // Bob: 35
2717+
assert_eq!(sum_col.value(3), 35);
2718+
2719+
assert_eq!(city_col.value(4), "Seattle"); // Eve: 28
2720+
assert_eq!(sum_col.value(4), 28);
2721+
}
2722+
2723+
#[tokio::test]
2724+
async fn test_sum_without_alias_has_descriptive_name() {
2725+
let person_batch = create_person_dataset();
2726+
let config = GraphConfig::builder()
2727+
.with_node_label("Person", "id")
2728+
.build()
2729+
.unwrap();
2730+
2731+
let query = CypherQuery::new("MATCH (p:Person) RETURN sum(p.age)")
2732+
.unwrap()
2733+
.with_config(config);
2734+
2735+
let mut datasets = HashMap::new();
2736+
datasets.insert("Person".to_string(), person_batch);
2737+
2738+
let result = query.execute_datafusion(datasets).await.unwrap();
2739+
2740+
assert_eq!(result.num_rows(), 1);
2741+
// Should have column named "sum(p.age)" not "expr"
2742+
let sum_col = result.column_by_name("sum(p.age)");
2743+
assert!(
2744+
sum_col.is_some(),
2745+
"Expected column named 'sum(p.age)' but schema is: {:?}",
2746+
result.schema()
2747+
);
2748+
}

0 commit comments

Comments
 (0)