diff --git a/rust/lance-graph/src/datafusion_planner/expression.rs b/rust/lance-graph/src/datafusion_planner/expression.rs
index 0468720e..cfe4745b 100644
--- a/rust/lance-graph/src/datafusion_planner/expression.rs
+++ b/rust/lance-graph/src/datafusion_planner/expression.rs
@@ -9,6 +9,8 @@ use crate::ast::{BooleanExpression, PropertyValue, ValueExpression};
 use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator};
 use datafusion_functions_aggregate::average::avg;
 use datafusion_functions_aggregate::count::count;
+use datafusion_functions_aggregate::min_max::max;
+use datafusion_functions_aggregate::min_max::min;
 use datafusion_functions_aggregate::sum::sum;
 
 /// Convert BooleanExpression to DataFusion Expr
@@ -131,6 +133,22 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
                         lit(0)
                     }
                 }
+                "min" => {
+                    if args.len() == 1 {
+                        let arg_expr = to_df_value_expr(&args[0]);
+                        min(arg_expr)
+                    } else {
+                        lit(0)
+                    }
+                }
+                "max" => {
+                    if args.len() == 1 {
+                        let arg_expr = to_df_value_expr(&args[0]);
+                        max(arg_expr)
+                    } else {
+                        lit(0)
+                    }
+                }
                 _ => {
                     // Unsupported function - return placeholder for now
                     lit(0)
@@ -555,6 +573,44 @@ mod tests {
         assert!(s.contains("p__amount"), "Should contain column reference");
     }
 
+    #[test]
+    fn test_value_expr_function_min() {
+        let expr = ValueExpression::Function {
+            name: "MIN".into(),
+            args: vec![ValueExpression::Property(PropertyRef {
+                variable: "p".into(),
+                property: "amount".into(),
+            })],
+        };
+
+        let df_expr = to_df_value_expr(&expr);
+        let s = format!("{:?}", df_expr);
+        assert!(
+            s.contains("min") || s.contains("Min"),
+            "Should be MIN function"
+        );
+        assert!(s.contains("p__amount"), "Should contain column reference");
+    }
+
+    #[test]
+    fn test_value_expr_function_max() {
+        let expr = ValueExpression::Function {
+            name: "MAX".into(),
+            args: vec![ValueExpression::Property(PropertyRef {
+                variable: "p".into(),
+                property: "amount".into(),
+            })],
+        };
+
+        let df_expr = to_df_value_expr(&expr);
+        let s = format!("{:?}", df_expr);
+        assert!(
+            s.contains("max") || s.contains("Max"),
+            "Should be MAX function"
+        );
+        assert!(s.contains("p__amount"), "Should contain column reference");
+    }
+
     // ========================================================================
     // Unit tests for contains_aggregate()
     // ========================================================================
@@ -588,6 +644,38 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_contains_aggregate_min() {
+        let expr = ValueExpression::Function {
+            name: "MIN".into(),
+            args: vec![ValueExpression::Property(PropertyRef {
+                variable: "p".into(),
+                property: "value".into(),
+            })],
+        };
+
+        assert!(
+            contains_aggregate(&expr),
+            "MIN should be detected as aggregate"
+        );
+    }
+
+    #[test]
+    fn test_contains_aggregate_max() {
+        let expr = ValueExpression::Function {
+            name: "MAX".into(),
+            args: vec![ValueExpression::Property(PropertyRef {
+                variable: "p".into(),
+                property: "value".into(),
+            })],
+        };
+
+        assert!(
+            contains_aggregate(&expr),
+            "MAX should be detected as aggregate"
+        );
+    }
+
     #[test]
     fn test_contains_aggregate_property() {
         let expr = ValueExpression::Property(PropertyRef {
diff --git a/rust/lance-graph/tests/test_datafusion_pipeline.rs b/rust/lance-graph/tests/test_datafusion_pipeline.rs
index 1f6669f8..1d244ea8 100644
--- a/rust/lance-graph/tests/test_datafusion_pipeline.rs
+++ b/rust/lance-graph/tests/test_datafusion_pipeline.rs
@@ -3123,6 +3123,162 @@ async fn test_avg_without_alias_has_descriptive_name() {
     );
 }
 
+#[tokio::test]
+async fn test_min_property() {
+    let person_batch = create_person_dataset();
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "id")
+        .build()
+        .unwrap();
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN min(p.age) AS min_age")
+        .unwrap()
+        .with_config(config);
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), person_batch);
+
+    let result = query
+        .execute(datasets, Some(ExecutionStrategy::DataFusion))
+        .await
+        .unwrap();
+
+    assert_eq!(result.num_rows(), 1);
+
+    let min_col = result
+        .column_by_name("min_age")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+
+    // Ages: 25, 35, 30, 40, 28 => min = 25
+    assert_eq!(min_col.value(0), 25);
+}
+
+#[tokio::test]
+async fn test_max_property() {
+    let person_batch = create_person_dataset();
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "id")
+        .build()
+        .unwrap();
+
+    let query = CypherQuery::new("MATCH (p:Person) RETURN max(p.age) AS max_age")
+        .unwrap()
+        .with_config(config);
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), person_batch);
+
+    let result = query
+        .execute(datasets, Some(ExecutionStrategy::DataFusion))
+        .await
+        .unwrap();
+
+    assert_eq!(result.num_rows(), 1);
+
+    let max_col = result
+        .column_by_name("max_age")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+
+    // Ages: 25, 35, 30, 40, 28 => max = 40
+    assert_eq!(max_col.value(0), 40);
+}
+
+#[tokio::test]
+async fn test_min_max_with_grouping() {
+    let person_batch = create_person_dataset();
+    let config = GraphConfig::builder()
+        .with_node_label("Person", "id")
+        .build()
+        .unwrap();
+
+    // One person per city in this dataset (including NULL), so min(age) == max(age) == that person's age
+    let query_min =
+        CypherQuery::new("MATCH (p:Person) RETURN p.city, min(p.age) AS min_age ORDER BY p.city")
+            .unwrap()
+            .with_config(config.clone());
+
+    let query_max =
+        CypherQuery::new("MATCH (p:Person) RETURN p.city, max(p.age) AS max_age ORDER BY p.city")
+            .unwrap()
+            .with_config(config);
+
+    let mut datasets = HashMap::new();
+    datasets.insert("Person".to_string(), person_batch);
+
+    let result_min = query_min
+        .execute(datasets.clone(), Some(ExecutionStrategy::DataFusion))
+        .await
+        .unwrap();
+
+    let result_max = query_max
+        .execute(datasets, Some(ExecutionStrategy::DataFusion))
+        .await
+        .unwrap();
+
+    assert_eq!(result_min.num_rows(), 5);
+    assert_eq!(result_max.num_rows(), 5);
+
+    let city_col_min = result_min
+        .column_by_name("p.city")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .unwrap();
+
+    let age_col_min = result_min
+        .column_by_name("min_age")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+
+    let city_col_max = result_max
+        .column_by_name("p.city")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .unwrap();
+
+    let age_col_max = result_max
+        .column_by_name("max_age")
+        .unwrap()
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+
+    // ORDER BY p.city places the NULL city first, as in the other tests
+    assert!(city_col_min.is_null(0)); // David city NULL
+    assert!(city_col_max.is_null(0));
+    assert_eq!(age_col_min.value(0), 40);
+    assert_eq!(age_col_max.value(0), 40);
+
+    assert_eq!(city_col_min.value(1), "Chicago"); // Charlie
+    assert_eq!(city_col_max.value(1), "Chicago");
+    assert_eq!(age_col_min.value(1), 30);
+    assert_eq!(age_col_max.value(1), 30);
+
+    assert_eq!(city_col_min.value(2), "New York"); // Alice
+    assert_eq!(city_col_max.value(2), "New York");
+    assert_eq!(age_col_min.value(2), 25);
+    assert_eq!(age_col_max.value(2), 25);
+
+    assert_eq!(city_col_min.value(3), "San Francisco"); // Bob
+    assert_eq!(city_col_max.value(3), "San Francisco");
+    assert_eq!(age_col_min.value(3), 35);
+    assert_eq!(age_col_max.value(3), 35);
+
+    assert_eq!(city_col_min.value(4), "Seattle"); // Eve
+    assert_eq!(city_col_max.value(4), "Seattle");
+    assert_eq!(age_col_min.value(4), 28);
+    assert_eq!(age_col_max.value(4), 28);
+}
+
 // ============================================================================
 // Disconnected Pattern (Join) Tests
 // ============================================================================