Skip to content

Commit a824faa

Browse files
gatesndimitarvdimitrov
authored andcommitted
IsConstant Aggregate Function (#7028)
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent e1f2e3c commit a824faa

67 files changed

Lines changed: 1294 additions & 1389 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

encodings/datetime-parts/public-api.lock

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_datetime_parts::DateTim
2222

2323
pub fn vortex_datetime_parts::DateTimeParts::slice(array: &Self::Array, range: core::ops::range::Range<usize>) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
2424

25-
impl vortex_array::compute::is_constant::IsConstantKernel for vortex_datetime_parts::DateTimeParts
26-
27-
pub fn vortex_datetime_parts::DateTimeParts::is_constant(&self, array: &vortex_datetime_parts::DateTimePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult<core::option::Option<bool>>
28-
2925
impl vortex_array::scalar_fn::fns::binary::compare::CompareKernel for vortex_datetime_parts::DateTimeParts
3026

3127
pub fn vortex_datetime_parts::DateTimeParts::compare(lhs: &vortex_datetime_parts::DateTimePartsArray, rhs: &vortex_array::array::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
@@ -220,4 +216,6 @@ pub vortex_datetime_parts::TemporalParts::seconds: vortex_array::array::ArrayRef
220216

221217
pub vortex_datetime_parts::TemporalParts::subseconds: vortex_array::array::ArrayRef
222218

219+
pub fn vortex_datetime_parts::initialize(session: &mut vortex_session::VortexSession)
220+
223221
pub fn vortex_datetime_parts::split_temporal(array: vortex_array::arrays::datetime::TemporalArray) -> vortex_error::VortexResult<vortex_datetime_parts::TemporalParts>
Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,41 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_array::compute::IsConstantKernel;
5-
use vortex_array::compute::IsConstantKernelAdapter;
6-
use vortex_array::compute::IsConstantOpts;
7-
use vortex_array::compute::is_constant_opts;
8-
use vortex_array::register_kernel;
4+
use vortex_array::ArrayRef;
5+
use vortex_array::ExecutionCtx;
6+
use vortex_array::aggregate_fn::AggregateFnRef;
7+
use vortex_array::aggregate_fn::fns::is_constant::IsConstant;
8+
use vortex_array::aggregate_fn::fns::is_constant::is_constant;
9+
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
10+
use vortex_array::scalar::Scalar;
911
use vortex_error::VortexResult;
1012

1113
use crate::DateTimeParts;
12-
use crate::DateTimePartsArray;
1314

14-
impl IsConstantKernel for DateTimeParts {
15-
fn is_constant(
16-
&self,
17-
array: &DateTimePartsArray,
18-
opts: &IsConstantOpts,
19-
) -> VortexResult<Option<bool>> {
20-
let Some(days) = is_constant_opts(array.days(), opts)? else {
21-
return Ok(None);
22-
};
23-
if !days {
24-
return Ok(Some(false));
25-
}
15+
/// DateTimeParts-specific is_constant kernel.
16+
///
17+
/// Checks each component (days, seconds, subseconds) individually.
18+
#[derive(Debug)]
19+
pub(crate) struct DateTimePartsIsConstantKernel;
2620

27-
let Some(seconds) = is_constant_opts(array.seconds(), opts)? else {
21+
impl DynAggregateKernel for DateTimePartsIsConstantKernel {
22+
fn aggregate(
23+
&self,
24+
aggregate_fn: &AggregateFnRef,
25+
batch: &ArrayRef,
26+
ctx: &mut ExecutionCtx,
27+
) -> VortexResult<Option<Scalar>> {
28+
if !aggregate_fn.is::<IsConstant>() {
2829
return Ok(None);
29-
};
30-
if !seconds {
31-
return Ok(Some(false));
3230
}
3331

34-
let Some(subseconds) = is_constant_opts(array.subseconds(), opts)? else {
32+
let Some(array) = batch.as_opt::<DateTimeParts>() else {
3533
return Ok(None);
3634
};
37-
if !subseconds {
38-
return Ok(Some(false));
39-
}
4035

41-
Ok(Some(true))
36+
let result = is_constant(array.days(), ctx)?
37+
&& is_constant(array.seconds(), ctx)?
38+
&& is_constant(array.subseconds(), ctx)?;
39+
Ok(Some(IsConstant::make_partial(batch, result)?))
4240
}
4341
}
44-
45-
register_kernel!(IsConstantKernelAdapter(DateTimeParts).lift());

encodings/datetime-parts/src/compute/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
mod cast;
55
mod compare;
66
mod filter;
7-
mod is_constant;
7+
pub(crate) mod is_constant;
88
pub(crate) mod kernel;
99
mod mask;
1010
pub(super) mod rules;

encodings/datetime-parts/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@ mod compute;
1111
mod ops;
1212
mod timestamp;
1313

14+
use vortex_array::aggregate_fn::AggregateFnVTable;
15+
use vortex_array::aggregate_fn::fns::is_constant::IsConstant;
16+
use vortex_array::aggregate_fn::session::AggregateFnSessionExt;
17+
use vortex_array::session::ArraySessionExt;
18+
use vortex_session::VortexSession;
19+
20+
/// Initialize datetime-parts encoding in the given session.
21+
pub fn initialize(session: &mut VortexSession) {
22+
session.arrays().register(DateTimeParts::ID, DateTimeParts);
23+
24+
session.aggregate_fns().register_aggregate_kernel(
25+
DateTimeParts::ID,
26+
Some(IsConstant.id()),
27+
&compute::is_constant::DateTimePartsIsConstantKernel,
28+
);
29+
}
30+
1431
#[cfg(test)]
1532
mod test {
1633
use vortex_array::ProstMetadata;

encodings/decimal-byte-parts/public-api.lock

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_decimal_byte_parts::Dec
2222

2323
pub fn vortex_decimal_byte_parts::DecimalByteParts::slice(array: &vortex_decimal_byte_parts::DecimalBytePartsArray, range: core::ops::range::Range<usize>) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
2424

25-
impl vortex_array::compute::is_constant::IsConstantKernel for vortex_decimal_byte_parts::DecimalByteParts
26-
27-
pub fn vortex_decimal_byte_parts::DecimalByteParts::is_constant(&self, array: &vortex_decimal_byte_parts::DecimalBytePartsArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult<core::option::Option<bool>>
28-
2925
impl vortex_array::scalar_fn::fns::binary::compare::CompareKernel for vortex_decimal_byte_parts::DecimalByteParts
3026

3127
pub fn vortex_decimal_byte_parts::DecimalByteParts::compare(lhs: &Self::Array, rhs: &vortex_array::array::ArrayRef, operator: vortex_array::scalar_fn::fns::operators::CompareOperator, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
@@ -167,3 +163,5 @@ impl prost::message::Message for vortex_decimal_byte_parts::DecimalBytesPartsMet
167163
pub fn vortex_decimal_byte_parts::DecimalBytesPartsMetadata::clear(&mut self)
168164

169165
pub fn vortex_decimal_byte_parts::DecimalBytesPartsMetadata::encoded_len(&self) -> usize
166+
167+
pub fn vortex_decimal_byte_parts::initialize(session: &mut vortex_session::VortexSession)

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/is_constant.rs

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,39 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
use vortex_array::compute::IsConstantKernel;
5-
use vortex_array::compute::IsConstantKernelAdapter;
6-
use vortex_array::compute::IsConstantOpts;
7-
use vortex_array::compute::is_constant_opts;
8-
use vortex_array::register_kernel;
4+
use vortex_array::ArrayRef;
5+
use vortex_array::ExecutionCtx;
6+
use vortex_array::aggregate_fn::AggregateFnRef;
7+
use vortex_array::aggregate_fn::fns::is_constant::IsConstant;
8+
use vortex_array::aggregate_fn::fns::is_constant::is_constant;
9+
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
10+
use vortex_array::scalar::Scalar;
911
use vortex_error::VortexResult;
1012

1113
use crate::DecimalByteParts;
12-
use crate::DecimalBytePartsArray;
1314

14-
impl IsConstantKernel for DecimalByteParts {
15-
fn is_constant(
15+
/// DecimalByteParts-specific is_constant kernel.
16+
///
17+
/// Delegates to checking if the MSP (most significant part) is constant.
18+
#[derive(Debug)]
19+
pub(crate) struct DecimalBytePartsIsConstantKernel;
20+
21+
impl DynAggregateKernel for DecimalBytePartsIsConstantKernel {
22+
fn aggregate(
1623
&self,
17-
array: &DecimalBytePartsArray,
18-
opts: &IsConstantOpts,
19-
) -> VortexResult<Option<bool>> {
20-
is_constant_opts(&array.msp, opts)
24+
aggregate_fn: &AggregateFnRef,
25+
batch: &ArrayRef,
26+
ctx: &mut ExecutionCtx,
27+
) -> VortexResult<Option<Scalar>> {
28+
if !aggregate_fn.is::<IsConstant>() {
29+
return Ok(None);
30+
}
31+
32+
let Some(array) = batch.as_opt::<DecimalByteParts>() else {
33+
return Ok(None);
34+
};
35+
36+
let result = is_constant(array.msp(), ctx)?;
37+
Ok(Some(IsConstant::make_partial(batch, result)?))
2138
}
2239
}
23-
24-
register_kernel!(IsConstantKernelAdapter(DecimalByteParts).lift());

encodings/decimal-byte-parts/src/decimal_byte_parts/compute/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
mod cast;
55
mod compare;
66
mod filter;
7-
mod is_constant;
7+
pub(crate) mod is_constant;
88
pub(crate) mod kernel;
99
mod mask;
1010
mod take;

encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4-
mod compute;
4+
pub(crate) mod compute;
55
mod rules;
66
mod slice;
77

encodings/decimal-byte-parts/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
mod decimal_byte_parts;
55

6+
use decimal_byte_parts::compute::is_constant::DecimalBytePartsIsConstantKernel;
67
/// This encoding allow compression of decimals using integer compression schemes.
78
/// Decimals can be compressed by narrowing the signed decimal value into the smallest signed value,
89
/// then integer compression if that is a value `ptype`, otherwise the decimal can be split into
@@ -12,3 +13,21 @@ mod decimal_byte_parts;
1213
/// an i128 decimal could be converted into a [i64, u64] with further narrowing applied to either
1314
/// value.
1415
pub use decimal_byte_parts::*;
16+
use vortex_array::aggregate_fn::AggregateFnVTable;
17+
use vortex_array::aggregate_fn::fns::is_constant::IsConstant;
18+
use vortex_array::aggregate_fn::session::AggregateFnSessionExt;
19+
use vortex_array::session::ArraySessionExt;
20+
use vortex_session::VortexSession;
21+
22+
/// Initialize decimal-byte-parts encoding in the given session.
23+
pub fn initialize(session: &mut VortexSession) {
24+
session
25+
.arrays()
26+
.register(DecimalByteParts::ID, DecimalByteParts);
27+
28+
session.aggregate_fns().register_aggregate_kernel(
29+
DecimalByteParts::ID,
30+
Some(IsConstant.id()),
31+
&DecimalBytePartsIsConstantKernel,
32+
);
33+
}

encodings/fastlanes/public-api.lock

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@ impl vortex_array::arrays::slice::SliceKernel for vortex_fastlanes::BitPacked
134134

135135
pub fn vortex_fastlanes::BitPacked::slice(array: &vortex_fastlanes::BitPackedArray, range: core::ops::range::Range<usize>, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
136136

137-
impl vortex_array::compute::is_constant::IsConstantKernel for vortex_fastlanes::BitPacked
138-
139-
pub fn vortex_fastlanes::BitPacked::is_constant(&self, array: &vortex_fastlanes::BitPackedArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult<core::option::Option<bool>>
140-
141137
impl vortex_array::scalar_fn::fns::cast::kernel::CastReduce for vortex_fastlanes::BitPacked
142138

143139
pub fn vortex_fastlanes::BitPacked::cast(array: &vortex_fastlanes::BitPackedArray, dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
@@ -424,10 +420,6 @@ impl vortex_array::arrays::slice::SliceReduce for vortex_fastlanes::FoR
424420

425421
pub fn vortex_fastlanes::FoR::slice(array: &Self::Array, range: core::ops::range::Range<usize>) -> vortex_error::VortexResult<core::option::Option<vortex_array::array::ArrayRef>>
426422

427-
impl vortex_array::compute::is_constant::IsConstantKernel for vortex_fastlanes::FoR
428-
429-
pub fn vortex_fastlanes::FoR::is_constant(&self, array: &vortex_fastlanes::FoRArray, opts: &vortex_array::compute::is_constant::IsConstantOpts) -> vortex_error::VortexResult<core::option::Option<bool>>
430-
431423
impl vortex_array::compute::is_sorted::IsSortedKernel for vortex_fastlanes::FoR
432424

433425
pub fn vortex_fastlanes::FoR::is_sorted(&self, array: &vortex_fastlanes::FoRArray) -> vortex_error::VortexResult<core::option::Option<bool>>
@@ -683,3 +675,5 @@ impl vortex_array::vtable::validity::ValidityChildSliceHelper for vortex_fastlan
683675
pub fn vortex_fastlanes::RLEArray::unsliced_child_and_slice(&self) -> (&vortex_array::array::ArrayRef, usize, usize)
684676

685677
pub fn vortex_fastlanes::delta_compress(array: &vortex_array::arrays::primitive::array::PrimitiveArray) -> vortex_error::VortexResult<(vortex_array::arrays::primitive::array::PrimitiveArray, vortex_array::arrays::primitive::array::PrimitiveArray)>
678+
679+
pub fn vortex_fastlanes::initialize(session: &mut vortex_session::VortexSession)

0 commit comments

Comments
 (0)