Skip to content

Commit 2a15a9f

Browse files
authored
perf(datafusion): push down list_length expression (#8600)
1 parent 1e689a1 commit 2a15a9f

7 files changed

Lines changed: 339 additions & 10 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,7 @@ datafusion-datasource = { version = "54", default-features = false }
144144
datafusion-execution = { version = "54" }
145145
datafusion-expr = { version = "54" }
146146
datafusion-functions = { version = "54" }
147+
datafusion-functions-nested = { version = "54" }
147148
datafusion-physical-expr = { version = "54" }
148149
datafusion-physical-expr-adapter = { version = "54" }
149150
datafusion-physical-expr-common = { version = "54" }

vortex-datafusion/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ datafusion-datasource = { workspace = true, default-features = false }
2424
datafusion-execution = { workspace = true }
2525
datafusion-expr = { workspace = true }
2626
datafusion-functions = { workspace = true }
27+
datafusion-functions-nested = { workspace = true }
2728
datafusion-physical-expr = { workspace = true }
2829
datafusion-physical-expr-adapter = { workspace = true }
2930
datafusion-physical-expr-common = { workspace = true }

vortex-datafusion/src/convert/exprs.rs

Lines changed: 163 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,14 @@ use arrow_schema::DataType;
77
use arrow_schema::Field;
88
use arrow_schema::Schema;
99
use datafusion_common::Result as DFResult;
10+
use datafusion_common::ScalarValue;
1011
use datafusion_common::exec_datafusion_err;
1112
use datafusion_common::tree_node::TreeNode;
1213
use datafusion_common::tree_node::TreeNodeRecursion;
1314
use datafusion_expr::Operator as DFOperator;
1415
use datafusion_functions::core::getfield::GetFieldFunc;
1516
use datafusion_functions::string::octet_length::OctetLengthFunc;
17+
use datafusion_functions_nested::length::ArrayLength;
1618
use datafusion_physical_expr::PhysicalExpr;
1719
use datafusion_physical_expr::ScalarFunctionExpr;
1820
use datafusion_physical_expr::projection::ProjectionExpr;
@@ -32,6 +34,7 @@ use vortex::expr::get_item;
3234
use vortex::expr::is_not_null;
3335
use vortex::expr::is_null;
3436
use vortex::expr::list_contains;
37+
use vortex::expr::list_length;
3538
use vortex::expr::lit;
3639
use vortex::expr::nested_case_when;
3740
use vortex::expr::not;
@@ -155,6 +158,32 @@ impl DefaultExpressionConvertor {
155158
Ok(cast(byte_length(input), return_dtype))
156159
}
157160

161+
/// Attempts to convert DataFusion's `array_length` function (aliased as `list_length`) to
162+
/// Vortex `list_length`.
163+
///
164+
/// Supports the single-argument form `array_length(arr)` and the equivalent two-argument
165+
/// form with an explicit first dimension `array_length(arr, 1)`.
166+
fn try_convert_array_length(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
167+
let Some(input) = array_length_input(scalar_fn) else {
168+
return Err(exec_datafusion_err!(
169+
"array_length pushdown supports only the one-argument form or an explicit first \
170+
dimension"
171+
));
172+
};
173+
174+
let input = self.convert(input.as_ref())?;
175+
let return_dtype = self
176+
.session
177+
.arrow()
178+
.from_arrow_field(&Field::new(
179+
"",
180+
scalar_fn.return_type().clone(),
181+
scalar_fn.nullable(),
182+
))
183+
.map_err(|e| exec_datafusion_err!("Failed to convert return type to dtype: {e}"))?;
184+
Ok(cast(list_length(input), return_dtype))
185+
}
186+
158187
/// Attempts to convert a DataFusion ScalarFunctionExpr to a Vortex expression.
159188
fn try_convert_scalar_function(&self, scalar_fn: &ScalarFunctionExpr) -> DFResult<Expression> {
160189
if let Some(octet_length_fn) =
@@ -163,6 +192,12 @@ impl DefaultExpressionConvertor {
163192
return self.try_convert_octet_length(octet_length_fn);
164193
}
165194

195+
if let Some(array_length_fn) =
196+
ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
197+
{
198+
return self.try_convert_array_length(array_length_fn);
199+
}
200+
166201
if let Some(get_field_fn) = ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn)
167202
{
168203
// DataFusion's GetFieldFunc flattens nested field access into a single call
@@ -511,6 +546,7 @@ fn is_convertible_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
511546
|| expr.downcast_ref::<ScalarFunctionExpr>().is_some_and(|sf| {
512547
ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(sf).is_some()
513548
|| ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(sf).is_some()
549+
|| ScalarFunctionExpr::try_downcast_func::<ArrayLength>(sf).is_some()
514550
})
515551
}
516552

@@ -572,14 +608,20 @@ fn supported_data_types(dt: &DataType) -> bool {
572608
}
573609

574610
/// Checks if a scalar function can be pushed down.
575-
/// Currently GetFieldFunc and OctetLengthFunc are supported.
611+
/// Currently GetFieldFunc, OctetLengthFunc, and ArrayLength are supported.
576612
fn can_scalar_fn_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
577613
if ScalarFunctionExpr::try_downcast_func::<GetFieldFunc>(scalar_fn).is_some() {
578614
return true;
579615
}
580616

581-
ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
617+
if ScalarFunctionExpr::try_downcast_func::<OctetLengthFunc>(scalar_fn)
582618
.is_some_and(|octet_length| can_octet_length_be_pushed_down(octet_length, schema))
619+
{
620+
return true;
621+
}
622+
623+
ScalarFunctionExpr::try_downcast_func::<ArrayLength>(scalar_fn)
624+
.is_some_and(|array_length| can_array_length_be_pushed_down(array_length, schema))
583625
}
584626

585627
fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
@@ -598,6 +640,42 @@ fn can_octet_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Sche
598640
}) && can_be_pushed_down_impl(input, schema)
599641
}
600642

643+
fn can_array_length_be_pushed_down(scalar_fn: &ScalarFunctionExpr, schema: &Schema) -> bool {
644+
let Some(input) = array_length_input(scalar_fn) else {
645+
return false;
646+
};
647+
648+
// The argument must resolve to a list type. We gate on the resolved data type rather than
649+
// `can_be_pushed_down_impl`, since list columns are intentionally rejected there. We still
650+
// require the argument to be a convertible expression (e.g. a column or struct field access).
651+
input.data_type(schema).as_ref().is_ok_and(|data_type| {
652+
matches!(
653+
data_type,
654+
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
655+
)
656+
}) && is_convertible_expr(input)
657+
}
658+
659+
/// Returns the list argument of an `array_length` call if the call is a form we can rewrite to
660+
/// `list_length`: either the single-argument form `array_length(arr)`, or the two-argument form
661+
/// with an explicit first dimension `array_length(arr, 1)`, which is equivalent. Higher
662+
/// dimensions recurse into nested lists and are not supported.
663+
fn array_length_input(scalar_fn: &ScalarFunctionExpr) -> Option<&Arc<dyn PhysicalExpr>> {
664+
match scalar_fn.args() {
665+
[input] => Some(input),
666+
[input, dimension] if is_dimension_one(dimension) => Some(input),
667+
_ => None,
668+
}
669+
}
670+
671+
/// Returns true if `expr` is an `Int64` literal equal to 1. DataFusion coerces the `array_length`
672+
/// dimension argument to `Int64`, so that is the only form we need to recognize; any other literal
673+
/// simply isn't pushed down.
674+
fn is_dimension_one(expr: &Arc<dyn PhysicalExpr>) -> bool {
675+
expr.downcast_ref::<df_expr::Literal>()
676+
.is_some_and(|literal| matches!(literal.value(), ScalarValue::Int64(Some(1))))
677+
}
678+
601679
#[cfg(test)]
602680
mod tests {
603681
use std::sync::Arc;
@@ -633,7 +711,7 @@ mod tests {
633711
true,
634712
),
635713
Field::new(
636-
"unsupported_list",
714+
"tags",
637715
DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
638716
true,
639717
),
@@ -652,6 +730,21 @@ mod tests {
652730
)
653731
}
654732

733+
fn array_length_expr(
734+
args: Vec<Arc<dyn PhysicalExpr>>,
735+
schema: &Schema,
736+
) -> Arc<dyn PhysicalExpr> {
737+
Arc::new(
738+
ScalarFunctionExpr::try_new(
739+
Arc::new(ScalarUDF::from(ArrayLength::new())),
740+
args,
741+
schema,
742+
Arc::new(ConfigOptions::new()),
743+
)
744+
.unwrap(),
745+
)
746+
}
747+
655748
#[test]
656749
fn test_make_vortex_predicate_empty() {
657750
let expr_convertor = DefaultExpressionConvertor::default();
@@ -798,6 +891,23 @@ mod tests {
798891
");
799892
}
800893

894+
#[rstest]
895+
fn test_expr_from_df_array_length(test_schema: Schema) {
896+
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
897+
let array_length = array_length_expr(vec![expr], &test_schema);
898+
899+
let result = DefaultExpressionConvertor::default()
900+
.convert(array_length.as_ref())
901+
.unwrap();
902+
903+
assert_snapshot!(result.display_tree().to_string(), @r"
904+
vortex.cast(u64?)
905+
└── input: vortex.list.length()
906+
└── input: vortex.get_item(tags)
907+
└── input: vortex.root()
908+
");
909+
}
910+
801911
#[rstest]
802912
// Supported types
803913
#[case::null(DataType::Null, true)]
@@ -861,8 +971,7 @@ mod tests {
861971

862972
#[rstest]
863973
fn test_can_be_pushed_down_column_unsupported_type(test_schema: Schema) {
864-
let col_expr =
865-
Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
974+
let col_expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
866975

867976
assert!(!can_be_pushed_down_impl(&col_expr, &test_schema));
868977
}
@@ -919,7 +1028,7 @@ mod tests {
9191028

9201029
#[rstest]
9211030
fn test_can_be_pushed_down_binary_unsupported_operand(test_schema: Schema) {
922-
let left = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
1031+
let left = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
9231032
let right =
9241033
Arc::new(df_expr::Literal::new(ScalarValue::Int32(Some(42)))) as Arc<dyn PhysicalExpr>;
9251034
let binary_expr = Arc::new(df_expr::BinaryExpr::new(left, DFOperator::Eq, right))
@@ -942,7 +1051,7 @@ mod tests {
9421051

9431052
#[rstest]
9441053
fn test_can_be_pushed_down_like_unsupported_operand(test_schema: Schema) {
945-
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
1054+
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
9461055
let pattern = Arc::new(df_expr::Literal::new(ScalarValue::Utf8(Some(
9471056
"test%".to_string(),
9481057
)))) as Arc<dyn PhysicalExpr>;
@@ -962,7 +1071,7 @@ mod tests {
9621071

9631072
#[rstest]
9641073
fn test_can_be_pushed_down_octet_length_unsupported_operand(test_schema: Schema) {
965-
let expr = Arc::new(df_expr::Column::new("unsupported_list", 5)) as Arc<dyn PhysicalExpr>;
1074+
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
9661075
let octet_length = Arc::new(ScalarFunctionExpr::new(
9671076
"octet_length",
9681077
Arc::new(ScalarUDF::from(OctetLengthFunc::new())),
@@ -974,6 +1083,52 @@ mod tests {
9741083
assert!(!can_be_pushed_down_impl(&octet_length, &test_schema));
9751084
}
9761085

1086+
#[rstest]
1087+
fn test_can_be_pushed_down_array_length_supported(test_schema: Schema) {
1088+
let expr = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
1089+
let array_length = array_length_expr(vec![expr], &test_schema);
1090+
1091+
assert!(can_be_pushed_down_impl(&array_length, &test_schema));
1092+
}
1093+
1094+
#[rstest]
1095+
fn test_can_be_pushed_down_array_length_unsupported_operand(test_schema: Schema) {
1096+
// `array_length` over a non-list column cannot be pushed down.
1097+
let expr = Arc::new(df_expr::Column::new("name", 1)) as Arc<dyn PhysicalExpr>;
1098+
let array_length = Arc::new(ScalarFunctionExpr::new(
1099+
"array_length",
1100+
Arc::new(ScalarUDF::from(ArrayLength::new())),
1101+
vec![expr],
1102+
Arc::new(Field::new("array_length", DataType::UInt64, true)),
1103+
Arc::new(ConfigOptions::new()),
1104+
)) as Arc<dyn PhysicalExpr>;
1105+
1106+
assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
1107+
}
1108+
1109+
#[rstest]
1110+
fn test_can_be_pushed_down_array_length_dimension_one_supported(test_schema: Schema) {
1111+
// `array_length(arr, 1)` is the first-dimension length, equivalent to `list_length`.
1112+
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
1113+
let dimension =
1114+
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(1)))) as Arc<dyn PhysicalExpr>;
1115+
let array_length = array_length_expr(vec![list, dimension], &test_schema);
1116+
1117+
assert!(can_be_pushed_down_impl(&array_length, &test_schema));
1118+
}
1119+
1120+
#[rstest]
1121+
fn test_can_be_pushed_down_array_length_higher_dimension_not_supported(test_schema: Schema) {
1122+
// Dimensions other than 1 recurse into nested lists, which `list_length` does not model,
1123+
// so they must not be pushed down.
1124+
let list = Arc::new(df_expr::Column::new("tags", 5)) as Arc<dyn PhysicalExpr>;
1125+
let dimension =
1126+
Arc::new(df_expr::Literal::new(ScalarValue::Int64(Some(2)))) as Arc<dyn PhysicalExpr>;
1127+
let array_length = array_length_expr(vec![list, dimension], &test_schema);
1128+
1129+
assert!(!can_be_pushed_down_impl(&array_length, &test_schema));
1130+
}
1131+
9771132
// https://github.com/vortex-data/vortex/issues/6211
9781133
#[tokio::test]
9791134
async fn test_cast_int_to_string() -> anyhow::Result<()> {

vortex-sqllogictest/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@ anyhow = { workspace = true }
1818
async-trait = { workspace = true }
1919
bigdecimal = { workspace = true }
2020
datafusion = { workspace = true }
21+
datafusion-functions-nested = { workspace = true }
2122
datafusion-sqllogictest = { workspace = true }
2223
indicatif = { workspace = true }
2324
regex = { workspace = true }

vortex-sqllogictest/bin/sqllogictests-runner.rs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,8 +73,12 @@ fn drive_datafusion(path: &Path, work_dir: &Path, mode: Mode) -> anyhow::Result<
7373
Arc::new(DefaultTableFactory::new()),
7474
)
7575
.with_file_formats(vec![factory]);
76-
let session =
77-
SessionContext::new_with_state(session_state_builder.build()).enable_url_table();
76+
// The workspace builds `datafusion` without the `nested_expressions` feature, so array
77+
// functions (e.g. `make_array`, `array_length`) are not registered by default. Register
78+
// them explicitly so SLT files can construct and query list columns.
79+
let mut session_state = session_state_builder.build();
80+
datafusion_functions_nested::register_all(&mut session_state)?;
81+
let session = SessionContext::new_with_state(session_state).enable_url_table();
7882

7983
let mut runner = Runner::new(|| async {
8084
Ok(PathNormalizing::new(

0 commit comments

Comments
 (0)