Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2835,7 +2835,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() {

// Verify that a dynamic filter was created
let dynamic_filter = hash_join
.dynamic_filter_for_test()
.dynamic_filter_expr()
.expect("Dynamic filter should be created");

// Verify that is_used() returns the expected value based on probe side support.
Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-expr-common/src/physical_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
/// Get the data type of this expression, given the schema of the input
/// Get the data type of this expression, given the schema of the input.
/// Returns an error if the data type cannot be determined, e.g. if the
/// schema is missing a required field.
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
Ok(self.return_field(input_schema)?.data_type().to_owned())
}
Expand Down
20 changes: 1 addition & 19 deletions datafusion/physical-expr/src/expressions/dynamic_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ pub struct DynamicFilterPhysicalExpr {
/// **Warning:** exposed publicly solely so that proto (de)serialization in
/// `datafusion-proto` can read and rebuild this state. Do not treat this type
/// or its layout as a stable API.
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Inner {
/// A unique identifier for the expression.
pub expression_id: u64,
Expand All @@ -100,24 +100,6 @@ pub struct Inner {
pub is_complete: bool,
}

// TODO: Include expression_id in Debug output.
//
// See https://github.com/apache/datafusion/issues/20418. Currently, plan nodes
// like `HashJoinExec`, `AggregateExec`, `SortExec` do not serialize their
// dynamic filter. They auto-create one on decode with a fresh `expression_id`,
// so a round-trip Debug comparison would diverge purely on the id even when
// the rest of the state is preserved. Excluding it from Debug keeps those
// roundtrip equality assertions meaningful until that work lands.
impl std::fmt::Debug for Inner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render all fields except `expression_id` (intentionally omitted so
        // proto round-trip Debug comparisons do not diverge on a fresh id).
        let mut out = f.debug_struct("Inner");
        out.field("generation", &self.generation);
        out.field("expr", &self.expr);
        out.field("is_complete", &self.is_complete);
        out.finish()
    }
}

impl Inner {
fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
Self {
Expand Down
204 changes: 189 additions & 15 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ use arrow_schema::FieldRef;
use datafusion_common::stats::Precision;
use datafusion_common::tree_node::TreeNodeRecursion;
use datafusion_common::{
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err,
internal_err, not_impl_err,
};
use datafusion_execution::TaskContext;
use datafusion_expr::{Accumulator, Aggregate};
Expand Down Expand Up @@ -892,6 +893,47 @@ impl AggregateExec {
&self.filter_expr
}

/// Returns the dynamic filter expression for this aggregate, if set.
pub fn dynamic_filter_expr(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
    // Dynamic filter state is only present when the aggregate supports it.
    match self.dynamic_filter.as_ref() {
        Some(df) => Some(&df.filter),
        None => None,
    }
}

/// Replace the dynamic filter expression. This method errors if the aggregate does not
/// support dynamic filtering or if the filter expression is incompatible with this
/// [`AggregateExec`].
pub fn with_dynamic_filter_expr(
    mut self,
    filter: Arc<DynamicFilterPhysicalExpr>,
) -> Result<Self> {
    // Dynamic filter state is only initialized via `try_new` when the
    // aggregate supports dynamic filtering; its absence means "unsupported".
    let supported_info = match self.dynamic_filter.as_ref() {
        Some(existing) => existing.supported_accumulators_info.clone(),
        None => {
            return internal_err!("Aggregate does not support dynamic filtering");
        }
    };

    // The replacement filter must reference exactly the same columns, in the
    // same order, as the columns this aggregate filters on.
    let expected_cols = self.cols_for_dynamic_filter(&supported_info);
    let actual_children = filter.children();
    if expected_cols.len() != actual_children.len() {
        return internal_err!(
            "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
        );
    }
    for (col, child) in expected_cols.iter().zip(actual_children) {
        if !col.eq(child) {
            return internal_err!(
                "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
            );
        }
    }

    // All checks passed: install the new filter, preserving the per-accumulator info.
    self.dynamic_filter = Some(Arc::new(AggrDynFilter {
        filter,
        supported_accumulators_info: supported_info,
    }));
    Ok(self)
}

/// Input plan
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
Expand Down Expand Up @@ -1284,6 +1326,28 @@ impl AggregateExec {
}
}

// Collect column references for the dynamic filter expression from the supported accumulators.
fn cols_for_dynamic_filter(
    &self,
    supported_accumulators_info: &[PerAccumulatorDynFilter],
) -> Vec<Arc<dyn PhysicalExpr>> {
    let mut all_cols: Vec<Arc<dyn PhysicalExpr>> =
        Vec::with_capacity(supported_accumulators_info.len());
    for info in supported_accumulators_info {
        // Invariant from `init_dynamic_filter`: each supported accumulator has
        // exactly one argument and that argument is a plain `Column`.
        let exprs = self.aggr_expr[info.aggr_index].expressions();
        if let [arg] = exprs.as_slice() {
            if arg.is::<Column>() {
                all_cols.push(Arc::clone(arg));
            }
        }
    }
    debug_assert!(all_cols.len() == supported_accumulators_info.len());
    all_cols
}

/// Calculate scaled byte size based on row count ratio.
/// Returns `Precision::Absent` if input statistics are insufficient.
/// Returns `Precision::Inexact` with the scaled value otherwise.
Expand Down Expand Up @@ -2177,6 +2241,7 @@ mod tests {
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::common;
use crate::common::collect;
use crate::empty::EmptyExec;
use crate::execution_plan::Boundedness;
use crate::expressions::col;
use crate::metrics::MetricValue;
Expand All @@ -2202,6 +2267,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_functions_aggregate::min_max::min_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::Partitioning;
use datafusion_physical_expr::PhysicalSortExpr;
Expand Down Expand Up @@ -3682,13 +3748,10 @@ mod tests {
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::min_max::min_udaf(),
vec![col("b", &schema)?],
)
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
Expand Down Expand Up @@ -3827,13 +3890,10 @@ mod tests {
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
Arc::new(
AggregateExprBuilder::new(
datafusion_functions_aggregate::min_max::min_udaf(),
vec![col("b", &schema)?],
)
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.alias("MIN(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
Expand Down Expand Up @@ -4781,4 +4841,118 @@ mod tests {

Ok(())
}

/// Test that [`AggregateExec::with_dynamic_filter_expr`] overrides the existing dynamic filter
#[test]
fn test_with_dynamic_filter() -> Result<()> {
// Single-column schema; the aggregate below filters on column "a".
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));

// Partial min aggregate supports dynamic filtering
let agg = AggregateExec::try_new(
AggregateMode::Partial,
PhysicalGroupBy::new_single(vec![]),
vec![Arc::new(
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
.schema(Arc::clone(&schema))
.alias("min_a")
.build()?,
)],
vec![None],
child,
Arc::clone(&schema),
)?;

// Assertion 1: A filter with the same children can override the existing
// dynamic filter.
// `lit(false)` is a distinctive inner expression so we can tell the
// replacement apart from whatever filter `try_new` installed.
let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
vec![col("a", &schema)?],
lit(false),
));
let agg = agg.with_dynamic_filter_expr(Arc::clone(&new_df))?;

// The aggregate's filter should now resolve to the new inner expression.
let swapped = agg
.dynamic_filter_expr()
.expect("should still have dynamic filter")
.current()?;
assert_eq!(format!("{swapped}"), format!("{}", lit(false)));

// Assertion 2: A filter that has been through `PhysicalExpr::with_new_children`
// should still be accepted when the new children are equivalent to the originals.
// Upcast to the trait object first so we exercise the trait-level
// `with_new_children` path rather than an inherent method.
let new_df_as_pexpr: Arc<dyn PhysicalExpr> =
Arc::<DynamicFilterPhysicalExpr>::clone(&new_df);
let remapped_pexpr =
new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?;
// Recover the concrete type via `Arc::downcast`; `with_new_children` is
// expected to return the same concrete expression type.
let Ok(remapped_df) = (remapped_pexpr as Arc<dyn std::any::Any + Send + Sync>)
.downcast::<DynamicFilterPhysicalExpr>()
else {
panic!("should be DynamicFilterPhysicalExpr after with_new_children");
};
// Hard to assert this because the filter is identical. No error means
// the filter was accepted. That's a good enough assertion for now.
let _agg = agg.with_dynamic_filter_expr(remapped_df)?;
Ok(())
}

/// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the aggregate does not support dynamic filtering
#[test]
fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ]));

    // Final mode with a group-by does not support dynamic filters, so the
    // aggregate is built without any dynamic filter state.
    let sum_b = Arc::new(
        AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("sum_b")
            .build()?,
    );
    let group_by =
        PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let agg = AggregateExec::try_new(
        AggregateMode::Final,
        group_by,
        vec![sum_b],
        vec![None],
        Arc::new(EmptyExec::new(Arc::clone(&schema))),
        Arc::clone(&schema),
    )?;
    assert!(agg.dynamic_filter_expr().is_none());

    // Attempting to install any dynamic filter must therefore be rejected.
    let candidate = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    assert!(agg.with_dynamic_filter_expr(candidate).is_err());
    Ok(())
}

/// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the column is not in the schema
#[test]
fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));

    // A partial MIN aggregate does initialize dynamic filter state, so the
    // failure below comes from column validation, not from lack of support.
    let min_a = Arc::new(
        AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
            .schema(Arc::clone(&schema))
            .alias("min_a")
            .build()?,
    );
    let agg = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![min_a],
        vec![None],
        Arc::new(EmptyExec::new(Arc::clone(&schema))),
        Arc::clone(&schema),
    )?;

    // A filter referencing a column that does not match the aggregate's
    // filtered column must be rejected.
    let bogus = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![Arc::new(Column::new("bad", 99)) as _],
        lit(true),
    ));
    assert!(agg.with_dynamic_filter_expr(bogus).is_err());
    Ok(())
}
}
Loading
Loading