Skip to content

Commit 6f74e51

Browse files
authored
feat: support MIN and MAX in the datafusion planner (#66)
1 parent 1bbc849 commit 6f74e51

2 files changed

Lines changed: 244 additions & 0 deletions

File tree

rust/lance-graph/src/datafusion_planner/expression.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use crate::ast::{BooleanExpression, PropertyValue, ValueExpression};
99
use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator};
1010
use datafusion_functions_aggregate::average::avg;
1111
use datafusion_functions_aggregate::count::count;
12+
use datafusion_functions_aggregate::min_max::max;
13+
use datafusion_functions_aggregate::min_max::min;
1214
use datafusion_functions_aggregate::sum::sum;
1315

1416
/// Convert BooleanExpression to DataFusion Expr
@@ -131,6 +133,22 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
131133
lit(0)
132134
}
133135
}
136+
"min" => {
137+
if args.len() == 1 {
138+
let arg_expr = to_df_value_expr(&args[0]);
139+
min(arg_expr)
140+
} else {
141+
lit(0)
142+
}
143+
}
144+
"max" => {
145+
if args.len() == 1 {
146+
let arg_expr = to_df_value_expr(&args[0]);
147+
max(arg_expr)
148+
} else {
149+
lit(0)
150+
}
151+
}
134152
_ => {
135153
// Unsupported function - return placeholder for now
136154
lit(0)
@@ -555,6 +573,44 @@ mod tests {
555573
assert!(s.contains("p__amount"), "Should contain column reference");
556574
}
557575

576+
#[test]
577+
fn test_value_expr_function_min() {
578+
let expr = ValueExpression::Function {
579+
name: "MIN".into(),
580+
args: vec![ValueExpression::Property(PropertyRef {
581+
variable: "p".into(),
582+
property: "amount".into(),
583+
})],
584+
};
585+
586+
let df_expr = to_df_value_expr(&expr);
587+
let s = format!("{:?}", df_expr);
588+
assert!(
589+
s.contains("min") || s.contains("Min"),
590+
"Should be MIN function"
591+
);
592+
assert!(s.contains("p__amount"), "Should contain column reference");
593+
}
594+
595+
#[test]
596+
fn test_value_expr_function_max() {
597+
let expr = ValueExpression::Function {
598+
name: "MAX".into(),
599+
args: vec![ValueExpression::Property(PropertyRef {
600+
variable: "p".into(),
601+
property: "amount".into(),
602+
})],
603+
};
604+
605+
let df_expr = to_df_value_expr(&expr);
606+
let s = format!("{:?}", df_expr);
607+
assert!(
608+
s.contains("max") || s.contains("Max"),
609+
"Should be MAX function"
610+
);
611+
assert!(s.contains("p__amount"), "Should contain column reference");
612+
}
613+
558614
// ========================================================================
559615
// Unit tests for contains_aggregate()
560616
// ========================================================================
@@ -588,6 +644,38 @@ mod tests {
588644
);
589645
}
590646

647+
#[test]
648+
fn test_contains_aggregate_min() {
649+
let expr = ValueExpression::Function {
650+
name: "MIN".into(),
651+
args: vec![ValueExpression::Property(PropertyRef {
652+
variable: "p".into(),
653+
property: "value".into(),
654+
})],
655+
};
656+
657+
assert!(
658+
contains_aggregate(&expr),
659+
"MIN should be detected as aggregate"
660+
);
661+
}
662+
663+
#[test]
664+
fn test_contains_aggregate_max() {
665+
let expr = ValueExpression::Function {
666+
name: "MAX".into(),
667+
args: vec![ValueExpression::Property(PropertyRef {
668+
variable: "p".into(),
669+
property: "value".into(),
670+
})],
671+
};
672+
673+
assert!(
674+
contains_aggregate(&expr),
675+
"MAX should be detected as aggregate"
676+
);
677+
}
678+
591679
#[test]
592680
fn test_contains_aggregate_property() {
593681
let expr = ValueExpression::Property(PropertyRef {

rust/lance-graph/tests/test_datafusion_pipeline.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3123,6 +3123,162 @@ async fn test_avg_without_alias_has_descriptive_name() {
31233123
);
31243124
}
31253125

3126+
#[tokio::test]
3127+
async fn test_min_property() {
3128+
let person_batch = create_person_dataset();
3129+
let config = GraphConfig::builder()
3130+
.with_node_label("Person", "id")
3131+
.build()
3132+
.unwrap();
3133+
3134+
let query = CypherQuery::new("MATCH (p:Person) RETURN min(p.age) AS min_age")
3135+
.unwrap()
3136+
.with_config(config);
3137+
3138+
let mut datasets = HashMap::new();
3139+
datasets.insert("Person".to_string(), person_batch);
3140+
3141+
let result = query
3142+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
3143+
.await
3144+
.unwrap();
3145+
3146+
assert_eq!(result.num_rows(), 1);
3147+
3148+
let min_col = result
3149+
.column_by_name("min_age")
3150+
.unwrap()
3151+
.as_any()
3152+
.downcast_ref::<Int64Array>()
3153+
.unwrap();
3154+
3155+
// Ages: 25, 35, 30, 40, 28 => min = 25
3156+
assert_eq!(min_col.value(0), 25);
3157+
}
3158+
3159+
#[tokio::test]
3160+
async fn test_max_property() {
3161+
let person_batch = create_person_dataset();
3162+
let config = GraphConfig::builder()
3163+
.with_node_label("Person", "id")
3164+
.build()
3165+
.unwrap();
3166+
3167+
let query = CypherQuery::new("MATCH (p:Person) RETURN max(p.age) AS max_age")
3168+
.unwrap()
3169+
.with_config(config);
3170+
3171+
let mut datasets = HashMap::new();
3172+
datasets.insert("Person".to_string(), person_batch);
3173+
3174+
let result = query
3175+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
3176+
.await
3177+
.unwrap();
3178+
3179+
assert_eq!(result.num_rows(), 1);
3180+
3181+
let max_col = result
3182+
.column_by_name("max_age")
3183+
.unwrap()
3184+
.as_any()
3185+
.downcast_ref::<Int64Array>()
3186+
.unwrap();
3187+
3188+
// Ages: 25, 35, 30, 40, 28 => max = 40
3189+
assert_eq!(max_col.value(0), 40);
3190+
}
3191+
3192+
#[tokio::test]
3193+
async fn test_min_max_with_grouping() {
3194+
let person_batch = create_person_dataset();
3195+
let config = GraphConfig::builder()
3196+
.with_node_label("Person", "id")
3197+
.build()
3198+
.unwrap();
3199+
3200+
// One person per city in this dataset (including NULL), so min(age) == that person's age
3201+
let query_min =
3202+
CypherQuery::new("MATCH (p:Person) RETURN p.city, min(p.age) AS min_age ORDER BY p.city")
3203+
.unwrap()
3204+
.with_config(config.clone());
3205+
3206+
let query_max =
3207+
CypherQuery::new("MATCH (p:Person) RETURN p.city, max(p.age) AS max_age ORDER BY p.city")
3208+
.unwrap()
3209+
.with_config(config);
3210+
3211+
let mut datasets = HashMap::new();
3212+
datasets.insert("Person".to_string(), person_batch);
3213+
3214+
let result_min = query_min
3215+
.execute(datasets.clone(), Some(ExecutionStrategy::DataFusion))
3216+
.await
3217+
.unwrap();
3218+
3219+
let result_max = query_max
3220+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
3221+
.await
3222+
.unwrap();
3223+
3224+
assert_eq!(result_min.num_rows(), 5);
3225+
assert_eq!(result_max.num_rows(), 5);
3226+
3227+
let city_col_min = result_min
3228+
.column_by_name("p.city")
3229+
.unwrap()
3230+
.as_any()
3231+
.downcast_ref::<StringArray>()
3232+
.unwrap();
3233+
3234+
let min_col_min = result_min
3235+
.column_by_name("min_age")
3236+
.unwrap()
3237+
.as_any()
3238+
.downcast_ref::<Int64Array>()
3239+
.unwrap();
3240+
3241+
let city_col_max = result_max
3242+
.column_by_name("p.city")
3243+
.unwrap()
3244+
.as_any()
3245+
.downcast_ref::<StringArray>()
3246+
.unwrap();
3247+
3248+
let min_col_max = result_max
3249+
.column_by_name("max_age")
3250+
.unwrap()
3251+
.as_any()
3252+
.downcast_ref::<Int64Array>()
3253+
.unwrap();
3254+
3255+
// ORDER BY p.city, NULL comes first per your other tests
3256+
assert!(city_col_min.is_null(0)); // David city NULL
3257+
assert!(city_col_max.is_null(0));
3258+
assert_eq!(min_col_min.value(0), 40);
3259+
assert_eq!(min_col_max.value(0), 40);
3260+
3261+
assert_eq!(city_col_min.value(1), "Chicago"); // Charlie
3262+
assert_eq!(city_col_max.value(1), "Chicago");
3263+
assert_eq!(min_col_min.value(1), 30);
3264+
assert_eq!(min_col_max.value(1), 30);
3265+
3266+
assert_eq!(city_col_min.value(2), "New York"); // Alice
3267+
assert_eq!(city_col_max.value(2), "New York");
3268+
assert_eq!(min_col_min.value(2), 25);
3269+
assert_eq!(min_col_max.value(2), 25);
3270+
3271+
assert_eq!(city_col_min.value(3), "San Francisco"); // Bob
3272+
assert_eq!(city_col_max.value(3), "San Francisco");
3273+
assert_eq!(min_col_min.value(3), 35);
3274+
assert_eq!(min_col_max.value(3), 35);
3275+
3276+
assert_eq!(city_col_min.value(4), "Seattle"); // Eve
3277+
assert_eq!(city_col_max.value(4), "Seattle");
3278+
assert_eq!(min_col_min.value(4), 28);
3279+
assert_eq!(min_col_max.value(4), 28);
3280+
}
3281+
31263282
// ============================================================================
31273283
// Disconnected Pattern (Join) Tests
31283284
// ============================================================================

0 commit comments

Comments
 (0)