Commit f4e24a5

fix: skips projection pruning for whole subtree (#20545)
## Which issue does this PR close?

- Closes #18816.

## Rationale for this change

In `UserDefinedLogicalNodeCore`, the default implementation of `necessary_children_exprs` returns `None`, which signals to the optimizer that it cannot determine which columns are required from the child. The optimizer takes a conservative approach and skips projection pruning for that node, leading to complex and redundant plans in the subtree. However, it would make more sense to assume all columns are required and let the optimizer proceed, rather than giving up on the entire subtree.

## What changes are included in this PR?

```rust
LogicalPlan::Extension(extension) => {
    if let Some(necessary_children_indices) =
        extension.node.necessary_children_exprs(indices.indices())
    {
        ...
    } else {
        // Requirements from parent cannot be routed down to user defined logical plan safely
        // Assume it requires all input exprs here
        plan.inputs()
            .into_iter()
            .map(RequiredIndices::new_for_all_exprs)
            .collect()
    }
}
```

instead of https://github.com/apache/datafusion/blob/b6d46a63824f003117297848d8d83b659ac2e759/datafusion/optimizer/src/optimize_projections/mod.rs#L331-L337

## Are these changes tested?

Yes. In addition to unit tests, I've also added a complete end-to-end integration test that reproduces the full scenario in the issue. This might seem redundant, bloated, or even unnecessary. Please let me know if I should remove these tests.

An existing test is modified, but I think the newer behavior is expected.

## Are there any user-facing changes?

Yes. But I think the new implementation is the expected behavior.
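To make the fallback concrete, here is a minimal, self-contained sketch of the idea in plain Rust. It is a toy model, not DataFusion code: `ToyNode`, `Opaque`, `PassThrough`, `child_widths`, and `required_from_children` are hypothetical names invented for illustration, and column requirements are modeled as plain `Vec<usize>` index lists rather than `RequiredIndices`.

```rust
/// Toy stand-in for a user-defined logical node (NOT DataFusion's real trait).
trait ToyNode {
    /// Number of columns produced by each child (hypothetical helper).
    fn child_widths(&self) -> Vec<usize>;

    /// Maps the columns the parent needs to the columns needed from each child.
    /// The default returns `None`, meaning "this node cannot tell".
    fn necessary_children_exprs(&self, _output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        None
    }
}

/// A node that keeps the default and therefore gives no routing information.
struct Opaque {
    width: usize,
}

impl ToyNode for Opaque {
    fn child_widths(&self) -> Vec<usize> {
        vec![self.width]
    }
}

/// A node that forwards its single child's columns unchanged.
struct PassThrough {
    width: usize,
}

impl ToyNode for PassThrough {
    fn child_widths(&self) -> Vec<usize> {
        vec![self.width]
    }

    fn necessary_children_exprs(&self, output_columns: &[usize]) -> Option<Vec<Vec<usize>>> {
        Some(vec![output_columns.to_vec()])
    }
}

/// Computes the column indices required from each child.
/// When the node answers `None`, every column of every child is assumed to be
/// required, so the caller can keep optimizing below the node instead of stopping.
fn required_from_children(node: &dyn ToyNode, parent_required: &[usize]) -> Vec<Vec<usize>> {
    match node.necessary_children_exprs(parent_required) {
        Some(per_child) => per_child,
        None => node
            .child_widths()
            .into_iter()
            .map(|width| (0..width).collect())
            .collect(),
    }
}

fn main() {
    // The opaque node forces all three child columns to be kept.
    println!("{:?}", required_from_children(&Opaque { width: 3 }, &[1]));
    // The pass-through node narrows the requirement to column 1 only.
    println!("{:?}", required_from_children(&PassThrough { width: 3 }, &[1]));
}
```

In this model the opaque node yields every child column, while the pass-through node yields only what the parent asked for. Either way the caller gets a usable answer and can continue into the subtree.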
1 parent 944bac2 commit f4e24a5

3 files changed: +253 -24 lines changed
datafusion/core/tests/dataframe/mod.rs

Lines changed: 2 additions & 3 deletions
@@ -2458,9 +2458,8 @@ async fn cache_producer_test() -> Result<()> {
         @r"
     CacheNode
       Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum
-        Projection: aggregate_test_100.c2, aggregate_test_100.c3
-          Limit: skip=0, fetch=1
-            TableScan: aggregate_test_100, fetch=1
+        Limit: skip=0, fetch=1
+          TableScan: aggregate_test_100 projection=[c2, c3], fetch=1
     "
     );
     Ok(())

datafusion/optimizer/src/optimize_projections/mod.rs

Lines changed: 98 additions & 19 deletions
@@ -329,29 +329,34 @@ fn optimize_projections(
                 .collect()
         }
         LogicalPlan::Extension(extension) => {
-            let Some(necessary_children_indices) =
+            if let Some(necessary_children_indices) =
                 extension.node.necessary_children_exprs(indices.indices())
-            else {
-                // Requirements from parent cannot be routed down to user defined logical plan safely
-                return Ok(Transformed::no(plan));
-            };
-            let children = extension.node.inputs();
-            assert_eq_or_internal_err!(
-                children.len(),
-                necessary_children_indices.len(),
-                "Inconsistent length between children and necessary children indices. \
+            {
+                let children = extension.node.inputs();
+                assert_eq_or_internal_err!(
+                    children.len(),
+                    necessary_children_indices.len(),
+                    "Inconsistent length between children and necessary children indices. \
                 Make sure `.necessary_children_exprs` implementation of the \
                 `UserDefinedLogicalNode` is consistent with actual children length \
                 for the node."
-            );
-            children
-                .into_iter()
-                .zip(necessary_children_indices)
-                .map(|(child, necessary_indices)| {
-                    RequiredIndices::new_from_indices(necessary_indices)
-                        .with_plan_exprs(&plan, child.schema())
-                })
-                .collect::<Result<Vec<_>>>()?
+                );
+                children
+                    .into_iter()
+                    .zip(necessary_children_indices)
+                    .map(|(child, necessary_indices)| {
+                        RequiredIndices::new_from_indices(necessary_indices)
+                            .with_plan_exprs(&plan, child.schema())
+                    })
+                    .collect::<Result<Vec<_>>>()?
+            } else {
+                // Requirements from parent cannot be routed down to user defined logical plan safely
+                // Assume it requires all input exprs here
+                plan.inputs()
+                    .into_iter()
+                    .map(RequiredIndices::new_for_all_exprs)
+                    .collect()
+            }
         }
         LogicalPlan::EmptyRelation(_)
         | LogicalPlan::Values(_)
@@ -1172,6 +1177,57 @@ mod tests {
         }
     }
 
+    /// A user-defined node that does NOT implement `necessary_children_exprs`,
+    /// so the optimizer cannot determine which columns are required from its
+    /// children and must assume all columns are needed.
+    #[derive(Debug, Hash, PartialEq, Eq)]
+    struct OpaqueRequirementsUserDefined {
+        input: Arc<LogicalPlan>,
+        schema: DFSchemaRef,
+    }
+
+    // Manual implementation needed because of `schema` field. Comparison excludes this field.
+    impl PartialOrd for OpaqueRequirementsUserDefined {
+        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+            self.input
+                .partial_cmp(&other.input)
+                .filter(|cmp| *cmp != Ordering::Equal || self == other)
+        }
+    }
+
+    impl UserDefinedLogicalNodeCore for OpaqueRequirementsUserDefined {
+        fn name(&self) -> &str {
+            "OpaqueRequirementsUserDefined"
+        }
+
+        fn inputs(&self) -> Vec<&LogicalPlan> {
+            vec![&self.input]
+        }
+
+        fn schema(&self) -> &DFSchemaRef {
+            &self.schema
+        }
+
+        fn expressions(&self) -> Vec<Expr> {
+            vec![]
+        }
+
+        fn with_exprs_and_inputs(
+            &self,
+            _exprs: Vec<Expr>,
+            mut inputs: Vec<LogicalPlan>,
+        ) -> Result<Self> {
+            Ok(Self {
+                input: Arc::new(inputs.swap_remove(0)),
+                schema: Arc::clone(&self.schema),
+            })
+        }
+
+        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+            write!(f, "OpaqueRequirementsUserDefined")
+        }
+    }
+
     #[test]
     fn merge_two_projection() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -2204,6 +2260,29 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_continue_processing_through_extension() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .project(vec![col("a")])?
+            .project(vec![col("a")])?
+            .build()?;
+        let plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(OpaqueRequirementsUserDefined {
+                input: Arc::new(plan),
+                schema: Arc::clone(table_scan.schema()),
+            }),
+        });
+        let plan = optimize(plan).expect("failed to optimize plan");
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+        OpaqueRequirementsUserDefined
+          TableScan: test projection=[a]
+        "
+        )
+    }
+
     /// tests that it removes an aggregate is never used downstream
     #[test]
     fn table_unused_aggregate() -> Result<()> {

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 153 additions & 2 deletions
@@ -16,16 +16,25 @@
 // under the License.
 
 use std::any::Any;
+use std::cmp::Ordering;
 use std::collections::HashMap;
+use std::fmt::Formatter;
 use std::sync::Arc;
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
 
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::{Result, TableReference, plan_err};
+use datafusion_common::{
+    DFSchemaRef, Result, ScalarValue, TableReference, ToDFSchema, plan_err,
+};
+use datafusion_expr::expr::Cast;
+use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
 use datafusion_expr::planner::ExprPlanner;
 use datafusion_expr::test::function_stub::sum_udaf;
-use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
+use datafusion_expr::{
+    AggregateUDF, Expr, Extension, LogicalPlan, ScalarUDF, SortExpr,
+    TableProviderFilterPushDown, TableSource, UserDefinedLogicalNodeCore, WindowUDF, col,
+};
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_functions_aggregate::planner::AggregateFunctionPlanner;
@@ -690,6 +699,148 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
 
 fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
 
+fn optimize_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
+    let config = OptimizerContext::new().with_skip_failing_rules(false);
+    let optimizer = Optimizer::new();
+    optimizer.optimize(plan, &config, observe)
+}
+
+/// Extension node that does NOT implement `necessary_children_exprs`.
+/// Used to test that the optimizer still processes subtrees below such nodes.
+#[derive(Debug, Hash, PartialEq, Eq)]
+struct OpaqueRequirementsExtension {
+    input: Arc<LogicalPlan>,
+    schema: DFSchemaRef,
+}
+
+impl PartialOrd for OpaqueRequirementsExtension {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.input
+            .partial_cmp(&other.input)
+            .filter(|cmp| *cmp != Ordering::Equal || self == other)
+    }
+}
+
+impl UserDefinedLogicalNodeCore for OpaqueRequirementsExtension {
+    fn name(&self) -> &str {
+        "OpaqueRequirementsExtension"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        &self.schema
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        mut inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        Ok(Self {
+            input: Arc::new(inputs.swap_remove(0)),
+            schema: Arc::clone(&self.schema),
+        })
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "OpaqueRequirementsExtension")
+    }
+}
+
+struct InexactFilterTableSource {
+    schema: SchemaRef,
+}
+
+impl TableSource for InexactFilterTableSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> Result<Vec<TableProviderFilterPushDown>> {
+        Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
+    }
+}
+
+/// Reproduction of https://github.com/apache/datafusion/issues/18816
+/// Extension nodes without `necessary_children_exprs` should not prevent
+/// the optimizer from pruning unnecessary columns in subtrees.
+#[test]
+fn extension_node_does_not_block_projection_pruning() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, true),
+        Field::new("b", DataType::Int32, true),
+        Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
+    ]));
+
+    let table_source: Arc<dyn TableSource> = Arc::new(InexactFilterTableSource {
+        schema: Arc::clone(&schema),
+    });
+
+    let ts_cast = Expr::Cast(Cast::new(
+        Box::new(col("t.ts")),
+        DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())),
+    ));
+    let ts_millis_1000 = Expr::Literal(
+        ScalarValue::TimestampMillisecond(Some(1000), Some("UTC".into())),
+        None,
+    );
+    let ts_millis_2000 = Expr::Literal(
+        ScalarValue::TimestampMillisecond(Some(2000), Some("UTC".into())),
+        None,
+    );
+
+    let plan = LogicalPlanBuilder::scan("t", table_source, None)?
+        .project(vec![col("t.a"), ts_cast.alias_qualified(Some("t"), "ts")])?
+        .filter(
+            col("t.ts")
+                .gt(ts_millis_1000)
+                .and(col("t.ts").lt(ts_millis_2000)),
+        )?
+        .sort(vec![
+            SortExpr::new(col("t.a"), true, true),
+            SortExpr::new(col("t.ts"), true, true),
+        ])?
+        .build()?;
+
+    let df_schema = schema.to_dfschema_ref()?;
+    let plan = LogicalPlan::Extension(Extension {
+        node: Arc::new(OpaqueRequirementsExtension {
+            input: Arc::new(plan),
+            schema: df_schema,
+        }),
+    });
+
+    let optimized = optimize_plan(plan)?;
+    assert_snapshot!(
+        format!("{optimized}"),
+        @r#"
+    OpaqueRequirementsExtension
+      Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
+        Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
+          Projection: t.a, t.ts
+            Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
+              Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts
+                TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]
+    "#,
+    );
+
+    Ok(())
+}
+
 #[derive(Default)]
 struct MyContextProvider {
     options: ConfigOptions,
