diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 7e8fad3252..2ca637105f 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -50,6 +50,8 @@ pub mod extension; #[cfg(feature = "build-binary")] /// Object store configuration and utilities for distributed file access. pub mod object_store; +/// Hash-partition (bucketing) metadata for distributed query optimization. +pub mod partitioning; /// Query planning utilities for distributed execution. pub mod planner; /// Runtime registry for codec and function registration. diff --git a/ballista/core/src/partitioning.rs b/ballista/core/src/partitioning.rs new file mode 100644 index 0000000000..7408a0b805 --- /dev/null +++ b/ballista/core/src/partitioning.rs @@ -0,0 +1,644 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Hash-partition (bucketing) metadata for distributed query optimization. +//! +//! Pinot-style colocated joins, sub-partitioning, and small-side broadcast +//! all need to know whether a table's data is already hash-bucketed by some +//! column. This module defines the contract by which a [`TableProvider`] can +//! declare that bucketing, plus a thin wrapper that lets users attach the +//! metadata to any existing provider without writing a custom one. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::{Constraints, Statistics}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::context::TaskContext; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, + PlanProperties, SendableRecordBatchStream, +}; +use futures::StreamExt; + +/// Hash function used to bucket a table's rows on disk. +/// +/// The optimizer uses this to verify that two co-located inputs were bucketed +/// by the same function before eliding a shuffle. Crucially, it does **not** +/// have to match DataFusion's internal `RepartitionExec` hasher — the +/// declaration is a promise about the on-disk layout, and the colocated-join +/// rule trusts it. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum HashFn { + /// Murmur3 (32-bit), as used by Spark/Hive bucketing. + Murmur3, + /// xxHash64. + XxHash64, +} + +impl fmt::Display for HashFn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + HashFn::Murmur3 => write!(f, "murmur3"), + HashFn::XxHash64 => write!(f, "xxhash64"), + } + } +} + +/// Declares that a table is hash-bucketed across `num_buckets` partitions on +/// the columns named in `keys`, using `hash_fn`. +/// +/// Two tables are *co-located* for a join when they share the same `keys` +/// (matching join keys positionally), the same `hash_fn`, and the same +/// `num_buckets` — bucket *k* of one matches bucket *k* of the other. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct HashDistribution { + /// Column names used as partition keys. + pub keys: Vec, + /// Hash function used to assign rows to buckets. + pub hash_fn: HashFn, + /// Number of buckets. + pub num_buckets: usize, +} + +impl HashDistribution { + /// Construct a new hash distribution descriptor. + pub fn new(keys: Vec, hash_fn: HashFn, num_buckets: usize) -> Self { + Self { + keys, + hash_fn, + num_buckets, + } + } +} + +/// Optional trait that a [`TableProvider`] can implement to advertise its +/// hash-partition layout to the Ballista physical optimizer. +/// +/// The optimizer downcasts each scan's underlying provider to this trait via +/// [`std::any::Any`]. If the downcast succeeds and the returned distribution +/// matches a join's required hash partitioning, the planner can elide the +/// shuffle. +pub trait BallistaPartitionMetadata: Send + Sync { + /// Returns the hash distribution of this table, or `None` if the table is + /// not bucketed. + fn hash_distribution(&self) -> Option; +} + +/// Convenience wrapper that attaches a [`HashDistribution`] to any existing +/// [`TableProvider`]. +/// +/// All `TableProvider` methods delegate to the inner provider; the only +/// extras are (a) the [`BallistaPartitionMetadata`] impl and (b) an +/// adapter that re-advertises the scan's `output_partitioning()` as +/// [`Partitioning::Hash`] so downstream optimizer rules can see it. +/// +/// The caller is responsible for ensuring that the inner provider produces +/// exactly `distribution.num_buckets` partitions, with file group *k* +/// containing bucket *k*. The standard Spark/Hive convention of +/// `part-NNNNN-…` filenames satisfies this when files are sorted by name. +#[derive(Debug)] +pub struct PartitionedTableProvider { + inner: Arc, + distribution: HashDistribution, +} + +impl PartitionedTableProvider { + /// Wrap an existing provider and declare its hash distribution. + pub fn new( + inner: Arc, + distribution: HashDistribution, + ) -> Self { + Self { + inner, + distribution, + } + } + + /// Returns the wrapped provider. + pub fn inner(&self) -> &Arc { + &self.inner + } +} + +impl BallistaPartitionMetadata for PartitionedTableProvider { + fn hash_distribution(&self) -> Option { + Some(self.distribution.clone()) + } +} + +#[async_trait] +impl TableProvider for PartitionedTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.inner.schema() + } + + fn constraints(&self) -> Option<&Constraints> { + self.inner.constraints() + } + + fn table_type(&self) -> TableType { + self.inner.table_type() + } + + fn get_table_definition(&self) -> Option<&str> { + self.inner.get_table_definition() + } + + fn get_column_default( + &self, + column: &str, + ) -> Option<&datafusion::logical_expr::Expr> { + self.inner.get_column_default(column) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let inner_plan = self.inner.scan(state, projection, filters, limit).await?; + HashDistributedScanExec::try_new(inner_plan, self.distribution.clone()) + .map(|exec| Arc::new(exec) as Arc) + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + self.inner.supports_filters_pushdown(filters) + } + + fn statistics(&self) -> Option { + self.inner.statistics() + } +} + +/// A passthrough [`ExecutionPlan`] adapter that overrides `output_partitioning` +/// to advertise a known [`HashDistribution`]. +/// +/// The inner plan is executed unchanged — this node only edits its +/// [`PlanProperties`] so optimizer rules looking at `output_partitioning()` +/// see `Partitioning::Hash(keys, num_buckets)`. +/// +/// Resolving the hash key column names to physical [`Column`] expressions +/// requires the inner schema, which is why construction can fail if a +/// declared key is not present after projection. +#[derive(Debug)] +pub struct HashDistributedScanExec { + inner: Arc, + distribution: HashDistribution, + properties: Arc, +} + +impl HashDistributedScanExec { + /// Wrap an inner scan plan and re-advertise its partitioning. + /// + /// Returns an error if any column in `distribution.keys` is missing from + /// the inner plan's output schema (e.g., projected away). + pub fn try_new( + inner: Arc, + distribution: HashDistribution, + ) -> Result { + let schema = inner.schema(); + let key_exprs = resolve_key_columns(&schema, &distribution.keys)?; + + let inner_props = inner.properties(); + let eq_properties = EquivalenceProperties::new(schema.clone()); + + let properties = Arc::new(PlanProperties::new( + eq_properties, + Partitioning::Hash(key_exprs, distribution.num_buckets), + inner_props.emission_type, + inner_props.boundedness, + )); + + Ok(Self { + inner, + distribution, + properties, + }) + } + + /// The hash distribution this adapter declares. + pub fn distribution(&self) -> &HashDistribution { + &self.distribution + } +} + +impl BallistaPartitionMetadata for HashDistributedScanExec { + fn hash_distribution(&self) -> Option { + Some(self.distribution.clone()) + } +} + +impl DisplayAs for HashDistributedScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "HashDistributedScanExec: keys=[{}], hash_fn={}, buckets={}", + self.distribution.keys.join(","), + self.distribution.hash_fn, + self.distribution.num_buckets, + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "buckets={} hash_fn={}", + self.distribution.num_buckets, self.distribution.hash_fn, + ) + } + } + } +} + +impl ExecutionPlan for HashDistributedScanExec { + fn name(&self) -> &str { + "HashDistributedScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [child] = <[_; 1]>::try_from(children).map_err(|_| { + DataFusionError::Plan( + "HashDistributedScanExec requires exactly one child".to_string(), + ) + })?; + Ok(Arc::new(Self::try_new(child, self.distribution.clone())?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.inner.execute(partition, context) + } + + fn partition_statistics(&self, partition: Option) -> Result { + self.inner.partition_statistics(partition) + } +} + +fn resolve_key_columns( + schema: &SchemaRef, + keys: &[String], +) -> Result>> { + keys.iter() + .map(|name| { + Column::new_with_schema(name, schema) + .map(|col| Arc::new(col) as Arc) + }) + .collect::>>() + .map_err(|e| { + DataFusionError::Plan(format!( + "hash distribution key not found in scan output: {e}" + )) + }) +} + +/// Per-partition concat that re-buckets an already-bucketed plan into a +/// smaller (divisor) bucket count without a network shuffle. +/// +/// When two co-located inputs were bucketed by the same key and same hash +/// function but different bucket counts (e.g., 16 vs 8), the larger side can +/// be locally coalesced so its partitioning matches the smaller side. This is +/// safe because, for any row key `k`, `(hash(k) % 16) % 8 == hash(k) % 8` — +/// so input partitions `[i, i+8]` of the 16-bucket side both belong in +/// output partition `i` of the 8-bucket projection. +/// +/// The exec is a pure remapping: it does not move data across executors, it +/// just chains record batches from the relevant input partitions. +#[derive(Debug)] +pub struct BucketSubPartitionExec { + inner: Arc, + output_distribution: HashDistribution, + properties: Arc, +} + +impl BucketSubPartitionExec { + /// Wrap `inner` to emit `output_distribution.num_buckets` partitions. + /// + /// `output_distribution.num_buckets` must evenly divide + /// `inner.output_partitioning().partition_count()`. + pub fn try_new( + inner: Arc, + output_distribution: HashDistribution, + ) -> Result { + let inner_count = inner.properties().output_partitioning().partition_count(); + let out_count = output_distribution.num_buckets; + if out_count == 0 || inner_count == 0 || !inner_count.is_multiple_of(out_count) { + return Err(DataFusionError::Plan(format!( + "BucketSubPartitionExec requires output buckets to divide input \ + partitions; got input={inner_count}, output={out_count}", + ))); + } + + let schema = inner.schema(); + let key_exprs = resolve_key_columns(&schema, &output_distribution.keys)?; + let inner_props = inner.properties(); + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::Hash(key_exprs, out_count), + inner_props.emission_type, + inner_props.boundedness, + )); + Ok(Self { + inner, + output_distribution, + properties, + }) + } + + /// The hash distribution this exec advertises. + pub fn output_distribution(&self) -> &HashDistribution { + &self.output_distribution + } + + /// Number of input partitions per output partition. + pub fn coalesce_factor(&self) -> usize { + self.inner.properties().output_partitioning().partition_count() + / self.output_distribution.num_buckets + } +} + +impl BallistaPartitionMetadata for BucketSubPartitionExec { + fn hash_distribution(&self) -> Option { + Some(self.output_distribution.clone()) + } +} + +impl DisplayAs for BucketSubPartitionExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "BucketSubPartitionExec: out_buckets={}, factor={}", + self.output_distribution.num_buckets, + self.coalesce_factor(), + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "out_buckets={} factor={}", + self.output_distribution.num_buckets, + self.coalesce_factor(), + ) + } + } + } +} + +impl ExecutionPlan for BucketSubPartitionExec { + fn name(&self) -> &str { + "BucketSubPartitionExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [child] = <[_; 1]>::try_from(children).map_err(|_| { + DataFusionError::Plan( + "BucketSubPartitionExec requires exactly one child".to_string(), + ) + })?; + Ok(Arc::new(Self::try_new(child, self.output_distribution.clone())?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let inner_count = + self.inner.properties().output_partitioning().partition_count(); + let stride = self.output_distribution.num_buckets; + let mut input_streams = Vec::with_capacity(self.coalesce_factor()); + let mut idx = partition; + while idx < inner_count { + input_streams.push(self.inner.execute(idx, Arc::clone(&context))?); + idx += stride; + } + let chained = futures::stream::iter(input_streams).flatten(); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + chained, + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::array::{Int32Array, StringArray}; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::datasource::MemTable; + use datafusion::prelude::SessionContext; + + fn sample_table() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + ], + ) + .unwrap(); + // Two batches → two partitions, mimicking a 2-bucket layout. + Arc::new( + MemTable::try_new(schema, vec![vec![batch.clone()], vec![batch]]).unwrap(), + ) + } + + #[tokio::test] + async fn provider_advertises_hash_distribution() { + let inner = sample_table(); + let dist = HashDistribution::new(vec!["id".into()], HashFn::Murmur3, 2); + let provider = PartitionedTableProvider::new(inner, dist.clone()); + + assert_eq!(provider.hash_distribution(), Some(dist.clone())); + + let ctx = SessionContext::new(); + let plan = provider + .scan(&ctx.state(), None, &[], None) + .await + .expect("scan should succeed"); + + match plan.properties().output_partitioning() { + Partitioning::Hash(keys, n) => { + assert_eq!(*n, 2); + assert_eq!(keys.len(), 1); + let col = keys[0] + .as_any() + .downcast_ref::() + .expect("expected Column expression"); + assert_eq!(col.name(), "id"); + } + other => panic!("expected Hash partitioning, got {other:?}"), + } + + let any_provider: &dyn Any = &provider; + assert!(any_provider.downcast_ref::().is_some()); + + let scan_meta = plan + .as_any() + .downcast_ref::() + .expect("scan should be wrapped in HashDistributedScanExec"); + assert_eq!(scan_meta.hash_distribution().as_ref(), Some(&dist)); + } + + #[tokio::test] + async fn missing_key_column_is_an_error() { + let inner = sample_table(); + let dist = HashDistribution::new(vec!["nope".into()], HashFn::XxHash64, 4); + let provider = PartitionedTableProvider::new(inner, dist); + let ctx = SessionContext::new(); + let err = provider + .scan(&ctx.state(), None, &[], None) + .await + .expect_err("scan should fail when key column missing"); + assert!( + err.to_string().contains("hash distribution key not found"), + "unexpected error: {err}" + ); + } + + #[tokio::test] + async fn projection_removing_key_is_an_error() { + let inner = sample_table(); + let dist = HashDistribution::new(vec!["id".into()], HashFn::Murmur3, 2); + let provider = PartitionedTableProvider::new(inner, dist); + let ctx = SessionContext::new(); + // Project away `id`; only `name` remains. + let err = provider + .scan(&ctx.state(), Some(&vec![1]), &[], None) + .await + .expect_err("scan should fail when key projected away"); + assert!(err.to_string().contains("hash distribution key not found")); + } + + async fn multi_partition_table( + partitions: usize, + rows_per_partition: usize, + ) -> Arc { + let schema = Arc::new(Schema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + let parts: Vec> = (0..partitions) + .map(|p| { + let arr = Arc::new(Int32Array::from_iter_values( + (0..rows_per_partition).map(|r| (p * 1000 + r) as i32), + )); + vec![RecordBatch::try_new(schema.clone(), vec![arr]).unwrap()] + }) + .collect(); + let provider = Arc::new(MemTable::try_new(schema, parts).unwrap()); + let ctx = SessionContext::new(); + provider.scan(&ctx.state(), None, &[], None).await.unwrap() + } + + #[tokio::test] + async fn sub_partition_chains_input_partitions() { + use datafusion::execution::context::TaskContext; + use futures::TryStreamExt; + + // 6 input partitions, chain into 3 outputs (factor=2). + // Output 0 reads inputs [0, 3]; output 1 reads [1, 4]; output 2 reads [2, 5]. + let inner = multi_partition_table(6, 2).await; + let dist = HashDistribution::new(vec!["id".into()], HashFn::Murmur3, 3); + let exec = BucketSubPartitionExec::try_new(inner, dist).unwrap(); + + match exec.properties().output_partitioning() { + Partitioning::Hash(_, n) => assert_eq!(*n, 3), + other => panic!("expected Hash, got {other:?}"), + } + assert_eq!(exec.coalesce_factor(), 2); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, ctx).unwrap(); + let batches: Vec<_> = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 2, "output 0 should chain inputs 0 and 3"); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 4); // 2 + 2 + } + + #[tokio::test] + async fn sub_partition_rejects_non_divisor() { + let inner = multi_partition_table(5, 1).await; + let dist = HashDistribution::new(vec!["id".into()], HashFn::Murmur3, 3); + let err = + BucketSubPartitionExec::try_new(inner, dist).expect_err("should reject"); + assert!( + err.to_string().contains("divide"), + "unexpected error: {err}" + ); + } +} diff --git a/ballista/scheduler/src/physical_optimizer/broadcast_small_side.rs b/ballista/scheduler/src/physical_optimizer/broadcast_small_side.rs new file mode 100644 index 0000000000..46125565cb --- /dev/null +++ b/ballista/scheduler/src/physical_optimizer/broadcast_small_side.rs @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Broadcast small-side optimizer rule. +//! +//! Converts a [`HashJoinExec`] in [`PartitionMode::Partitioned`] mode to +//! [`PartitionMode::CollectLeft`] (broadcast) when one side's total byte +//! size is below the configured threshold. Avoids the network shuffle on +//! the larger side when the smaller side fits. +//! +//! Restricted to `JoinType::Inner` because CollectLeft has known +//! correctness issues for Left/Full outer joins in Ballista +//! ([issue #1055](https://github.com/apache/datafusion-ballista/issues/1055)). +//! `should_swap_join_order` from the existing JoinSelection rule chooses +//! which side to put on the build side after the conversion. + +use std::sync::Arc; + +use ballista_core::config::BallistaConfig; +use datafusion::common::JoinType; +use datafusion::common::config::ConfigOptions; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::error::Result; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode}; + +use crate::physical_optimizer::join_selection::should_swap_join_order; + +/// Optimizer rule that promotes small-side broadcasts. +#[derive(Debug)] +pub struct BroadcastSmallSideRule { + threshold_bytes: usize, +} + +impl BroadcastSmallSideRule { + /// Build the rule with an explicit threshold (in bytes). + pub fn with_threshold(threshold_bytes: usize) -> Self { + Self { threshold_bytes } + } + + /// Build the rule from a [`SessionConfig`]'s broadcast-threshold setting. + /// Returns a rule with `threshold = 0` (no-op) if the session config has + /// no [`BallistaConfig`] extension — Ballista features should only fire + /// when the user has explicitly opted into Ballista. + pub fn from_session_config(config: &datafusion::prelude::SessionConfig) -> Self { + let threshold = config + .options() + .extensions + .get::() + .map(|c| c.broadcast_join_threshold_bytes()) + .unwrap_or(0); + Self::with_threshold(threshold) + } +} + +impl PhysicalOptimizerRule for BroadcastSmallSideRule { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + if self.threshold_bytes == 0 { + return Ok(plan); + } + let threshold = self.threshold_bytes; + plan.transform_up(|p| try_broadcast_small_side(p, threshold)) + .data() + } + + fn name(&self) -> &str { + "BroadcastSmallSideRule" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn try_broadcast_small_side( + plan: Arc, + threshold_bytes: usize, +) -> Result>> { + let Some(join) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + if !matches!(join.partition_mode(), PartitionMode::Partitioned) { + return Ok(Transformed::no(plan)); + } + if !matches!(join.join_type(), JoinType::Inner) { + return Ok(Transformed::no(plan)); + } + if join.null_aware { + // CollectLeft needs global probe-side state for null-aware semantics. + return Ok(Transformed::no(plan)); + } + + let left_bytes = side_byte_size(join.left()); + let right_bytes = side_byte_size(join.right()); + + let left_fits = left_bytes.is_some_and(|b| b > 0 && b < threshold_bytes); + let right_fits = right_bytes.is_some_and(|b| b > 0 && b < threshold_bytes); + + let new_plan: Arc = match (left_fits, right_fits) { + (false, false) => return Ok(Transformed::no(plan)), + (true, true) | (true, false) => { + // Left fits → keep it on the build side. + Arc::new( + join.builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ) + } + (false, true) => { + // Right fits → swap so the small side becomes the build side. + if !join.join_type().supports_swap() + || should_swap_join_order(&**join.left(), &**join.right())? + { + join.swap_inputs(PartitionMode::CollectLeft)? + } else { + Arc::new( + join.builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ) + } + } + }; + Ok(Transformed::yes(new_plan)) +} + +fn side_byte_size(plan: &Arc) -> Option { + plan.partition_statistics(None) + .ok() + .and_then(|s| s.total_byte_size.get_value().copied()) +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::stats::Precision; + use datafusion::common::{ColumnStatistics, Statistics}; + use datafusion::physical_expr::PhysicalExprRef; + use datafusion::physical_expr::expressions::Column; + use datafusion::physical_plan::joins::HashJoinExecBuilder; + use datafusion::physical_plan::repartition::RepartitionExec; + use datafusion::physical_plan::test::exec::StatisticsExec; + use datafusion::physical_plan::Partitioning; + + fn schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])) + } + + fn stats_exec(byte_size: usize) -> Arc { + Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(byte_size / 4), + total_byte_size: Precision::Inexact(byte_size), + column_statistics: vec![ColumnStatistics::new_unknown()], + }, + (*schema()).clone(), + )) + } + + fn col(name: &str, idx: usize) -> PhysicalExprRef { + Arc::new(Column::new(name, idx)) + } + + fn build_join( + left: Arc, + right: Arc, + join_type: JoinType, + ) -> Arc { + let on = vec![(col("id", 0), col("id", 0))]; + let (lk, rk): (Vec<_>, Vec<_>) = on.iter().cloned().unzip(); + let left = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(lk, 4)).unwrap(), + ); + let right = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(rk, 4)).unwrap(), + ); + Arc::new( + HashJoinExecBuilder::new(left, right, on, join_type) + .with_partition_mode(PartitionMode::Partitioned) + .build() + .unwrap(), + ) + } + + fn join_partition_mode(plan: &Arc) -> PartitionMode { + // After a swap, HashJoinExec may sit under a ProjectionExec that + // restores the original column order. Walk the tree to find it. + let mut cur: Arc = Arc::clone(plan); + loop { + if let Some(j) = cur.as_any().downcast_ref::() { + return *j.partition_mode(); + } + let children = cur.children(); + assert_eq!( + children.len(), + 1, + "expected HashJoinExec along single-child path" + ); + cur = Arc::clone(children[0]); + } + } + + #[test] + fn broadcasts_when_left_below_threshold() { + let small = stats_exec(1024); + let big = stats_exec(100 * 1024 * 1024); + let join = build_join(small, big, JoinType::Inner); + let optimized = BroadcastSmallSideRule::with_threshold(1024 * 1024) + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(join_partition_mode(&optimized), PartitionMode::CollectLeft); + } + + #[test] + fn broadcasts_and_swaps_when_right_below_threshold() { + let big = stats_exec(100 * 1024 * 1024); + let small = stats_exec(1024); + let join = build_join(big, small, JoinType::Inner); + let optimized = BroadcastSmallSideRule::with_threshold(1024 * 1024) + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(join_partition_mode(&optimized), PartitionMode::CollectLeft); + } + + #[test] + fn skips_when_neither_side_fits() { + let big1 = stats_exec(50 * 1024 * 1024); + let big2 = stats_exec(100 * 1024 * 1024); + let join = build_join(big1, big2, JoinType::Inner); + let optimized = BroadcastSmallSideRule::with_threshold(1024 * 1024) + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(join_partition_mode(&optimized), PartitionMode::Partitioned); + } + + #[test] + fn skips_outer_join() { + let small = stats_exec(1024); + let big = stats_exec(100 * 1024 * 1024); + let join = build_join(small, big, JoinType::Left); + let optimized = BroadcastSmallSideRule::with_threshold(1024 * 1024) + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(join_partition_mode(&optimized), PartitionMode::Partitioned); + } + + #[test] + fn disabled_when_threshold_is_zero() { + let small = stats_exec(1024); + let big = stats_exec(100 * 1024 * 1024); + let join = build_join(small, big, JoinType::Inner); + let optimized = BroadcastSmallSideRule::with_threshold(0) + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(join_partition_mode(&optimized), PartitionMode::Partitioned); + } +} diff --git a/ballista/scheduler/src/physical_optimizer/colocated_join.rs b/ballista/scheduler/src/physical_optimizer/colocated_join.rs new file mode 100644 index 0000000000..e543d41fb9 --- /dev/null +++ b/ballista/scheduler/src/physical_optimizer/colocated_join.rs @@ -0,0 +1,454 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Pinot-style colocated join rule. +//! +//! When both inputs of a [`HashJoinExec`] are already hash-partitioned by the +//! join keys on the same number of buckets with the same hash function, the +//! [`RepartitionExec`] that DataFusion's `EnforceDistribution` rule would +//! otherwise insert is redundant. This rule strips it, eliminating one shuffle +//! per pair of co-located inputs. +//! +//! The rule is conservative: it only removes an existing `RepartitionExec` +//! when both sides' source plans expose matching [`BallistaPartitionMetadata`] +//! and the join keys align positionally with the declared distribution keys. + +use std::sync::Arc; + +use ballista_core::partitioning::{ + BallistaPartitionMetadata, BucketSubPartitionExec, HashDistribution, +}; +use datafusion::common::config::ConfigOptions; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::error::Result; +use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::joins::HashJoinExec; +use datafusion::physical_plan::repartition::RepartitionExec; + +/// Optimizer rule that elides redundant [`RepartitionExec`] above the inputs +/// of a [`HashJoinExec`] when both inputs are co-located. +#[derive(Debug, Default)] +pub struct ColocatedJoinRule {} + +impl ColocatedJoinRule { + /// Construct the rule. + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for ColocatedJoinRule { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(try_elide_join_repartitions).data() + } + + fn name(&self) -> &str { + "ColocatedJoinRule" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn try_elide_join_repartitions( + plan: Arc, +) -> Result>> { + let Some(join) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let on = join.on(); + let (left_keys, right_keys): (Vec<_>, Vec<_>) = + on.iter().cloned().unzip(); + + let left_dist = source_distribution(join.left()); + let right_dist = source_distribution(join.right()); + let (Some(left_dist), Some(right_dist)) = (left_dist, right_dist) else { + return Ok(Transformed::no(plan)); + }; + + if !same_keying(&left_dist, &right_dist, &left_keys, &right_keys) { + return Ok(Transformed::no(plan)); + } + + let target_buckets = match colocation_target(&left_dist, &right_dist) { + Some(n) => n, + None => return Ok(Transformed::no(plan)), + }; + + let new_left = rewrite_side(join.left(), &left_dist, target_buckets)?; + let new_right = rewrite_side(join.right(), &right_dist, target_buckets)?; + if Arc::ptr_eq(&new_left, join.left()) && Arc::ptr_eq(&new_right, join.right()) { + return Ok(Transformed::no(plan)); + } + + let new_join = Arc::clone(&plan).with_new_children(vec![new_left, new_right])?; + Ok(Transformed::yes(new_join)) +} + +/// Decide the bucket count to colocate at: either the matching count (full +/// colocation) or the smaller of the two when one divides the other +/// (sub-partition colocation). Returns `None` when no rewrite is possible. +fn colocation_target(left: &HashDistribution, right: &HashDistribution) -> Option { + if left.num_buckets == right.num_buckets { + return Some(left.num_buckets); + } + let (small, large) = if left.num_buckets < right.num_buckets { + (left.num_buckets, right.num_buckets) + } else { + (right.num_buckets, left.num_buckets) + }; + if small > 0 && large % small == 0 { + Some(small) + } else { + None + } +} + +/// Walk down through a single [`RepartitionExec`] (if present) to find the +/// [`BallistaPartitionMetadata`] declared by the underlying source. +fn source_distribution(plan: &Arc) -> Option { + let candidate = plan + .as_any() + .downcast_ref::() + .map(|r| r.input()) + .unwrap_or(plan); + metadata_of(candidate) +} + +fn metadata_of(plan: &Arc) -> Option { + let any = plan.as_any(); + if let Some(meta) = + any.downcast_ref::() + { + return meta.hash_distribution(); + } + if let Some(meta) = any.downcast_ref::() { + return meta.hash_distribution(); + } + None +} + +fn same_keying( + left: &HashDistribution, + right: &HashDistribution, + left_keys: &[PhysicalExprRef], + right_keys: &[PhysicalExprRef], +) -> bool { + left.hash_fn == right.hash_fn + && keys_match_columns(left_keys, &left.keys) + && keys_match_columns(right_keys, &right.keys) +} + +fn keys_match_columns(exprs: &[PhysicalExprRef], names: &[String]) -> bool { + if exprs.len() != names.len() { + return false; + } + exprs.iter().zip(names.iter()).all(|(expr, name)| { + expr.as_any() + .downcast_ref::() + .is_some_and(|c| c.name() == name) + }) +} + +/// Rebuild one side of the join so its output partitioning is `Hash(keys, +/// target_buckets)` without going through a network shuffle. +/// +/// Walks past any single intervening `RepartitionExec` to find the source +/// plan that exposes a `BallistaPartitionMetadata`. If the source already has +/// `target_buckets` partitions, returns it directly (eliding the repartition). +/// Otherwise wraps it in a [`BucketSubPartitionExec`] that locally chains the +/// relevant input partitions. +fn rewrite_side( + plan: &Arc, + source_dist: &HashDistribution, + target_buckets: usize, +) -> Result> { + let source = repartition_input(plan) + .map(Arc::clone) + .unwrap_or_else(|| Arc::clone(plan)); + + if source_dist.num_buckets == target_buckets { + return Ok(source); + } + + let target_dist = HashDistribution::new( + source_dist.keys.clone(), + source_dist.hash_fn, + target_buckets, + ); + Ok(Arc::new(BucketSubPartitionExec::try_new(source, target_dist)?)) +} + +/// Returns the input of `plan` if `plan` is a `RepartitionExec`, else `None`. +fn repartition_input(plan: &Arc) -> Option<&Arc> { + plan.as_any() + .downcast_ref::() + .map(|r| r.input()) +} + +#[cfg(test)] +mod tests { + use super::*; + use ballista_core::partitioning::{ + HashDistributedScanExec, HashDistribution, HashFn, + }; + use datafusion::arrow::array::Int32Array; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::arrow::record_batch::RecordBatch; + use datafusion::common::JoinType; + use datafusion::datasource::{MemTable, TableProvider}; + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::displayable; + use datafusion::physical_plan::joins::{HashJoinExecBuilder, PartitionMode}; + use datafusion::prelude::SessionContext; + + fn build_scan( + keys: &[&str], + hash_fn: HashFn, + buckets: usize, + col_names: &[&str], + ) -> Arc { + let fields: Vec<_> = col_names + .iter() + .map(|n| Field::new(*n, DataType::Int32, false)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + let partitions: Vec> = (0..buckets) + .map(|b| { + let arrays: Vec<_> = (0..col_names.len()) + .map(|_| { + Arc::new(Int32Array::from(vec![b as i32, b as i32 + 1])) as _ + }) + .collect(); + vec![RecordBatch::try_new(schema.clone(), arrays).unwrap()] + }) + .collect(); + let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap()); + let dist = HashDistribution::new( + keys.iter().map(|s| s.to_string()).collect(), + hash_fn, + buckets, + ); + let rt = tokio::runtime::Builder::new_current_thread().build().unwrap(); + let ctx = SessionContext::new(); + let inner = rt + .block_on(provider.scan(&ctx.state(), None, &[], None)) + .unwrap(); + Arc::new(HashDistributedScanExec::try_new(inner, dist).unwrap()) + } + + fn col(name: &str, idx: usize) -> PhysicalExprRef { + Arc::new(Column::new(name, idx)) + } + + fn join_with_repartitions( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + buckets: usize, + ) -> Arc { + let (left_keys, right_keys): (Vec<_>, Vec<_>) = on.iter().cloned().unzip(); + let left_repart = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(left_keys, buckets)) + .unwrap(), + ); + let right_repart = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(right_keys, buckets)) + .unwrap(), + ); + Arc::new( + HashJoinExecBuilder::new(left_repart, right_repart, on, JoinType::Inner) + .with_partition_mode(PartitionMode::Partitioned) + .build() + .unwrap(), + ) + } + + fn count_repartitions(plan: &Arc) -> usize { + let mut count = 0; + let _ = plan.clone().transform_up(|p| { + if p.as_any().downcast_ref::().is_some() { + count += 1; + } + Ok(Transformed::no(p)) + }); + count + } + + fn count_sub_partitions(plan: &Arc) -> usize { + let mut count = 0; + let _ = plan.clone().transform_up(|p| { + if p.as_any() + .downcast_ref::() + .is_some() + { + count += 1; + } + Ok(Transformed::no(p)) + }); + count + } + + #[test] + fn elides_repartition_when_inputs_are_colocated() { + let left = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + assert_eq!(count_repartitions(&join), 2, "setup precondition"); + + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!( + count_repartitions(&optimized), + 0, + "expected both repartitions to be elided\n{}", + displayable(optimized.as_ref()).indent(false), + ); + } + + #[test] + fn elides_when_left_and_right_use_different_column_names() { + // Left scan bucketed by `id`, right scan bucketed by `other`. Each + // side's declared bucketing key still matches its join key, so the + // join is colocated despite the cross-side name difference. + let left = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "lval"]); + let right = build_scan(&["other"], HashFn::Murmur3, 4, &["other", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("other", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(count_repartitions(&optimized), 0); + } + + #[test] + fn preserves_repartition_when_bucket_count_mismatch() { + // 4 vs 5: not divisible. (For divisor case see + // `sub_partitions_when_bucket_count_divides`.) + let left = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::Murmur3, 5, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(count_repartitions(&optimized), 2); + } + + #[test] + fn sub_partitions_when_bucket_count_divides() { + // Left bucketed 8 ways, right 4 ways → divisor relationship. + // Larger side is locally coalesced; both repartitions are removed. + let left = build_scan(&["id"], HashFn::Murmur3, 8, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!( + count_repartitions(&optimized), + 0, + "expected both repartitions removed via sub-partitioning\n{}", + displayable(optimized.as_ref()).indent(false), + ); + assert_eq!( + count_sub_partitions(&optimized), + 1, + "expected one BucketSubPartitionExec on the larger side\n{}", + displayable(optimized.as_ref()).indent(false), + ); + } + + #[test] + fn preserves_repartition_when_bucket_count_indivisible() { + // 6 vs 4: neither divides the other → not colocated. + let left = build_scan(&["id"], HashFn::Murmur3, 6, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(count_repartitions(&optimized), 2); + assert_eq!(count_sub_partitions(&optimized), 0); + } + + #[test] + fn preserves_repartition_when_hash_fn_mismatch() { + let left = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::XxHash64, 4, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(count_repartitions(&optimized), 2); + } + + #[test] + fn preserves_repartition_when_join_key_misaligned() { + // Left scan declares bucketing on "lval", but join is on "id". + let left = build_scan(&["lval"], HashFn::Murmur3, 4, &["id", "lval"]); + let right = build_scan(&["id"], HashFn::Murmur3, 4, &["id", "rval"]); + let join = join_with_repartitions( + left, + right, + vec![(col("id", 0), col("id", 0))], + 4, + ); + let optimized = ColocatedJoinRule::new() + .optimize(join, &ConfigOptions::default()) + .unwrap(); + assert_eq!(count_repartitions(&optimized), 2); + } +} diff --git a/ballista/scheduler/src/physical_optimizer/mod.rs b/ballista/scheduler/src/physical_optimizer/mod.rs index 9bd816678e..69a036f809 100644 --- a/ballista/scheduler/src/physical_optimizer/mod.rs +++ b/ballista/scheduler/src/physical_optimizer/mod.rs @@ -15,4 +15,6 @@ // specific language governing permissions and limitations // under the License. +pub mod broadcast_small_side; +pub mod colocated_join; pub mod join_selection; diff --git a/ballista/scheduler/src/state/aqe/mod.rs b/ballista/scheduler/src/state/aqe/mod.rs index 470a6e9eed..f83c15c3a7 100644 --- a/ballista/scheduler/src/state/aqe/mod.rs +++ b/ballista/scheduler/src/state/aqe/mod.rs @@ -42,14 +42,6 @@ use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; use std::vec; -// TODO: the AQE planner runs DataFusion's DefaultPhysicalPlanner with a -// list of PhysicalOptimizerRules and never goes through -// DefaultDistributedPlanner::plan_query_stages_internal, so neither -// maybe_promote_to_broadcast nor the HashJoinExec(CollectLeft) shuffle -// lowering fire here. Joins that would broadcast under the default -// planner stay on the Partitioned shuffle path. Move the lowering into -// an AQE optimizer rule in a follow-up PR. - mod adapter; mod execution_plan; pub mod optimizer_rule; diff --git a/ballista/scheduler/src/state/aqe/planner.rs b/ballista/scheduler/src/state/aqe/planner.rs index 66c330e296..01a4eebbb2 100644 --- a/ballista/scheduler/src/state/aqe/planner.rs +++ b/ballista/scheduler/src/state/aqe/planner.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +use crate::physical_optimizer::broadcast_small_side::BroadcastSmallSideRule; +use crate::physical_optimizer::colocated_join::ColocatedJoinRule; use crate::state::aqe::adapter::BallistaAdapter; use crate::state::aqe::execution_plan::{AdaptiveDatafusionExec, ExchangeExec}; use crate::state::aqe::optimizer_rule::{ @@ -122,7 +124,7 @@ impl AdaptivePlanner { Self::try_new_with_optimizers( plan, session_config, - Self::default_optimizers(), + Self::default_optimizers(session_config), job_name, ) } @@ -419,8 +421,11 @@ impl AdaptivePlanner { /// /// # Returns /// A vector of default physical optimizer rules. - fn default_optimizers() -> Vec { + fn default_optimizers( + session_config: &SessionConfig, + ) -> Vec { let mut physical_optimizers = PhysicalOptimizer::new().rules; + physical_optimizers.extend(Self::join_optimization_rules(session_config)); physical_optimizers.push(Arc::new(PropagateEmptyExecRule::default())); // `DistributedExchangeRule` should be the last plan mutator rule in the chain physical_optimizers.push(Arc::new(DistributedExchangeRule::default())); @@ -448,6 +453,21 @@ impl AdaptivePlanner { .with_config(session_config.clone()) .build() } + /// Returns the colocated/broadcast optimizer rules that should run before + /// `DistributedExchangeRule`. + fn join_optimization_rules( + session_config: &SessionConfig, + ) -> Vec { + vec![ + // Strip redundant RepartitionExec above colocated joins or fold + // bucket-count divisor mismatches into local sub-partitioning. + Arc::new(ColocatedJoinRule::new()), + // Convert Partitioned joins to broadcast when one side is small + // enough — runs after ColocatedJoinRule so colocation wins when + // both apply. + Arc::new(BroadcastSmallSideRule::from_session_config(session_config)), + ] + } /// Recursively finds runnable exchanges in the execution plan. /// /// # Arguments diff --git a/ballista/scheduler/src/state/aqe/test/colocated_join_e2e.rs b/ballista/scheduler/src/state/aqe/test/colocated_join_e2e.rs new file mode 100644 index 0000000000..3ac39e9477 --- /dev/null +++ b/ballista/scheduler/src/state/aqe/test/colocated_join_e2e.rs @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! End-to-end verification for the Pinot-style colocated-join optimizer. +//! +//! These tests exercise the full path: SQL → DataFusion physical plan → +//! AdaptivePlanner (which runs `ColocatedJoinRule` and +//! `BroadcastSmallSideRule` in sequence with the rest of the rule chain). +//! +//! The colocation case relies on `PartitionedTableProvider` to advertise +//! per-table hash bucketing; the unbucketed case asserts the plan is +//! unchanged so we know the rule is silent on tables without metadata. + +use crate::assert_plan; +use crate::state::aqe::planner::AdaptivePlanner; +use ballista_core::extension::SessionConfigExt; +use ballista_core::partitioning::{ + HashDistribution, HashFn, PartitionedTableProvider, +}; +use datafusion::arrow::array::{Int32Array, RecordBatch}; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::TableProvider; +use datafusion::datasource::MemTable; +use datafusion::execution::SessionStateBuilder; +use datafusion::prelude::{SessionConfig, SessionContext}; +use std::sync::Arc; + +fn ctx() -> SessionContext { + let config = SessionConfig::new_with_ballista() + .with_target_partitions(4) + .with_round_robin_repartition(false) + // Disable broadcast so these tests assert only the colocation / + // sub-partition behavior — the default threshold (10 MB) would + // promote our tiny inputs to CollectLeft and overshadow the rule + // under test. + .with_ballista_broadcast_join_threshold_bytes(0) + // Upstream defaults to sort-merge join (issue #1648); opt back into + // hash join so ColocatedJoinRule (which only matches HashJoinExec) + // can fire on the planned join. + .set_bool("datafusion.optimizer.prefer_hash_join", true); + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .build(); + SessionContext::new_with_state(state) +} + +fn build_bucketed_table( + schema: Arc, + buckets: usize, + keys: &[&str], +) -> Arc { + let parts: Vec> = (0..buckets) + .map(|b| { + let arrays: Vec<_> = schema + .fields() + .iter() + .map(|_| { + Arc::new(Int32Array::from(vec![b as i32, b as i32 + 100])) as _ + }) + .collect(); + vec![RecordBatch::try_new(schema.clone(), arrays).unwrap()] + }) + .collect(); + let inner: Arc = + Arc::new(MemTable::try_new(schema, parts).unwrap()); + let dist = HashDistribution::new( + keys.iter().map(|s| s.to_string()).collect(), + HashFn::Murmur3, + buckets, + ); + Arc::new(PartitionedTableProvider::new(inner, dist)) +} + +fn ab_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ])) +} + +#[tokio::test] +async fn colocated_join_emits_no_exchange() -> datafusion::error::Result<()> { + let ctx = ctx(); + ctx.register_table("a", build_bucketed_table(ab_schema(), 4, &["id"]))?; + ctx.register_table("b", build_bucketed_table(ab_schema(), 4, &["id"]))?; + + let plan = ctx + .sql("select a.id, a.v, b.v from a inner join b on a.id = b.id") + .await? + .create_physical_plan() + .await?; + let planner = + AdaptivePlanner::try_new(ctx.state().config(), plan, "colo_inner".to_string())?; + + // Both inputs already satisfy the join's hash distribution, so the + // optimizer should leave the plan exchange-free. + assert_plan!(planner.current_plan(), @ r" + AdaptiveDatafusionExec: is_final=false, plan_id=0, stage_id=pending, stage_resolved=false + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)], projection=[id@0, v@1, v@3] + HashDistributedScanExec: keys=[id], hash_fn=murmur3, buckets=4 + DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1] + HashDistributedScanExec: keys=[id], hash_fn=murmur3, buckets=4 + DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1] + "); + Ok(()) +} + +#[tokio::test] +async fn divisor_join_inserts_sub_partition() -> datafusion::error::Result<()> { + // Divisor case (8/4=2): the larger side wraps in BucketSubPartitionExec + // and the join still avoids a network shuffle. + let ctx = ctx(); + ctx.register_table("a", build_bucketed_table(ab_schema(), 8, &["id"]))?; + ctx.register_table("b", build_bucketed_table(ab_schema(), 4, &["id"]))?; + + let plan = ctx + .sql("select a.id, a.v, b.v from a inner join b on a.id = b.id") + .await? + .create_physical_plan() + .await?; + let planner = + AdaptivePlanner::try_new(ctx.state().config(), plan, "colo_div".to_string())?; + + assert_plan!(planner.current_plan(), @ r" + AdaptiveDatafusionExec: is_final=false, plan_id=0, stage_id=pending, stage_resolved=false + ProjectionExec: expr=[id@1 as id, v@2 as v, v@0 as v] + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)], projection=[v@1, id@2, v@3] + HashDistributedScanExec: keys=[id], hash_fn=murmur3, buckets=4 + DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1] + BucketSubPartitionExec: out_buckets=4, factor=2 + HashDistributedScanExec: keys=[id], hash_fn=murmur3, buckets=8 + DataSourceExec: partitions=8, partition_sizes=[1, 1, 1, 1, 1, 1, 1, 1] + "); + Ok(()) +} + +#[tokio::test] +async fn unbucketed_join_keeps_exchange() -> datafusion::error::Result<()> { + // No PartitionedTableProvider here — the optimizer should be silent and + // the standard ExchangeExec stage boundaries should remain in place. + let ctx = ctx(); + let schema = ab_schema(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(Int32Array::from(vec![10, 20, 30, 40])), + ], + ) + .unwrap(); + let provider: Arc = Arc::new( + MemTable::try_new(schema, vec![vec![batch.clone()], vec![batch]]).unwrap(), + ); + ctx.register_table("a", provider.clone())?; + ctx.register_table("b", provider)?; + + let plan = ctx + .sql("select a.id, a.v, b.v from a inner join b on a.id = b.id") + .await? + .create_physical_plan() + .await?; + let planner = + AdaptivePlanner::try_new(ctx.state().config(), plan, "no_colo".to_string())?; + + assert_plan!(planner.current_plan(), @ r" + AdaptiveDatafusionExec: is_final=false, plan_id=2, stage_id=pending, stage_resolved=false + HashJoinExec: mode=Partitioned, join_type=Inner, on=[(id@0, id@0)], projection=[id@0, v@1, v@3] + ExchangeExec: partitioning=Hash([id@0], 4), plan_id=0, stage_id=pending, stage_resolved=false + DataSourceExec: partitions=2, partition_sizes=[1, 1] + ExchangeExec: partitioning=Hash([id@0], 4), plan_id=1, stage_id=pending, stage_resolved=false + DataSourceExec: partitions=2, partition_sizes=[1, 1] + "); + Ok(()) +} diff --git a/ballista/scheduler/src/state/aqe/test/mod.rs b/ballista/scheduler/src/state/aqe/test/mod.rs index a60cd038ab..097f076d67 100644 --- a/ballista/scheduler/src/state/aqe/test/mod.rs +++ b/ballista/scheduler/src/state/aqe/test/mod.rs @@ -17,6 +17,8 @@ /// Test if stages can be added or removed mod alter_stages; +/// End-to-end verification of the Pinot-style colocated-join optimizer. +mod colocated_join_e2e; /// Tests if plan is going to be split to stages correctly mod plan_to_stages;