Skip to content

Commit 23fc979

Browse files
committed
feat: Optimise convert_to_state for SUM and BIT_AND_OR
1 parent e1ad871 commit 23fc979

3 files changed

Lines changed: 32 additions & 7 deletions

File tree

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ where
5252
/// The starting value for new groups
5353
starting_value: T::Native,
5454

55+
/// When true, `starting_value` is the identity element for `prim_fn`,
56+
/// i.e. `prim_fn(starting_value, x) == x` for all x. This allows
57+
/// `convert_to_state` to skip allocating an initial-state array and
58+
/// the element-wise arithmetic, returning the input values directly.
59+
starting_value_is_identity: bool,
60+
5561
/// Track nulls in the input / filters
5662
null_state: NullState,
5763

@@ -70,6 +76,7 @@ where
7076
data_type: data_type.clone(),
7177
null_state: NullState::new(),
7278
starting_value: T::default_value(),
79+
starting_value_is_identity: false,
7380
prim_fn,
7481
}
7582
}
@@ -79,6 +86,12 @@ where
7986
self.starting_value = starting_value;
8087
self
8188
}
89+
90+
/// Mark that `starting_value` is the identity element
91+
pub fn with_starting_value_as_identity(mut self) -> Self {
92+
self.starting_value_is_identity = true;
93+
self
94+
}
8295
}
8396

8497
impl<T, F> GroupsAccumulator for PrimitiveGroupsAccumulator<T, F>
@@ -150,10 +163,6 @@ where
150163
) -> Result<Vec<ArrayRef>> {
151164
let values = values[0].as_primitive::<T>().clone();
152165

153-
// Initializing state with starting values
154-
let initial_state =
155-
PrimitiveArray::<T>::from_value(self.starting_value, values.len());
156-
157166
// Recalculating values in case there is filter
158167
let values = match opt_filter {
159168
None => values,
@@ -175,6 +184,20 @@ where
175184
}
176185
};
177186

187+
// When starting_value is the identity element for prim_fn
188+
// (e.g. 0 for SUM/BIT_OR/BIT_XOR), prim_fn(starting_value, x) == x,
189+
// so we can skip allocating the initial_state array and the binary_mut
190+
// arithmetic entirely — just return the (filtered) input values.
191+
if self.starting_value_is_identity {
192+
return Ok(vec![Arc::new(
193+
values.with_data_type(self.data_type.clone()),
194+
)]);
195+
}
196+
197+
// For non-identity starting values (MIN, MAX, BIT_AND, etc.),
198+
// apply the operation against the starting value array.
199+
let initial_state =
200+
PrimitiveArray::<T>::from_value(self.starting_value, values.len());
178201
let state_values = compute::binary_mut(initial_state, &values, |mut x, y| {
179202
(self.prim_fn)(&mut x, y);
180203
x

datafusion/functions-aggregate/src/bit_and_or_xor.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ macro_rules! group_accumulator_helper {
5454
.with_starting_value(!0),
5555
)),
5656
BitwiseOperationType::Or => Ok(Box::new(
57-
PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y)),
57+
PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y))
58+
.with_starting_value_as_identity(),
5859
)),
5960
BitwiseOperationType::Xor => Ok(Box::new(
60-
PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y)),
61+
PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y))
62+
.with_starting_value_as_identity(),
6163
)),
6264
}
6365
};

datafusion/functions-aggregate/src/sum.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ impl AggregateUDFImpl for Sum {
292292
Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new(
293293
&$dt,
294294
|x, y| *x = x.add_wrapping(y),
295-
)))
295+
).with_starting_value_as_identity()))
296296
};
297297
}
298298
downcast_sum!(args, helper)

0 commit comments

Comments
 (0)