Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions datafusion/catalog/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use async_trait::async_trait;
use datafusion_common::{Constraints, Statistics, not_impl_err};
use datafusion_common::{Result, internal_err};
use datafusion_expr::Expr;
use datafusion_expr::statistics::StatisticsRequest;

use datafusion_expr::dml::InsertOp;
use datafusion_expr::{
Expand Down Expand Up @@ -406,6 +407,7 @@ pub struct ScanArgs<'a> {
filters: Option<&'a [Expr]>,
projection: Option<&'a [usize]>,
limit: Option<usize>,
statistics_requests: &'a [StatisticsRequest],
}

impl<'a> ScanArgs<'a> {
Expand Down Expand Up @@ -467,6 +469,27 @@ impl<'a> ScanArgs<'a> {
/// Maximum number of rows the scan should produce (`None` = unlimited),
/// as previously set via `with_limit`.
pub fn limit(&self) -> Option<usize> {
    self.limit
}

/// Set the statistics the caller would like the provider to answer for
/// this scan, if it can do so cheaply.
///
/// Providers read these via [`Self::statistics_requests()`]; anything a
/// provider cannot answer cheaply it simply ignores. DataFusion's own
/// `TableProvider`s ignore this field — it exists so a request can be
/// threaded from a custom optimizer rule (which annotates
/// `TableScan::statistics_requests`) through to a custom provider.
pub fn with_statistics_requests(
mut self,
statistics_requests: &'a [StatisticsRequest],
) -> Self {
self.statistics_requests = statistics_requests;
self
}

/// Get the statistics requests for the scan. Empty if none were set.
///
/// Counterpart of `with_statistics_requests`; a provider reads this inside
/// `scan_with_args` and may ignore any entry it cannot answer cheaply.
pub fn statistics_requests(&self) -> &'a [StatisticsRequest] {
    self.statistics_requests
}
}

/// Result of a table scan operation from [`TableProvider::scan_with_args`].
Expand Down
4 changes: 3 additions & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ impl DefaultPhysicalPlanner {
filters,
fetch,
projected_schema,
statistics_requests,
..
} = scan;

Expand All @@ -657,7 +658,8 @@ impl DefaultPhysicalPlanner {
let opts = ScanArgs::default()
.with_projection(projection.as_deref())
.with_filters(Some(&filters_vec))
.with_limit(*fetch);
.with_limit(*fetch)
.with_statistics_requests(statistics_requests);
let res = source.scan_with_args(session_state, opts).await?;
Arc::clone(res.plan())
} else {
Expand Down
4 changes: 4 additions & 0 deletions datafusion/core/tests/user_defined/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ mod relation_planner;

/// Tests for insert operations
mod insert_operation;

/// Tests for `StatisticsRequest`s flowing from a custom optimizer rule
/// through the physical planner into a custom `TableProvider`.
mod statistics_requests;
214 changes: 214 additions & 0 deletions datafusion/core/tests/user_defined/statistics_requests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! End-to-end test that a *custom* optimizer rule can annotate a
//! `TableScan` with `StatisticsRequest`s and have them reach a *custom*
//! `TableProvider`'s `scan_with_args`.
//!
//! DataFusion ships no rule that populates `TableScan::statistics_requests`
//! and no provider that consumes `ScanArgs::statistics_requests`. This test
//! plays both roles, demonstrating that the request-side hooks are
//! sufficient to build the whole feature outside of DataFusion.

use std::sync::{Arc, Mutex};

use arrow::array::{Int64Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::catalog::{ScanArgs, ScanResult, Session, TableProvider};
use datafusion::common::tree_node::Transformed;
use datafusion::common::{Column, Result};
use datafusion::datasource::TableType;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::logical_expr::statistics::StatisticsRequest;
use datafusion::logical_expr::{Expr, LogicalPlan};
use datafusion::optimizer::{ApplyOrder, OptimizerConfig, OptimizerRule};
use datafusion::physical_plan::ExecutionPlan;

/// A custom optimizer rule that annotates every `TableScan` with a
/// `RowCount` request plus a `Min` request for each of its columns.
///
/// This stands in for whatever request-derivation logic an external
/// implementer would write (e.g. Min/Max for sort keys, DistinctCount for
/// join keys). Here it is intentionally trivial and deterministic.
//
// Unit struct: the rule carries no configuration or state.
#[derive(Debug)]
struct RequestColumnStatistics;

impl OptimizerRule for RequestColumnStatistics {
fn name(&self) -> &str {
"test_request_column_statistics"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}

fn supports_rewrite(&self) -> bool {
true
}

fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
let LogicalPlan::TableScan(scan) = plan else {
return Ok(Transformed::no(plan));
};
// Idempotent: the optimizer runs rules to a fixpoint, so only
// annotate a scan we have not already touched.
if !scan.statistics_requests.is_empty() {
return Ok(Transformed::no(LogicalPlan::TableScan(scan)));
}
let mut requests = vec![StatisticsRequest::RowCount];
for field in scan.projected_schema.fields() {
requests.push(StatisticsRequest::Min(Column::new_unqualified(
field.name(),
)));
}
Ok(Transformed::yes(LogicalPlan::TableScan(
scan.with_statistics_requests(requests),
)))
}
}

/// A `TableProvider` that records the `statistics_requests` it was asked
/// for, so the test can assert what reached it.
#[derive(Debug)]
struct RecordingTable {
    // Schema reported by `schema()` and used for the produced batch.
    schema: SchemaRef,
    // Single in-memory batch served by `scan`.
    batch: RecordBatch,
    // Requests observed by the most recent `scan_with_args` call; the
    // handle is shared with the test via `make_table`.
    last_requests: Arc<Mutex<Vec<StatisticsRequest>>>,
}

#[async_trait]
impl TableProvider for RecordingTable {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    /// Serve the single in-memory batch, honoring the projection.
    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let partitions = vec![vec![self.batch.clone()]];
        let exec = MemorySourceConfig::try_new_exec(
            &partitions,
            Arc::clone(&self.schema),
            projection.cloned(),
        )?;
        Ok(exec)
    }

    /// Capture the statistics requests that reached this provider, then
    /// delegate to the plain `scan` implementation.
    async fn scan_with_args<'a>(
        &self,
        state: &dyn Session,
        args: ScanArgs<'a>,
    ) -> Result<ScanResult> {
        // Scope the lock so the guard is dropped before the await below.
        {
            let mut recorded = self.last_requests.lock().unwrap();
            *recorded = args.statistics_requests().to_vec();
        }
        let projection = args.projection().map(<[usize]>::to_vec);
        let plan = self
            .scan(
                state,
                projection.as_ref(),
                args.filters().unwrap_or_default(),
                args.limit(),
            )
            .await?;
        Ok(ScanResult::new(plan))
    }
}

fn make_table() -> (Arc<RecordingTable>, Arc<Mutex<Vec<StatisticsRequest>>>) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int64, false),
Field::new("b", DataType::Int64, false),
]));
let batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![1, 2, 3])),
Arc::new(Int64Array::from(vec![10, 20, 30])),
],
)
.unwrap();
let last_requests = Arc::new(Mutex::new(Vec::new()));
let provider = Arc::new(RecordingTable {
schema,
batch,
last_requests: Arc::clone(&last_requests),
});
(provider, last_requests)
}

#[tokio::test]
async fn custom_rule_requests_reach_custom_provider() -> Result<()> {
    let (provider, last_requests) = make_table();

    // Session built with the custom rule appended to the default
    // optimizer set.
    let ctx = SessionContext::new_with_state(
        SessionStateBuilder::new()
            .with_default_features()
            .with_optimizer_rule(Arc::new(RequestColumnStatistics))
            .build(),
    );
    ctx.register_table("t", provider)?;

    // Run any query over `t`; the rule annotates its scan on the way.
    let df = ctx.sql("SELECT a, b FROM t").await?;
    df.collect().await?;

    // The provider recorded exactly what reached `scan_with_args`.
    let got = last_requests.lock().unwrap().clone();
    assert_eq!(
        got.len(),
        3,
        "expected RowCount + Min(a) + Min(b), got {got:?}"
    );
    assert!(
        got.contains(&StatisticsRequest::RowCount),
        "expected RowCount, got {got:?}"
    );
    assert!(
        got.contains(&StatisticsRequest::Min(Column::new_unqualified("a"))),
        "expected Min(a), got {got:?}"
    );
    assert!(
        got.contains(&StatisticsRequest::Min(Column::new_unqualified("b"))),
        "expected Min(b), got {got:?}"
    );
    Ok(())
}

#[tokio::test]
async fn no_requests_without_a_rule() -> Result<()> {
    // Stock DataFusion ships no rule that populates
    // `TableScan::statistics_requests`, so the provider should observe an
    // empty request list.
    let (provider, last_requests) = make_table();
    let ctx = SessionContext::new();
    ctx.register_table("t", provider)?;

    let df = ctx.sql("SELECT a, b FROM t").await?;
    df.collect().await?;

    let recorded = last_requests.lock().unwrap();
    assert!(
        recorded.is_empty(),
        "expected no requests without a custom rule"
    );
    Ok(())
}
111 changes: 111 additions & 0 deletions datafusion/expr-common/src/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1694,3 +1694,114 @@ mod tests {
all_ops.into_iter().collect()
}
}

// ---------------------------------------------------------------------------
// Query-aware statistics request / response.
//
// A small extension to the existing `Statistics` model: instead of "give me
// everything you have for every column", a caller can ask for a specific list
// of stats by name. Providers that have something cheap to offer (parquet
// thrift footers, an external catalog, cached metadata) answer the entries
// they can; everything else is reported `Absent`. Callers (e.g. custom
// optimizer rules) decide what to do with the gaps.
//
// These types are intentionally just a vocabulary — DataFusion itself does
// not (yet) populate or consume them. They exist so that the request can be
// threaded from a `TableScan` (see `TableScan::statistics_requests`) through
// `ScanArgs::statistics_requests` to a `TableProvider`, which is enough for
// the whole feature to be implemented outside of DataFusion.
// ---------------------------------------------------------------------------

use datafusion_common::Column;
use datafusion_common::stats::Precision;

/// A statistic a caller would like a provider to supply, if it can do so
/// cheaply.
///
/// Each variant maps onto a field of [`datafusion_common::Statistics`] /
/// [`datafusion_common::ColumnStatistics`], so a provider that already
/// populates one can answer the other trivially. The companion
/// [`StatisticsValue`] is paired 1:1 with the request in the response.
/// Whether a value is exact or estimated is encoded in the returned
/// [`Precision`] wrapper, not in the request kind itself.
///
/// Derives `Eq` + `Hash` so a request can serve as the key of a
/// [`SatisfiedStatistics`] map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StatisticsRequest {
    /// Smallest non-null value of `column`.
    Min(Column),
    /// Largest non-null value of `column`.
    Max(Column),
    /// Number of NULLs in `column`.
    NullCount(Column),
    /// Number of distinct values in `column` (exact or estimated).
    DistinctCount(Column),
    /// Sum of values in `column` (numerics, widened per
    /// `ColumnStatistics::sum_value`).
    Sum(Column),
    /// Encoded/output byte size of `column`.
    ByteSize(Column),
    /// Number of rows in the container (table / file).
    RowCount,
    /// Total byte size of the container's output.
    TotalByteSize,
}

/// Response value paired 1:1 with an inbound [`StatisticsRequest`].
///
/// `Min` / `Max` answers carry the column's natural type; `RowCount` /
/// `NullCount` / `DistinctCount` answers carry a `ScalarValue::UInt64`. A
/// provider that cannot (or will not) answer a request returns
/// [`StatisticsValue::Absent`] — the caller decides whether to fall back to
/// another mechanism.
#[derive(Debug, Clone)]
pub enum StatisticsValue {
    /// A single scalar value; exactness is carried by the [`Precision`]
    /// wrapper.
    Scalar(Precision<ScalarValue>),
    /// Provider can't (or won't) answer this request.
    Absent,
}

impl StatisticsValue {
/// Convenience: an `Exact` scalar response.
pub fn exact(value: ScalarValue) -> Self {
Self::Scalar(Precision::Exact(value))
}

/// Convenience: an `Inexact` scalar response.
pub fn inexact(value: ScalarValue) -> Self {
Self::Scalar(Precision::Inexact(value))
}
}

/// Sparse map of statistics answers, keyed by request. Only entries the
/// provider actually answered are present, so memory scales with what was
/// asked rather than with table width. Consumers infer `Absent` from a
/// missing key.
///
/// Keying works because [`StatisticsRequest`] derives `Eq` + `Hash`.
pub type SatisfiedStatistics =
    std::collections::HashMap<StatisticsRequest, StatisticsValue>;

#[cfg(test)]
mod stats_request_tests {
    use super::*;
    use std::collections::HashMap;

    /// Sanity-check the `Eq`/`Hash` contract that `SatisfiedStatistics`
    /// relies on.
    #[test]
    fn statistics_request_is_hashable_keyable() {
        // Two equal `StatisticsRequest`s hash equal and round-trip through a
        // `HashMap`, so they can key a `SatisfiedStatistics` map.
        let r1 = StatisticsRequest::Min(Column::new_unqualified("c"));
        let r2 = StatisticsRequest::Min(Column::new_unqualified("c"));
        assert_eq!(r1, r2);

        let mut map: SatisfiedStatistics = HashMap::new();
        // `r1` is not used again, so move it into the map directly; the
        // previous `r1.clone()` here was a redundant allocation.
        map.insert(r1, StatisticsValue::exact(ScalarValue::Int64(Some(7))));
        match map.get(&r2) {
            Some(StatisticsValue::Scalar(Precision::Exact(ScalarValue::Int64(
                Some(7),
            )))) => {}
            other => panic!("unexpected lookup: {other:?}"),
        }
    }
}
Loading
Loading