Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2458,9 +2458,8 @@ async fn cache_producer_test() -> Result<()> {
@r"
CacheNode
Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum
Projection: aggregate_test_100.c2, aggregate_test_100.c3
Limit: skip=0, fetch=1
TableScan: aggregate_test_100, fetch=1
Limit: skip=0, fetch=1
TableScan: aggregate_test_100 projection=[c2, c3], fetch=1
"
);
Ok(())
Expand Down
117 changes: 98 additions & 19 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,29 +329,34 @@ fn optimize_projections(
.collect()
}
LogicalPlan::Extension(extension) => {
let Some(necessary_children_indices) =
if let Some(necessary_children_indices) =
extension.node.necessary_children_exprs(indices.indices())
else {
// Requirements from parent cannot be routed down to user defined logical plan safely
return Ok(Transformed::no(plan));
};
let children = extension.node.inputs();
assert_eq_or_internal_err!(
children.len(),
necessary_children_indices.len(),
"Inconsistent length between children and necessary children indices. \
{
let children = extension.node.inputs();
assert_eq_or_internal_err!(
children.len(),
necessary_children_indices.len(),
"Inconsistent length between children and necessary children indices. \
Make sure `.necessary_children_exprs` implementation of the \
`UserDefinedLogicalNode` is consistent with actual children length \
for the node."
);
children
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
RequiredIndices::new_from_indices(necessary_indices)
.with_plan_exprs(&plan, child.schema())
})
.collect::<Result<Vec<_>>>()?
);
children
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
RequiredIndices::new_from_indices(necessary_indices)
.with_plan_exprs(&plan, child.schema())
})
.collect::<Result<Vec<_>>>()?
} else {
// Requirements from parent cannot be routed down to user defined logical plan safely
// Assume it requires all input exprs here
plan.inputs()
.into_iter()
.map(RequiredIndices::new_for_all_exprs)
.collect()
}
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::Values(_)
Expand Down Expand Up @@ -1172,6 +1177,57 @@ mod tests {
}
}

    /// A user-defined node that does NOT implement `necessary_children_exprs`,
    /// so the optimizer cannot determine which columns are required from its
    /// children and must assume all columns are needed.
    #[derive(Debug, Hash, PartialEq, Eq)]
    struct OpaqueRequirementsUserDefined {
        // The single child plan wrapped by this extension node.
        input: Arc<LogicalPlan>,
        // The output schema this node reports via `schema()`; stored separately
        // from `input` so the node can advertise its own schema.
        schema: DFSchemaRef,
    }

// Manual implementation needed because of `schema` field. Comparison excludes this field.
impl PartialOrd for OpaqueRequirementsUserDefined {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.input
.partial_cmp(&other.input)
.filter(|cmp| *cmp != Ordering::Equal || self == other)
}
}

impl UserDefinedLogicalNodeCore for OpaqueRequirementsUserDefined {
fn name(&self) -> &str {
"OpaqueRequirementsUserDefined"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}

fn schema(&self) -> &DFSchemaRef {
&self.schema
}

fn expressions(&self) -> Vec<Expr> {
vec![]
}

fn with_exprs_and_inputs(
&self,
_exprs: Vec<Expr>,
mut inputs: Vec<LogicalPlan>,
) -> Result<Self> {
Ok(Self {
input: Arc::new(inputs.swap_remove(0)),
schema: Arc::clone(&self.schema),
})
}

fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "OpaqueRequirementsUserDefined")
}
}

#[test]
fn merge_two_projection() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down Expand Up @@ -2204,6 +2260,29 @@ mod tests {
Ok(())
}

#[test]
fn test_continue_processing_through_extension() -> Result<()> {
let table_scan = test_table_scan()?;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![col("a")])?
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlan::Extension(Extension {
node: Arc::new(OpaqueRequirementsUserDefined {
input: Arc::new(plan),
schema: Arc::clone(table_scan.schema()),
}),
});
let plan = optimize(plan).expect("failed to optimize plan");
assert_optimized_plan_equal!(
plan,
@r"
OpaqueRequirementsUserDefined
TableScan: test projection=[a]
"
)
}

/// tests that it removes an aggregate is never used downstream
#[test]
fn table_unused_aggregate() -> Result<()> {
Expand Down
155 changes: 153 additions & 2 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,25 @@
// under the License.

use std::any::Any;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::fmt::Formatter;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};

use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, TableReference, plan_err};
use datafusion_common::{
DFSchemaRef, Result, ScalarValue, TableReference, ToDFSchema, plan_err,
};
use datafusion_expr::expr::Cast;
use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::test::function_stub::sum_udaf;
use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
use datafusion_expr::{
AggregateUDF, Expr, Extension, LogicalPlan, ScalarUDF, SortExpr,
TableProviderFilterPushDown, TableSource, UserDefinedLogicalNodeCore, WindowUDF, col,
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::planner::AggregateFunctionPlanner;
Expand Down Expand Up @@ -690,6 +699,148 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {

fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

/// Runs the full default optimizer rule set over `plan`.
fn optimize_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
    // Fail the test loudly if any rule errors instead of silently skipping it.
    let context = OptimizerContext::new().with_skip_failing_rules(false);
    Optimizer::new().optimize(plan, &context, observe)
}

/// Extension node that does NOT implement `necessary_children_exprs`.
/// Used to test that the optimizer still processes subtrees below such nodes.
#[derive(Debug, Hash, PartialEq, Eq)]
struct OpaqueRequirementsExtension {
    // The single child plan wrapped by this extension node.
    input: Arc<LogicalPlan>,
    // The output schema this node reports via `schema()`; stored separately
    // from `input` so the node can advertise its own schema.
    schema: DFSchemaRef,
}

// Manual implementation needed because of the `schema` field; ordering is
// driven by `input` only.
impl PartialOrd for OpaqueRequirementsExtension {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // If the inputs compare equal but the nodes themselves differ
        // (e.g. different schemas), report "incomparable" rather than `Equal`.
        match self.input.partial_cmp(&other.input) {
            Some(Ordering::Equal) if self != other => None,
            ordering => ordering,
        }
    }
}

impl UserDefinedLogicalNodeCore for OpaqueRequirementsExtension {
    fn name(&self) -> &str {
        "OpaqueRequirementsExtension"
    }

    fn inputs(&self) -> Vec<&LogicalPlan> {
        // Single-input node.
        vec![self.input.as_ref()]
    }

    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    fn expressions(&self) -> Vec<Expr> {
        // No expressions of its own; `necessary_children_exprs` deliberately
        // keeps its default implementation — that is what this type tests.
        Vec::new()
    }

    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        mut inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        // Take ownership of the (single) replacement input; the schema is
        // carried over unchanged.
        let new_input = inputs.swap_remove(0);
        Ok(Self {
            input: Arc::new(new_input),
            schema: Arc::clone(&self.schema),
        })
    }

    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "OpaqueRequirementsExtension")
    }
}

/// Table source whose filter pushdown is only `Inexact`, so pushed-down
/// predicates are kept in the plan (as `partial_filters`) rather than removed.
struct InexactFilterTableSource {
    // Arrow schema reported by this source.
    schema: SchemaRef,
}

impl TableSource for InexactFilterTableSource {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
self.schema.clone()
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> Result<Vec<TableProviderFilterPushDown>> {
Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()])
}
}

/// Reproduction of https://github.com/apache/datafusion/issues/18816
/// Extension nodes without `necessary_children_exprs` should not prevent
/// the optimizer from pruning unnecessary columns in subtrees.
#[test]
fn extension_node_does_not_block_projection_pruning() -> Result<()> {
    // Three-column source: only `a` and `ts` are used downstream, so the
    // optimizer should prune `b` from the scan even below the opaque node.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
        Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true),
    ]));

    let table_source: Arc<dyn TableSource> = Arc::new(InexactFilterTableSource {
        schema: Arc::clone(&schema),
    });

    // Cast ns -> ms("UTC") so the filter below compares against the cast
    // expression, exercising common-subexpression extraction too.
    let ts_cast = Expr::Cast(Cast::new(
        Box::new(col("t.ts")),
        DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())),
    ));
    let ts_millis_1000 = Expr::Literal(
        ScalarValue::TimestampMillisecond(Some(1000), Some("UTC".into())),
        None,
    );
    let ts_millis_2000 = Expr::Literal(
        ScalarValue::TimestampMillisecond(Some(2000), Some("UTC".into())),
        None,
    );

    // scan -> project (a, cast(ts)) -> filter on ts range -> sort.
    let plan = LogicalPlanBuilder::scan("t", table_source, None)?
        .project(vec![col("t.a"), ts_cast.alias_qualified(Some("t"), "ts")])?
        .filter(
            col("t.ts")
                .gt(ts_millis_1000)
                .and(col("t.ts").lt(ts_millis_2000)),
        )?
        .sort(vec![
            SortExpr::new(col("t.a"), true, true),
            SortExpr::new(col("t.ts"), true, true),
        ])?
        .build()?;

    // Wrap everything in an extension whose child requirements are unknown.
    let df_schema = schema.to_dfschema_ref()?;
    let plan = LogicalPlan::Extension(Extension {
        node: Arc::new(OpaqueRequirementsExtension {
            input: Arc::new(plan),
            schema: df_schema,
        }),
    });

    let optimized = optimize_plan(plan)?;
    // Key assertion: `TableScan: t projection=[a, ts]` — column `b` was pruned
    // and filters were pushed (kept as partial_filters by the inexact source)
    // even though the top-level extension node is opaque to the optimizer.
    assert_snapshot!(
        format!("{optimized}"),
        @r#"
    OpaqueRequirementsExtension
      Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST
        Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts
          Projection: t.a, t.ts
            Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC"))
              Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts
                TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))]
    "#,
    );

    Ok(())
}

#[derive(Default)]
struct MyContextProvider {
options: ConfigOptions,
Expand Down
Loading