|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
4 | | -use vortex_array::compute::MinMaxKernel; |
5 | | -use vortex_array::compute::MinMaxKernelAdapter; |
6 | | -use vortex_array::compute::MinMaxResult; |
7 | | -use vortex_array::dtype::Nullability::NonNullable; |
8 | | -use vortex_array::register_kernel; |
| 4 | +use vortex_array::ArrayRef; |
| 5 | +use vortex_array::ExecutionCtx; |
| 6 | +use vortex_array::aggregate_fn::AggregateFnRef; |
| 7 | +use vortex_array::aggregate_fn::fns::min_max::MinMax; |
| 8 | +use vortex_array::aggregate_fn::fns::min_max::make_minmax_dtype; |
| 9 | +use vortex_array::aggregate_fn::kernels::DynAggregateKernel; |
| 10 | +use vortex_array::dtype::DType; |
| 11 | +use vortex_array::dtype::Nullability; |
| 12 | +use vortex_array::match_each_pvalue; |
| 13 | +use vortex_array::scalar::PValue; |
9 | 14 | use vortex_array::scalar::Scalar; |
| 15 | +use vortex_array::scalar::ScalarValue; |
10 | 16 | use vortex_error::VortexResult; |
11 | 17 |
|
12 | | -use crate::SequenceArray; |
13 | | -use crate::array::Sequence; |
14 | | - |
15 | | -impl MinMaxKernel for Sequence { |
16 | | - fn min_max(&self, array: &SequenceArray) -> VortexResult<Option<MinMaxResult>> { |
17 | | - let base = array.base(); |
18 | | - let last = array.last(); |
19 | | - let (min, max) = if base < last { |
20 | | - (base, last) |
21 | | - } else { |
22 | | - (last, base) |
| 18 | +use crate::Sequence; |
| 19 | + |
| 20 | +/// Sequence-specific min/max kernel. |
| 21 | +/// |
| 22 | +/// A sequence array represents `A[i] = base + i * multiplier`, so min/max can be computed |
| 23 | +/// algebraically from `base` and `last` based on the sign of the multiplier. |
| 24 | +#[derive(Debug)] |
| 25 | +pub(crate) struct SequenceMinMaxKernel; |
| 26 | + |
| 27 | +impl DynAggregateKernel for SequenceMinMaxKernel { |
| 28 | + fn aggregate( |
| 29 | + &self, |
| 30 | + aggregate_fn: &AggregateFnRef, |
| 31 | + batch: &ArrayRef, |
| 32 | + _ctx: &mut ExecutionCtx, |
| 33 | + ) -> VortexResult<Option<Scalar>> { |
| 34 | + if !aggregate_fn.is::<MinMax>() { |
| 35 | + return Ok(None); |
| 36 | + } |
| 37 | + |
| 38 | + let Some(seq) = batch.as_opt::<Sequence>() else { |
| 39 | + return Ok(None); |
23 | 40 | }; |
24 | | - Ok(Some(MinMaxResult { |
25 | | - min: Scalar::primitive_value(min, array.ptype(), NonNullable), |
26 | | - max: Scalar::primitive_value(max, array.ptype(), NonNullable), |
27 | | - })) |
| 41 | + |
| 42 | + let struct_dtype = make_minmax_dtype(batch.dtype()); |
| 43 | + |
| 44 | + // Empty sequences shouldn't exist (try_new validates length), but handle gracefully. |
| 45 | + if seq.is_empty() { |
| 46 | + return Ok(Some(Scalar::null(struct_dtype))); |
| 47 | + } |
| 48 | + |
| 49 | + let base = seq.base(); |
| 50 | + let last = seq.last(); |
| 51 | + |
| 52 | + // Determine min and max based on multiplier direction. |
| 53 | + // For unsigned types, multiplier is always >= 0. |
| 54 | + let (min_pvalue, max_pvalue) = match_each_pvalue!( |
| 55 | + seq.multiplier(), |
| 56 | + uint: |_v| { (base, last) }, |
| 57 | + int: |v| { |
| 58 | + if v >= 0 { |
| 59 | + (base, last) |
| 60 | + } else { |
| 61 | + (last, base) |
| 62 | + } |
| 63 | + }, |
| 64 | + float: |_v| { unreachable!("float multiplier not supported for SequenceArray") } |
| 65 | + ); |
| 66 | + |
| 67 | + let non_nullable_dtype = DType::Primitive(seq.ptype(), Nullability::NonNullable); |
| 68 | + let min_scalar = Scalar::try_new( |
| 69 | + non_nullable_dtype.clone(), |
| 70 | + Some(ScalarValue::Primitive(min_pvalue)), |
| 71 | + )?; |
| 72 | + let max_scalar = |
| 73 | + Scalar::try_new(non_nullable_dtype, Some(ScalarValue::Primitive(max_pvalue)))?; |
| 74 | + |
| 75 | + Ok(Some(Scalar::struct_( |
| 76 | + struct_dtype, |
| 77 | + vec![min_scalar, max_scalar], |
| 78 | + ))) |
28 | 79 | } |
29 | 80 | } |
30 | | - |
31 | | -register_kernel!(MinMaxKernelAdapter(Sequence).lift()); |
|
0 commit comments