Skip to content

Commit 47655fd

Browse files
proto: serialize dynamic filters on Sort, Aggregate, HashJoin plan nodes (#22011)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #20418 (Looks like this was accidentally closed early) - Informs: #21207 (comment) ## Rationale for this change `SortExec`, `AggregateExec`, and `HashJoinExec` do not serialize their dynamic filters, so plans lose dynamic filtering when they are serialized and sent across network boundaries. ## What changes are included in this PR? This change adds `with_dynamic_filter_expr()` and `dynamic_filter_expr()` to `SortExec`, `AggregateExec`, and `HashJoinExec`. ``` pub fn with_dynamic_filter_expr( mut self, filter: Arc<DynamicFilterPhysicalExpr>, ) -> Result<Self> pub fn dynamic_filter_expr(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> { ``` These are used as getters and setters for the `proto` crate to get and set dynamic filters. ## Are these changes tested? Yes. See `datafusion/proto/tests/cases/roundtrip_physical_plan.rs`. There are also tests for the plan nodes in the `physical-plan` crate. ## Are there any user-facing changes? `SortExec`, `AggregateExec`, and `HashJoinExec` now roundtrip serialize dynamic filter expressions. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 5cf8eef commit 47655fd

11 files changed

Lines changed: 755 additions & 81 deletions

File tree

datafusion/core/tests/physical_optimizer/filter_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2835,7 +2835,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() {
28352835

28362836
// Verify that a dynamic filter was created
28372837
let dynamic_filter = hash_join
2838-
.dynamic_filter_for_test()
2838+
.dynamic_filter_expr()
28392839
.expect("Dynamic filter should be created");
28402840

28412841
// Verify that is_used() returns the expected value based on probe side support.

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
7373
/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html
7474
/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html
7575
pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
76-
/// Get the data type of this expression, given the schema of the input
76+
/// Get the data type of this expression, given the schema of the input.
77+
/// Returns an error if the data type cannot be determined, e.g. if the
78+
/// schema is missing a required field.
7779
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
7880
Ok(self.return_field(input_schema)?.data_type().to_owned())
7981
}

datafusion/physical-expr/src/expressions/dynamic_filters.rs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ pub struct DynamicFilterPhysicalExpr {
8686
/// **Warning:** exposed publicly solely so that proto (de)serialization in
8787
/// `datafusion-proto` can read and rebuild this state. Do not treat this type
8888
/// or its layout as a stable API.
89-
#[derive(Clone)]
89+
#[derive(Clone, Debug)]
9090
pub struct Inner {
9191
/// A unique identifier for the expression.
9292
pub expression_id: u64,
@@ -100,24 +100,6 @@ pub struct Inner {
100100
pub is_complete: bool,
101101
}
102102

103-
// TODO: Include expression_id in Debug output.
104-
//
105-
// See https://github.com/apache/datafusion/issues/20418. Currently, plan nodes
106-
// like `HashJoinExec`, `AggregateExec`, `SortExec` do not serialize their
107-
// dynamic filter. They auto-create one on decode with a fresh `expression_id`,
108-
// so a round-trip Debug comparison would diverge purely on the id even when
109-
// the rest of the state is preserved. Excluding it from Debug keeps those
110-
// roundtrip equality assertions meaningful until that work lands.
111-
impl std::fmt::Debug for Inner {
112-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113-
f.debug_struct("Inner")
114-
.field("generation", &self.generation)
115-
.field("expr", &self.expr)
116-
.field("is_complete", &self.is_complete)
117-
.finish()
118-
}
119-
}
120-
121103
impl Inner {
122104
fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
123105
Self {

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 189 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ use arrow_schema::FieldRef;
4747
use datafusion_common::stats::Precision;
4848
use datafusion_common::tree_node::TreeNodeRecursion;
4949
use datafusion_common::{
50-
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err,
50+
Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err,
51+
internal_err, not_impl_err,
5152
};
5253
use datafusion_execution::TaskContext;
5354
use datafusion_expr::{Accumulator, Aggregate};
@@ -893,6 +894,47 @@ impl AggregateExec {
893894
&self.filter_expr
894895
}
895896

897+
/// Returns the dynamic filter expression for this aggregate, if set.
898+
pub fn dynamic_filter_expr(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
899+
self.dynamic_filter.as_ref().map(|df| &df.filter)
900+
}
901+
902+
/// Replace the dynamic filter expression. This method errors if the aggregate does not
903+
/// support dynamic filtering or if the filter expression is incompatible with this
904+
/// [`AggregateExec`].
905+
pub fn with_dynamic_filter_expr(
906+
mut self,
907+
filter: Arc<DynamicFilterPhysicalExpr>,
908+
) -> Result<Self> {
909+
// If there is no dynamic filter state initialized via `try_new`, then
910+
// we can safely assume that the aggregate does not support dynamic filtering.
911+
let Some(dyn_filter) = self.dynamic_filter.as_ref() else {
912+
return internal_err!("Aggregate does not support dynamic filtering");
913+
};
914+
915+
// Validate that the filter is compatible with the aggregation columns.
916+
let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info);
917+
if cols.len() != filter.children().len() {
918+
return internal_err!(
919+
"Dynamic filter expression is incompatible with aggregate due to mismatched number of columns"
920+
);
921+
}
922+
for (col, child) in cols.iter().zip(filter.children()) {
923+
if !col.eq(child) {
924+
return internal_err!(
925+
"Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}"
926+
);
927+
}
928+
}
929+
930+
// Overwrite our filter
931+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
932+
filter,
933+
supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(),
934+
}));
935+
Ok(self)
936+
}
937+
896938
/// Input plan
897939
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
898940
&self.input
@@ -1285,6 +1327,28 @@ impl AggregateExec {
12851327
}
12861328
}
12871329

1330+
// Collect column references for the dynamic filter expression from the supported accumulators.
1331+
fn cols_for_dynamic_filter(
1332+
&self,
1333+
supported_accumulators_info: &[PerAccumulatorDynFilter],
1334+
) -> Vec<Arc<dyn PhysicalExpr>> {
1335+
let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1336+
.iter()
1337+
.filter_map(|info| {
1338+
// This should always be true due to how the supported accumulators
1339+
// are constructed. See `init_dynamic_filter` for more details.
1340+
if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice()
1341+
&& arg.is::<Column>()
1342+
{
1343+
return Some(Arc::clone(arg));
1344+
}
1345+
None
1346+
})
1347+
.collect();
1348+
debug_assert!(all_cols.len() == supported_accumulators_info.len());
1349+
all_cols
1350+
}
1351+
12881352
/// Calculate scaled byte size based on row count ratio.
12891353
/// Returns `Precision::Absent` if input statistics are insufficient.
12901354
/// Returns `Precision::Inexact` with the scaled value otherwise.
@@ -2200,6 +2264,7 @@ mod tests {
22002264
use crate::coalesce_partitions::CoalescePartitionsExec;
22012265
use crate::common;
22022266
use crate::common::collect;
2267+
use crate::empty::EmptyExec;
22032268
use crate::execution_plan::Boundedness;
22042269
use crate::expressions::col;
22052270
use crate::metrics::MetricValue;
@@ -2225,6 +2290,7 @@ mod tests {
22252290
use datafusion_functions_aggregate::count::count_udaf;
22262291
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
22272292
use datafusion_functions_aggregate::median::median_udaf;
2293+
use datafusion_functions_aggregate::min_max::min_udaf;
22282294
use datafusion_functions_aggregate::sum::sum_udaf;
22292295
use datafusion_physical_expr::Partitioning;
22302296
use datafusion_physical_expr::PhysicalSortExpr;
@@ -3846,13 +3912,10 @@ mod tests {
38463912
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
38473913
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
38483914
Arc::new(
3849-
AggregateExprBuilder::new(
3850-
datafusion_functions_aggregate::min_max::min_udaf(),
3851-
vec![col("b", &schema)?],
3852-
)
3853-
.schema(Arc::clone(&schema))
3854-
.alias("MIN(b)")
3855-
.build()?,
3915+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3916+
.schema(Arc::clone(&schema))
3917+
.alias("MIN(b)")
3918+
.build()?,
38563919
),
38573920
Arc::new(
38583921
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -3991,13 +4054,10 @@ mod tests {
39914054
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
39924055
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
39934056
Arc::new(
3994-
AggregateExprBuilder::new(
3995-
datafusion_functions_aggregate::min_max::min_udaf(),
3996-
vec![col("b", &schema)?],
3997-
)
3998-
.schema(Arc::clone(&schema))
3999-
.alias("MIN(b)")
4000-
.build()?,
4057+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
4058+
.schema(Arc::clone(&schema))
4059+
.alias("MIN(b)")
4060+
.build()?,
40014061
),
40024062
Arc::new(
40034063
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -4945,4 +5005,118 @@ mod tests {
49455005

49465006
Ok(())
49475007
}
5008+
5009+
/// Test that [`AggregateExec::with_dynamic_filter_expr`] overrides the existing dynamic filter
5010+
#[test]
5011+
fn test_with_dynamic_filter() -> Result<()> {
5012+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
5013+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
5014+
5015+
// Partial min aggregate supports dynamic filtering
5016+
let agg = AggregateExec::try_new(
5017+
AggregateMode::Partial,
5018+
PhysicalGroupBy::new_single(vec![]),
5019+
vec![Arc::new(
5020+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
5021+
.schema(Arc::clone(&schema))
5022+
.alias("min_a")
5023+
.build()?,
5024+
)],
5025+
vec![None],
5026+
child,
5027+
Arc::clone(&schema),
5028+
)?;
5029+
5030+
// Assertion 1: A filter with the same children can override the existing
5031+
// dynamic filter.
5032+
let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
5033+
vec![col("a", &schema)?],
5034+
lit(false),
5035+
));
5036+
let agg = agg.with_dynamic_filter_expr(Arc::clone(&new_df))?;
5037+
5038+
// The aggregate's filter should now resolve to the new inner expression.
5039+
let swapped = agg
5040+
.dynamic_filter_expr()
5041+
.expect("should still have dynamic filter")
5042+
.current()?;
5043+
assert_eq!(format!("{swapped}"), format!("{}", lit(false)));
5044+
5045+
// Assertion 2: A filter that has been through `PhysicalExpr::with_new_children`
5046+
// should still be accepted when the new children are equivalent to the originals.
5047+
let new_df_as_pexpr: Arc<dyn PhysicalExpr> =
5048+
Arc::<DynamicFilterPhysicalExpr>::clone(&new_df);
5049+
let remapped_pexpr =
5050+
new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?;
5051+
let Ok(remapped_df) = (remapped_pexpr as Arc<dyn std::any::Any + Send + Sync>)
5052+
.downcast::<DynamicFilterPhysicalExpr>()
5053+
else {
5054+
panic!("should be DynamicFilterPhysicalExpr after with_new_children");
5055+
};
5056+
// Hard to assert this because the filter is identical. No error means
5057+
// the filter was accepted. That's a good enough assertion for now.
5058+
let _agg = agg.with_dynamic_filter_expr(remapped_df)?;
5059+
Ok(())
5060+
}
5061+
5062+
/// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the aggregate does not support dynamic filtering
5063+
#[test]
5064+
fn test_with_dynamic_filter_error_unsupported() -> Result<()> {
5065+
let schema = Arc::new(Schema::new(vec![
5066+
Field::new("a", DataType::Int64, false),
5067+
Field::new("b", DataType::Int64, false),
5068+
]));
5069+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
5070+
5071+
// Final mode with a group-by does not support dynamic filters.
5072+
let agg = AggregateExec::try_new(
5073+
AggregateMode::Final,
5074+
PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]),
5075+
vec![Arc::new(
5076+
AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
5077+
.schema(Arc::clone(&schema))
5078+
.alias("sum_b")
5079+
.build()?,
5080+
)],
5081+
vec![None],
5082+
child,
5083+
Arc::clone(&schema),
5084+
)?;
5085+
assert!(agg.dynamic_filter_expr().is_none());
5086+
5087+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
5088+
vec![col("a", &schema)?],
5089+
lit(true),
5090+
));
5091+
assert!(agg.with_dynamic_filter_expr(df).is_err());
5092+
Ok(())
5093+
}
5094+
5095+
/// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the filter's column references do not match the aggregate's columns
5096+
#[test]
5097+
fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> {
5098+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
5099+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
5100+
5101+
let agg = AggregateExec::try_new(
5102+
AggregateMode::Partial,
5103+
PhysicalGroupBy::new_single(vec![]),
5104+
vec![Arc::new(
5105+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
5106+
.schema(Arc::clone(&schema))
5107+
.alias("min_a")
5108+
.build()?,
5109+
)],
5110+
vec![None],
5111+
child,
5112+
Arc::clone(&schema),
5113+
)?;
5114+
5115+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
5116+
vec![Arc::new(Column::new("bad", 99)) as _],
5117+
lit(true),
5118+
));
5119+
assert!(agg.with_dynamic_filter_expr(df).is_err());
5120+
Ok(())
5121+
}
49485122
}

0 commit comments

Comments
 (0)