diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1c661744e0867..6e22aeab61089 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -2458,9 +2458,8 @@ async fn cache_producer_test() -> Result<()> { @r" CacheNode Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum - Projection: aggregate_test_100.c2, aggregate_test_100.c3 - Limit: skip=0, fetch=1 - TableScan: aggregate_test_100, fetch=1 + Limit: skip=0, fetch=1 + TableScan: aggregate_test_100 projection=[c2, c3], fetch=1 " ); Ok(()) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 93df300bb50b4..26fbb71af6c09 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -329,29 +329,34 @@ fn optimize_projections( .collect() } LogicalPlan::Extension(extension) => { - let Some(necessary_children_indices) = + if let Some(necessary_children_indices) = extension.node.necessary_children_exprs(indices.indices()) - else { - // Requirements from parent cannot be routed down to user defined logical plan safely - return Ok(Transformed::no(plan)); - }; - let children = extension.node.inputs(); - assert_eq_or_internal_err!( - children.len(), - necessary_children_indices.len(), - "Inconsistent length between children and necessary children indices. \ + { + let children = extension.node.inputs(); + assert_eq_or_internal_err!( + children.len(), + necessary_children_indices.len(), + "Inconsistent length between children and necessary children indices. \ Make sure `.necessary_children_exprs` implementation of the \ `UserDefinedLogicalNode` is consistent with actual children length \ for the node." 
- ); - children - .into_iter() - .zip(necessary_children_indices) - .map(|(child, necessary_indices)| { - RequiredIndices::new_from_indices(necessary_indices) - .with_plan_exprs(&plan, child.schema()) - }) - .collect::<Result<Vec<_>>>()? + ); + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + RequiredIndices::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) + }) + .collect::<Result<Vec<_>>>()? + } else { + // Requirements from parent cannot be routed down to user defined logical plan safely + // Assume it requires all input exprs here + plan.inputs() + .into_iter() + .map(RequiredIndices::new_for_all_exprs) + .collect() + } } LogicalPlan::EmptyRelation(_) | LogicalPlan::Values(_) @@ -1172,6 +1177,57 @@ mod tests { } } + /// A user-defined node that does NOT implement `necessary_children_exprs`, + /// so the optimizer cannot determine which columns are required from its + /// children and must assume all columns are needed. + #[derive(Debug, Hash, PartialEq, Eq)] + struct OpaqueRequirementsUserDefined { + input: Arc<LogicalPlan>, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. 
+ impl PartialOrd for OpaqueRequirementsUserDefined { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.input + .partial_cmp(&other.input) + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } + } + + impl UserDefinedLogicalNodeCore for OpaqueRequirementsUserDefined { + fn name(&self) -> &str { + "OpaqueRequirementsUserDefined" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec<Expr> { + vec![] + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec<Expr>, + mut inputs: Vec<LogicalPlan>, + ) -> Result<Self> { + Ok(Self { + input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&self.schema), + }) + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "OpaqueRequirementsUserDefined") + } + } + #[test] fn merge_two_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -2204,6 +2260,29 @@ mod tests { Ok(()) } + #[test] + fn test_continue_processing_through_extension() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![col("a")])? + .project(vec![col("a")])? 
+ .build()?; + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(OpaqueRequirementsUserDefined { + input: Arc::new(plan), + schema: Arc::clone(table_scan.schema()), + }), + }); + let plan = optimize(plan).expect("failed to optimize plan"); + assert_optimized_plan_equal!( + plan, + @r" + OpaqueRequirementsUserDefined + TableScan: test projection=[a] + " + ) + } + /// tests that it removes an aggregate is never used downstream #[test] fn table_unused_aggregate() -> Result<()> { diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index fd4991c24413f..ed4b0c0410c2b 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -16,16 +16,25 @@ // under the License. use std::any::Any; +use std::cmp::Ordering; use std::collections::HashMap; +use std::fmt::Formatter; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Result, TableReference, plan_err}; +use datafusion_common::{ + DFSchemaRef, Result, ScalarValue, TableReference, ToDFSchema, plan_err, +}; +use datafusion_expr::expr::Cast; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::test::function_stub::sum_udaf; -use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{ + AggregateUDF, Expr, Extension, LogicalPlan, ScalarUDF, SortExpr, + TableProviderFilterPushDown, TableSource, UserDefinedLogicalNodeCore, WindowUDF, col, +}; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::planner::AggregateFunctionPlanner; @@ -690,6 +699,148 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} +fn 
optimize_plan(plan: LogicalPlan) -> Result<LogicalPlan> { + let config = OptimizerContext::new().with_skip_failing_rules(false); + let optimizer = Optimizer::new(); + optimizer.optimize(plan, &config, observe) +} + +/// Extension node that does NOT implement `necessary_children_exprs`. +/// Used to test that the optimizer still processes subtrees below such nodes. +#[derive(Debug, Hash, PartialEq, Eq)] +struct OpaqueRequirementsExtension { + input: Arc<LogicalPlan>, + schema: DFSchemaRef, +} + +impl PartialOrd for OpaqueRequirementsExtension { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.input + .partial_cmp(&other.input) + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } +} + +impl UserDefinedLogicalNodeCore for OpaqueRequirementsExtension { + fn name(&self) -> &str { + "OpaqueRequirementsExtension" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec<Expr> { + vec![] + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec<Expr>, + mut inputs: Vec<LogicalPlan>, + ) -> Result<Self> { + Ok(Self { + input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&self.schema), + }) + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "OpaqueRequirementsExtension") + } +} + +struct InexactFilterTableSource { + schema: SchemaRef, +} + +impl TableSource for InexactFilterTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result<Vec<TableProviderFilterPushDown>> { + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// Reproduction of https://github.com/apache/datafusion/issues/18816 +/// Extension nodes without `necessary_children_exprs` should not prevent +/// the optimizer from pruning unnecessary columns in subtrees. 
+#[test] +fn extension_node_does_not_block_projection_pruning() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + + let table_source: Arc<dyn TableSource> = Arc::new(InexactFilterTableSource { + schema: Arc::clone(&schema), + }); + + let ts_cast = Expr::Cast(Cast::new( + Box::new(col("t.ts")), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + )); + let ts_millis_1000 = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1000), Some("UTC".into())), + None, + ); + let ts_millis_2000 = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(2000), Some("UTC".into())), + None, + ); + + let plan = LogicalPlanBuilder::scan("t", table_source, None)? + .project(vec![col("t.a"), ts_cast.alias_qualified(Some("t"), "ts")])? + .filter( + col("t.ts") + .gt(ts_millis_1000) + .and(col("t.ts").lt(ts_millis_2000)), + )? + .sort(vec![ + SortExpr::new(col("t.a"), true, true), + SortExpr::new(col("t.ts"), true, true), + ])? 
+ .build()?; + + let df_schema = schema.to_dfschema_ref()?; + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(OpaqueRequirementsExtension { + input: Arc::new(plan), + schema: df_schema, + }), + }); + + let optimized = optimize_plan(plan)?; + assert_snapshot!( + format!("{optimized}"), + @r#" + OpaqueRequirementsExtension + Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST + Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts + Projection: t.a, t.ts + Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC")) + Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts + TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))] + "#, + ); + + Ok(()) +} + #[derive(Default)] struct MyContextProvider { options: ConfigOptions,