diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 334d5b9912..d3ea5724fd 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -67,6 +67,10 @@ pub const BALLISTA_CLIENT_IO_RETRY_WAIT_TIME_MS: &str = "ballista.client.io_retry_wait_time_ms"; /// Enables adaptive query planning pub const BALLISTA_ADAPTIVE_PLANNER_ENABLED: &str = "ballista.planner.adaptive.enabled"; + +/// Setting key for [`BallistaConfig::aqe_limit_early_stop_enabled`]. +pub const BALLISTA_AQE_LIMIT_EARLY_STOP_ENABLED: &str = + "ballista.aqe.limit_early_stop.enabled"; /// Configuration key for enabling sort-based shuffle. pub const BALLISTA_SHUFFLE_SORT_BASED_ENABLED: &str = "ballista.shuffle.sort_based.enabled"; @@ -138,6 +142,13 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Enables Adaptive Query Planning (EXPERIMENTAL)".to_string(), DataType::Boolean, Some(false.to_string())), + ConfigEntry::new(BALLISTA_AQE_LIMIT_EARLY_STOP_ENABLED.to_string(), + "When AQE is enabled, cancels remaining tasks of stages feeding an \ + eligible bare GlobalLimitExec once the rows already written exceed \ + the limit. Eligibility excludes OFFSET, sorted inputs, and limits \ + inside a non-pass-through subtree.".to_string(), + DataType::Boolean, + Some(true.to_string())), ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_ENABLED.to_string(), "Enable sort-based shuffle which writes consolidated files with index".to_string(), DataType::Boolean, @@ -352,6 +363,16 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_ADAPTIVE_PLANNER_ENABLED) } + /// Is AQE early-stop on global LIMIT enabled. + /// + /// Only takes effect when [`Self::adaptive_query_planner_enabled`] is + /// also true. When on, the scheduler tracks rows produced by stages + /// that feed an eligible `GlobalLimitExec` and cancels remaining + /// tasks once the limit is satisfied. + pub fn aqe_limit_early_stop_enabled(&self) -> bool { + self.get_bool_setting(BALLISTA_AQE_LIMIT_EARLY_STOP_ENABLED) + } + /// Returns whether sort-based shuffle is enabled. /// /// When enabled, shuffle writes produce a single consolidated file per input diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 08040b4e7c..02550d8f2e 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -220,6 +220,10 @@ pub trait SessionConfigExt { /// Is adaptive query planner enabled fn ballista_adaptive_query_planner_enabled(&self) -> bool; + /// Is AQE early-stop on global LIMIT enabled. Only takes effect when + /// [`Self::ballista_adaptive_query_planner_enabled`] is also true. + fn ballista_aqe_limit_early_stop_enabled(&self) -> bool; + /// Set user defined metadata keys in Ballista gRPC requests fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self; @@ -537,6 +541,16 @@ impl SessionConfigExt for SessionConfig { .unwrap_or_else(|| BallistaConfig::default().adaptive_query_planner_enabled()) } + fn ballista_aqe_limit_early_stop_enabled(&self) -> bool { + self.options() + .extensions + .get::() + .map(|c| c.aqe_limit_early_stop_enabled()) + .unwrap_or_else(|| { + BallistaConfig::default().aqe_limit_early_stop_enabled() + }) + } + fn with_ballista_grpc_metadata(self, metadata: HashMap) -> Self { let extension = BallistaGrpcMetadataInterceptor::new(metadata); self.with_extension(Arc::new(extension)) diff --git a/ballista/scheduler/src/scheduler_server/event.rs b/ballista/scheduler/src/scheduler_server/event.rs index c6d11fb1bf..9cf2a046d2 100644 --- a/ballista/scheduler/src/scheduler_server/event.rs +++ b/ballista/scheduler/src/scheduler_server/event.rs @@ -96,6 +96,15 @@ pub enum QueryStageSchedulerEvent { ExecutorLost(String, Option), /// Request to cancel specific running tasks. CancelTasks(Vec), + /// AQE early-stop trigger fired: a tracked job's accumulated row + /// count has crossed its LIMIT threshold. The handler short-stops the + /// tagged producer stages (synthesizing successful completion for any + /// remaining tasks so the consumer stage can run with the partial + /// shuffle output) and cancels the now-irrelevant in-flight tasks. + EarlyStopCancel { + /// Unique job identifier. + job_id: String, + }, } impl Debug for QueryStageSchedulerEvent { @@ -165,6 +174,9 @@ impl Debug for QueryStageSchedulerEvent { QueryStageSchedulerEvent::CancelTasks(status) => { write!(f, "CancelTasks : status:[{status:?}].") } + QueryStageSchedulerEvent::EarlyStopCancel { job_id } => { + write!(f, "EarlyStopCancel : job_id={job_id}.") + } } } } diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs index 85b0924e28..96c7a8df1b 100644 --- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs +++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs @@ -342,6 +342,39 @@ impl warn!("Fail to cancel running tasks due to {e:?}"); } } + QueryStageSchedulerEvent::EarlyStopCancel { job_id } => { + info!("AQE early-stop firing for job {job_id}"); + match self.state.task_manager.early_stop_job(&job_id).await { + Ok((tasks_to_cancel, follow_up_events)) => { + if !tasks_to_cancel.is_empty() { + event_sender + .post_event( + QueryStageSchedulerEvent::CancelTasks( + tasks_to_cancel, + ), + ) + .await?; + } + for ev in follow_up_events { + event_sender.post_event(ev).await?; + } + // Producer stages are now Successful so the + // consumer (LIMIT) stage can be scheduled; + // revive offers to pick it up. + if self.state.config.is_push_staged_scheduling() { + event_sender + .post_event(QueryStageSchedulerEvent::ReviveOffers) + .await?; + } + } + Err(e) => { + error!( + "Failed to short-stop AQE producer stages for \ + job {job_id}: {e:?}" + ); + } + } + } QueryStageSchedulerEvent::JobDataClean(job_id) => { self.state.executor_manager.clean_up_job_data(job_id); } diff --git a/ballista/scheduler/src/state/aqe/limit_early_stop.rs b/ballista/scheduler/src/state/aqe/limit_early_stop.rs new file mode 100644 index 0000000000..8cf435053b --- /dev/null +++ b/ballista/scheduler/src/state/aqe/limit_early_stop.rs @@ -0,0 +1,371 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Plan-time analyzer that identifies eligible `GlobalLimitExec` operators +//! and the producer stages whose row counts feed them. +//! +//! Runs once at the end of [`AdaptiveExecutionGraph::try_new`]. Walks the +//! optimized plan top-down, finds the outermost eligible LIMIT, and traces +//! through pass-through operators (LocalLimit, Coalesce*, Projection) to +//! the immediate `ExchangeExec`s that produce its input. Returns one +//! [`JobLimitContext`] per eligible LIMIT, naming the producer stage IDs +//! the runtime tracker should observe. +//! +//! Eligibility (false negatives are safe — the query simply runs without +//! early-stop, matching today's behavior): +//! - `skip == 0` (no OFFSET; v2) +//! - `fetch.is_some()` (no `LIMIT ALL`) +//! - no `SortExec` / `SortPreservingMergeExec` / ordered aggregate / +//! ordered window in the subtree (Top-K is a separate optimization; +//! ordered cancellation would be wrong without sort awareness) +//! - the LIMIT is connected to its producer Exchange(s) only through +//! pass-through operators (LocalLimit, CoalesceBatches, +//! CoalescePartitions, Projection, Union) +//! - each producer Exchange already has a `stage_id` assigned +//! +//! Nested LIMITs: only the outermost is tagged. + +use std::collections::HashSet; +use std::sync::Arc; + +use datafusion::physical_plan::ExecutionPlan; +use datafusion::physical_plan::aggregates::AggregateExec; +#[allow(deprecated)] +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion::physical_plan::union::UnionExec; +use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; +use log::debug; + +use crate::state::aqe::execution_plan::ExchangeExec; + +/// Per-job tagging produced by [`LimitEarlyStopAnalyzer`]. +/// +/// `fetch` is the LIMIT's row count. `producer_stage_ids` is the set of +/// stages whose `ShuffleWritePartition.num_rows` should be summed by the +/// runtime tracker; once the sum crosses `fetch * safety_factor` the +/// scheduler fires an early-stop cancellation for the job. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct JobLimitContext { + pub fetch: u64, + pub producer_stage_ids: HashSet, +} + +/// Walks the post-stage-resolution plan and collects [`JobLimitContext`] +/// entries for eligible LIMIT operators. +pub struct LimitEarlyStopAnalyzer<'a> { + plan: &'a Arc, +} + +impl<'a> LimitEarlyStopAnalyzer<'a> { + pub fn new(plan: &'a Arc) -> Self { + Self { plan } + } + + /// Top-down scan: emit one context per outermost eligible LIMIT. + /// + /// Returns an empty vec for plans with no eligible LIMIT, including + /// plans whose producer Exchange has no stage_id yet (early-stop can + /// still be added in a later AQE iteration; v1 simply skips). + pub fn analyze(&self) -> Vec { + let mut contexts = Vec::new(); + Self::visit(self.plan, &mut contexts); + contexts + } + + fn visit(plan: &Arc, out: &mut Vec) { + if let Some(limit) = plan.as_any().downcast_ref::() { + if let Some(ctx) = Self::try_build_context(limit) { + out.push(ctx); + // Outermost-only: do not recurse below an emitted LIMIT. + return; + } + // Ineligible — still recurse to find a nested eligible LIMIT. + } + for child in plan.children() { + Self::visit(child, out); + } + } + + fn try_build_context(limit: &GlobalLimitExec) -> Option { + if limit.skip() != 0 { + return None; + } + let fetch = limit.fetch()?; + if fetch == 0 { + return None; + } + let child = limit.children().into_iter().next()?.clone(); + if subtree_has_ordering(&child) { + return None; + } + let mut producer_stage_ids = HashSet::new(); + if !collect_producer_stage_ids(&child, &mut producer_stage_ids) { + return None; + } + if producer_stage_ids.is_empty() { + return None; + } + Some(JobLimitContext { + fetch: fetch as u64, + producer_stage_ids, + }) + } +} + +/// Walk from the LIMIT's input until reaching producer Exchanges, +/// pushing each Exchange's `stage_id` into `out`. Returns false if any +/// branch leads to an Exchange without a `stage_id` set, or hits a node +/// outside the pass-through allowlist before finding one. +/// +/// Allowlist (operators that do not invalidate row-count tracking when +/// placed between the LIMIT and its producer Exchange): +/// - `LocalLimitExec` (slices but the writer above counts pre-slice) +/// - `CoalesceBatchesExec`, `CoalescePartitionsExec` +/// - `ProjectionExec` (pure expression evaluation) +/// - `UnionExec` (recurse into each branch) +fn collect_producer_stage_ids( + node: &Arc, + out: &mut HashSet, +) -> bool { + let any = node.as_any(); + if let Some(exchange) = any.downcast_ref::() { + match exchange.stage_id() { + Some(id) => { + out.insert(id); + true + } + None => { + debug!( + "LimitEarlyStopAnalyzer: producer ExchangeExec has no \ + stage_id yet; skipping early-stop for this LIMIT" + ); + false + } + } + } else if any.downcast_ref::().is_some() + || is_coalesce_batches(any) + || any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + { + // Pass-through: recurse through all children. + for child in node.children() { + if !collect_producer_stage_ids(child, out) { + return false; + } + } + true + } else { + // Anything else (joins, aggregates, scans, sorts, etc.) breaks + // the simple producer-row counting model used in v1. + false + } +} + +#[allow(deprecated)] +fn is_coalesce_batches(any: &dyn std::any::Any) -> bool { + any.downcast_ref::().is_some() +} + +/// Conservative ordering detection. Any of these in the LIMIT's subtree +/// means we bail — early-stop must not silently reorder results. +fn subtree_has_ordering(node: &Arc) -> bool { + let any = node.as_any(); + if any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + || any.downcast_ref::().is_some() + { + return true; + } + if let Some(agg) = any.downcast_ref::() { + if agg.input_order_mode() != &datafusion::physical_plan::InputOrderMode::Linear + { + return true; + } + } + node.children() + .iter() + .any(|c| subtree_has_ordering(c)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::state::aqe::execution_plan::AdaptiveDatafusionExec; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::physical_plan::empty::EmptyExec; + use datafusion::physical_plan::projection::ProjectionExec; + use datafusion::physical_plan::sorts::sort::SortExec; + use datafusion::physical_plan::union::UnionExec; + use datafusion::physical_plan::{ExecutionPlan, expressions::Column}; + use std::sync::Arc; + + fn schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, true)])) + } + + fn leaf() -> Arc { + Arc::new(EmptyExec::new(schema())) + } + + fn exchange_with_stage( + input: Arc, + stage_id: Option, + ) -> Arc { + let exec = ExchangeExec::new(input, None, 0); + if let Some(id) = stage_id { + exec.set_stage_id(id); + } + Arc::new(exec) + } + + fn limit( + skip: usize, + fetch: Option, + input: Arc, + ) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) + } + + fn analyze(plan: Arc) -> Vec { + let plan = Arc::new(AdaptiveDatafusionExec::new(99, plan)); + let plan: Arc = plan; + LimitEarlyStopAnalyzer::new(&plan).analyze() + } + + #[test] + fn eligible_simple_limit() { + // root -> GlobalLimit(100) -> Exchange(stage=7) -> leaf + let plan = limit(0, Some(100), exchange_with_stage(leaf(), Some(7))); + let ctx = analyze(plan); + assert_eq!(ctx.len(), 1); + assert_eq!(ctx[0].fetch, 100); + assert_eq!(ctx[0].producer_stage_ids, HashSet::from([7])); + } + + #[test] + fn eligible_with_passthrough_operators() { + // GlobalLimit -> CoalescePartitions -> LocalLimit -> Exchange + let exchange = exchange_with_stage(leaf(), Some(3)); + let local = Arc::new(LocalLimitExec::new(exchange, 100)); + let coalesce = Arc::new(CoalescePartitionsExec::new(local)); + let plan = limit(0, Some(100), coalesce); + let ctx = analyze(plan); + assert_eq!(ctx.len(), 1); + assert_eq!(ctx[0].producer_stage_ids, HashSet::from([3])); + } + + #[test] + fn eligible_union_multi_producer() { + // GlobalLimit -> Union(Exchange(2), Exchange(5)) + let union: Arc = UnionExec::try_new(vec![ + exchange_with_stage(leaf(), Some(2)), + exchange_with_stage(leaf(), Some(5)), + ]) + .unwrap(); + let plan = limit(0, Some(50), union); + let ctx = analyze(plan); + assert_eq!(ctx.len(), 1); + assert_eq!(ctx[0].producer_stage_ids, HashSet::from([2, 5])); + } + + #[test] + fn ineligible_offset() { + let plan = limit(5, Some(100), exchange_with_stage(leaf(), Some(7))); + assert!(analyze(plan).is_empty()); + } + + #[test] + fn ineligible_no_fetch() { + let plan = limit(0, None, exchange_with_stage(leaf(), Some(7))); + assert!(analyze(plan).is_empty()); + } + + #[test] + fn ineligible_sort_below() { + let sort_expr = + datafusion::physical_expr::PhysicalSortExpr::new_default(Arc::new( + Column::new("c", 0), + )); + let sorted = Arc::new(SortExec::new( + datafusion::physical_expr::LexOrdering::new(vec![sort_expr]).unwrap(), + leaf(), + )); + let plan = limit(0, Some(100), exchange_with_stage(sorted, Some(7))); + assert!(analyze(plan).is_empty()); + } + + #[test] + fn ineligible_producer_exchange_without_stage_id() { + let plan = limit(0, Some(100), exchange_with_stage(leaf(), None)); + assert!(analyze(plan).is_empty()); + } + + #[test] + fn ineligible_non_passthrough_between_limit_and_exchange() { + // GlobalLimit -> Projection(empty list) -> Exchange : Projection IS + // in the allowlist, so this should still be eligible. Pick something + // that's not allow-listed instead: Sort with a single sort expr. + let sort_expr = + datafusion::physical_expr::PhysicalSortExpr::new_default(Arc::new( + Column::new("c", 0), + )); + let sorted = Arc::new(SortExec::new( + datafusion::physical_expr::LexOrdering::new(vec![sort_expr]).unwrap(), + exchange_with_stage(leaf(), Some(7)), + )); + let plan = limit(0, Some(100), sorted); + // Subtree contains SortExec → bails out via ordering check before + // producer collection. Either way: ineligible. + assert!(analyze(plan).is_empty()); + } + + #[test] + fn projection_between_limit_and_exchange_is_eligible() { + use datafusion::physical_plan::PhysicalExpr; + let col: Arc = Arc::new(Column::new("c", 0)); + let proj: Arc = Arc::new( + ProjectionExec::try_new( + vec![(col, "c".to_string())], + exchange_with_stage(leaf(), Some(11)), + ) + .unwrap(), + ); + let plan = limit(0, Some(100), proj); + let ctx = analyze(plan); + assert_eq!(ctx.len(), 1); + assert_eq!(ctx[0].producer_stage_ids, HashSet::from([11])); + } + + #[test] + fn nested_limits_outermost_only() { + // Outer Limit(50) -> Exchange(stage 4) -> inner Limit(10) -> Exchange(stage 9) + let inner = limit(0, Some(10), exchange_with_stage(leaf(), Some(9))); + let middle_exchange = exchange_with_stage(inner, Some(4)); + let plan = limit(0, Some(50), middle_exchange); + let ctx = analyze(plan); + assert_eq!(ctx.len(), 1); + assert_eq!(ctx[0].fetch, 50); + assert_eq!(ctx[0].producer_stage_ids, HashSet::from([4])); + } +} diff --git a/ballista/scheduler/src/state/aqe/limit_tracker.rs b/ballista/scheduler/src/state/aqe/limit_tracker.rs new file mode 100644 index 0000000000..3b3b757a8e --- /dev/null +++ b/ballista/scheduler/src/state/aqe/limit_tracker.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Per-job row-count tracker that fires once a global `LIMIT` is satisfied. +//! +//! Used by the AQE early-stop feature: the scheduler-side analyzer tags +//! the producer stages feeding an eligible `GlobalLimitExec`, then this +//! tracker observes per-task row counts as `TaskStatus` updates arrive. +//! When the sum crosses `limit * safety_factor`, the tracker returns +//! `CancelRemaining` exactly once and the scheduler dispatches an +//! `EarlyStopCancel` event for the job. + +use std::collections::HashSet; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +/// Decision returned by [`JobLimitTracker::observe`]. +#[derive(Debug, PartialEq, Eq)] +pub enum EarlyStopDecision { + /// Threshold has not been crossed, or another observer already won the + /// race to fire the trigger. + Continue, + /// Threshold has just been crossed by this observation. Caller must + /// dispatch the cancellation event exactly once. + CancelRemaining, +} + +/// Tracks the running sum of rows produced by a job's tagged producer +/// stages and fires a one-shot trigger when `sum >= threshold`. +/// +/// The tracker is keyed by job ID externally (in the scheduler's +/// `DashMap>`); the tracker itself does +/// not hold the ID. `Send + Sync` so it can be shared across the +/// scheduler's async tasks. +#[derive(Debug)] +pub struct JobLimitTracker { + limit: u64, + /// Pre-computed `limit * safety_factor`, with the floating-point + /// multiplication done once at construction time to keep `observe` + /// branch-free of f64 arithmetic. + threshold: u64, + tagged_producer_stage_ids: HashSet, + rows_so_far: AtomicU64, + triggered: AtomicBool, +} + +impl JobLimitTracker { + /// Construct a tracker with the given `limit` (the LIMIT's fetch + /// count) and `safety_factor` (typically 1.5). + /// + /// Panics if `limit == 0` (a zero-row LIMIT should be handled by the + /// optimizer, not the tracker) or if `safety_factor < 1.0` (we must + /// never trigger before the limit is reached). + pub fn new( + limit: u64, + safety_factor: f64, + tagged_producer_stage_ids: HashSet, + ) -> Self { + assert!(limit > 0, "JobLimitTracker requires a positive limit"); + assert!( + safety_factor >= 1.0, + "safety_factor must be >= 1.0 to preserve the asymmetric \ + correctness invariant (fire no earlier than rows >= limit)" + ); + let threshold = compute_threshold(limit, safety_factor); + Self { + limit, + threshold, + tagged_producer_stage_ids, + rows_so_far: AtomicU64::new(0), + triggered: AtomicBool::new(false), + } + } + + /// Record that `partition_rows` rows were produced by `stage_id`. + /// + /// Returns `CancelRemaining` exactly once across all calls, when the + /// running sum first crosses `threshold`. + /// + /// Asymmetric correctness invariant: we must fire no earlier than + /// `sum >= limit` (otherwise the downstream LimitExec under-reports). + /// Firing late is always safe; it only wastes I/O. + pub fn observe( + &self, + stage_id: usize, + partition_rows: u64, + ) -> EarlyStopDecision { + if !self.tagged_producer_stage_ids.contains(&stage_id) { + return EarlyStopDecision::Continue; + } + let new_total = + self.rows_so_far.fetch_add(partition_rows, Ordering::Relaxed) + + partition_rows; + if new_total >= self.threshold + && !self.triggered.swap(true, Ordering::SeqCst) + { + EarlyStopDecision::CancelRemaining + } else { + EarlyStopDecision::Continue + } + } + + pub fn limit(&self) -> u64 { + self.limit + } + + pub fn threshold(&self) -> u64 { + self.threshold + } + + pub fn rows_so_far(&self) -> u64 { + self.rows_so_far.load(Ordering::Relaxed) + } + + pub fn is_triggered(&self) -> bool { + self.triggered.load(Ordering::Relaxed) + } + + /// The producer stage IDs whose row counts this tracker observes. + pub fn tagged_producer_stage_ids(&self) -> &HashSet { + &self.tagged_producer_stage_ids + } +} + +/// Compute `limit * safety_factor` rounded up to the next integer, using +/// fixed-point arithmetic to avoid f64 imprecision at large limits. +fn compute_threshold(limit: u64, safety_factor: f64) -> u64 { + const SCALE: u128 = 1_000_000; + let scaled = (safety_factor * SCALE as f64).round() as u128; + let product = (limit as u128) * scaled; + let threshold = product.div_ceil(SCALE); + threshold.min(u64::MAX as u128) as u64 +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::thread; + + fn tracker(limit: u64, stages: &[usize]) -> JobLimitTracker { + JobLimitTracker::new(limit, 1.5, stages.iter().copied().collect()) + } + + #[test] + fn threshold_rounding() { + assert_eq!(compute_threshold(10, 1.5), 15); + assert_eq!(compute_threshold(100, 1.5), 150); + assert_eq!(compute_threshold(1, 1.5), 2); // 1.5 rounds up + assert_eq!(compute_threshold(1_000_000, 1.0), 1_000_000); + } + + #[test] + fn observe_continues_below_threshold() { + let t = tracker(100, &[0]); + assert_eq!(t.observe(0, 50), EarlyStopDecision::Continue); + assert_eq!(t.observe(0, 50), EarlyStopDecision::Continue); + assert_eq!(t.rows_so_far(), 100); + assert!(!t.is_triggered()); + } + + #[test] + fn observe_fires_at_threshold() { + // limit=10, sf=1.5 -> threshold=15 + let t = tracker(10, &[0]); + assert_eq!(t.observe(0, 10), EarlyStopDecision::Continue); + assert_eq!(t.observe(0, 5), EarlyStopDecision::CancelRemaining); + assert!(t.is_triggered()); + } + + #[test] + fn observe_fires_once_then_continue() { + let t = tracker(10, &[0]); + assert_eq!(t.observe(0, 20), EarlyStopDecision::CancelRemaining); + // Subsequent observations after triggering must not re-fire. + assert_eq!(t.observe(0, 5), EarlyStopDecision::Continue); + assert_eq!(t.observe(0, 100), EarlyStopDecision::Continue); + } + + #[test] + fn observe_ignores_untagged_stage() { + let t = tracker(10, &[0]); + assert_eq!(t.observe(42, 1_000_000), EarlyStopDecision::Continue); + assert_eq!(t.rows_so_far(), 0); + assert!(!t.is_triggered()); + } + + #[test] + fn observe_aggregates_multiple_tagged_stages() { + let t = tracker(100, &[1, 2, 3]); + assert_eq!(t.observe(1, 50), EarlyStopDecision::Continue); + assert_eq!(t.observe(2, 50), EarlyStopDecision::Continue); + assert_eq!(t.observe(3, 50), EarlyStopDecision::CancelRemaining); + } + + #[test] + fn concurrent_observers_trigger_once() { + let t = Arc::new(tracker(1_000, &[0])); + let fires = Arc::new(AtomicUsize::new(0)); + let mut handles = vec![]; + for _ in 0..32 { + let t = t.clone(); + let fires = fires.clone(); + handles.push(thread::spawn(move || { + // Each thread adds 100 rows -> total 3200 >> threshold(1500). + if t.observe(0, 100) == EarlyStopDecision::CancelRemaining { + fires.fetch_add(1, Ordering::Relaxed); + } + })); + } + for h in handles { + h.join().unwrap(); + } + assert_eq!(fires.load(Ordering::Relaxed), 1); + assert!(t.is_triggered()); + assert_eq!(t.rows_so_far(), 3200); + } + + #[test] + #[should_panic(expected = "positive limit")] + fn zero_limit_panics() { + JobLimitTracker::new(0, 1.5, HashSet::new()); + } + + #[test] + #[should_panic(expected = "safety_factor")] + fn safety_factor_below_one_panics() { + JobLimitTracker::new(10, 0.9, HashSet::new()); + } +} diff --git a/ballista/scheduler/src/state/aqe/mod.rs b/ballista/scheduler/src/state/aqe/mod.rs index 6b1903dc94..8bee847b35 100644 --- a/ballista/scheduler/src/state/aqe/mod.rs +++ b/ballista/scheduler/src/state/aqe/mod.rs @@ -18,20 +18,25 @@ use crate::display::print_stage_metrics; use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::timestamp_millis; +use crate::state::aqe::limit_early_stop::{ + JobLimitContext, LimitEarlyStopAnalyzer, +}; use crate::state::aqe::planner::AdaptivePlanner; use crate::state::execution_graph::{ ExecutionGraph, ExecutionGraphBox, ExecutionStage, ResolvedStage, RunningTaskInfo, StageOutput, }; use crate::state::execution_stage::RunningStage; +use crate::state::execution_stage::TaskInfo; use crate::state::task_manager::UpdatedStages; use ballista_core::error::BallistaError; use ballista_core::execution_plans::ShuffleWriter; +use ballista_core::extension::SessionConfigExt; use ballista_core::serde::protobuf::failed_task::FailedReason; use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{ - FailedJob, FailedTask, JobStatus, ResultLost, RunningJob, SuccessfulJob, TaskStatus, - job_status, task_status, + FailedJob, FailedTask, JobStatus, ResultLost, RunningJob, SuccessfulJob, + SuccessfulTask, TaskStatus, job_status, task_status, }; use ballista_core::serde::scheduler::{ExecutorMetadata, PartitionLocation}; use datafusion::physical_plan::ExecutionPlan; @@ -39,6 +44,7 @@ use datafusion::prelude::SessionConfig; use log::{debug, error, info, warn}; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; use std::vec; // TODO: the AQE planner runs DataFusion's DefaultPhysicalPlanner with a @@ -51,6 +57,8 @@ use std::vec; mod adapter; mod execution_plan; +pub mod limit_early_stop; +pub mod limit_tracker; pub mod optimizer_rule; pub mod planner; #[cfg(test)] @@ -120,6 +128,46 @@ pub(crate) struct AdaptiveExecutionGraph { logical_plan: Option, /// Physical plan, captured at submission time. physical_plan: Arc, + /// Eligible global-LIMIT contexts identified by the AQE early-stop + /// analyzer at submission time. Empty if early-stop is disabled or no + /// eligible LIMIT was found. Consumed by `TaskManager::submit_job` to + /// register a `JobLimitTracker` for the job. + pub(crate) limit_contexts: Vec, +} + +/// AQE early-stop helper: fill every task_info slot of `stage` with a +/// synthetic Successful TaskInfo so the stage's `to_successful` path +/// (which panics on any None) sees a complete set. The synthetic +/// entries carry no shuffle output partitions — the partial output +/// already on disk (from real completed tasks) is what the downstream +/// consumer will read. +fn synthesize_early_stop_completion(stage: &mut RunningStage) { + let now_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis()) + .unwrap_or_default(); + for slot in stage.task_infos.iter_mut() { + let needs_synthesis = match slot { + None => true, + Some(info) => { + !matches!(info.task_status, task_status::Status::Successful(_)) + } + }; + if needs_synthesis { + *slot = Some(TaskInfo { + task_id: 0, + scheduled_time: now_ms, + launch_time: now_ms, + start_exec_time: now_ms, + end_exec_time: now_ms, + finish_time: now_ms, + task_status: task_status::Status::Successful(SuccessfulTask { + executor_id: "aqe-early-stop-synthetic".to_string(), + partitions: vec![], + }), + }); + } + } } impl AdaptiveExecutionGraph { @@ -166,6 +214,26 @@ impl AdaptiveExecutionGraph { .collect(); let stages = stages?; + // Run the AQE early-stop analyzer on the post-stage-resolution plan, + // so producer ExchangeExecs have their stage_ids assigned. If the + // flag is disabled, we still build the field (empty) so downstream + // code can unconditionally read it. + let limit_contexts = + if session_config.ballista_aqe_limit_early_stop_enabled() { + let contexts = LimitEarlyStopAnalyzer::new(planner.plan()).analyze(); + if !contexts.is_empty() { + info!( + "AQE early-stop analyzer tagged {} eligible LIMIT \ + context(s) for job {}", + contexts.len(), + job_id, + ); + } + contexts + } else { + Vec::new() + }; + Ok(Self { planner, scheduler_id: Some(scheduler_id.to_string()), @@ -192,6 +260,7 @@ impl AdaptiveExecutionGraph { session_config, logical_plan, physical_plan: plan, + limit_contexts, }) } } @@ -1212,6 +1281,85 @@ impl ExecutionGraph for AdaptiveExecutionGraph { Ok(()) } + fn early_stop_stages( + &mut self, + producer_stage_ids: &HashSet, + ) -> ballista_core::error::Result<( + Vec, + Vec, + )> { + let job_id = self.job_id.clone(); + let mut tasks_to_cancel: Vec = Vec::new(); + let mut newly_successful: HashSet = HashSet::new(); + + for &stage_id in producer_stage_ids { + let Some(ExecutionStage::Running(running_stage)) = + self.stages.get_mut(&stage_id) + else { + warn!( + "early_stop_stages: stage {}/{} is not in Running \ + state; skipping", + job_id, stage_id + ); + continue; + }; + + // Collect actually-running tasks BEFORE we synthesize their + // completion. The scheduler will cancel these via RPC; any + // that race to completion before their cancel arrives are + // fine — their shuffle output is already on disk and the + // downstream LimitExec slices to exact fetch. + for (task_id, _stage_id, partition_id, executor_id) in + running_stage.running_tasks() + { + tasks_to_cancel.push(RunningTaskInfo { + task_id, + job_id: job_id.clone(), + stage_id, + partition_id, + executor_id, + }); + } + + // Synthesize Successful TaskInfo with empty shuffle output + // for every partition that has not already completed + // successfully. `to_successful` requires every task_info + // slot to be Some(Successful(_)); without this synthesis it + // would panic on the unfinished partitions. + synthesize_early_stop_completion(running_stage); + newly_successful.insert(stage_id); + } + + // Drive the planner's per-stage finalisation. The planner has + // accumulated locations from real completed tasks via + // update_exchange_locations; finalise_stage extracts them and + // identifies the next batch of runnable stages. + for &stage_id in &newly_successful { + let stages_to_cancel = + self.update_stage_progress(stage_id, true, vec![])?; + if !stages_to_cancel.is_empty() { + debug!( + "early_stop_stages: planner flagged stages {:?} for \ + cancellation while finalising stage {}", + stages_to_cancel, stage_id + ); + } + } + + // Cascade graph state: for each early-stopped stage, transition + // Running -> Successful via succeed_stage and, if the consumer + // stage is also already done, finalize the job as Successful. + let events = self.processing_stages_update(UpdatedStages { + resolved_stages: HashSet::new(), + successful_stages: newly_successful, + failed_stages: HashMap::new(), + rollback_running_stages: HashMap::new(), + resubmit_successful_stages: HashSet::new(), + })?; + + Ok((tasks_to_cancel, events)) + } + fn stages(&self) -> &HashMap { &self.stages } diff --git a/ballista/scheduler/src/state/aqe/planner.rs b/ballista/scheduler/src/state/aqe/planner.rs index 2346bb380b..86145ec68f 100644 --- a/ballista/scheduler/src/state/aqe/planner.rs +++ b/ballista/scheduler/src/state/aqe/planner.rs @@ -417,6 +417,13 @@ impl AdaptivePlanner { pub fn current_plan(&self) -> &dyn ExecutionPlan { self.plan.as_ref() } + + /// Returns the current physical plan as an `Arc`. Used by the + /// scheduler-side AQE early-stop analyzer which needs to walk the + /// post-stage-resolution plan tree to identify eligible LIMITs. + pub(crate) fn plan(&self) -> &Arc { + &self.plan + } /// Returns the default set of physical optimizer rules. /// /// # Returns diff --git a/ballista/scheduler/src/state/aqe/test/early_stop.rs b/ballista/scheduler/src/state/aqe/test/early_stop.rs new file mode 100644 index 0000000000..c919c66a87 --- /dev/null +++ b/ballista/scheduler/src/state/aqe/test/early_stop.rs @@ -0,0 +1,482 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration tests for AQE early-stop wiring: +//! +//! * the analyzer is invoked from `AdaptiveExecutionGraph::try_new` and +//! correctly returns 0 contexts for plans that DataFusion's +//! limit-pushdown optimizer has stripped the GlobalLimitExec from; +//! * `early_stop_stages` correctly synthesizes successful completion +//! for unfinished partitions and reports the running tasks for +//! cancellation; +//! * the AdaptiveExecutionGraph respects the `enabled` config flag; +//! * the full `TaskManager` pipeline (submit → observe → emit +//! `EarlyStopCancel` → `early_stop_job` finalize) is wired correctly. +//! +//! The analyzer's own eligibility-check coverage (eligible / ineligible +//! mixes of operators) lives in +//! `state::aqe::limit_early_stop::tests` using synthetic plans; here we +//! only exercise the end-to-end wiring against real SQL plans and real +//! graph state. Note: in current DataFusion (53.x), the +//! `LimitPushdown` physical optimizer rewrites `GlobalLimitExec` into +//! `LocalLimitExec` + fetch-bearing parents in most query shapes, so +//! these end-to-end tests for the no-shuffle case end up with empty +//! `limit_contexts`. Extending the analyzer to also recognise +//! `LocalLimitExec` at a stage boundary is tracked in the v2 spec +//! (`aqe-tasks/03b-early-stop-global-limit-v2.md`, section v2.6). For +//! the same reason, the `TaskManager`-level tests below inject the +//! tracker manually rather than relying on the analyzer. + +use std::collections::HashSet; +use std::sync::Arc; + +use ballista_core::config::BALLISTA_ADAPTIVE_PLANNER_ENABLED; +use ballista_core::extension::SessionConfigExt; +use ballista_core::serde::BallistaCodec; +use ballista_core::serde::protobuf::{ + ShuffleWritePartition, SuccessfulTask, TaskStatus, task_status, +}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; + +use crate::scheduler_server::event::QueryStageSchedulerEvent; +use crate::scheduler_server::timestamp_millis; +use crate::state::SchedulerState; +use crate::state::aqe::AdaptiveExecutionGraph; +use crate::state::aqe::limit_tracker::JobLimitTracker; +use crate::state::aqe::test::{ + mock_context, mock_context_with_ballista_flag, mock_memory_table, +}; +use crate::state::execution_graph::{ExecutionGraph, ExecutionStage}; +use crate::test_utils::{mock_executor, test_cluster_context}; + +/// SQL whose physical plan reliably contains at least one Running +/// producer stage after `revive()`, used by every test below that +/// needs a real graph to short-stop. +const GROUP_BY_SQL: &str = "select a, count(*) from t group by a"; + +async fn build_graph(sql: &str) -> AdaptiveExecutionGraph { + let ctx = mock_context(); + ctx.register_table("t", mock_memory_table()).unwrap(); + let plan = ctx + .sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap(); + AdaptiveExecutionGraph::try_new( + "scheduler-1", + "job-aqe-early-stop", + "test", + ctx.state().session_id(), + plan, + timestamp_millis(), + Arc::new(ctx.copied_config()), + None, + ) + .unwrap() +} + +/// Return the id of the first stage in `Running` state. Panics if none +/// exists — callers expect the graph to have been `revive()`d. +fn first_running_stage(graph: &G) -> usize { + graph + .stages() + .iter() + .find_map(|(id, stage)| match stage { + ExecutionStage::Running(_) => Some(*id), + _ => None, + }) + .expect("expected at least one Running stage") +} + +#[tokio::test] +async fn analyzer_invoked_for_aqe_jobs() { + // Any AQE job (with or without an eligible LIMIT) should have the + // analyzer run. For LIMIT queries whose GlobalLimitExec was + // stripped by DF's limit-pushdown, the resulting set is empty — + // that's a correct ineligibility answer (nothing to short-stop). + let graph = + build_graph("select a, count(*) from t group by a limit 3").await; + assert!(graph.limit_contexts.is_empty()); +} + +#[tokio::test] +async fn disabling_flag_short_circuits_analyzer() { + use ballista_core::config::BALLISTA_AQE_LIMIT_EARLY_STOP_ENABLED; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::{SessionConfig, SessionContext}; + + let config = SessionConfig::new_with_ballista() + .set_str(BALLISTA_AQE_LIMIT_EARLY_STOP_ENABLED, "false") + .with_target_partitions(2) + .with_round_robin_repartition(false); + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .build(); + let ctx = SessionContext::new_with_state(state); + ctx.register_table("t", mock_memory_table()).unwrap(); + let plan = ctx + .sql(GROUP_BY_SQL) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap(); + let graph = AdaptiveExecutionGraph::try_new( + "scheduler-1", + "job-aqe-early-stop", + "test", + ctx.state().session_id(), + plan, + timestamp_millis(), + Arc::new(ctx.copied_config()), + None, + ) + .unwrap(); + assert!(graph.limit_contexts.is_empty()); +} + +#[tokio::test] +async fn early_stop_stages_is_no_op_for_unknown_stage_ids() { + let mut graph = build_graph(GROUP_BY_SQL).await; + graph.revive(); + let bogus: HashSet = [999].into_iter().collect(); + let (cancel, events) = graph.early_stop_stages(&bogus).unwrap(); + assert!(cancel.is_empty()); + assert!(events.is_empty()); +} + +#[tokio::test] +async fn early_stop_stages_synthesizes_completion_and_reports_running_tasks() { + let mut graph = build_graph(GROUP_BY_SQL).await; + graph.revive(); + let producer = first_running_stage(&graph); + + let executor = mock_executor("exec-1".to_string()); + let mut launched = 0usize; + while let Some(task) = graph.pop_next_task(&executor.id).unwrap() { + if task.partition.stage_id != producer { + break; + } + launched += 1; + } + assert!( + launched > 0, + "expected to launch at least one task in stage {producer}" + ); + + let producer_ids: HashSet = [producer].into_iter().collect(); + let (cancel, _events) = graph.early_stop_stages(&producer_ids).unwrap(); + assert_eq!( + cancel.len(), + launched, + "every launched task should be reported for cancellation" + ); + for task in &cancel { + assert_eq!(task.stage_id, producer); + assert_eq!(task.executor_id, executor.id); + } + assert!( + matches!( + graph.stages().get(&producer), + Some(ExecutionStage::Successful(_)) + ), + "producer stage should have transitioned Running -> Successful" + ); +} + +// ------------------------------------------------------------------------- +// TaskManager-level integration tests +// +// Each test runs: +// submit_job -> active graph cached +// inject JobLimitTracker (manual; see module-level note about v2.6) +// update_task_statuses with synthetic successful TaskStatus(es) +// -> assert presence / absence / cardinality of EarlyStopCancel +// One additional test exercises the early_stop_job finalization path. + +type TestSchedulerState = SchedulerState; + +async fn build_state_with_job( + sql: &str, + job_id: &str, +) -> (TestSchedulerState, ExecutorMetadata) { + let cluster = test_cluster_context(); + let state = SchedulerState::new_with_default_scheduler_name( + cluster, + BallistaCodec::default(), + ); + + let ctx = mock_context_with_ballista_flag(BALLISTA_ADAPTIVE_PLANNER_ENABLED); + ctx.register_table("t", mock_memory_table()).unwrap(); + let plan = ctx + .sql(sql) + .await + .unwrap() + .create_physical_plan() + .await + .unwrap(); + + state + .task_manager + .queue_job(job_id, "test", timestamp_millis()) + .unwrap(); + state + .task_manager + .submit_job( + job_id, + "test", + &ctx.state().session_id(), + plan, + timestamp_millis(), + Arc::new(ctx.copied_config()), + None, + None, + ) + .await + .unwrap(); + + let executor = mock_executor("exec-1".to_string()); + state + .executor_manager + .register_executor( + executor.clone(), + ExecutorData { + executor_id: executor.id.clone(), + total_task_slots: 16, + available_task_slots: 16, + }, + ) + .await + .unwrap(); + + (state, executor) +} + +async fn producer_stage_id(state: &TestSchedulerState, job_id: &str) -> usize { + let cached = state + .task_manager + .get_active_execution_graph(job_id) + .expect("job should be in active cache"); + let graph = cached.read().await; + first_running_stage(&**graph) +} + +fn synthetic_status( + job_id: &str, + stage_id: u32, + partition_id: u32, + executor_id: &str, + num_rows: u64, +) -> TaskStatus { + TaskStatus { + // task_id is synthetic — the observe path keys off (job_id, + // stage_id) only. + task_id: partition_id, + job_id: job_id.to_string(), + stage_id, + stage_attempt_num: 0, + partition_id, + launch_time: 0, + start_exec_time: 0, + end_exec_time: 0, + metrics: vec![], + status: Some(task_status::Status::Successful(SuccessfulTask { + executor_id: executor_id.to_string(), + partitions: vec![ShuffleWritePartition { + partition_id: partition_id as u64, + num_batches: 1, + num_rows, + num_bytes: num_rows * 8, + file_id: None, + is_sort_shuffle: false, + }], + })), + } +} + +/// Inject a tracker for `job_id` with the configured `limit` targeting +/// the given producer stage. Returns the `Arc` so callers can inspect +/// `is_triggered()` / `rows_so_far()` afterwards. +fn inject_tracker( + state: &TestSchedulerState, + job_id: &str, + limit: u64, + producer_stage: usize, +) -> Arc { + let tracker = Arc::new(JobLimitTracker::new( + limit, + 1.5, + [producer_stage].into_iter().collect(), + )); + state + .task_manager + .insert_limit_tracker_for_test(job_id, tracker.clone()); + tracker +} + +/// Drive a single batch of synthetic successful statuses for one stage +/// and return the count of `EarlyStopCancel` events emitted for `job_id`. +async fn drive_rows( + state: &TestSchedulerState, + executor: &ExecutorMetadata, + job_id: &str, + stage_id: usize, + rows_per_partition: &[u64], +) -> usize { + let statuses: Vec = rows_per_partition + .iter() + .enumerate() + .map(|(idx, rows)| { + synthetic_status( + job_id, + stage_id as u32, + idx as u32, + &executor.id, + *rows, + ) + }) + .collect(); + let events = state + .task_manager + .update_task_statuses(executor, statuses) + .await + .unwrap(); + events + .iter() + .filter(|e| { + matches!( + e, + QueryStageSchedulerEvent::EarlyStopCancel { job_id: jid } + if jid == job_id + ) + }) + .count() +} + +#[tokio::test] +async fn update_task_statuses_emits_early_stop_cancel_at_threshold() { + let job_id = "job-early-stop-emit"; + let (state, executor) = build_state_with_job(GROUP_BY_SQL, job_id).await; + let producer = producer_stage_id(&state, job_id).await; + let tracker = inject_tracker(&state, job_id, 10, producer); + + // limit=10, sf=1.5 -> threshold=15; 20 rows crosses on the first batch. + let fires = drive_rows(&state, &executor, job_id, producer, &[20]).await; + assert_eq!(fires, 1, "expected exactly one EarlyStopCancel"); + assert!(tracker.is_triggered()); + assert_eq!(tracker.rows_so_far(), 20); +} + +#[tokio::test] +async fn update_task_statuses_does_not_emit_below_threshold() { + let job_id = "job-early-stop-below"; + let (state, executor) = build_state_with_job(GROUP_BY_SQL, job_id).await; + let producer = producer_stage_id(&state, job_id).await; + let tracker = inject_tracker(&state, job_id, 100, producer); + + // limit=100, sf=1.5 -> threshold=150; 50 rows stays below. + let fires = drive_rows(&state, &executor, job_id, producer, &[50]).await; + assert_eq!(fires, 0); + assert!(!tracker.is_triggered()); + assert_eq!(tracker.rows_so_far(), 50); +} + +#[tokio::test] +async fn update_task_statuses_emits_early_stop_only_once() { + let job_id = "job-early-stop-once"; + let (state, executor) = build_state_with_job(GROUP_BY_SQL, job_id).await; + let producer = producer_stage_id(&state, job_id).await; + let _tracker = inject_tracker(&state, job_id, 10, producer); + + let fires1 = drive_rows(&state, &executor, job_id, producer, &[50]).await; + assert_eq!(fires1, 1); + let fires2 = drive_rows(&state, &executor, job_id, producer, &[50]).await; + assert_eq!(fires2, 0, "EarlyStopCancel must fire exactly once per job"); +} + +#[tokio::test] +async fn early_stop_job_finalizes_producer_stage_and_clears_tracker() { + let job_id = "job-early-stop-finalize"; + let (state, executor) = build_state_with_job(GROUP_BY_SQL, job_id).await; + let producer = producer_stage_id(&state, job_id).await; + + // Launch tasks so the producer has running work that early_stop_job + // must report for cancellation. + { + let cached = state + .task_manager + .get_active_execution_graph(job_id) + .unwrap(); + let mut graph = cached.write().await; + let mut launched = 0; + while let Some(task) = graph.pop_next_task(&executor.id).unwrap() { + if task.partition.stage_id != producer { + break; + } + launched += 1; + } + assert!(launched > 0, "expected to launch at least one task"); + } + + inject_tracker(&state, job_id, 10, producer); + + let (cancel, _events) = + state.task_manager.early_stop_job(job_id).await.unwrap(); + assert!( + !cancel.is_empty(), + "expected at least one running task to be reported for cancellation" + ); + for task in &cancel { + assert_eq!(task.stage_id, producer); + assert_eq!(task.executor_id, executor.id); + } + assert!( + state + .task_manager + .get_limit_tracker_for_test(job_id) + .is_none(), + "tracker must be removed once early_stop_job fires" + ); + + let cached = state + .task_manager + .get_active_execution_graph(job_id) + .unwrap(); + let graph = cached.read().await; + assert!( + matches!( + graph.stages().get(&producer), + Some(ExecutionStage::Successful(_)) + ), + "producer stage should be Successful after early_stop_job" + ); +} + +#[tokio::test] +async fn early_stop_job_is_no_op_without_tracker() { + let job_id = "job-early-stop-missing"; + let (state, _executor) = build_state_with_job(GROUP_BY_SQL, job_id).await; + + let (cancel, events) = + state.task_manager.early_stop_job(job_id).await.unwrap(); + assert!(cancel.is_empty()); + assert!(events.is_empty()); +} diff --git a/ballista/scheduler/src/state/aqe/test/mod.rs b/ballista/scheduler/src/state/aqe/test/mod.rs index a60cd038ab..51c63a41da 100644 --- a/ballista/scheduler/src/state/aqe/test/mod.rs +++ b/ballista/scheduler/src/state/aqe/test/mod.rs @@ -17,6 +17,8 @@ /// Test if stages can be added or removed mod alter_stages; +/// Integration tests for the AQE early-stop on global LIMIT feature +mod early_stop; /// Tests if plan is going to be split to stages correctly mod plan_to_stages; @@ -126,9 +128,12 @@ pub(crate) fn mock_context() -> SessionContext { SessionContext::new_with_state(state) } -pub(crate) fn mock_context_sort_shuffle() -> SessionContext { +/// Build a session context with a single Ballista boolean flag set to +/// `"true"`. Shared by tests that need the AQE planner or sort-based +/// shuffle enabled and otherwise mirror `mock_context`'s settings. +pub(crate) fn mock_context_with_ballista_flag(flag: &str) -> SessionContext { let config = SessionConfig::new_with_ballista() - .set_str(BALLISTA_SHUFFLE_SORT_BASED_ENABLED, "true") + .set_str(flag, "true") .with_target_partitions(2) .with_round_robin_repartition(false); @@ -139,3 +144,7 @@ pub(crate) fn mock_context_sort_shuffle() -> SessionContext { SessionContext::new_with_state(state) } + +pub(crate) fn mock_context_sort_shuffle() -> SessionContext { + mock_context_with_ballista_flag(BALLISTA_SHUFFLE_SORT_BASED_ENABLED) +} diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 8f0c30cfa5..aabf6a0011 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -221,6 +221,26 @@ pub trait ExecutionGraph: Debug { /// Returns an error if the job is not in a successful state. fn succeed_job(&mut self) -> Result<()>; + /// AQE early-stop: short-stop the producer stages identified by the + /// `LimitEarlyStopAnalyzer`. For each stage, synthesizes successful + /// completion for any tasks that have not yet finished (so the + /// `LimitExec`-bearing consumer stage can run with the partial + /// shuffle output already on disk), and returns: + /// - the list of in-flight tasks the scheduler must cancel + /// - the cascading scheduler events emitted by the stage + /// transition (e.g. dependent stages becoming resolved, the job + /// reaching `Successful`) + /// + /// Default impl is a no-op for graphs without AQE support (the + /// `LimitEarlyStopAnalyzer` only runs under AQE, so this should + /// never be invoked for non-adaptive graphs). + fn early_stop_stages( + &mut self, + _producer_stage_ids: &HashSet, + ) -> Result<(Vec, Vec)> { + Ok((vec![], vec![])) + } + /// Exposes executions stages and stage id's fn stages(&self) -> &HashMap; diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 20aa6990ea..abe6f83a56 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -35,11 +35,13 @@ use crate::cluster::JobState; use ballista_core::serde::BallistaCodec; use ballista_core::serde::protobuf::{ JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, TaskStatus, job_status, + task_status, }; use ballista_core::serde::scheduler::ExecutorMetadata; use dashmap::DashMap; use crate::state::aqe::AdaptiveExecutionGraph; +use crate::state::aqe::limit_tracker::{EarlyStopDecision, JobLimitTracker}; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; @@ -135,6 +137,12 @@ pub struct TaskManager task_max_failures: usize, /// Maximum number of failure attempts for stage-level retry before the stage is considered failed. stage_max_failures: usize, + /// Per-job row-count trackers for AQE early-stop on global LIMIT. + /// Populated by `submit_job` when the AQE planner tagged an eligible + /// LIMIT. Consulted by `update_task_statuses` on each task completion. + /// Removed by `clean_up_job_delayed` on job completion (success or + /// failure) and by `early_stop_job` once the trigger has fired. + job_limit_trackers: Arc>>, } /// Cache for active job information managed by this scheduler. @@ -232,6 +240,7 @@ impl TaskManager launcher: Arc::new(DefaultTaskLauncher::new(scheduler_id)), task_max_failures: config.task_max_failures, stage_max_failures: config.stage_max_failures, + job_limit_trackers: Arc::new(DashMap::new()), } } @@ -251,6 +260,7 @@ impl TaskManager launcher, task_max_failures: config.task_max_failures, stage_max_failures: config.stage_max_failures, + job_limit_trackers: Arc::new(DashMap::new()), } } @@ -292,7 +302,7 @@ impl TaskManager warn!( "Adaptive Query Planning is EXPERIMENTAL, should be used for testing purposes only!" ); - Box::new(AdaptiveExecutionGraph::try_new( + let adaptive = AdaptiveExecutionGraph::try_new( &self.scheduler_id, job_id, job_name, @@ -301,7 +311,9 @@ impl TaskManager queued_at, session_config, logical_plan, - )?) as ExecutionGraphBox + )?; + self.register_limit_tracker(job_id, &adaptive); + Box::new(adaptive) as ExecutionGraphBox } else { debug!("Using static query planner for job planning"); Box::new(StaticExecutionGraph::new( @@ -458,6 +470,15 @@ impl TaskManager let num_tasks = statuses.len(); debug!("Updating {num_tasks} tasks in job {job_id}"); + // Observe early-stop row counts BEFORE mutating the graph so + // that a single status batch that crosses the threshold still + // fires exactly once (the swap on `triggered` enforces this). + if let Some(early_stop_event) = + self.observe_early_stop_rows(&job_id, &statuses) + { + events.push(early_stop_event); + } + // let graph = self.get_active_execution_graph(&job_id).await; let job_events = if let Some(cached) = self.get_active_execution_graph(&job_id) @@ -485,6 +506,120 @@ impl TaskManager Ok(events) } + /// Register a `JobLimitTracker` if the AQE planner tagged exactly one + /// eligible LIMIT for this job. Multi-LIMIT jobs are deferred to v2 + /// — see `aqe-tasks/03b-early-stop-global-limit-v2.md`. + fn register_limit_tracker( + &self, + job_id: &str, + graph: &AdaptiveExecutionGraph, + ) { + // Default safety_factor; the config system does not currently + // support f64 keys, so this is hardcoded. Surface as a tunable + // when the config layer grows f64 support. + const SAFETY_FACTOR: f64 = 1.5; + match graph.limit_contexts.as_slice() { + [] => {} + [ctx] => { + let tracker = Arc::new(JobLimitTracker::new( + ctx.fetch, + SAFETY_FACTOR, + ctx.producer_stage_ids.clone(), + )); + info!( + "Registering AQE early-stop tracker for job {} (limit={}, \ + threshold={}, producer stages={:?})", + job_id, + tracker.limit(), + tracker.threshold(), + ctx.producer_stage_ids, + ); + self.job_limit_trackers + .insert(job_id.to_string(), tracker); + } + many => { + warn!( + "AQE early-stop: job {} has {} eligible LIMIT contexts; \ + v1 only tracks single-LIMIT jobs. See \ + aqe-tasks/03b-early-stop-global-limit-v2.md.", + job_id, + many.len() + ); + } + } + } + + /// Test-only: inject a pre-built tracker for `job_id`. This bypasses + /// the analyzer-driven registration in `register_limit_tracker` so + /// integration tests can exercise the observation + cancel pipeline + /// against real graphs whose `GlobalLimitExec` was stripped by + /// DataFusion's `LimitPushdown` (see v2.6 in + /// `aqe-tasks/03b-early-stop-global-limit-v2.md`). + #[cfg(test)] + pub(crate) fn insert_limit_tracker_for_test( + &self, + job_id: &str, + tracker: Arc, + ) { + self.job_limit_trackers + .insert(job_id.to_string(), tracker); + } + + /// Test-only: read the current tracker for `job_id`, if any. + #[cfg(test)] + pub(crate) fn get_limit_tracker_for_test( + &self, + job_id: &str, + ) -> Option> { + self.job_limit_trackers.get(job_id).map(|r| r.clone()) + } + + /// Sum `ShuffleWritePartition.num_rows` across all successful tasks + /// in this batch and pass to the tracker. Returns an + /// `EarlyStopCancel` event iff this batch crossed the threshold for + /// the first time. + fn observe_early_stop_rows( + &self, + job_id: &str, + statuses: &[TaskStatus], + ) -> Option { + let tracker = self.job_limit_trackers.get(job_id)?; + if tracker.is_triggered() { + return None; + } + let mut decision = EarlyStopDecision::Continue; + for status in statuses { + let Some(task_status::Status::Successful(ref ok)) = status.status + else { + continue; + }; + let rows: u64 = ok.partitions.iter().map(|p| p.num_rows).sum(); + if rows == 0 { + continue; + } + if let EarlyStopDecision::CancelRemaining = + tracker.observe(status.stage_id as usize, rows) + { + decision = EarlyStopDecision::CancelRemaining; + } + } + match decision { + EarlyStopDecision::CancelRemaining => { + info!( + "AQE early-stop trigger fired for job {} \ + (rows_so_far={}, threshold={})", + job_id, + tracker.rows_so_far(), + tracker.threshold(), + ); + Some(QueryStageSchedulerEvent::EarlyStopCancel { + job_id: job_id.to_string(), + }) + } + EarlyStopDecision::Continue => None, + } + } + /// Mark a job to success. This will create a key under the CompletedJobs keyspace /// and remove the job from ActiveJobs pub(crate) async fn succeed_job(&self, job_id: &str) -> Result<()> { @@ -513,6 +648,41 @@ impl TaskManager self.abort_job(job_id, "Cancelled".to_owned()).await } + /// Short-stop the producer stages tagged by the AQE early-stop + /// tracker for `job_id`. Synthesizes successful completion for the + /// remaining tasks of those stages so the consumer (LIMIT) stage can + /// run on the partial shuffle output. Returns the in-flight tasks to + /// cancel and any follow-up scheduler events emitted by the cascade. + /// + /// The job is NOT removed from the active cache here — finalization + /// happens naturally when the consumer stage completes, at which + /// point the job ends `Successful` (not `Cancelled`) with the + /// downstream `LimitExec` slicing to exactly `fetch` rows. + pub(crate) async fn early_stop_job( + &self, + job_id: &str, + ) -> Result<(Vec, Vec)> { + let Some(tracker) = + self.job_limit_trackers.remove(job_id).map(|(_, v)| v) + else { + warn!( + "early_stop_job called for job {job_id} but no tracker is \ + registered; this is a no-op" + ); + return Ok((vec![], vec![])); + }; + let producer_stage_ids = tracker.tagged_producer_stage_ids().clone(); + let Some(cached) = self.get_active_execution_graph(job_id) else { + warn!( + "early_stop_job: job {job_id} no longer in active cache; \ + nothing to do" + ); + return Ok((vec![], vec![])); + }; + let mut graph = cached.write().await; + graph.early_stop_stages(&producer_stage_ids) + } + /// Abort the job and return a Vec of running tasks need to cancel pub(crate) async fn abort_job( &self, @@ -764,11 +934,14 @@ impl TaskManager .map(|cached| cached.execution_graph.clone()) } - /// Remove the `ExecutionGraph` for the given job ID from cache + /// Remove the `ExecutionGraph` for the given job ID from cache. Also + /// drops any AQE early-stop tracker registered for the job since the + /// graph is the only consumer. pub(crate) fn remove_active_execution_graph( &self, job_id: &str, ) -> Option>> { + self.job_limit_trackers.remove(job_id); self.active_job_cache .remove(job_id) .map(|value| value.1.execution_graph)