Skip to content

Commit 6bafabe

Browse files
blaginin and claude committed
Count and Mean aggregates
Signed-off-by: blaginin <dima@spiraldb.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent d0ed3fc commit 6bafabe

4 files changed

Lines changed: 743 additions & 2 deletions

File tree

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
237237
if validity.value(offset) {
238238
let group = elements.slice(offset..offset + size)?;
239239
accumulator.accumulate(&group, ctx)?;
240-
states.append_scalar(&accumulator.finish()?)?;
240+
states.append_scalar(&accumulator.flush()?)?;
241241
} else {
242242
states.append_null()
243243
}
@@ -309,7 +309,7 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
309309
if validity.value(i) {
310310
let group = elements.slice(offset..offset + size)?;
311311
accumulator.accumulate(&group, ctx)?;
312-
states.append_scalar(&accumulator.finish()?)?;
312+
states.append_scalar(&accumulator.flush()?)?;
313313
} else {
314314
states.append_null()
315315
}
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexExpect;
5+
use vortex_error::VortexResult;
6+
7+
use crate::ArrayRef;
8+
use crate::Columnar;
9+
use crate::ExecutionCtx;
10+
use crate::aggregate_fn::Accumulator;
11+
use crate::aggregate_fn::AggregateFnId;
12+
use crate::aggregate_fn::AggregateFnVTable;
13+
use crate::aggregate_fn::DynAccumulator;
14+
use crate::aggregate_fn::EmptyOptions;
15+
use crate::dtype::DType;
16+
use crate::dtype::Nullability;
17+
use crate::dtype::PType;
18+
use crate::scalar::Scalar;
19+
20+
/// Return the count of non-null elements in an array.
21+
///
22+
/// See [`Count`] for details.
23+
pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<u64> {
24+
let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?;
25+
acc.accumulate(array, ctx)?;
26+
let result = acc.finish()?;
27+
28+
Ok(result
29+
.as_primitive()
30+
.typed_value::<u64>()
31+
.vortex_expect("count result should not be null"))
32+
}
33+
34+
/// Aggregate function that counts the non-null elements of its input.
///
/// The input dtype is not inspected, so any dtype is accepted. The result is always a
/// non-nullable `u64`, and the partial state starts at zero.
#[derive(Clone, Debug)]
pub struct Count;
40+
41+
impl AggregateFnVTable for Count {
42+
type Options = EmptyOptions;
43+
type Partial = u64;
44+
45+
fn id(&self) -> AggregateFnId {
46+
AggregateFnId::new_ref("vortex.count")
47+
}
48+
49+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50+
Ok(Some(vec![]))
51+
}
52+
53+
fn deserialize(
54+
&self,
55+
_metadata: &[u8],
56+
_session: &vortex_session::VortexSession,
57+
) -> VortexResult<Self::Options> {
58+
Ok(EmptyOptions)
59+
}
60+
61+
fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> {
62+
Some(DType::Primitive(PType::U64, Nullability::NonNullable))
63+
}
64+
65+
fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
66+
self.return_dtype(options, input_dtype)
67+
}
68+
69+
fn empty_partial(
70+
&self,
71+
_options: &Self::Options,
72+
_input_dtype: &DType,
73+
) -> VortexResult<Self::Partial> {
74+
Ok(0u64)
75+
}
76+
77+
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
78+
let val = other
79+
.as_primitive()
80+
.typed_value::<u64>()
81+
.vortex_expect("count partial should not be null");
82+
*partial += val;
83+
Ok(())
84+
}
85+
86+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
87+
Ok(Scalar::primitive(*partial, Nullability::NonNullable))
88+
}
89+
90+
fn reset(&self, partial: &mut Self::Partial) {
91+
*partial = 0;
92+
}
93+
94+
#[inline]
95+
fn is_saturated(&self, _partial: &Self::Partial) -> bool {
96+
false
97+
}
98+
99+
fn accumulate(
100+
&self,
101+
partial: &mut Self::Partial,
102+
batch: &Columnar,
103+
_ctx: &mut ExecutionCtx,
104+
) -> VortexResult<()> {
105+
match batch {
106+
Columnar::Constant(c) => {
107+
if !c.scalar().is_null() {
108+
*partial += c.len() as u64;
109+
}
110+
}
111+
Columnar::Canonical(c) => {
112+
let valid = c.as_ref().valid_count()?;
113+
*partial += valid as u64;
114+
}
115+
}
116+
Ok(())
117+
}
118+
119+
fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
120+
Ok(partials)
121+
}
122+
123+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
124+
self.to_scalar(partial)
125+
}
126+
}
127+
128+
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::AggregateFnVTable;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::count::Count;
    use crate::aggregate_fn::fns::count::count;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// A non-nullable array counts every element.
    #[test]
    fn count_all_valid() -> VortexResult<()> {
        let array =
            PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 5);
        Ok(())
    }

    /// Null entries are excluded from the count.
    #[test]
    fn count_with_nulls() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)])
            .into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 3);
        Ok(())
    }

    /// An array holding only nulls counts as zero.
    #[test]
    fn count_all_null() -> VortexResult<()> {
        let array = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array, &mut ctx)?, 0);
        Ok(())
    }

    /// Finishing an accumulator that never saw a batch yields zero.
    #[test]
    fn count_empty() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;
        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(0));
        Ok(())
    }

    /// Counts from successive batches add up before `finish`.
    #[test]
    fn count_multi_batch() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array();
        acc.accumulate(&first, &mut ctx)?;

        let second = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array();
        acc.accumulate(&second, &mut ctx)?;

        let result = acc.finish()?;
        assert_eq!(result.as_primitive().typed_value::<u64>(), Some(3));
        Ok(())
    }

    /// After `finish`, the accumulator starts counting from zero again.
    #[test]
    fn count_finish_resets_state() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?;

        let first = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array();
        acc.accumulate(&first, &mut ctx)?;
        let first_result = acc.finish()?;
        assert_eq!(first_result.as_primitive().typed_value::<u64>(), Some(1));

        let second = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array();
        acc.accumulate(&second, &mut ctx)?;
        let second_result = acc.finish()?;
        assert_eq!(second_result.as_primitive().typed_value::<u64>(), Some(2));
        Ok(())
    }

    /// Combining partials sums them, and `reset` returns the partial to zero.
    #[test]
    fn count_state_merge() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let mut state = Count.empty_partial(&EmptyOptions, &dtype)?;

        let five = Scalar::primitive(5u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, five)?;

        let three = Scalar::primitive(3u64, Nullability::NonNullable);
        Count.combine_partials(&mut state, three)?;

        let merged = Count.to_scalar(&state)?;
        assert_eq!(merged.as_primitive().typed_value::<u64>(), Some(8));

        // Verify the effect of `reset` instead of only calling it.
        Count.reset(&mut state);
        assert_eq!(state, 0);
        Ok(())
    }

    /// A non-null constant array counts its full length.
    #[test]
    fn count_constant_non_null() -> VortexResult<()> {
        let array = ConstantArray::new(42i32, 10);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 10);
        Ok(())
    }

    /// A null constant array counts as zero regardless of its length.
    #[test]
    fn count_constant_null() -> VortexResult<()> {
        let array = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            10,
        );
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&array.into_array(), &mut ctx)?, 0);
        Ok(())
    }

    /// Non-null elements are counted across all chunks of a chunked array.
    #[test]
    fn count_chunked() -> VortexResult<()> {
        let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
        let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]);
        let dtype = chunk1.dtype().clone();
        let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?;
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3);
        Ok(())
    }
}

0 commit comments

Comments
 (0)