Skip to content

Commit 28050c0

Browse files
committed
Centralize aggregate stat bridge
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 95f429d commit 28050c0

6 files changed

Lines changed: 75 additions & 101 deletions

File tree

vortex-array/public-api.lock

Lines changed: 46 additions & 46 deletions
Large diffs are not rendered by default.

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use crate::aggregate_fn::session::AggregateFnSessionExt;
1515
use crate::columnar::AnyColumnar;
1616
use crate::dtype::DType;
1717
use crate::executor::max_iterations;
18+
use crate::expr::stats::Precision;
19+
use crate::expr::stats::StatsProvider;
1820
use crate::scalar::Scalar;
1921

2022
/// Reference-counted type-erased accumulator.
@@ -116,17 +118,20 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
116118
batch.dtype()
117119
);
118120

119-
// 0. Stats-driven shortcut: if the aggregate can be derived directly from the batch's
120-
// cached statistics, use that and skip both kernel dispatch and decode. This is the
121-
// only layer that consults `batch.statistics()`; encoding kernels must not.
122-
if let Some(result) = self.vtable.try_partial_from_stats(batch)? {
121+
// 0. Legacy stats bridge: if this aggregate is still cached under a legacy Stat slot,
122+
// consume its value, when it is `Precision::Exact`, as a partial before kernel dispatch or decode.
123+
if let Some(stat) = self.vtable.maybe_stat()
124+
&& let Some(Precision::Exact(partial)) = batch.statistics().get(stat)
125+
{
123126
vortex_ensure!(
124-
result.dtype() == &self.partial_dtype,
125-
"Aggregate try_partial_from_stats returned {}, expected {}",
126-
result.dtype(),
127+
partial.dtype() == &self.partial_dtype,
128+
"Aggregate {} read legacy stat {} with dtype {}, expected {}",
129+
self.aggregate_fn,
130+
stat,
131+
partial.dtype(),
127132
self.partial_dtype,
128133
);
129-
self.vtable.combine_partials(&mut self.partial, result)?;
134+
self.vtable.combine_partials(&mut self.partial, partial)?;
130135
return Ok(());
131136
}
132137

vortex-array/src/aggregate_fn/fns/nan_count/mod.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ mod primitive;
66
use vortex_error::VortexExpect;
77
use vortex_error::VortexResult;
88
use vortex_error::vortex_bail;
9-
use vortex_error::vortex_err;
109

1110
use self::primitive::accumulate_primitive;
1211
use crate::ArrayRef;
@@ -23,20 +22,13 @@ use crate::dtype::Nullability::NonNullable;
2322
use crate::dtype::PType;
2423
use crate::expr::stats::Precision;
2524
use crate::expr::stats::Stat;
26-
use crate::expr::stats::StatsProvider;
2725
use crate::scalar::Scalar;
2826
use crate::scalar::ScalarValue;
2927

3028
/// Return the number of NaN values in an array.
3129
///
3230
/// See [`NanCount`] for details.
3331
pub fn nan_count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> {
34-
// Short-circuit using cached array statistics.
35-
if let Some(Precision::Exact(nan_count_scalar)) = array.statistics().get(Stat::NaNCount) {
36-
return usize::try_from(&nan_count_scalar)
37-
.map_err(|e| vortex_err!("Failed to convert NaN count stat to usize: {e}"));
38-
}
39-
4032
// Short-circuit for non-float types.
4133
if NanCount
4234
.return_dtype(&EmptyOptions, array.dtype())
@@ -132,6 +124,10 @@ impl AggregateFnVTable for NanCount {
132124
false
133125
}
134126

127+
fn maybe_stat(&self) -> Option<Stat> {
128+
Some(Stat::NaNCount)
129+
}
130+
135131
fn accumulate(
136132
&self,
137133
partial: &mut Self::Partial,

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,19 +32,13 @@ use crate::dtype::Nullability;
3232
use crate::dtype::PType;
3333
use crate::expr::stats::Precision;
3434
use crate::expr::stats::Stat;
35-
use crate::expr::stats::StatsProvider;
3635
use crate::scalar::DecimalValue;
3736
use crate::scalar::Scalar;
3837

3938
/// Return the sum of an array.
4039
///
4140
/// See [`Sum`] for details.
4241
pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
43-
// Short-circuit using cached array statistics.
44-
if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) {
45-
return Ok(sum_scalar);
46-
}
47-
4842
// Compute using Accumulator<Sum>.
4943
// TODO(ngates): we may want to wrap this three-step dance up into an extension crate maybe.
5044
let mut acc = Accumulator::try_new(Sum, EmptyOptions, array.dtype().clone())?;
@@ -213,6 +207,10 @@ impl AggregateFnVTable for Sum {
213207
}
214208
}
215209

210+
fn maybe_stat(&self) -> Option<Stat> {
211+
Some(Stat::Sum)
212+
}
213+
216214
fn accumulate(
217215
&self,
218216
partial: &mut Self::Partial,

vortex-array/src/aggregate_fn/fns/uncompressed_size_in_bytes/mod.rs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ use crate::dtype::Nullability::NonNullable;
4848
use crate::dtype::PType;
4949
use crate::expr::stats::Precision;
5050
use crate::expr::stats::Stat;
51-
use crate::expr::stats::StatsProvider;
5251
use crate::scalar::Scalar;
5352
use crate::scalar::ScalarValue;
5453

@@ -63,13 +62,6 @@ pub fn uncompressed_size_in_bytes(array: &ArrayRef, ctx: &mut ExecutionCtx) -> V
6362
}
6463

6564
fn uncompressed_size_in_bytes_u64(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<u64> {
66-
if let Some(Precision::Exact(size_scalar)) =
67-
array.statistics().get(Stat::UncompressedSizeInBytes)
68-
{
69-
return u64::try_from(&size_scalar)
70-
.map_err(|e| vortex_err!("Failed to convert uncompressed size stat to u64: {e}"));
71-
}
72-
7365
let mut acc =
7466
Accumulator::try_new(UncompressedSizeInBytes, EmptyOptions, array.dtype().clone())?;
7567
acc.accumulate(array, ctx)?;
@@ -150,15 +142,8 @@ impl AggregateFnVTable for UncompressedSizeInBytes {
150142
false
151143
}
152144

153-
fn try_partial_from_stats(&self, batch: &ArrayRef) -> VortexResult<Option<Scalar>> {
154-
let Some(Precision::Exact(size_scalar)) =
155-
batch.statistics().get(Stat::UncompressedSizeInBytes)
156-
else {
157-
return Ok(None);
158-
};
159-
let size = u64::try_from(&size_scalar)
160-
.map_err(|e| vortex_err!("Failed to convert uncompressed size stat to u64: {e}"))?;
161-
Ok(Some(Scalar::primitive(size, NonNullable)))
145+
fn maybe_stat(&self) -> Option<Stat> {
146+
Some(Stat::UncompressedSizeInBytes)
162147
}
163148

164149
fn accumulate(

vortex-array/src/aggregate_fn/vtable.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::aggregate_fn::AggregateFn;
1818
use crate::aggregate_fn::AggregateFnId;
1919
use crate::aggregate_fn::AggregateFnRef;
2020
use crate::dtype::DType;
21+
use crate::expr::stats::Stat;
2122
use crate::scalar::Scalar;
2223

2324
/// Defines the interface for aggregate function vtables.
@@ -102,23 +103,12 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync {
102103
/// final result is fully determined.
103104
fn is_saturated(&self, state: &Self::Partial) -> bool;
104105

105-
/// Try to derive a partial scalar from the batch's cached statistics, before any
106-
/// kernel dispatch or canonicalization.
106+
/// Return the legacy [`Stat`] slot that stores this aggregate, if one exists.
107107
///
108-
/// Returns `Some(partial_scalar)` if the answer can be read directly from `batch.statistics()`,
109-
/// otherwise `Ok(None)` to fall through to the rest of dispatch. The returned scalar must
110-
/// have the dtype reported by `partial_dtype`.
111-
///
112-
/// This is the single place stats-based shortcuts live; encoding kernels must not consult
113-
/// stats themselves. Runs first so that an upstream producer who pre-populates the relevant
114-
/// stat (e.g. a layout reader hydrating `Stat::UncompressedSizeInBytes` from file metadata)
115-
/// can skip both kernel dispatch and decode.
116-
///
117-
/// TODO: this hook may be removed once `ArrayStats` stores aggregate partials internally —
118-
/// at that point stat-driven shortcuts can be resolved automatically by the dispatch layer
119-
/// without each aggregate vtable opting in.
120-
fn try_partial_from_stats(&self, _batch: &ArrayRef) -> VortexResult<Option<Scalar>> {
121-
Ok(None)
108+
/// This is a temporary bridge while some aggregate partials are still cached under the
109+
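/// legacy [`Stat`] enum rather than by aggregate function identity. The accumulator combines an exact value found in this slot as a partial, so the stored scalar must have the partial dtype.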
110+
fn maybe_stat(&self) -> Option<Stat> {
111+
None
122112
}
123113

124114
/// Try to accumulate the raw array before decompression.

0 commit comments

Comments
 (0)