Skip to content

Commit da187f2

Browse files
authored
feat: support PartialMerge aggregation mode (#4003)
1 parent 7200cc7 commit da187f2

35 files changed

Lines changed: 1892 additions & 1003 deletions

File tree

native/common/src/error.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ pub enum SparkError {
7272
#[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
7373
IntegralDivideOverflow,
7474

75-
#[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
76-
DecimalSumOverflow,
75+
#[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. Use `try_{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
76+
DecimalSumOverflow { function_name: String },
7777

7878
#[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
7979
DivideByZero,
@@ -240,7 +240,7 @@ impl SparkError {
240240
SparkError::CannotParseDecimal => "CannotParseDecimal",
241241
SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow",
242242
SparkError::IntegralDivideOverflow => "IntegralDivideOverflow",
243-
SparkError::DecimalSumOverflow => "DecimalSumOverflow",
243+
SparkError::DecimalSumOverflow { .. } => "DecimalSumOverflow",
244244
SparkError::DivideByZero => "DivideByZero",
245245
SparkError::RemainderByZero => "RemainderByZero",
246246
SparkError::IntervalDividedByZero => "IntervalDividedByZero",
@@ -336,6 +336,11 @@ impl SparkError {
336336
"fromType": from_type,
337337
})
338338
}
339+
SparkError::DecimalSumOverflow { function_name } => {
340+
serde_json::json!({
341+
"functionName": function_name,
342+
})
343+
}
339344
SparkError::BinaryArithmeticOverflow {
340345
value1,
341346
symbol,
@@ -523,7 +528,7 @@ impl SparkError {
523528
| SparkError::NumericOutOfRange { .. } // Comet-specific extension
524529
| SparkError::ArithmeticOverflow { .. }
525530
| SparkError::IntegralDivideOverflow
526-
| SparkError::DecimalSumOverflow
531+
| SparkError::DecimalSumOverflow { .. }
527532
| SparkError::BinaryArithmeticOverflow { .. }
528533
| SparkError::IntervalArithmeticOverflowWithSuggestion { .. }
529534
| SparkError::IntervalArithmeticOverflowWithoutSuggestion
@@ -601,7 +606,7 @@ impl SparkError {
601606
SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"),
602607
SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"),
603608
SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"),
604-
SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"),
609+
SparkError::DecimalSumOverflow { .. } => Some("ARITHMETIC_OVERFLOW"),
605610
SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"),
606611
SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => {
607612
Some("INTERVAL_ARITHMETIC_OVERFLOW")
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! MergeAsPartial wrapper for implementing Spark's PartialMerge aggregate mode.
19+
//!
20+
//! Spark's PartialMerge mode merges intermediate state buffers and outputs intermediate
21+
//! state (not final values). DataFusion has no equivalent mode — `Partial` calls
22+
//! `update_batch` and outputs state, while `Final` calls `merge_batch` and outputs
23+
//! evaluated results.
24+
//!
25+
//! This wrapper bridges the gap: it operates under DataFusion's `Partial` mode (which
26+
//! outputs state) but redirects `update_batch` calls to `merge_batch`, giving merge
27+
//! semantics with state output.
28+
29+
use std::any::Any;
30+
use std::fmt::Debug;
31+
use std::hash::{Hash, Hasher};
32+
33+
use arrow::array::{ArrayRef, BooleanArray};
34+
use arrow::datatypes::{DataType, FieldRef};
35+
use datafusion::common::Result;
36+
use datafusion::logical_expr::function::AccumulatorArgs;
37+
use datafusion::logical_expr::function::StateFieldsArgs;
38+
use datafusion::logical_expr::{
39+
Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF,
40+
Signature, Volatility,
41+
};
42+
use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
43+
use datafusion::scalar::ScalarValue;
44+
45+
/// An AggregateUDF wrapper that gives merge semantics in Partial mode.
46+
///
47+
/// When DataFusion runs an AggregateExec in Partial mode, it calls `update_batch`
48+
/// on each accumulator and outputs `state()`. This wrapper intercepts `update_batch`
49+
/// and redirects it to `merge_batch` on the inner accumulator, effectively
50+
/// implementing PartialMerge: merge inputs, output state.
51+
///
52+
/// We store the inner AggregateUDF (not the AggregateFunctionExpr) to avoid keeping
53+
/// references to UnboundColumn expressions that would panic if evaluated.
54+
#[derive(Debug)]
55+
pub struct MergeAsPartialUDF {
56+
/// The inner aggregate UDF, cloned from the original expression.
57+
inner_udf: AggregateUDF,
58+
/// Pre-computed return type from the original expression.
59+
return_type: DataType,
60+
/// Pre-computed state fields from the original expression.
61+
cached_state_fields: Vec<FieldRef>,
62+
/// Cached signature that accepts state field types.
63+
signature: Signature,
64+
/// Name for this wrapper.
65+
name: String,
66+
}
67+
68+
impl PartialEq for MergeAsPartialUDF {
69+
fn eq(&self, other: &Self) -> bool {
70+
self.name == other.name
71+
}
72+
}
73+
74+
impl Eq for MergeAsPartialUDF {}
75+
76+
impl Hash for MergeAsPartialUDF {
77+
fn hash<H: Hasher>(&self, state: &mut H) {
78+
self.name.hash(state);
79+
}
80+
}
81+
82+
impl MergeAsPartialUDF {
83+
pub fn new(inner_expr: &AggregateFunctionExpr) -> Result<Self> {
84+
let name = format!("merge_as_partial_{}", inner_expr.name());
85+
let return_type = inner_expr.field().data_type().clone();
86+
let cached_state_fields = inner_expr.state_fields()?;
87+
88+
// Use a permissive signature since we accept state field types which
89+
// vary per aggregate function.
90+
let signature = Signature::variadic_any(Volatility::Immutable);
91+
92+
Ok(Self {
93+
inner_udf: inner_expr.fun().clone(),
94+
return_type,
95+
cached_state_fields,
96+
signature,
97+
name,
98+
})
99+
}
100+
}
101+
102+
impl AggregateUDFImpl for MergeAsPartialUDF {
103+
fn as_any(&self) -> &dyn Any {
104+
self
105+
}
106+
107+
fn name(&self) -> &str {
108+
&self.name
109+
}
110+
111+
fn signature(&self) -> &Signature {
112+
&self.signature
113+
}
114+
115+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
116+
// In Partial mode, return_type isn't used for output schema (state_fields is).
117+
// Return the inner function's return type for consistency.
118+
Ok(self.return_type.clone())
119+
}
120+
121+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
122+
// Cached at construction: state schema depends on the inner aggregate's
123+
// return type, not on StateFieldsArgs.
124+
Ok(self.cached_state_fields.clone())
125+
}
126+
127+
fn accumulator(&self, args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
128+
// args.exprs are state-typed (match this wrapper's signature), not the
129+
// inner aggregate's original inputs. Safe for built-ins (SUM/COUNT/
130+
// MIN/MAX/AVG) which build accumulators from return_type; aggregates
131+
// that inspect args.exprs types would need reconsideration.
132+
let inner_acc = self.inner_udf.accumulator(args)?;
133+
Ok(Box::new(MergeAsPartialAccumulator { inner: inner_acc }))
134+
}
135+
136+
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
137+
// See `accumulator`: args.exprs are state-typed.
138+
self.inner_udf.groups_accumulator_supported(args)
139+
}
140+
141+
fn create_groups_accumulator(
142+
&self,
143+
args: AccumulatorArgs,
144+
) -> Result<Box<dyn GroupsAccumulator>> {
145+
// See `accumulator`: args.exprs are state-typed.
146+
let inner_acc = self.inner_udf.create_groups_accumulator(args)?;
147+
Ok(Box::new(MergeAsPartialGroupsAccumulator {
148+
inner: inner_acc,
149+
}))
150+
}
151+
152+
fn reverse_expr(&self) -> ReversedUDAF {
153+
ReversedUDAF::NotSupported
154+
}
155+
156+
fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
157+
ScalarValue::try_from(data_type)
158+
}
159+
160+
fn is_descending(&self) -> Option<bool> {
161+
None
162+
}
163+
}
164+
165+
/// Accumulator wrapper that redirects update_batch to merge_batch.
166+
struct MergeAsPartialAccumulator {
167+
inner: Box<dyn Accumulator>,
168+
}
169+
170+
impl Debug for MergeAsPartialAccumulator {
171+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172+
f.debug_struct("MergeAsPartialAccumulator").finish()
173+
}
174+
}
175+
176+
impl Accumulator for MergeAsPartialAccumulator {
177+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
178+
// Redirect update to merge — this is the key trick.
179+
self.inner.merge_batch(values)
180+
}
181+
182+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
183+
self.inner.merge_batch(states)
184+
}
185+
186+
fn evaluate(&mut self) -> Result<ScalarValue> {
187+
self.inner.evaluate()
188+
}
189+
190+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
191+
self.inner.state()
192+
}
193+
194+
fn size(&self) -> usize {
195+
self.inner.size()
196+
}
197+
}
198+
199+
/// GroupsAccumulator wrapper that redirects update_batch to merge_batch.
200+
struct MergeAsPartialGroupsAccumulator {
201+
inner: Box<dyn GroupsAccumulator>,
202+
}
203+
204+
impl Debug for MergeAsPartialGroupsAccumulator {
205+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206+
f.debug_struct("MergeAsPartialGroupsAccumulator").finish()
207+
}
208+
}
209+
210+
impl GroupsAccumulator for MergeAsPartialGroupsAccumulator {
211+
fn update_batch(
212+
&mut self,
213+
values: &[ArrayRef],
214+
group_indices: &[usize],
215+
opt_filter: Option<&BooleanArray>,
216+
total_num_groups: usize,
217+
) -> Result<()> {
218+
// Redirect update to merge — this is the key trick.
219+
self.inner
220+
.merge_batch(values, group_indices, opt_filter, total_num_groups)
221+
}
222+
223+
fn merge_batch(
224+
&mut self,
225+
values: &[ArrayRef],
226+
group_indices: &[usize],
227+
opt_filter: Option<&BooleanArray>,
228+
total_num_groups: usize,
229+
) -> Result<()> {
230+
self.inner
231+
.merge_batch(values, group_indices, opt_filter, total_num_groups)
232+
}
233+
234+
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
235+
self.inner.evaluate(emit_to)
236+
}
237+
238+
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
239+
self.inner.state(emit_to)
240+
}
241+
242+
fn size(&self) -> usize {
243+
self.inner.size()
244+
}
245+
}

native/core/src/execution/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
pub mod columnar_to_row;
2020
pub mod expressions;
2121
pub mod jni_api;
22+
pub(crate) mod merge_as_partial;
2223
pub(crate) mod metrics;
2324
pub mod operators;
2425
pub(crate) mod planner;

0 commit comments

Comments
 (0)