Skip to content

Commit 20ed9aa

Browse files
committed
Add a cost model for the physical layer
1 parent 4cb23c0 commit 20ed9aa

7 files changed

Lines changed: 1347 additions & 0 deletions

File tree

src/distributed_planner/distributed_config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ extensions_options! {
6565
/// budget will still be admitted (otherwise we would livelock), so the actual peak per
6666
/// connection is `worker_connection_buffer_budget_bytes + max_message_size`.
6767
pub worker_connection_buffer_budget_bytes: usize, default = 64 * 1024 * 1024
68+
/// Distributed DataFusion relies on row count estimation in order to infer how many workers
69+
/// should be used in serving the query. Some plans might not implement any kind of row count
70+
/// estimation, and this parameter sets the default estimated row count for those plans.
71+
pub default_estimated_row_count: Option<usize>, default = Some(0)
6872
/// Collection of [TaskEstimator]s that will be applied to leaf nodes in order to
6973
/// estimate how many tasks should be spawned for the [Stage] containing the leaf node.
7074
pub(crate) __private_task_estimator: CombinedTaskEstimator, default = CombinedTaskEstimator::default()

src/distributed_planner/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod partial_reduce_below_network_shuffles;
77
mod prepare_network_boundaries;
88
mod push_fetch_into_network_coalesce;
99
mod session_state_builder_ext;
10+
mod statistics;
1011
mod task_estimator;
1112

1213
pub use distributed_config::DistributedConfig;

src/distributed_planner/statistics/compute_per_node.rs

Lines changed: 1009 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use crate::DistributedConfig;
2+
use crate::distributed_planner::statistics::compute_per_node::calculate_compute_complexity;
3+
use crate::distributed_planner::statistics::plan_statistics::plan_statistics;
4+
use datafusion::common::Result;
5+
use datafusion::physical_plan::{ExecutionPlan, Statistics};
6+
use std::sync::Arc;
7+
8+
pub(crate) fn calculate_cost(
9+
plan: &Arc<dyn ExecutionPlan>,
10+
cfg: &DistributedConfig,
11+
) -> Result<usize> {
12+
f(plan, cfg).map(|(cost, _stats)| cost)
13+
}
14+
15+
/// Recursively walks `plan` bottom-up and returns the accumulated cost of the subtree rooted at
/// `plan`, together with the statistics computed for `plan` itself.
///
/// Children are visited first, so that their statistics can be handed to [plan_statistics] and
/// to the compute-complexity model of the parent node.
///
/// Costs are accumulated with saturating arithmetic: a pathological estimate pins the total at
/// `usize::MAX` instead of wrapping around to a small number (release builds) or panicking
/// (debug builds).
///
/// # Errors
///
/// Propagates any error returned by [plan_statistics] for this node or any of its descendants.
fn f(plan: &Arc<dyn ExecutionPlan>, d_cfg: &DistributedConfig) -> Result<(usize, Arc<Statistics>)> {
    let children = plan.children();
    let mut child_stats = Vec::with_capacity(children.len());
    let mut acc_cost: usize = 0;
    for child in children {
        let (cost, child_stat) = f(child, d_cfg)?;
        acc_cost = acc_cost.saturating_add(cost);
        child_stats.push(child_stat);
    }

    let stats = plan_statistics(plan, &child_stats, d_cfg)?;
    let complexity = calculate_compute_complexity(plan);
    // When the complexity model yields no cost for this node, it contributes nothing.
    let own_cost = complexity.cost(&stats, &child_stats).unwrap_or(0);
    Ok((acc_cost.saturating_add(own_cost), stats))
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
use datafusion::arrow::datatypes::{DataType, IntervalUnit};
2+
3+
/// Default data size estimate for variable-width columns when no statistics are available.
4+
///
5+
/// Reference: Trino's PlanNodeStatsEstimate.java:40
6+
/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L40
7+
const DEFAULT_DATA_SIZE_PER_COLUMN: usize = 50;
8+
9+
/// This function returns the amount of bytes each row is estimated to occupy.
10+
///
11+
/// The estimation follows Trino's approach for calculating output size per row:
12+
/// - For fixed-width (primitive) types: uses the type's fixed byte width
13+
/// - For variable-width types: uses a default estimate plus offset overhead
14+
/// - Accounts for validity bitmap overhead (1 bit per value, rounded to 1 byte per row)
15+
///
16+
/// DataFusion has `Statistics::calculate_total_byte_size()` which uses `DataType::primitive_width()`,
17+
/// but it returns `Precision::Absent` (unknown) when encountering any non-primitive type:
18+
/// https://github.com/apache/datafusion/blob/branch-52/datafusion/common/src/stats.rs#L326-L347
19+
///
20+
/// For distributed query planning, we need estimates even for variable-width types to make
21+
/// cost-based decisions about data shuffling and task count assignation. This implementation
22+
/// provides estimates for all types following Trino's cost model.
23+
///
24+
/// Reference: Trino's PlanNodeStatsEstimate.getOutputSizeForSymbol()
25+
/// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L89-L114
26+
pub(super) fn default_bytes_for_datatype(data_type: &DataType) -> usize {
27+
// 1 byte for validity bitmap per row (Arrow uses 1 bit, but we round up for estimation).
28+
// Trino calls this the "is null" boolean array.
29+
// Reference: PlanNodeStatsEstimate.java:98-99
30+
// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L98-L99
31+
const VALIDITY_OVERHEAD: usize = 1;
32+
33+
// Handle non-primitive types.
34+
// NOTE: The cases below are Arrow-specific adaptations. Trino only distinguishes between
35+
// FixedWidthType and variable-width types, using Integer.BYTES (4) for offsets.
36+
// Reference: PlanNodeStatsEstimate.java:108-109
37+
// https://github.com/trinodb/trino/blob/458/core/trino-main/src/main/java/io/trino/cost/PlanNodeStatsEstimate.java#L108-L109
38+
match data_type {
39+
// Primitive types from data_type.primitive_width()
40+
DataType::Int8 => VALIDITY_OVERHEAD + 1,
41+
DataType::Int16 => VALIDITY_OVERHEAD + 2,
42+
DataType::Int32 => VALIDITY_OVERHEAD + 4,
43+
DataType::Int64 => VALIDITY_OVERHEAD + 8,
44+
DataType::UInt8 => VALIDITY_OVERHEAD + 1,
45+
DataType::UInt16 => VALIDITY_OVERHEAD + 2,
46+
DataType::UInt32 => VALIDITY_OVERHEAD + 4,
47+
DataType::UInt64 => VALIDITY_OVERHEAD + 8,
48+
DataType::Float16 => VALIDITY_OVERHEAD + 2,
49+
DataType::Float32 => VALIDITY_OVERHEAD + 4,
50+
DataType::Float64 => VALIDITY_OVERHEAD + 8,
51+
DataType::Timestamp(_, _) => VALIDITY_OVERHEAD + 8,
52+
DataType::Date32 => VALIDITY_OVERHEAD + 4,
53+
DataType::Date64 => VALIDITY_OVERHEAD + 8,
54+
DataType::Time32(_) => VALIDITY_OVERHEAD + 4,
55+
DataType::Time64(_) => VALIDITY_OVERHEAD + 8,
56+
DataType::Duration(_) => VALIDITY_OVERHEAD + 8,
57+
DataType::Interval(IntervalUnit::YearMonth) => VALIDITY_OVERHEAD + 4,
58+
DataType::Interval(IntervalUnit::DayTime) => VALIDITY_OVERHEAD + 8,
59+
DataType::Interval(IntervalUnit::MonthDayNano) => VALIDITY_OVERHEAD + 16,
60+
DataType::Decimal32(_, _) => VALIDITY_OVERHEAD + 4,
61+
DataType::Decimal64(_, _) => VALIDITY_OVERHEAD + 8,
62+
DataType::Decimal128(_, _) => VALIDITY_OVERHEAD + 16,
63+
DataType::Decimal256(_, _) => VALIDITY_OVERHEAD + 32,
64+
// Null type has no data (Arrow-specific)
65+
DataType::Null => 0,
66+
67+
// Boolean is stored as bits (1/8 byte per value), but we round up (Arrow-specific)
68+
DataType::Boolean => VALIDITY_OVERHEAD + 1,
69+
70+
// Fixed-size binary: just the fixed size + validity (Arrow-specific)
71+
DataType::FixedSizeBinary(size) => VALIDITY_OVERHEAD + (*size as usize),
72+
73+
// Fixed-size list: fixed count * element size (Arrow-specific)
74+
DataType::FixedSizeList(field, size) => {
75+
VALIDITY_OVERHEAD + (*size as usize) * default_bytes_for_datatype(field.data_type())
76+
}
77+
78+
// Struct: sum of all child field sizes (Arrow-specific)
79+
// Trino would treat ROW types as variable-width
80+
DataType::Struct(fields) => fields
81+
.iter()
82+
.map(|f| default_bytes_for_datatype(f.data_type()))
83+
.sum(),
84+
85+
// Dictionary-encoded: just the key indices, values are shared across rows (Arrow-specific)
86+
// Trino doesn't have dictionary encoding at the type level
87+
DataType::Dictionary(key_type, _value_type) => default_bytes_for_datatype(key_type),
88+
89+
// Union: type_id (1 byte) + max child size (Arrow-specific)
90+
DataType::Union(fields, _) => {
91+
let max_child_size = fields
92+
.iter()
93+
.map(|(_, f)| default_bytes_for_datatype(f.data_type()))
94+
.max()
95+
.unwrap_or(0);
96+
1 + max_child_size
97+
}
98+
99+
// Run-end encoded: estimate as if it were the value type (Arrow-specific)
100+
// Actual compression depends on data distribution
101+
DataType::RunEndEncoded(_, values) => default_bytes_for_datatype(values.data_type()),
102+
103+
// Variable-width string/binary types.
104+
// Offset size follows Trino's Integer.BYTES (4 bytes).
105+
// Reference: PlanNodeStatsEstimate.java:109
106+
DataType::Utf8 | DataType::Binary => {
107+
VALIDITY_OVERHEAD + size_of::<i32>() + DEFAULT_DATA_SIZE_PER_COLUMN
108+
}
109+
// Large variants use i64 offsets (Arrow-specific, Trino doesn't have large variants)
110+
DataType::LargeUtf8 | DataType::LargeBinary => {
111+
VALIDITY_OVERHEAD + size_of::<i64>() + DEFAULT_DATA_SIZE_PER_COLUMN
112+
}
113+
// View types use 16-byte inline representation (Arrow-specific)
114+
// Reference: https://arrow.apache.org/docs/format/Columnar.html#variable-size-binary-view-layout
115+
DataType::Utf8View | DataType::BinaryView => VALIDITY_OVERHEAD + 16,
116+
117+
// List types (Arrow-specific adaptation)
118+
// Spark assumes 1 element average for collections (SPARK-18853). Trino treats them
119+
// as flat variable-width with 50-byte default. We follow Spark's 1-element assumption
120+
// to avoid massive overestimation (e.g. Map<Int,String> was 605 bytes with 10 elements).
121+
DataType::List(field) => {
122+
VALIDITY_OVERHEAD + size_of::<i32>() + default_bytes_for_datatype(field.data_type())
123+
}
124+
DataType::LargeList(field) => {
125+
VALIDITY_OVERHEAD + size_of::<i64>() + default_bytes_for_datatype(field.data_type())
126+
}
127+
DataType::ListView(field) | DataType::LargeListView(field) => {
128+
VALIDITY_OVERHEAD + 8 + default_bytes_for_datatype(field.data_type())
129+
}
130+
131+
// Map type: stored as List<Struct<key, value>> (Arrow-specific)
132+
// Uses same 1-element assumption as List types (following Spark).
133+
DataType::Map(field, _) => {
134+
VALIDITY_OVERHEAD + size_of::<i32>() + default_bytes_for_datatype(field.data_type())
135+
} // Fallback for any other types - use Trino's default
136+
}
137+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod compute_per_node;
2+
mod cost;
3+
mod default_bytes_for_datatype;
4+
mod plan_statistics;
5+
6+
#[allow(unused)] // will be used in a follow-up PR.
7+
pub(crate) use cost::calculate_cost;
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
use crate::DistributedConfig;
2+
use crate::distributed_planner::statistics::default_bytes_for_datatype::default_bytes_for_datatype;
3+
use datafusion::common::stats::Precision;
4+
use datafusion::common::{Statistics, not_impl_err, plan_err};
5+
use datafusion::config::ConfigOptions;
6+
use datafusion::error::Result;
7+
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
8+
use datafusion::physical_plan::execution_plan::CardinalityEffect;
9+
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
10+
use delegate::delegate;
11+
use itertools::Itertools;
12+
use std::fmt::Formatter;
13+
use std::sync::Arc;
14+
15+
/// Uses upstream DataFusion stats system with some small overrides.
16+
pub(super) fn plan_statistics(
17+
node: &Arc<dyn ExecutionPlan>,
18+
children_stats: &[Arc<Statistics>],
19+
opts: &DistributedConfig,
20+
) -> Result<Arc<Statistics>> {
21+
let mut stats = partition_statistics_with_children_override(node, None, children_stats)?;
22+
23+
// If rows are absent, be conservative and assume that all the rows from all the children
24+
// are going to be returned.
25+
if matches!(stats.num_rows, Precision::Absent) {
26+
let num_rows = children_stats
27+
.iter()
28+
.flat_map(|v| v.num_rows.get_value())
29+
.sum1::<usize>();
30+
let num_rows = if let Some(num_rows) = num_rows {
31+
num_rows
32+
} else if let Some(default) = opts.default_estimated_row_count {
33+
default
34+
} else {
35+
return plan_err!(
36+
"{} does not provide row stats, and none of its children [{}] provides a row count",
37+
node.name(),
38+
node.children()
39+
.iter()
40+
.map(|v| v.name())
41+
.collect::<Vec<_>>()
42+
.join(", ")
43+
);
44+
};
45+
stats.num_rows = Precision::Inexact(num_rows)
46+
}
47+
48+
let schema = node.schema();
49+
50+
for (i, col_stats) in &mut stats.column_statistics.iter_mut().enumerate() {
51+
let rows = stats.num_rows.get_value().unwrap_or(&0);
52+
53+
// If some of the NDVs are not present in one of the column-level stats, assume the
54+
// worst and use the same as the input number of rows.
55+
if matches!(col_stats.distinct_count, Precision::Absent) {
56+
col_stats.distinct_count = Precision::Inexact(*rows);
57+
}
58+
59+
// If the per-column byte size stats are not present, estimate the byte size based on the
60+
// data type and the row count.
61+
let Some(dt) = schema.fields.get(i).map(|v| v.data_type()) else {
62+
return plan_err!("Field with index {i} not present in schema: {schema:?}");
63+
};
64+
65+
// If it turns out that we do not have `byte_size` stats, but we do have an estimated number
66+
// of rows, do a best-effort in trying to infer the byte size for each column.
67+
if matches!(col_stats.byte_size, Precision::Absent) {
68+
col_stats.byte_size = Precision::Inexact(default_bytes_for_datatype(dt) * rows)
69+
}
70+
}
71+
72+
// If bytes are absent, let's just infer them based on the schema and the
73+
// number of rows.
74+
if matches!(stats.total_byte_size, Precision::Absent) {
75+
let mut total_byte_size = 0;
76+
for col_stats in &stats.column_statistics {
77+
total_byte_size += col_stats.byte_size.get_value().unwrap_or(&0);
78+
}
79+
stats.total_byte_size = Precision::Inexact(total_byte_size);
80+
}
81+
82+
Ok(Arc::new(stats))
83+
}
84+
85+
// FIXME: because of limitations the the statistics API on DataFusion, we need to resource to
86+
// this sketchy way of overriding child statistics, as we cannot just provide our own.
87+
// If we don't do this:
88+
// 1. we cannot tell nodes to compute statistics based on the ones we provide.
89+
// 2. we recompute statistics unnecessarily across the plan
90+
// This is tracked by https://github.com/apache/datafusion/issues/20184 upstream, and until
91+
// that one is solved, we need to resource to this wrapper.
92+
fn partition_statistics_with_children_override(
93+
node: &Arc<dyn ExecutionPlan>,
94+
partition: Option<usize>,
95+
child_stats: &[Arc<Statistics>],
96+
) -> Result<Statistics> {
97+
// DataFusion stats system is not very mature yet. This override layer brings in changes
98+
// that might not have already been released or informed overrides.
99+
let statistics_wrapped_children = child_stats
100+
.iter()
101+
.zip(node.children())
102+
.map(|(stats, child)| StatisticsWrapper {
103+
inner: Arc::clone(child),
104+
stats: Arc::clone(stats),
105+
})
106+
.map(|v| Arc::new(v) as _)
107+
.collect();
108+
109+
let stats = Arc::clone(node)
110+
.with_new_children(statistics_wrapped_children)?
111+
.partition_statistics(partition)?;
112+
113+
Ok(stats.as_ref().clone())
114+
}
115+
116+
#[derive(Debug)]
117+
struct StatisticsWrapper {
118+
stats: Arc<Statistics>,
119+
inner: Arc<dyn ExecutionPlan>,
120+
}
121+
122+
impl DisplayAs for StatisticsWrapper {
123+
delegate! {
124+
to self.inner {
125+
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result;
126+
}
127+
}
128+
}
129+
130+
impl ExecutionPlan for StatisticsWrapper {
131+
fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
132+
if partition.is_some() {
133+
return plan_err!("StatisticsWrapper not prepared for partition-specific stats");
134+
}
135+
Ok(Arc::clone(&self.stats))
136+
}
137+
138+
delegate! {
139+
to self.inner {
140+
fn name(&self) -> &str;
141+
fn properties(&self) -> &Arc<PlanProperties>;
142+
fn maintains_input_order(&self) -> Vec<bool>;
143+
fn benefits_from_input_partitioning(&self) -> Vec<bool>;
144+
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>>;
145+
fn repartitioned(&self, _target_partitions: usize, _config: &ConfigOptions) -> Result<Option<Arc<dyn ExecutionPlan>>>;
146+
fn execute(&self, partition: usize, context: Arc<TaskContext>) -> Result<SendableRecordBatchStream>;
147+
fn supports_limit_pushdown(&self) -> bool;
148+
fn with_fetch(&self, _limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>>;
149+
fn fetch(&self) -> Option<usize>;
150+
fn cardinality_effect(&self) -> CardinalityEffect;
151+
}
152+
}
153+
154+
fn with_new_children(
155+
self: Arc<Self>,
156+
_: Vec<Arc<dyn ExecutionPlan>>,
157+
) -> Result<Arc<dyn ExecutionPlan>> {
158+
not_impl_err!("with_new_children not implemented")
159+
}
160+
}

0 commit comments

Comments
 (0)