Skip to content

Commit fb87092

Browse files
authored
NanCount AggregateFn (#7003)
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 2ba5636 commit fb87092

17 files changed

Lines changed: 517 additions & 290 deletions

File tree

encodings/alp/public-api.lock

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@ impl vortex_array::arrays::slice::SliceKernel for vortex_alp::ALP
2424

2525
pub fn vortex_alp::ALP::slice(array: &Self::Array, range: core::ops::range::Range<usize>, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
2626

27-
impl vortex_array::compute::nan_count::NaNCountKernel for vortex_alp::ALP
28-
29-
pub fn vortex_alp::ALP::nan_count(&self, array: &vortex_alp::ALPArray) -> vortex_error::VortexResult<usize>
30-
3127
impl vortex_array::scalar_fn::fns::between::kernel::BetweenReduce for vortex_alp::ALP
3228

3329
pub fn vortex_alp::ALP::between(array: &vortex_alp::ALPArray, lower: &vortex_array::array::ArrayRef, upper: &vortex_array::array::ArrayRef, options: &vortex_array::scalar_fn::fns::between::BetweenOptions) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
@@ -559,3 +555,5 @@ pub fn vortex_alp::alp_encode(parray: &vortex_array::arrays::primitive::array::P
559555
pub fn vortex_alp::alp_rd_decode<T: vortex_alp::ALPRDFloat>(left_parts: vortex_buffer::buffer::Buffer<u16>, left_parts_dict: &[u16], right_bit_width: u8, right_parts: vortex_buffer::buffer_mut::BufferMut<<T as vortex_alp::ALPRDFloat>::UINT>, left_parts_patches: core::option::Option<&vortex_array::patches::Patches>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_buffer::buffer::Buffer<T>>
560556

561557
pub fn vortex_alp::decompress_into_array(array: vortex_alp::ALPArray, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::arrays::primitive::array::PrimitiveArray>
558+
559+
pub fn vortex_alp::initialize(session: &mut vortex_session::VortexSession)

encodings/alp/src/alp/compute/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mod cast;
66
mod compare;
77
mod filter;
88
mod mask;
9-
mod nan_count;
9+
pub(crate) mod nan_count;
1010
mod slice;
1111
mod take;
1212

encodings/alp/src/alp/compute/nan_count.rs

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,45 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_array::compute::NaNCountKernel;
5-
use vortex_array::compute::NaNCountKernelAdapter;
6-
use vortex_array::compute::nan_count;
7-
use vortex_array::register_kernel;
4+
use vortex_array::ArrayRef;
5+
use vortex_array::ExecutionCtx;
6+
use vortex_array::aggregate_fn::AggregateFnRef;
7+
use vortex_array::aggregate_fn::fns::nan_count::NanCount;
8+
use vortex_array::aggregate_fn::fns::nan_count::nan_count;
9+
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
10+
use vortex_array::scalar::Scalar;
811
use vortex_error::VortexResult;
912

1013
use crate::ALP;
11-
use crate::ALPArray;
1214

13-
impl NaNCountKernel for ALP {
14-
fn nan_count(&self, array: &ALPArray) -> VortexResult<usize> {
15-
// NANs can only be in patches
16-
if let Some(patches) = array.patches() {
17-
nan_count(patches.values())
18-
} else {
19-
Ok(0)
15+
/// ALP-specific NaN count kernel.
16+
///
17+
/// NaN values can only appear in the patches array of an ALP-encoded array, since the encoded
18+
/// integer values cannot represent NaN. This avoids decoding the entire array.
19+
#[derive(Debug)]
20+
pub(crate) struct ALPNanCountKernel;
21+
22+
impl DynAggregateKernel for ALPNanCountKernel {
23+
fn aggregate(
24+
&self,
25+
aggregate_fn: &AggregateFnRef,
26+
batch: &ArrayRef,
27+
ctx: &mut ExecutionCtx,
28+
) -> VortexResult<Option<Scalar>> {
29+
if !aggregate_fn.is::<NanCount>() {
30+
return Ok(None);
2031
}
32+
33+
let Some(alp) = batch.as_opt::<ALP>() else {
34+
return Ok(None);
35+
};
36+
37+
let count = if let Some(patches) = alp.patches() {
38+
nan_count(patches.values(), ctx)?
39+
} else {
40+
0
41+
};
42+
43+
Ok(Some(Scalar::from(count as u64)))
2144
}
2245
}
23-
24-
register_kernel!(NaNCountKernelAdapter(ALP).lift());

encodings/alp/src/alp/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use num_traits::ToPrimitive;
1515

1616
mod array;
1717
mod compress;
18-
mod compute;
18+
pub(crate) mod compute;
1919
mod decompress;
2020
mod ops;
2121
mod rules;

encodings/alp/src/lib.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,24 @@
1818
1919
pub use alp::*;
2020
pub use alp_rd::*;
21+
use vortex_array::aggregate_fn::AggregateFnVTable;
22+
use vortex_array::aggregate_fn::fns::nan_count::NanCount;
23+
use vortex_array::aggregate_fn::session::AggregateFnSessionExt;
24+
use vortex_array::session::ArraySessionExt;
25+
use vortex_session::VortexSession;
2126

2227
mod alp;
2328
mod alp_rd;
29+
30+
/// Initialize ALP encoding in the given session.
31+
pub fn initialize(session: &mut VortexSession) {
32+
session.arrays().register(ALP::ID, ALP);
33+
session.arrays().register(ALPRD::ID, ALPRD);
34+
35+
// Register the ALP-specific NaN count aggregate kernel.
36+
session.aggregate_fns().register_aggregate_kernel(
37+
ALP::ID,
38+
Some(NanCount.id()),
39+
&compute::nan_count::ALPNanCountKernel,
40+
);
41+
}

vortex-array/public-api.lock

Lines changed: 78 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,52 @@ pub mod vortex_array::aggregate_fn
3030

3131
pub mod vortex_array::aggregate_fn::fns
3232

33+
pub mod vortex_array::aggregate_fn::fns::nan_count
34+
35+
pub struct vortex_array::aggregate_fn::fns::nan_count::NanCount
36+
37+
impl core::clone::Clone for vortex_array::aggregate_fn::fns::nan_count::NanCount
38+
39+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::clone(&self) -> vortex_array::aggregate_fn::fns::nan_count::NanCount
40+
41+
impl core::fmt::Debug for vortex_array::aggregate_fn::fns::nan_count::NanCount
42+
43+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
44+
45+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::nan_count::NanCount
46+
47+
pub type vortex_array::aggregate_fn::fns::nan_count::NanCount::Options = vortex_array::aggregate_fn::EmptyOptions
48+
49+
pub type vortex_array::aggregate_fn::fns::nan_count::NanCount::Partial = u64
50+
51+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
52+
53+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
54+
55+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
56+
57+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
58+
59+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
60+
61+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
62+
63+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
64+
65+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
66+
67+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
68+
69+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::is_saturated(&self, _partial: &Self::Partial) -> bool
70+
71+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
72+
73+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
74+
75+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
76+
77+
pub fn vortex_array::aggregate_fn::fns::nan_count::nan_count(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<usize>
78+
3379
pub mod vortex_array::aggregate_fn::fns::sum
3480

3581
pub enum vortex_array::aggregate_fn::fns::sum::SumState
@@ -318,6 +364,38 @@ pub fn vortex_array::aggregate_fn::AggregateFnVTable::return_dtype(&self, option
318364

319365
pub fn vortex_array::aggregate_fn::AggregateFnVTable::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
320366

367+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::nan_count::NanCount
368+
369+
pub type vortex_array::aggregate_fn::fns::nan_count::NanCount::Options = vortex_array::aggregate_fn::EmptyOptions
370+
371+
pub type vortex_array::aggregate_fn::fns::nan_count::NanCount::Partial = u64
372+
373+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
374+
375+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
376+
377+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
378+
379+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
380+
381+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::empty_partial(&self, _options: &Self::Options, _input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
382+
383+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
384+
385+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::finalize_scalar(&self, partial: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
386+
387+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::flush(&self, partial: &mut Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
388+
389+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
390+
391+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::is_saturated(&self, _partial: &Self::Partial) -> bool
392+
393+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
394+
395+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
396+
397+
pub fn vortex_array::aggregate_fn::fns::nan_count::NanCount::serialize(&self, options: &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
398+
321399
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::sum::Sum
322400

323401
pub type vortex_array::aggregate_fn::fns::sum::Sum::Options = vortex_array::aggregate_fn::EmptyOptions
@@ -2954,10 +3032,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Primitive
29543032

29553033
pub fn vortex_array::arrays::Primitive::min_max(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::MinMaxResult>>
29563034

2957-
impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::Primitive
2958-
2959-
pub fn vortex_array::arrays::Primitive::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult<usize>
2960-
29613035
impl vortex_array::optimizer::rules::ArrayParentReduceRule<vortex_array::arrays::Primitive> for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule
29623036

29633037
pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::Masked
@@ -6542,10 +6616,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::Primitive
65426616

65436617
pub fn vortex_array::arrays::Primitive::min_max(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::MinMaxResult>>
65446618

6545-
impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::Primitive
6546-
6547-
pub fn vortex_array::arrays::Primitive::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult<usize>
6548-
65496619
impl vortex_array::optimizer::rules::ArrayParentReduceRule<vortex_array::arrays::Primitive> for vortex_array::arrays::primitive::PrimitiveMaskedValidityRule
65506620

65516621
pub type vortex_array::arrays::primitive::PrimitiveMaskedValidityRule::Parent = vortex_array::arrays::Masked
@@ -9634,24 +9704,6 @@ pub fn vortex_array::compute::MinMaxResult::fmt(&self, f: &mut core::fmt::Format
96349704

96359705
impl core::marker::StructuralPartialEq for vortex_array::compute::MinMaxResult
96369706

9637-
pub struct vortex_array::compute::NaNCountKernelAdapter<V: vortex_array::vtable::VTable>(pub V)
9638-
9639-
impl<V: vortex_array::vtable::VTable + vortex_array::compute::NaNCountKernel> vortex_array::compute::NaNCountKernelAdapter<V>
9640-
9641-
pub const fn vortex_array::compute::NaNCountKernelAdapter<V>::lift(&'static self) -> vortex_array::compute::NaNCountKernelRef
9642-
9643-
impl<V: core::fmt::Debug + vortex_array::vtable::VTable> core::fmt::Debug for vortex_array::compute::NaNCountKernelAdapter<V>
9644-
9645-
pub fn vortex_array::compute::NaNCountKernelAdapter<V>::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result
9646-
9647-
impl<V: vortex_array::vtable::VTable + vortex_array::compute::NaNCountKernel> vortex_array::compute::Kernel for vortex_array::compute::NaNCountKernelAdapter<V>
9648-
9649-
pub fn vortex_array::compute::NaNCountKernelAdapter<V>::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::Output>>
9650-
9651-
pub struct vortex_array::compute::NaNCountKernelRef(_)
9652-
9653-
impl inventory::Collect for vortex_array::compute::NaNCountKernelRef
9654-
96559707
pub struct vortex_array::compute::UnaryArgs<'a, O: vortex_array::compute::Options>
96569708

96579709
pub vortex_array::compute::UnaryArgs::array: &'a dyn vortex_array::DynArray
@@ -9832,10 +9884,6 @@ impl<V: vortex_array::vtable::VTable + vortex_array::compute::MinMaxKernel> vort
98329884

98339885
pub fn vortex_array::compute::MinMaxKernelAdapter<V>::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::Output>>
98349886

9835-
impl<V: vortex_array::vtable::VTable + vortex_array::compute::NaNCountKernel> vortex_array::compute::Kernel for vortex_array::compute::NaNCountKernelAdapter<V>
9836-
9837-
pub fn vortex_array::compute::NaNCountKernelAdapter<V>::invoke(&self, args: &vortex_array::compute::InvocationArgs<'_>) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::Output>>
9838-
98399887
pub trait vortex_array::compute::MinMaxKernel: vortex_array::vtable::VTable
98409888

98419889
pub fn vortex_array::compute::MinMaxKernel::min_max(&self, array: &Self::Array) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::MinMaxResult>>
@@ -9896,14 +9944,6 @@ impl vortex_array::compute::MinMaxKernel for vortex_array::arrays::null::Null
98969944

98979945
pub fn vortex_array::arrays::null::Null::min_max(&self, _array: &vortex_array::arrays::null::NullArray) -> vortex_error::VortexResult<core::option::Option<vortex_array::compute::MinMaxResult>>
98989946

9899-
pub trait vortex_array::compute::NaNCountKernel: vortex_array::vtable::VTable
9900-
9901-
pub fn vortex_array::compute::NaNCountKernel::nan_count(&self, array: &Self::Array) -> vortex_error::VortexResult<usize>
9902-
9903-
impl vortex_array::compute::NaNCountKernel for vortex_array::arrays::Primitive
9904-
9905-
pub fn vortex_array::arrays::Primitive::nan_count(&self, array: &vortex_array::arrays::PrimitiveArray) -> vortex_error::VortexResult<usize>
9906-
99079947
pub trait vortex_array::compute::Options: 'static
99089948

99099949
pub fn vortex_array::compute::Options::as_any(&self) -> &dyn core::any::Any
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
pub mod nan_count;
45
pub mod sum;

0 commit comments

Comments
 (0)