Skip to content

Commit 1923f42

Browse files
dcfocus and Allen Cheng authored
feat(graph): add COLLECT aggregation function (lance-format#85)
Adds COLLECT() aggregation that collects values into an array, translating to DataFusion's array_agg function.

- Add collect case in to_df_value_expr for DataFusion translation
- Update contains_aggregate to recognize collect
- Update semantic validation to allow COLLECT with bare variables
- Add tests for COLLECT with and without grouping

Co-authored-by: Allen Cheng <dacheng@Allens-MacBook-Pro.local>
1 parent 357c8aa commit 1923f42

3 files changed

Lines changed: 80 additions & 3 deletions

File tree

rust/lance-graph/src/datafusion_planner/expression.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::datafusion_planner::udf;
1010
use datafusion::functions::string::lower;
1111
use datafusion::functions::string::upper;
1212
use datafusion::logical_expr::{col, lit, BinaryExpr, Expr, Operator};
13+
use datafusion_functions_aggregate::array_agg::array_agg;
1314
use datafusion_functions_aggregate::average::avg;
1415
use datafusion_functions_aggregate::count::count;
1516
use datafusion_functions_aggregate::min_max::max;
@@ -188,6 +189,15 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr {
188189
lit(0)
189190
}
190191
}
192+
// COLLECT aggregation - collects values into an array
193+
"collect" => {
194+
if args.len() == 1 {
195+
let arg_expr = to_df_value_expr(&args[0]);
196+
array_agg(arg_expr)
197+
} else {
198+
lit(0)
199+
}
200+
}
191201
// String functions
192202
"tolower" | "lower" => {
193203
if args.len() == 1 {
@@ -307,7 +317,7 @@ pub(crate) fn contains_aggregate(expr: &ValueExpression) -> bool {
307317
// Check if this is an aggregate function
308318
let is_aggregate = matches!(
309319
name.to_lowercase().as_str(),
310-
"count" | "sum" | "avg" | "min" | "max"
320+
"count" | "sum" | "avg" | "min" | "max" | "collect"
311321
);
312322
// Also check arguments recursively
313323
is_aggregate || args.iter().any(contains_aggregate)

rust/lance-graph/src/semantic.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ impl SemanticAnalyzer {
289289
ValueExpression::Function { name, args } => {
290290
// Validate function-specific arity and signature rules
291291
match name.to_lowercase().as_str() {
292-
"count" | "sum" | "avg" | "min" | "max" => {
292+
"count" | "sum" | "avg" | "min" | "max" | "collect" => {
293293
if args.len() != 1 {
294294
return Err(GraphError::PlanError {
295295
message: format!(
@@ -302,7 +302,7 @@ impl SemanticAnalyzer {
302302
}
303303

304304
// Additional validation for SUM, AVG, MIN, MAX: they require properties, not bare variables
305-
// Only COUNT allows bare variables (COUNT(*) or COUNT(p))
305+
// Only COUNT and COLLECT allow bare variables (COUNT(*), COUNT(p), COLLECT(p))
306306
if matches!(name.to_lowercase().as_str(), "sum" | "avg" | "min" | "max") {
307307
if let Some(ValueExpression::Variable(v)) = args.first() {
308308
return Err(GraphError::PlanError {

rust/lance-graph/tests/test_datafusion_pipeline.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4244,3 +4244,70 @@ async fn test_tolower_with_integer_column_in_return() {
42444244
assert_eq!(names.value(0), "Charlie");
42454245
assert_eq!(ages.value(0), 30);
42464246
}
4247+
4248+
#[tokio::test]
4249+
async fn test_collect_property() {
4250+
// Test COLLECT aggregation - collects values into an array
4251+
let person_batch = create_person_dataset();
4252+
let config = GraphConfig::builder()
4253+
.with_node_label("Person", "id")
4254+
.build()
4255+
.unwrap();
4256+
4257+
let query = CypherQuery::new("MATCH (p:Person) RETURN collect(p.name) AS all_names")
4258+
.unwrap()
4259+
.with_config(config);
4260+
4261+
let mut datasets = HashMap::new();
4262+
datasets.insert("Person".to_string(), person_batch);
4263+
4264+
let result = query
4265+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
4266+
.await
4267+
.unwrap();
4268+
4269+
// COLLECT returns a single row with an array of all values
4270+
assert_eq!(result.num_rows(), 1);
4271+
// Verify the column exists
4272+
assert!(result.column_by_name("all_names").is_some());
4273+
}
4274+
4275+
#[tokio::test]
4276+
async fn test_collect_with_grouping() {
4277+
// Test COLLECT with GROUP BY - collect names grouped by city
4278+
let person_batch = create_person_dataset();
4279+
let config = GraphConfig::builder()
4280+
.with_node_label("Person", "id")
4281+
.build()
4282+
.unwrap();
4283+
4284+
let query = CypherQuery::new(
4285+
"MATCH (p:Person) WHERE p.city IS NOT NULL RETURN p.city, collect(p.name) AS names ORDER BY p.city",
4286+
)
4287+
.unwrap()
4288+
.with_config(config);
4289+
4290+
let mut datasets = HashMap::new();
4291+
datasets.insert("Person".to_string(), person_batch);
4292+
4293+
let result = query
4294+
.execute(datasets, Some(ExecutionStrategy::DataFusion))
4295+
.await
4296+
.unwrap();
4297+
4298+
// Should have one row per city (4 cities with non-null values)
4299+
assert_eq!(result.num_rows(), 4);
4300+
4301+
let cities = result
4302+
.column_by_name("p.city")
4303+
.unwrap()
4304+
.as_any()
4305+
.downcast_ref::<StringArray>()
4306+
.unwrap();
4307+
4308+
// Cities should be ordered: Chicago, New York, San Francisco, Seattle
4309+
assert_eq!(cities.value(0), "Chicago");
4310+
assert_eq!(cities.value(1), "New York");
4311+
assert_eq!(cities.value(2), "San Francisco");
4312+
assert_eq!(cities.value(3), "Seattle");
4313+
}

0 commit comments

Comments
 (0)