Skip to content

Commit 72a44a8

Browse files
committed
add tests.
1 parent 357246e commit 72a44a8

1 file changed

Lines changed: 63 additions & 0 deletions

File tree

  • datafusion/functions-aggregate-common/src/aggregate/groups_accumulator

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,66 @@ impl<V: SeenValues, O: GroupIndexOperations> NullState<V, O> {
237237
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
238238
self.seen_values.emit(emit_to)
239239
}
240+
241+
/// Clone and build a single [`BooleanBuffer`] from `seen_values`,
242+
/// only used for testing.
243+
#[cfg(test)]
244+
fn build_cloned_seen_values(&self) -> BooleanBuffer {
245+
if let Some(seen_values) =
246+
self.seen_values.as_any().downcast_ref::<FlatSeenValues>()
247+
{
248+
seen_values.builder.finish_cloned()
249+
} else if let Some(seen_values) = self
250+
.seen_values
251+
.as_any()
252+
.downcast_ref::<BlockedSeenValues>()
253+
{
254+
let mut return_builder = BooleanBufferBuilder::new(0);
255+
for builder in &seen_values.blocked_builders {
256+
for idx in 0..builder.len() {
257+
return_builder.append(builder.get_bit(idx));
258+
}
259+
}
260+
return_builder.finish()
261+
} else {
262+
unreachable!("unknown impl of SeenValues")
263+
}
264+
}
265+
266+
/// Emit a single [`NullBuffer`], only used for testing.
267+
#[cfg(test)]
268+
fn emit_all_in_once(&mut self, total_num_groups: usize) -> NullBuffer {
269+
if let Some(seen_values) =
270+
self.seen_values.as_any().downcast_ref::<FlatSeenValues>()
271+
{
272+
seen_values.emit(EmitTo::All)
273+
} else if let Some(seen_values) = self
274+
.seen_values
275+
.as_any()
276+
.downcast_ref::<BlockedSeenValues>()
277+
{
278+
let mut return_builder = BooleanBufferBuilder::new(0);
279+
let num_blocks = seen_values.blocked_builders.len();
280+
for _ in 0..num_blocks {
281+
let blocked_nulls = seen_values.emit(EmitTo::NextBlock(true));
282+
for bit in blocked_nulls.inner().iter() {
283+
return_builder.append(bit);
284+
}
285+
}
286+
287+
NullBuffer::new(return_builder.finish())
288+
} else {
289+
unreachable!("unknown impl of SeenValues")
290+
}
291+
}
240292
}
241293

242294
/// Structure marking if accumulating groups are seen at least one
243295
pub trait SeenValues: Default + Debug + Send {
296+
fn as_any(&self) -> &dyn std::any::Any {
297+
self
298+
}
299+
244300
fn resize(&mut self, total_num_groups: usize, default_value: bool);
245301

246302
fn set_bit(&mut self, block_id: u32, block_offset: u64, value: bool);
@@ -585,6 +641,8 @@ pub fn accumulate<T, F>(
585641
/// * `group_idx`: The group index for the current row
586642
/// * `batch_idx`: The index of the current row in the input arrays
587643
/// * `columns`: Reference to all input arrays for accessing values
644+
// TODO: support `blocked group index` for `accumulate_multiple`
645+
// (for supporting `blocked group index` for correlation group accumulator)
588646
pub fn accumulate_multiple<T, F>(
589647
group_indices: &[usize],
590648
value_columns: &[&PrimitiveArray<T>],
@@ -648,6 +706,8 @@ pub fn accumulate_multiple<T, F>(
648706
///
649707
/// See [`NullState::accumulate`], for more details on other
650708
/// arguments.
709+
// TODO: support `blocked group index` for `accumulate_indices`
710+
// (for supporting `blocked group index` for count group accumulator)
651711
pub fn accumulate_indices<F>(
652712
group_indices: &[usize],
653713
nulls: Option<&NullBuffer>,
@@ -839,6 +899,9 @@ mod test {
839899

840900
/// filter (defaults to None)
841901
filter: BooleanArray,
902+
903+
///
904+
block_size: Option<usize>,
842905
}
843906

844907
impl Fixture {

0 commit comments

Comments
 (0)