diff --git a/AGENTS.md b/AGENTS.md index e5c3d0cc13b..759008d730b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -113,6 +113,14 @@ cargo +nightly fmt --all cargo clippy --all-targets --all-features ``` +Do not push Rust code changes before running the applicable lint command above. If the change adds +or edits Rustdoc on public APIs, also run the CI docs command so broken intra-doc links are caught +locally: + +```bash +RUSTDOCFLAGS="-D warnings" cargo doc --profile ci --no-deps +``` + Notes: - For `.github/` changes, follow `.github/AGENTS.md` and run @@ -190,5 +198,8 @@ you ran and call out any checks you could not run. All commits must be signed off by the committers in this form: ```text -Signed-off-by: "COMMITTER" +Signed-off-by: COMMITTER ``` + +Do not wrap the committer name in quotes; the DCO check expects the exact unquoted name/email +pair from the commit author. diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index b067314c1d9..2d46a5cce8a 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -18,10 +18,8 @@ use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::GroupedAccumulator; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; -use vortex_array::arrays::ListViewArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::VarBinViewArray; -use vortex_array::dtype::DType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -45,44 +43,42 @@ fn total_element_count(group_sizes: &[usize]) -> usize { group_sizes.iter().sum() } -fn contiguous_list_view(elements: ArrayRef, group_sizes: &[usize]) -> ArrayRef { - let mut offset = 0usize; - let offsets: Buffer = group_sizes +struct DenseGroupedInput { + values: ArrayRef, + group_ids: Vec, + num_groups: usize, +} + +fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput { + assert_eq!(values.len(), total_element_count(group_sizes)); + + let group_ids = group_sizes .iter() - .map(|&size| { - let current_offset = offset; - offset += size; - current_offset as u32 - }) + .enumerate() + .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)) .collect(); - let sizes: Buffer = group_sizes.iter().map(|&size| size as u32).collect(); - assert_eq!(elements.len(), total_element_count(group_sizes)); - - ListViewArray::try_new( - elements, - offsets.into_array(), - sizes.into_array(), - Validity::NonNullable, - ) - .unwrap() - .into_array() + DenseGroupedInput { + values, + group_ids, + num_groups: group_sizes.len(), + } } -fn i32_nullable_all_valid_input() -> ArrayRef { +fn i32_nullable_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Buffer = (0..element_count) .map(|i| (i % 1024) as i32 - 512) .collect(); let validity = Validity::from_iter(std::iter::repeat_n(true, element_count)); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, validity).into_array(), &group_sizes, ) } -fn i32_clustered_nulls_input() -> ArrayRef { +fn i32_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values = (0..element_count).map(|i| { @@ -92,26 +88,26 @@ fn i32_clustered_nulls_input() -> ArrayRef { Some((i % 1024) as i32 - 512) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn f64_all_valid_input() -> ArrayRef { +fn f64_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); let values: Buffer = (0..element_count) .map(|_| rng.random_range(-1000.0..1000.0)) .collect(); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, Validity::NonNullable).into_array(), &group_sizes, ) } -fn f64_clustered_nulls_input() -> ArrayRef { +fn f64_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); @@ -122,40 +118,38 @@ fn f64_clustered_nulls_input() -> ArrayRef { Some(rng.random_range(-1000.0f64..1000.0)) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn varbinview_input() -> ArrayRef { +fn varbinview_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Vec = (0..element_count) .map(|i| format!("value-{i:06}")) .collect(); - contiguous_list_view( + dense_grouped_input( VarBinViewArray::from_iter_str(values.iter().map(String::as_str)).into_array(), &group_sizes, ) } -fn list_element_dtype(list_view: &ArrayRef) -> DType { - match list_view.dtype() { - DType::List(element_dtype, _) => element_dtype.as_ref().clone(), - dtype => unreachable!("expected List dtype, got {dtype}"), - } -} - -fn grouped_accumulator(list_view: &ArrayRef, vtable: V) -> ArrayRef +fn grouped_accumulator(input: &DenseGroupedInput, vtable: V) -> ArrayRef where V: AggregateFnVTable + Clone, { let mut acc = - GroupedAccumulator::try_new(vtable, EmptyOptions, list_element_dtype(list_view)).unwrap(); - acc.accumulate_list(list_view, &mut LEGACY_SESSION.create_execution_ctx()) - .unwrap(); - divan::black_box(acc.finish().unwrap()) + GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap(); + acc.accumulate( + &input.values, + &input.group_ids, + input.num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + divan::black_box(acc.finish(input.num_groups).unwrap()) } #[divan::bench] diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index c89418e67a6..ab4e0ee26ba 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -172,7 +172,7 @@ impl DynAccumulator for Accumulator { } // 3. Iteratively check the registry against each intermediate encoding, executing one - // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`. + // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate`. // Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of // keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant) // since the vtable's `accumulate(&Columnar)` handles both cases directly. diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 4b94159127b..938ee2b0dd6 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -1,19 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use arrow_buffer::ArrowNativeType; use vortex_buffer::Buffer; -use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_error::vortex_panic; -use vortex_mask::Mask; -use crate::AnyCanonical; use crate::ArrayRef; -use crate::Canonical; use crate::Columnar; use crate::ExecutionCtx; use crate::IntoArray; @@ -22,26 +15,23 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; -use crate::arrays::ChunkedArray; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; -use crate::arrays::fixed_size_list::FixedSizeListArrayExt; -use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; -use crate::builtins::ArrayBuiltins; +use crate::columnar::AnyColumnar; use crate::dtype::DType; -use crate::dtype::IntegerPType; use crate::executor::max_iterations; -use crate::match_each_integer_ptype; +use crate::scalar::Scalar; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; -/// An accumulator used for computing grouped aggregates. +/// An accumulator used for computing aggregates over dense group ids. /// -/// Note that the groups must be processed in order, and the accumulator does not support random -/// access to groups. +/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches +/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a +/// raw group key. The accumulator keeps one partial state per slot, so ordered and unordered +/// grouping only differ in how the caller assigns ids. pub struct GroupedAccumulator { /// The vtable of the aggregate function. vtable: V, @@ -55,8 +45,8 @@ pub struct GroupedAccumulator { return_dtype: DType, /// The DType of the partial accumulator state. partial_dtype: DType, - /// The accumulated state for prior batches of groups. - partials: Vec, + /// Dense per-group partial state. + partials: Vec, } impl GroupedAccumulator { @@ -84,249 +74,322 @@ impl GroupedAccumulator { dtype, return_dtype, partial_dtype, - partials: vec![], + partials: Vec::new(), }) } + + fn ensure_groups(&mut self, num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + + while self.partials.len() < num_groups { + self.partials + .push(self.vtable.empty_partial(&self.options, &self.dtype)?); + } + Ok(()) + } + + fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + for &group_id in group_ids { + vortex_ensure!( + (group_id as usize) < num_groups, + "Group id {} out of range for {} groups", + group_id, + num_groups + ); + } + Ok(()) + } + + fn accumulate_kernel_result( + &mut self, + result: GroupedAggregateKernelResult, + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx) + } + + fn accumulate_fallback( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + let Some((&first, rest)) = group_ids.split_first() else { + return Ok(()); + }; + let mut first = first; + let mut last = first; + for &group_id in rest { + first = first.min(group_id); + last = last.max(group_id); + } + + let first = first as usize; + let mut buckets = vec![Vec::new(); last as usize - first + 1]; + for (row_idx, &group_id) in group_ids.iter().enumerate() { + buckets[group_id as usize - first].push(row_idx as u64); + } + + for (offset, rows) in buckets.into_iter().enumerate() { + if rows.is_empty() { + continue; + } + + let group = first + offset; + if self.vtable.is_saturated(&self.partials[group]) { + continue; + } + + let taken = batch.clone().take(Buffer::from_iter(rows).into_array())?; + let mut accumulator = Accumulator::try_new( + self.vtable.clone(), + self.options.clone(), + self.dtype.clone(), + )?; + accumulator.accumulate(&taken, ctx)?; + let partial = accumulator.flush()?; + self.vtable + .combine_partials(&mut self.partials[group], partial)?; + } + Ok(()) + } +} + +fn validate_num_groups(num_groups: usize) -> VortexResult<()> { + vortex_ensure!( + num_groups == 0 || u32::try_from(num_groups - 1).is_ok(), + "num_groups {} exceeds dense u32 group id capacity", + num_groups + ); + Ok(()) } -/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate -/// function is not known at compile time. +/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the +/// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { - /// Accumulate a list of groups into the accumulator. - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; + /// Accumulate a values batch into dense group state. + /// + /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in + /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch. + fn accumulate( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Fold columnar partial states into dense group state. + /// + /// `group_ids` is parallel to `partials` and follows the same dense ordinal contract as + /// [`Self::accumulate`]. + fn accumulate_partials( + &mut self, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Merge one group from another grouped accumulator into this accumulator. + fn merge_group( + &mut self, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, + ) -> VortexResult<()>; - /// Finish the accumulation and return the partial aggregate results for all groups. + /// Return this accumulator's partial dtype. + fn partial_dtype(&self) -> &DType; + + /// Read one group's current partial state. + fn partial_scalar(&self, group_id: u32) -> VortexResult; + + /// Finish the accumulation and return partial aggregate results for all groups. + /// /// Resets the accumulator state for the next round of accumulation. - fn flush(&mut self) -> VortexResult; + fn flush_partials(&mut self, num_groups: usize) -> VortexResult; - /// Finish the accumulation and return the final aggregate results for all groups. + /// Finish the accumulation and return final aggregate results for all groups. + /// /// Resets the accumulator state for the next round of accumulation. - fn finish(&mut self) -> VortexResult; + fn finish(&mut self, num_groups: usize) -> VortexResult; } impl DynGroupedAccumulator for GroupedAccumulator { - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - let elements_dtype = match groups.dtype() { - DType::List(elem, _) => elem, - DType::FixedSizeList(elem, ..) => elem, - _ => vortex_bail!( - "Input DType mismatch: expected List or FixedSizeList, got {}", - groups.dtype() - ), - }; + fn accumulate( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { vortex_ensure!( - elements_dtype.as_ref() == &self.dtype, + batch.dtype() == &self.dtype, "Input DType mismatch: expected {}, got {}", self.dtype, - elements_dtype + batch.dtype() ); - - // We first execute the groups until it is a ListView or FixedSizeList, since we only - // dispatch the aggregate kernel over the elements of these arrays. - let canonical = match groups.clone().execute::(ctx)? { - Columnar::Canonical(c) => c, - Columnar::Constant(c) => c.into_array().execute::(ctx)?, - }; - match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), - _ => vortex_panic!("We checked the DType above, so this should never happen"), - } - } - - fn flush(&mut self) -> VortexResult { - let states = std::mem::take(&mut self.partials); - Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array()) - } - - fn finish(&mut self) -> VortexResult { - let states = self.flush()?; - let results = self.vtable.finalize(states)?; - vortex_ensure!( - results.dtype() == &self.return_dtype, - "Return DType mismatch: expected {}, got {}", - self.return_dtype, - results.dtype() + batch.len() == group_ids.len(), + "Grouped aggregate input length mismatch: {} values, {} group ids", + batch.len(), + group_ids.len() ); - Ok(results) - } -} + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; -impl GroupedAccumulator { - fn accumulate_list_view( - &mut self, - groups: &ListViewArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; let session = ctx.session().clone(); + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(batch.encoding_id(), self.aggregate_fn.id()) + && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? + { + return self.accumulate_kernel_result(result, num_groups, ctx); + } + + if self.vtable.try_accumulate_grouped( + &mut self.partials[..num_groups], + batch, + group_ids, + ctx, + )? { + return Ok(()); + } + + let input = batch.clone(); + let mut batch = batch.clone(); for _ in 0..max_iterations() { - if elements.is::() { + if batch.is::() { break; } - if let Some(result) = session + if let Some(kernel) = session .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - ListViewArray::new_unchecked( - elements.clone(), - groups.offsets().clone(), - groups.sizes().clone(), - groups_validity.clone(), - ) - }; - kernel - .grouped_aggregate(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? + .find_grouped_kernel(batch.encoding_id(), self.aggregate_fn.id()) + && let Some(result) = kernel.grouped_aggregate( + &self.aggregate_fn, + &batch, + group_ids, + num_groups, + ctx, + )? { - return self.push_result(result); + return self.accumulate_kernel_result(result, num_groups, ctx); } - // Execute one step and try again - elements = elements.execute(ctx)?; + batch = batch.execute(ctx)?; } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let offsets = groups.offsets(); - let sizes = groups.sizes().cast(offsets.dtype().clone())?; - let validity = groups_validity.execute_mask(offsets.len(), ctx)?; - - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets = offsets.clone().execute::>(ctx)?; - let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed( - &elements, - offsets.as_ref(), - sizes.as_ref(), - &validity, - ctx, - ) - }) + let columnar = batch.clone().execute::(ctx)?; + if self.vtable.accumulate_grouped( + &mut self.partials[..num_groups], + &columnar, + group_ids, + ctx, + )? { + return Ok(()); + } + + self.accumulate_fallback(&input, group_ids, ctx) } - fn accumulate_list_view_typed( + fn accumulate_partials( &mut self, - elements: &ArrayRef, - offsets: &[O], - sizes: &[O], - validity: &Mask, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, offsets.len()); - - // `validity` is the per-group list-view validity, so it is zipped element-wise with the - // offsets and sizes (one entry per group). - for ((offset, size), valid) in offsets.iter().zip(sizes.iter()).zip(validity.iter()) { - let offset = offset.to_usize().vortex_expect("Offset value is not usize"); - let size = size.to_usize().vortex_expect("Size value is not usize"); - - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - } + vortex_ensure!( + partials.dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", + self.partial_dtype, + partials.dtype() + ); + vortex_ensure!( + partials.len() == group_ids.len(), + "Grouped aggregate partial length mismatch: {} partials, {} group ids", + partials.len(), + group_ids.len() + ); + + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; - self.push_result(states.finish()) + for (row_idx, &group_id) in group_ids.iter().enumerate() { + let partial = partials.execute_scalar(row_idx, ctx)?; + self.vtable + .combine_partials(&mut self.partials[group_id as usize], partial)?; + } + Ok(()) } - fn accumulate_fixed_size_list( + fn merge_group( &mut self, - groups: &FixedSizeListArray, - ctx: &mut ExecutionCtx, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; - let session = ctx.session().clone(); - - for _ in 0..64 { - if elements.is::() { - break; - } + vortex_ensure!( + other.partial_dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", + self.partial_dtype, + other.partial_dtype() + ); + self.ensure_groups((into as usize) + 1)?; + let partial = other.partial_scalar(from)?; + self.vtable + .combine_partials(&mut self.partials[into as usize], partial) + } - if let Some(result) = session - .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - FixedSizeListArray::new_unchecked( - elements.clone(), - groups.list_size(), - groups_validity.clone(), - groups.len(), - ) - }; - - kernel - .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? - { - return self.push_result(result); - } + fn partial_dtype(&self) -> &DType { + &self.partial_dtype + } - // Execute one step and try again - elements = elements.execute(ctx)?; + fn partial_scalar(&self, group_id: u32) -> VortexResult { + if let Some(partial) = self.partials.get(group_id as usize) { + self.vtable.to_scalar(partial) + } else { + let partial = self.vtable.empty_partial(&self.options, &self.dtype)?; + self.vtable.to_scalar(&partial) } + } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let validity = groups_validity.execute_mask(groups.len(), ctx)?; - - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); - - let mut offset = 0; - let size = groups - .list_size() - .to_usize() - .vortex_expect("List size is not usize"); - - for valid in validity.iter() { - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - offset += size; + fn flush_partials(&mut self, num_groups: usize) -> VortexResult { + vortex_ensure!( + num_groups >= self.partials.len(), + "Cannot flush {} groups after accumulating {} groups", + num_groups, + self.partials.len() + ); + self.ensure_groups(num_groups)?; + + let mut states = builder_with_capacity(&self.partial_dtype, num_groups); + for partial in &self.partials { + states.append_scalar(&self.vtable.to_scalar(partial)?)?; } + self.partials.clear(); - self.push_result(states.finish()) + Ok(states.finish()) } - fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> { + fn finish(&mut self, num_groups: usize) -> VortexResult { + let states = self.flush_partials(num_groups)?; + let results = self.vtable.finalize(states)?; + vortex_ensure!( - state.dtype() == &self.partial_dtype, - "State DType mismatch: expected {}, got {}", - self.partial_dtype, - state.dtype() + results.dtype() == &self.return_dtype, + "Return DType mismatch: expected {}, got {}", + self.return_dtype, + results.dtype() ); - self.partials.push(state); - Ok(()) + + Ok(results) } } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e25c42e0845..07395211ca8 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -82,6 +82,22 @@ impl AggregateFnVTable for Count { Ok(true) } + fn try_accumulate_grouped( + &self, + states: &mut [Self::Partial], + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; + for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + if valid { + states[group_id as usize] += 1; + } + } + Ok(true) + } + fn accumulate( &self, _partial: &mut Self::Partial, @@ -114,11 +130,14 @@ mod tests { use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -225,6 +244,53 @@ mod tests { Ok(()) } + #[test] + fn grouped_count_dense_ids() -> VortexResult<()> { + let values = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None, Some(6)]) + .into_array(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + &values, + &[0, 0, 1, 1, 2, 2], + 3, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let actual = acc.finish(3)?; + let expected = PrimitiveArray::from_iter([1u64, 2, 1]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + assert!(acc.accumulate(&values, &[0, 2], 2, &mut ctx).is_err()); + Ok(()) + } + + #[test] + fn grouped_count_accumulate_partials_and_merge_group() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let partials = PrimitiveArray::from_iter([2u64, 3, 5]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let mut left = GroupedAccumulator::try_new(Count, EmptyOptions, dtype.clone())?; + left.accumulate_partials(&partials, &[0, 1, 1], 2, &mut ctx)?; + + let mut right = GroupedAccumulator::try_new(Count, EmptyOptions, dtype)?; + right.merge_group(0, &left, 1)?; + + let actual = right.finish(1)?; + let expected = PrimitiveArray::from_iter([8u64]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn count_constant_non_null() -> VortexResult<()> { let array = ConstantArray::new(42i32, 10); diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 24799570ff7..4e75a1390f3 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -6,11 +6,15 @@ mod constant; mod decimal; mod primitive; +use num_traits::AsPrimitive; +use num_traits::ToPrimitive; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_mask::AllOr; +use vortex_mask::Mask; use self::bool::accumulate_bool; use self::constant::multiply_constant; @@ -25,14 +29,19 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::BoolArray; +use crate::arrays::PrimitiveArray; +use crate::arrays::bool::BoolArrayExt; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; +use crate::dtype::NativePType; use crate::dtype::Nullability; use crate::dtype::PType; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; +use crate::match_each_native_ptype; use crate::scalar::DecimalValue; use crate::scalar::Scalar; @@ -253,6 +262,30 @@ impl AggregateFnVTable for Sum { Ok(()) } + fn accumulate_grouped( + &self, + partials: &mut [Self::Partial], + batch: &Columnar, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + match batch { + Columnar::Canonical(Canonical::Primitive(p)) => { + accumulate_grouped_primitive(partials, p, group_ids, ctx)?; + Ok(true) + } + Columnar::Canonical(Canonical::Bool(b)) => { + accumulate_grouped_bool(partials, b, group_ids, ctx)?; + Ok(true) + } + // Decimal and constants still use the universal grouped fallback. + Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => Ok(false), + Columnar::Canonical(_) => { + vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()) + } + } + } + fn finalize(&self, partials: ArrayRef) -> VortexResult { Ok(partials) } @@ -299,6 +332,146 @@ fn make_zero_state(return_dtype: &DType) -> SumState { } } +fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { + match validity.indices() { + AllOr::All => { + for idx in 0..len { + f(idx); + } + } + AllOr::None => {} + AllOr::Some(indices) => { + for &idx in indices { + f(idx); + } + } + } +} + +fn accumulate_grouped_unsigned(partials: &mut [SumPartial], group_id: u32, value: u64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Unsigned(acc)) => checked_add_u64(acc, value), + Some(_) => vortex_panic!("unsigned sum state with non-unsigned input"), + }; + if saturated { + partial.current = None; + } +} + +fn accumulate_grouped_signed(partials: &mut [SumPartial], group_id: u32, value: i64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Signed(acc)) => checked_add_i64(acc, value), + Some(_) => vortex_panic!("signed sum state with non-signed input"), + }; + if saturated { + partial.current = None; + } +} + +fn accumulate_grouped_float(partials: &mut [SumPartial], group_id: u32, value: f64) { + if value.is_nan() { + return; + } + + match partials[group_id as usize].current.as_mut() { + None => {} + Some(SumState::Float(acc)) => *acc += value, + Some(_) => vortex_panic!("float sum state with non-float input"), + } +} + +fn accumulate_grouped_primitive( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = primitive + .as_ref() + .validity()? + .execute_mask(primitive.as_ref().len(), ctx)?; + match_each_native_ptype!(primitive.ptype(), + unsigned: |T| { + accumulate_grouped_primitive_unsigned::(partials, primitive, group_ids, &validity); + Ok(()) + }, + signed: |T| { + accumulate_grouped_primitive_signed::(partials, primitive, group_ids, &validity); + Ok(()) + }, + floating: |T| { + accumulate_grouped_primitive_float::(partials, primitive, group_ids, &validity); + Ok(()) + } + ) +} + +fn accumulate_grouped_primitive_unsigned( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_unsigned(partials, group_ids[idx], values[idx].as_()); + }); +} + +fn accumulate_grouped_primitive_signed( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_signed(partials, group_ids[idx], values[idx].as_()); + }); +} + +fn accumulate_grouped_primitive_float( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + ToPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + let value = values[idx].to_f64().vortex_expect("float to f64"); + accumulate_grouped_float(partials, group_ids[idx], value); + }); +} + +fn accumulate_grouped_bool( + partials: &mut [SumPartial], + bools: &BoolArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = bools + .as_ref() + .validity()? + .execute_mask(bools.as_ref().len(), ctx)?; + let values = bools.to_bit_buffer(); + for_each_valid_idx(&validity, values.len(), |idx| { + if values.value(idx) { + accumulate_grouped_unsigned(partials, group_ids[idx], 1); + } + }); + Ok(()) +} + /// Checked add for u64, returning true if overflow occurred. #[inline(always)] fn checked_add_u64(acc: &mut u64, val: u64) -> bool { @@ -346,8 +519,6 @@ mod tests { use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; - use crate::arrays::FixedSizeListArray; - use crate::arrays::ListViewArray; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::dtype::DType; @@ -512,20 +683,26 @@ mod tests { // Grouped sum tests - fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; - acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; - acc.finish() + fn run_grouped_sum( + values: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + values, + group_ids, + num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + acc.finish(num_groups) } #[test] - fn grouped_sum_fixed_size_list() -> VortexResult<()> { - let elements = + fn grouped_sum_dense_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -534,13 +711,10 @@ mod tests { #[test] fn grouped_sum_with_null_elements() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)]) .into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -548,30 +722,22 @@ mod tests { } #[test] - fn grouped_sum_with_null_group() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable) - .into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + fn grouped_sum_empty_group() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![1i32, 2, 3, 7, 8, 9], Validity::NonNullable).into_array(); + let result = run_grouped_sum(&values, &[0, 0, 0, 2, 2, 2], 3)?; let expected = - PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(6i64), Some(0i64), Some(24i64)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } #[test] fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([None::, None, Some(3), Some(4)]).into_array(); - let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -580,12 +746,8 @@ mod tests { #[test] fn grouped_sum_bool() -> VortexResult<()> { - let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect(); - let groups = - FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Bool(Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let values: BoolArray = [true, false, true, true, true, true].into_iter().collect(); + let result = run_grouped_sum(&values.into_array(), &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -598,19 +760,17 @@ mod tests { let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; - let elements1 = + let values1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; - acc.accumulate_list(&groups1.into_array(), &mut ctx)?; - let result1 = acc.finish()?; + acc.accumulate(&values1, &[0, 0, 1, 1], 2, &mut ctx)?; + let result1 = acc.finish(2)?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result1, &expected1); - let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; - acc.accumulate_list(&groups2.into_array(), &mut ctx)?; - let result2 = acc.finish()?; + let values2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&values2, &[0, 0], 1, &mut ctx)?; + let result2 = acc.finish(1)?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); assert_arrays_eq!(&result2, &expected2); @@ -618,20 +778,13 @@ mod tests { } #[test] - fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { - let elements = + fn grouped_sum_out_of_order_group_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array(); - let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array(); - let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array(); - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups, &elem_dtype)?; + let result = run_grouped_sum(&values, &[2, 0, 1], 3)?; - // group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200. let expected = - PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(200i64), Some(300), Some(100)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index d806b18d84d..91248091437 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -6,13 +6,12 @@ use std::fmt::Debug; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -28,36 +27,50 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { ) -> VortexResult>; } -/// A pluggable kernel for batch aggregation of many groups. +/// Partial grouped aggregate output produced by an encoding-specific grouped kernel. /// -/// The kernel is matched on the encoding of the _elements_ array, which is the inner array of the -/// provided `ListViewArray`. This is more pragmatic than having every kernel match on the outer -/// list encoding and having to deal with the possibility of multiple list encodings. +/// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the +/// corresponding dense group ordinal. The ids may repeat, omit, and reorder groups, but must be +/// valid slots in the accumulator's `0..num_groups` range. The grouped accumulator merges this +/// batch through `accumulate_partials`. +#[derive(Clone, Debug)] +pub struct GroupedAggregateKernelResult { + group_ids: Buffer, + partials: ArrayRef, +} + +impl GroupedAggregateKernelResult { + pub fn new(group_ids: Buffer, partials: ArrayRef) -> Self { + Self { + group_ids, + partials, + } + } + + pub fn group_ids(&self) -> &[u32] { + self.group_ids.as_ref() + } + + pub fn partials(&self) -> &ArrayRef { + &self.partials + } +} + +/// A pluggable kernel for batch aggregation of many groups. /// -/// Each element of the list array represents a group and the result of the grouped aggregate -/// should be an array of the same length, where each element is the aggregate state of the -/// corresponding group. +/// The kernel is matched on the encoding of the values array. It receives the same dense group +/// ordinals that the caller passed to the grouped accumulator and may aggregate directly in the +/// encoded domain. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate each group in the provided `ListViewArray` and return an array of the - /// aggregate states. + /// Aggregate values into a partial-state batch keyed by dense group ordinal. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, - groups: &ListViewArray, - ) -> VortexResult>; - - /// Aggregate each group in the provided `FixedSizeListArray` and return an array of the - /// aggregate states. - fn grouped_aggregate_fixed_size( - &self, - aggregate_fn: &AggregateFnRef, - groups: &FixedSizeListArray, - ) -> VortexResult> { - // TODO(ngates): we could automatically delegate to `grouped_aggregate` if SequenceArray - // was in the vortex-array crate - let _ = (aggregate_fn, groups); - Ok(None) - } + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 28b91d45166..24c2113e64a 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -146,6 +146,37 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { ctx: &mut ExecutionCtx, ) -> VortexResult<()>; + /// Try to accumulate a raw values batch into dense per-group states before decompression. + /// + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. + fn try_accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &ArrayRef, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + + /// Accumulate a canonical values batch into dense per-group states. + /// + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. The provided default preserves universal + /// correctness through [`crate::aggregate_fn::GroupedAccumulator`]'s fallback. + fn accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &Columnar, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + /// Finalize an array of accumulator states into an array of aggregate results. /// /// The provides `states` array has dtype as specified by `state_dtype`, the result array