diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 84b3fb9513d..11a6116e6f9 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -78,6 +78,58 @@ pub fn vortex_array::aggregate_fn::fns::count::Count::to_scalar(&self, partial: pub fn vortex_array::aggregate_fn::fns::count::Count::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +pub mod vortex_array::aggregate_fn::fns::first + +pub struct vortex_array::aggregate_fn::fns::first::First + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::first::First + +pub fn vortex_array::aggregate_fn::fns::first::First::clone(&self) -> vortex_array::aggregate_fn::fns::first::First + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::first::First + +pub fn vortex_array::aggregate_fn::fns::first::First::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::first::First + +pub type vortex_array::aggregate_fn::fns::first::First::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::first::First::Partial = vortex_array::aggregate_fn::fns::first::FirstPartial + +pub fn vortex_array::aggregate_fn::fns::first::First::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::first::First::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::first::First::deserialize(&self, _metadata: &[u8], _session: 
&vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::first::First::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::first::First::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::first::First::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::first::First::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::first::First::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::first::First::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::try_accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::fns::first::FirstPartial + +pub fn vortex_array::aggregate_fn::fns::first::first(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::is_constant pub mod vortex_array::aggregate_fn::fns::is_constant::primitive @@ -238,6 +290,58 
@@ pub fn vortex_array::aggregate_fn::fns::is_sorted::is_strict_sorted(array: &vort pub fn vortex_array::aggregate_fn::fns::is_sorted::make_is_sorted_partial_dtype(element_dtype: &vortex_array::dtype::DType) -> vortex_array::dtype::DType +pub mod vortex_array::aggregate_fn::fns::last + +pub struct vortex_array::aggregate_fn::fns::last::Last + +impl core::clone::Clone for vortex_array::aggregate_fn::fns::last::Last + +pub fn vortex_array::aggregate_fn::fns::last::Last::clone(&self) -> vortex_array::aggregate_fn::fns::last::Last + +impl core::fmt::Debug for vortex_array::aggregate_fn::fns::last::Last + +pub fn vortex_array::aggregate_fn::fns::last::Last::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::last::Last + +pub type vortex_array::aggregate_fn::fns::last::Last::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::last::Last::Partial = vortex_array::aggregate_fn::fns::last::LastPartial + +pub fn vortex_array::aggregate_fn::fns::last::Last::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::last::Last::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::last::Last::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn 
vortex_array::aggregate_fn::fns::last::Last::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::last::Last::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::last::Last::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::last::Last::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::last::Last::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::last::Last::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::last::Last::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::try_accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + +pub struct vortex_array::aggregate_fn::fns::last::LastPartial + +pub fn vortex_array::aggregate_fn::fns::last::last(array: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + pub mod vortex_array::aggregate_fn::fns::min_max pub struct vortex_array::aggregate_fn::fns::min_max::MinMax @@ -712,6 +816,42 @@ pub fn vortex_array::aggregate_fn::fns::count::Count::to_scalar(&self, partial: pub fn vortex_array::aggregate_fn::fns::count::Count::try_accumulate(&self, state: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> 
vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::first::First + +pub type vortex_array::aggregate_fn::fns::first::First::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::first::First::Partial = vortex_array::aggregate_fn::fns::first::FirstPartial + +pub fn vortex_array::aggregate_fn::fns::first::First::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::first::First::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::first::First::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::first::First::is_saturated(&self, partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::first::First::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn 
vortex_array::aggregate_fn::fns::first::First::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::first::First::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::first::First::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::first::First::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::first::First::try_accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::is_constant::IsConstant pub type vortex_array::aggregate_fn::fns::is_constant::IsConstant::Options = vortex_array::aggregate_fn::EmptyOptions @@ -784,6 +924,42 @@ pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::to_scalar(&self, pa pub fn vortex_array::aggregate_fn::fns::is_sorted::IsSorted::try_accumulate(&self, _state: &mut Self::Partial, _batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult +impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::last::Last + +pub type vortex_array::aggregate_fn::fns::last::Last::Options = vortex_array::aggregate_fn::EmptyOptions + +pub type vortex_array::aggregate_fn::fns::last::Last::Partial = vortex_array::aggregate_fn::fns::last::LastPartial + +pub fn vortex_array::aggregate_fn::fns::last::Last::accumulate(&self, _partial: &mut Self::Partial, _batch: &vortex_array::Columnar, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::last::Last::coerce_args(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn 
vortex_array::aggregate_fn::fns::last::Last::combine_partials(&self, partial: &mut Self::Partial, other: vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()> + +pub fn vortex_array::aggregate_fn::fns::last::Last::deserialize(&self, _metadata: &[u8], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::empty_partial(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::finalize(&self, partials: vortex_array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::finalize_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::id(&self) -> vortex_array::aggregate_fn::AggregateFnId + +pub fn vortex_array::aggregate_fn::fns::last::Last::is_saturated(&self, _partial: &Self::Partial) -> bool + +pub fn vortex_array::aggregate_fn::fns::last::Last::partial_dtype(&self, options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::last::Last::reset(&self, partial: &mut Self::Partial) + +pub fn vortex_array::aggregate_fn::fns::last::Last::return_dtype(&self, _options: &Self::Options, input_dtype: &vortex_array::dtype::DType) -> core::option::Option + +pub fn vortex_array::aggregate_fn::fns::last::Last::serialize(&self, _options: &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_array::aggregate_fn::fns::last::Last::to_scalar(&self, partial: &Self::Partial) -> vortex_error::VortexResult + +pub fn vortex_array::aggregate_fn::fns::last::Last::try_accumulate(&self, partial: &mut Self::Partial, batch: &vortex_array::ArrayRef, _ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult + impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::min_max::MinMax pub 
type vortex_array::aggregate_fn::fns::min_max::MinMax::Options = vortex_array::aggregate_fn::EmptyOptions diff --git a/vortex-array/src/aggregate_fn/fns/first/mod.rs b/vortex-array/src/aggregate_fn/fns/first/mod.rs new file mode 100644 index 00000000000..8ff71191dcd --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/first/mod.rs @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::scalar::Scalar; + +/// Return the first non-null value of an array. +/// +/// See [`First`] for details. +pub fn first(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(First, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Return the first non-null value seen across all batches. +#[derive(Clone, Debug)] +pub struct First; + +/// Partial accumulator state for the [`First`] aggregate. +pub struct FirstPartial { + /// The nullable version of the input dtype, used for the result and for empty/all-null inputs. + return_dtype: DType, + /// The first non-null value seen so far, or `None` if no non-null value has been observed. 
+ value: Option, +} + +impl AggregateFnVTable for First { + type Options = EmptyOptions; + type Partial = FirstPartial; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.first") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + Some(input_dtype.as_nullable()) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + _options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + Ok(FirstPartial { + return_dtype: input_dtype.as_nullable(), + value: None, + }) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + // Only the first non-null partial wins; later ones are ignored. 
+ if partial.value.is_none() && !other.is_null() { + partial.value = Some(other); + } + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + Ok(match &partial.value { + Some(v) => v.clone(), + None => Scalar::null(partial.return_dtype.clone()), + }) + } + + fn reset(&self, partial: &mut Self::Partial) { + partial.value = None; + } + + #[inline] + fn is_saturated(&self, partial: &Self::Partial) -> bool { + partial.value.is_some() + } + + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + if partial.value.is_some() { + return Ok(true); + } + if let Some(idx) = batch.validity_mask()?.first() { + let scalar = batch.scalar_at(idx)?; + partial.value = Some(scalar.into_nullable()); + } + Ok(true) + } + + fn accumulate( + &self, + _partial: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("First::try_accumulate handles all arrays") + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + self.to_scalar(partial) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::first::First; + use crate::aggregate_fn::fns::first::first; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn 
first_non_null() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(first(&array, &mut ctx)?, Scalar::primitive(10i32, Nullable)); + Ok(()) + } + + #[test] + fn first_skips_leading_nulls() -> VortexResult<()> { + let array = + PrimitiveArray::from_option_iter([None, None, Some(7i32), Some(8)]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(first(&array, &mut ctx)?, Scalar::primitive(7i32, Nullable)); + Ok(()) + } + + #[test] + fn first_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullable); + assert_eq!(first(&array, &mut ctx)?, Scalar::null(dtype)); + Ok(()) + } + + #[test] + fn first_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(First, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result, Scalar::null(DType::Primitive(PType::I32, Nullable))); + Ok(()) + } + + #[test] + fn first_constant() -> VortexResult<()> { + let array = ConstantArray::new(42i32, 10).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(first(&array, &mut ctx)?, Scalar::primitive(42i32, Nullable)); + Ok(()) + } + + #[test] + fn first_constant_null() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullable); + let array = ConstantArray::new(Scalar::null(dtype.clone()), 10).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(first(&array, &mut ctx)?, Scalar::null(dtype)); + Ok(()) + } + + #[test] + fn first_varbin() -> VortexResult<()> { + let array = VarBinArray::from_iter( + vec![None, Some("hello"), Some("world")], + DType::Utf8(Nullable), + ) + .into_array(); + let mut ctx = 
LEGACY_SESSION.create_execution_ctx(); + assert_eq!(first(&array, &mut ctx)?, Scalar::utf8("hello", Nullable)); + Ok(()) + } + + #[test] + fn first_multi_batch_picks_earliest_non_null() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullable); + let mut acc = Accumulator::try_new(First, EmptyOptions, dtype)?; + + // First batch is all null - should not saturate. + let batch1 = PrimitiveArray::from_option_iter::([None, None]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + assert!(!acc.is_saturated()); + + // Second batch contains the first non-null value. + let batch2 = PrimitiveArray::from_option_iter([None, Some(99i32), Some(100)]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + assert!(acc.is_saturated()); + + // Third batch must be ignored - First is already saturated. + let batch3 = PrimitiveArray::from_option_iter([Some(1i32)]).into_array(); + acc.accumulate(&batch3, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result, Scalar::primitive(99i32, Nullable)); + Ok(()) + } + + #[test] + fn first_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(First, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + assert_eq!(acc.finish()?, Scalar::primitive(10i32, Nullable)); + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + assert_eq!(acc.finish()?, Scalar::primitive(3i32, Nullable)); + Ok(()) + } + + #[test] + fn first_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut state = First.empty_partial(&EmptyOptions, &dtype)?; + + // A null partial means the sub-accumulator saw 
nothing valid - should be ignored. + First.combine_partials(&mut state, Scalar::null(dtype.as_nullable()))?; + assert!(!First.is_saturated(&state)); + + First.combine_partials(&mut state, Scalar::primitive(5i32, Nullable))?; + assert!(First.is_saturated(&state)); + + // Subsequent valid partials are dropped. + First.combine_partials(&mut state, Scalar::primitive(7i32, Nullable))?; + assert_eq!(First.to_scalar(&state)?, Scalar::primitive(5i32, Nullable)); + Ok(()) + } + + #[test] + fn first_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter::([None, None]); + let chunk2 = PrimitiveArray::from_option_iter([None, Some(42i32), Some(100)]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!( + first(&chunked.into_array(), &mut ctx)?, + Scalar::primitive(42i32, Nullable) + ); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/last/mod.rs b/vortex-array/src/aggregate_fn/fns/last/mod.rs new file mode 100644 index 00000000000..37419dfebf9 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/last/mod.rs @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::scalar::Scalar; + +/// Return the last non-null value of an array. +/// +/// See [`Last`] for details. 
+pub fn last(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(Last, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Return the last non-null value seen across all batches. +#[derive(Clone, Debug)] +pub struct Last; + +/// Partial accumulator state for the [`Last`] aggregate. +pub struct LastPartial { + /// The nullable version of the input dtype, used for the result and for empty/all-null inputs. + return_dtype: DType, + /// The last non-null value seen so far, or `None` if no non-null value has been observed. + value: Option, +} + +impl AggregateFnVTable for Last { + type Options = EmptyOptions; + type Partial = LastPartial; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.last") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + Some(input_dtype.as_nullable()) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + _options: &Self::Options, + input_dtype: &DType, + ) -> VortexResult { + Ok(LastPartial { + return_dtype: input_dtype.as_nullable(), + value: None, + }) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + // Each new non-null partial replaces the previous one; nulls are ignored. 
+ if !other.is_null() { + partial.value = Some(other); + } + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + Ok(match &partial.value { + Some(v) => v.clone(), + None => Scalar::null(partial.return_dtype.clone()), + }) + } + + fn reset(&self, partial: &mut Self::Partial) { + partial.value = None; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + // Last can never short-circuit: a later batch can always supersede the current value. + false + } + + fn try_accumulate( + &self, + partial: &mut Self::Partial, + batch: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + if let Some(idx) = batch.validity_mask()?.last() { + let scalar = batch.scalar_at(idx)?; + partial.value = Some(scalar.into_nullable()); + } + Ok(true) + } + + fn accumulate( + &self, + _partial: &mut Self::Partial, + _batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + unreachable!("Last::try_accumulate handles all arrays") + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + self.to_scalar(partial) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::last::Last; + use crate::aggregate_fn::fns::last::last; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::Nullability::Nullable; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn last_non_null() -> 
VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(last(&array, &mut ctx)?, Scalar::primitive(30i32, Nullable)); + Ok(()) + } + + #[test] + fn last_skips_trailing_nulls() -> VortexResult<()> { + let array = + PrimitiveArray::from_option_iter([Some(7i32), Some(8), None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(last(&array, &mut ctx)?, Scalar::primitive(8i32, Nullable)); + Ok(()) + } + + #[test] + fn last_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullable); + assert_eq!(last(&array, &mut ctx)?, Scalar::null(dtype)); + Ok(()) + } + + #[test] + fn last_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Last, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result, Scalar::null(DType::Primitive(PType::I32, Nullable))); + Ok(()) + } + + #[test] + fn last_constant() -> VortexResult<()> { + let array = ConstantArray::new(42i32, 10).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(last(&array, &mut ctx)?, Scalar::primitive(42i32, Nullable)); + Ok(()) + } + + #[test] + fn last_constant_null() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullable); + let array = ConstantArray::new(Scalar::null(dtype.clone()), 10).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(last(&array, &mut ctx)?, Scalar::null(dtype)); + Ok(()) + } + + #[test] + fn last_varbin() -> VortexResult<()> { + let array = VarBinArray::from_iter( + vec![Some("hello"), Some("world"), None], + DType::Utf8(Nullable), + ) + .into_array(); + let mut ctx = 
LEGACY_SESSION.create_execution_ctx(); + assert_eq!(last(&array, &mut ctx)?, Scalar::utf8("world", Nullable)); + Ok(()) + } + + #[test] + fn last_multi_batch_picks_latest_non_null() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullable); + let mut acc = Accumulator::try_new(Last, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::from_option_iter([Some(1i32), Some(2)]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + // All-null batch must not clobber the previously-stored value. + let batch2 = PrimitiveArray::from_option_iter::([None, None]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let batch3 = PrimitiveArray::from_option_iter([Some(99i32), None]).into_array(); + acc.accumulate(&batch3, &mut ctx)?; + + // Last is never saturated; later batches keep updating it. + assert!(!acc.is_saturated()); + + let result = acc.finish()?; + assert_eq!(result, Scalar::primitive(99i32, Nullable)); + Ok(()) + } + + #[test] + fn last_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Last, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + assert_eq!(acc.finish()?, Scalar::primitive(20i32, Nullable)); + + let batch2 = PrimitiveArray::new(buffer![3i32, 6, 9], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + assert_eq!(acc.finish()?, Scalar::primitive(9i32, Nullable)); + Ok(()) + } + + #[test] + fn last_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut state = Last.empty_partial(&EmptyOptions, &dtype)?; + + Last.combine_partials(&mut state, Scalar::primitive(5i32, Nullable))?; + assert_eq!(Last.to_scalar(&state)?, 
Scalar::primitive(5i32, Nullable)); + + // A later non-null partial replaces the prior value. + Last.combine_partials(&mut state, Scalar::primitive(7i32, Nullable))?; + assert_eq!(Last.to_scalar(&state)?, Scalar::primitive(7i32, Nullable)); + + // A null partial must not clobber the stored value. + Last.combine_partials(&mut state, Scalar::null(dtype.as_nullable()))?; + assert_eq!(Last.to_scalar(&state)?, Scalar::primitive(7i32, Nullable)); + Ok(()) + } + + #[test] + fn last_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(42i32), Some(100)]); + let chunk2 = PrimitiveArray::from_option_iter::([None, None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!( + last(&chunked.into_array(), &mut ctx)?, + Scalar::primitive(100i32, Nullable) + ); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index a8ae116d939..38d5340cd1f 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -2,8 +2,10 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors pub mod count; +pub mod first; pub mod is_constant; pub mod is_sorted; +pub mod last; pub mod min_max; pub mod nan_count; pub mod sum; diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index c5024fd4a8d..64e66195358 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -12,8 +12,10 @@ use vortex_utils::aliases::hash_map::HashMap; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnPluginRef; use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; +use 
crate::aggregate_fn::fns::last::Last; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::sum::Sum; @@ -50,8 +52,10 @@ impl Default for AggregateFnSession { }; // Register the built-in aggregate functions + this.register(First); this.register(IsConstant); this.register(IsSorted); + this.register(Last); this.register(MinMax); this.register(NanCount); this.register(Sum); diff --git a/vortex-mask/public-api.lock b/vortex-mask/public-api.lock index 1fdf9ccd598..a986381bcff 100644 --- a/vortex-mask/public-api.lock +++ b/vortex-mask/public-api.lock @@ -68,6 +68,8 @@ pub fn vortex_mask::Mask::into_bit_buffer(self) -> vortex_buffer::bit::buf::BitB pub fn vortex_mask::Mask::is_empty(&self) -> bool +pub fn vortex_mask::Mask::last(&self) -> core::option::Option<usize> + pub fn vortex_mask::Mask::len(&self) -> usize pub fn vortex_mask::Mask::limit(self, limit: usize) -> Self diff --git a/vortex-mask/src/lib.rs b/vortex-mask/src/lib.rs index bf60d6fa5d6..9c551837de1 100644 --- a/vortex-mask/src/lib.rs +++ b/vortex-mask/src/lib.rs @@ -450,6 +450,23 @@ impl Mask { } } + /// Returns the index of the last true value in the mask, or `None` if the mask has no true values (including a zero-length mask). + pub fn last(&self) -> Option<usize> { + match &self { + Self::AllTrue(len) => len.checked_sub(1), // `None` for a zero-length mask; never evaluates `0 - 1` + Self::AllFalse(_) => None, + Self::Values(values) => { + if let Some(indices) = values.indices.get() { + return indices.last().copied(); + } + if let Some(slices) = values.slices.get() { + return slices.last().map(|(_, end)| end - 1); + } + values.buffer.set_slices().last().map(|(_, end)| end - 1) + } + } + } + /// Returns the position in the mask of the nth true value. pub fn rank(&self, n: usize) -> usize { if n >= self.true_count() {