diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
index 4218f76fa135a..850f9d187780b 100644
--- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
@@ -23,16 +23,23 @@ use arrow::array::Int32Array;
 use arrow::array::{Int64Array, StringArray};
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::record_batch::RecordBatch;
+use datafusion::datasource::listing::PartitionedFile;
 use datafusion::datasource::memory::MemTable;
 use datafusion::datasource::memory::MemorySourceConfig;
+use datafusion::datasource::physical_plan::ParquetSource;
 use datafusion::datasource::source::DataSourceExec;
 use datafusion::prelude::{SessionConfig, SessionContext};
-use datafusion_common::Result;
 use datafusion_common::assert_batches_eq;
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::config::ConfigOptions;
+use datafusion_common::stats::Precision;
+use datafusion_common::{ColumnStatistics, Result, Statistics};
+use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_execution::TaskContext;
+use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_expr::Operator;
+use datafusion_functions_aggregate::count::count_udaf;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
 use datafusion_physical_expr::expressions::{self, cast};
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
 use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
@@ -402,3 +409,152 @@ async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> {
 
     Ok(())
 }
+
+/// Checks that `COUNT(DISTINCT col)` is replaced by a constant projection only
+/// when the column's distinct count is exact; absent or inexact statistics and
+/// non-column arguments must leave the `AggregateExec` in place.
+#[tokio::test]
+async fn test_count_distinct_optimization() -> Result<()> {
+    struct TestCase {
+        name: &'static str,
+        distinct_count: Precision<usize>,
+        use_column_expr: bool,
+        expect_optimized: bool,
+        expected_value: Option<i64>,
+    }
+
+    let cases = vec![
+        TestCase {
+            name: "exact statistics",
+            distinct_count: Precision::Exact(42),
+            use_column_expr: true,
+            expect_optimized: true,
+            expected_value: Some(42),
+        },
+        TestCase {
+            name: "absent statistics",
+            distinct_count: Precision::Absent,
+            use_column_expr: true,
+            expect_optimized: false,
+            expected_value: None,
+        },
+        TestCase {
+            name: "inexact statistics",
+            distinct_count: Precision::Inexact(42),
+            use_column_expr: true,
+            expect_optimized: false,
+            expected_value: None,
+        },
+        TestCase {
+            name: "non-column expression with exact statistics",
+            distinct_count: Precision::Exact(42),
+            use_column_expr: false,
+            expect_optimized: false,
+            expected_value: None,
+        },
+    ];
+
+    for case in cases {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        // Only column `a` (index 0) carries the distinct count under test.
+        let statistics = Statistics {
+            num_rows: Precision::Exact(100),
+            total_byte_size: Precision::Absent,
+            column_statistics: vec![
+                ColumnStatistics {
+                    distinct_count: case.distinct_count,
+                    null_count: Precision::Exact(10),
+                    ..Default::default()
+                },
+                ColumnStatistics::default(),
+            ],
+        };
+
+        let config = FileScanConfigBuilder::new(
+            ObjectStoreUrl::parse("test:///").unwrap(),
+            Arc::new(ParquetSource::new(Arc::clone(&schema))),
+        )
+        .with_file(PartitionedFile::new("x".to_string(), 100))
+        .with_statistics(statistics)
+        .build();
+
+        let source: Arc<dyn ExecutionPlan> = DataSourceExec::from_data_source(config);
+        let schema = source.schema();
+
+        let (agg_args, alias): (Vec<Arc<dyn PhysicalExpr>>, _) =
+            if case.use_column_expr {
+                (vec![expressions::col("a", &schema)?], "COUNT(DISTINCT a)")
+            } else {
+                (
+                    vec![expressions::binary(
+                        expressions::col("a", &schema)?,
+                        Operator::Plus,
+                        expressions::col("b", &schema)?,
+                        &schema,
+                    )?],
+                    "COUNT(DISTINCT a + b)",
+                )
+            };
+
+        let count_distinct_expr = AggregateExprBuilder::new(count_udaf(), agg_args)
+            .schema(Arc::clone(&schema))
+            .alias(alias)
+            .distinct()
+            .build()?;
+
+        // Ungrouped two-phase aggregation (Partial -> Final) over the source.
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![Arc::new(count_distinct_expr.clone())],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![Arc::new(count_distinct_expr)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        let conf = ConfigOptions::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+        if case.expect_optimized {
+            assert!(
+                optimized.as_any().is::<ProjectionExec>(),
+                "'{}': expected ProjectionExec",
+                case.name
+            );
+
+            if let Some(expected_val) = case.expected_value {
+                let task_ctx = Arc::new(TaskContext::default());
+                let result = common::collect(optimized.execute(0, task_ctx)?).await?;
+                assert_eq!(result.len(), 1, "'{}': expected 1 batch", case.name);
+                assert_eq!(
+                    as_int64_array(result[0].column(0)).unwrap().values(),
+                    &[expected_val],
+                    "'{}': unexpected value",
+                    case.name
+                );
+            }
+        } else {
+            assert!(
+                optimized.as_any().is::<AggregateExec>(),
+                "'{}': expected AggregateExec (not optimized)",
+                case.name
+            );
+        }
+    }
+
+    Ok(())
+}
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index 376cf39745903..ebe3c60a4ddde 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -365,31 +365,52 @@ impl AggregateUDFImpl for Count {
     }
 
+    /// Derives the `COUNT` result from plan statistics without scanning data.
+    ///
+    /// Returns `Some(ScalarValue::Int64)` only for a single argument when:
+    /// - `COUNT(DISTINCT col)`: the column's distinct count is exact;
+    /// - `COUNT(col)`: both the row count and the column's null count are exact;
+    /// - the argument is the `COUNT_STAR_EXPANSION` literal: the row count is exact.
+    ///
+    /// Returns `None` in every other case, including when a value does not fit
+    /// in `i64` or the statistics are inconsistent.
     fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
+        let [expr] = statistics_args.exprs else {
+            return None;
+        };
+        let col_stats = &statistics_args.statistics.column_statistics;
+
         if statistics_args.is_distinct {
+            // Only column references can be resolved from statistics;
+            // expressions like casts or literals are not supported.
+            let col_expr = expr.as_any().downcast_ref::<expressions::Column>()?;
+            // `get` avoids a panic if the statistics have no entry at this index.
+            let stats = col_stats.get(col_expr.index())?;
+            if let Precision::Exact(dc) = stats.distinct_count {
+                let dc = i64::try_from(dc).ok()?;
+                return Some(ScalarValue::Int64(Some(dc)));
+            }
             return None;
         }
-        if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
-            && statistics_args.exprs.len() == 1
-        {
-            // TODO optimize with exprs other than Column
-            if let Some(col_expr) = statistics_args.exprs[0]
-                .as_any()
-                .downcast_ref::<expressions::Column>()
-            {
-                let current_val = &statistics_args.statistics.column_statistics
-                    [col_expr.index()]
-                .null_count;
-                if let &Precision::Exact(val) = current_val {
-                    return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
-                }
-            } else if let Some(lit_expr) = statistics_args.exprs[0]
-                .as_any()
-                .downcast_ref::<expressions::Literal>()
-                && lit_expr.value() == &COUNT_STAR_EXPANSION
-            {
-                return Some(ScalarValue::Int64(Some(num_rows as i64)));
+
+        let Precision::Exact(num_rows) = statistics_args.statistics.num_rows else {
+            return None;
+        };
+
+        // TODO optimize with exprs other than Column
+        if let Some(col_expr) = expr.as_any().downcast_ref::<expressions::Column>() {
+            if let Precision::Exact(val) = col_stats.get(col_expr.index())?.null_count {
+                // `checked_sub` guards against a null count larger than the row count.
+                let count = i64::try_from(num_rows.checked_sub(val)?).ok()?;
+                return Some(ScalarValue::Int64(Some(count)));
             }
+        } else if let Some(lit_expr) =
+            expr.as_any().downcast_ref::<expressions::Literal>()
+            && lit_expr.value() == &COUNT_STAR_EXPANSION
+        {
+            let num_rows = i64::try_from(num_rows).ok()?;
+            return Some(ScalarValue::Int64(Some(num_rows)));
         }
+
         None
     }
 