Skip to content

Commit 17b5653

Browse files
committed
Add ClosureStatisticsProvider for test injection and cardinality feedback
1 parent d3c3bfe commit 17b5653

1 file changed

Lines changed: 143 additions & 2 deletions

File tree

  • datafusion/physical-plan/src/operator_statistics

datafusion/physical-plan/src/operator_statistics/mod.rs

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,10 @@ impl StatisticsProvider for AggregateStatisticsProvider {
713713

714714
let num_rows = Precision::Inexact(estimate);
715715

716-
// TODO: once #20184 lands, pass enhanced child_stats to partition_statistics
717-
// so column-level stats (NDV, min/max) propagate through the registry walk.
716+
// TODO: column-level stats (NDV, min/max) enriched by the registry walk
717+
// are lost here because partition_statistics(None) re-fetches raw child
718+
// stats internally. Once #20184 lands, pass enhanced child_stats so the
719+
// operator's built-in column mapping uses them instead.
718720
let mut base = Arc::unwrap_or_clone(plan.partition_statistics(None)?);
719721
rescale_byte_size(&mut base, num_rows);
720722

@@ -961,6 +963,67 @@ impl StatisticsProvider for UnionStatisticsProvider {
961963
}
962964
}
963965

966+
/// Unsized closure type that [`ClosureStatisticsProvider`] stores behind a `Box`:
/// it takes the plan node and its children's statistics and returns a
/// `StatisticsResult`. The `Send + Sync` bounds are part of the trait object.
type ProviderFn = dyn Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result<StatisticsResult>
967+
+ Send
968+
+ Sync;
969+
970+
/// A [`StatisticsProvider`] backed by a user-supplied closure.
971+
///
972+
/// Useful for injecting custom statistics in tests or for cardinality feedback
973+
/// pipelines where real runtime statistics need to override plan estimates.
974+
/// The closure receives the current plan node and its children's enhanced
975+
/// statistics, returning a [`StatisticsResult`].
///
/// Whatever the closure returns (including an `Err`) is passed through
/// unchanged by [`StatisticsProvider::compute_statistics`]. Return
/// `StatisticsResult::Computed` to supply statistics for the node, or
/// `StatisticsResult::Delegate` to decline it (presumably leaving the node to
/// the next provider in the registry; confirm against the registry walk).
976+
///
977+
/// To distinguish between multiple nodes of the same type (e.g., two
978+
/// `FilterExec` nodes), match on structural properties like the input schema's
979+
/// column names, number of columns, or child row counts.
980+
///
981+
/// # Example
982+
///
983+
/// ```rust,ignore (requires crate-internal imports)
984+
/// let provider = ClosureStatisticsProvider::new(|plan, child_stats| {
985+
/// if plan.downcast_ref::<FilterExec>().is_some() {
986+
/// Ok(StatisticsResult::Computed(ExtendedStatistics::from(Statistics {
987+
/// num_rows: Precision::Inexact(42),
988+
/// ..Statistics::new_unknown(plan.schema().as_ref())
989+
/// })))
990+
/// } else {
991+
/// Ok(StatisticsResult::Delegate)
992+
/// }
993+
/// });
994+
/// ```
995+
pub struct ClosureStatisticsProvider {
996+
// Boxed user closure; invoked once per `compute_statistics` call.
f: Box<ProviderFn>,
997+
}
998+
999+
impl ClosureStatisticsProvider {
1000+
/// Create a new provider from a closure.
///
/// The closure must be `Send + Sync + 'static`; it is moved into a `Box` and
/// stored in the provider. It is called with the plan node and the slice of
/// its children's statistics, and returns a `Result<StatisticsResult>`.
1001+
pub fn new(
1002+
f: impl Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result<StatisticsResult>
1003+
+ Send
1004+
+ Sync
1005+
+ 'static,
1006+
) -> Self {
1007+
Self { f: Box::new(f) }
1008+
}
1009+
}
}
1010+
1011+
// Written by hand rather than derived: the only field is a boxed `dyn Fn`
// trait object, which has no `Debug` impl, so just the type name is printed.
impl Debug for ClosureStatisticsProvider {
1012+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1013+
write!(f, "ClosureStatisticsProvider")
1014+
}
1015+
}
1016+
1017+
impl StatisticsProvider for ClosureStatisticsProvider {
1018+
/// Forwards `plan` and `child_stats` to the stored closure and returns its
/// result, `Ok` or `Err`, without modification.
fn compute_statistics(
1019+
&self,
1020+
plan: &dyn ExecutionPlan,
1021+
child_stats: &[ExtendedStatistics],
1022+
) -> Result<StatisticsResult> {
1023+
// Parentheses are required to call a closure held in a field.
(self.f)(plan, child_stats)
1024+
}
1025+
}
}
1026+
9641027
#[cfg(test)]
9651028
mod tests {
9661029
use super::*;
@@ -2153,4 +2216,82 @@ mod tests {
21532216
assert_eq!(stats.base.num_rows, Precision::Absent);
21542217
Ok(())
21552218
}
2219+
2220+
// =========================================================================
2221+
// ClosureStatisticsProvider tests
2222+
// =========================================================================
2223+
2224+
// Checks that a closure provider returning `Computed` for `FilterExec` nodes
// determines the row count reported by the registry for a filter plan.
#[test]
2225+
fn test_closure_provider_basic() -> Result<()> {
2226+
// Override all FilterExec stats with a fixed row count
2227+
let provider = ClosureStatisticsProvider::new(|plan, _child_stats| {
2228+
if plan.downcast_ref::<FilterExec>().is_some() {
2229+
Ok(StatisticsResult::Computed(ExtendedStatistics::from(
2230+
Statistics {
2231+
num_rows: Precision::Inexact(42),
2232+
total_byte_size: Precision::Absent,
2233+
column_statistics: vec![],
2234+
},
2235+
)))
2236+
} else {
2237+
// Any non-filter node is left to other providers.
Ok(StatisticsResult::Delegate)
2238+
}
2239+
});
2240+
2241+
// The closure provider is listed ahead of the default provider
// (presumably consulted in list order; confirm in StatisticsRegistry).
let registry = StatisticsRegistry::with_providers(vec![
2242+
Arc::new(provider),
2243+
Arc::new(DefaultStatisticsProvider),
2244+
]);
2245+
2246+
let source = make_source(1000);
2247+
let filter: Arc<dyn ExecutionPlan> =
2248+
Arc::new(FilterExec::try_new(lit(true), source)?);
2249+
let stats = registry.compute(filter.as_ref())?;
2250+
// The closure's fixed value of 42 is what the registry reports for the filter.
assert_eq!(stats.base.num_rows, Precision::Inexact(42));
2251+
Ok(())
2252+
}
}
2253+
2254+
// Checks that one closure can return different statistics for two nodes of
// the same operator type by inspecting the child statistics it is handed.
#[test]
2255+
fn test_closure_provider_distinguishes_nodes_by_child_stats() -> Result<()> {
2256+
// Two FilterExec nodes with different input sizes.
2257+
// The closure uses the child row count as a proxy to distinguish them,
2258+
// which mirrors the cardinality feedback use case where you match a
2259+
// runtime-observed count to the right node in the plan tree.
2260+
let provider = ClosureStatisticsProvider::new(|plan, child_stats| {
2261+
if plan.downcast_ref::<FilterExec>().is_none() {
2262+
return Ok(StatisticsResult::Delegate);
2263+
}
2264+
// Index 0 is the stats of the filter's input. Direct indexing would panic
// on an empty slice; this test assumes a filter always has one child.
match child_stats[0].base.num_rows.get_value().copied() {
2265+
Some(500) => Ok(StatisticsResult::Computed(ExtendedStatistics::from(
2266+
Statistics {
2267+
num_rows: Precision::Inexact(100),
2268+
total_byte_size: Precision::Absent,
2269+
column_statistics: vec![],
2270+
},
2271+
))),
2272+
Some(200) => Ok(StatisticsResult::Computed(ExtendedStatistics::from(
2273+
Statistics {
2274+
num_rows: Precision::Inexact(50),
2275+
total_byte_size: Precision::Absent,
2276+
column_statistics: vec![],
2277+
},
2278+
))),
2279+
// Unknown or unrecognised child row count: decline the node.
_ => Ok(StatisticsResult::Delegate),
2280+
}
2281+
});
2282+
2283+
// Only the closure provider is registered here (no DefaultStatisticsProvider).
let registry = StatisticsRegistry::with_providers(vec![Arc::new(provider)]);
2284+
2285+
let filter_a: Arc<dyn ExecutionPlan> =
2286+
Arc::new(FilterExec::try_new(lit(true), make_source(500))?);
2287+
let filter_b: Arc<dyn ExecutionPlan> =
2288+
Arc::new(FilterExec::try_new(lit(true), make_source(200))?);
2289+
2290+
let stats_a = registry.compute(filter_a.as_ref())?;
2291+
let stats_b = registry.compute(filter_b.as_ref())?;
2292+
2293+
// Each filter gets the override keyed by its own input's row count.
assert_eq!(stats_a.base.num_rows, Precision::Inexact(100));
2294+
assert_eq!(stats_b.base.num_rows, Precision::Inexact(50));
2295+
Ok(())
2296+
}
21562297
}

0 commit comments

Comments
 (0)