Skip to content

Commit 2c827e4

Browse files
committed
IsSortedKernel
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 263d8d9 commit 2c827e4

9 files changed

Lines changed: 85 additions & 94 deletions

File tree

vortex-array/src/aggregate_fn/accumulator.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
146146
}
147147

148148
fn flush(&mut self) -> VortexResult<Scalar> {
149-
let partial = self.vtable.flush(&mut self.partial)?;
149+
let partial = self.vtable.to_scalar(&self.partial)?;
150+
self.vtable.reset(&mut self.partial);
150151

151152
#[cfg(debug_assertions)]
152153
{
@@ -162,8 +163,8 @@ impl<V: AggregateFnVTable> DynAccumulator for Accumulator<V> {
162163
}
163164

164165
fn finish(&mut self) -> VortexResult<Scalar> {
165-
let partial = self.flush()?;
166-
let result = self.vtable.finalize_scalar(partial)?;
166+
let result = self.vtable.finalize_scalar(&self.partial)?;
167+
self.vtable.reset(&mut self.partial);
167168

168169
vortex_ensure!(
169170
result.dtype() == &self.return_dtype,

vortex-array/src/aggregate_fn/fns/is_constant/mod.rs

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,10 @@ impl IsConstantPartial {
234234
}
235235
}
236236

237+
static NAMES: std::sync::LazyLock<FieldNames> =
238+
std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
239+
237240
pub fn make_is_constant_partial_dtype(element_dtype: &DType) -> DType {
238-
static NAMES: std::sync::LazyLock<FieldNames> =
239-
std::sync::LazyLock::new(|| FieldNames::from(["is_constant", "value"]));
240241
DType::Struct(
241242
StructFields::new(
242243
NAMES.clone(),
@@ -313,9 +314,9 @@ impl AggregateFnVTable for IsConstant {
313314
Ok(())
314315
}
315316

316-
fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
317+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
317318
let dtype = make_is_constant_partial_dtype(&partial.element_dtype);
318-
let result = match &partial.first_value {
319+
Ok(match &partial.first_value {
319320
None => {
320321
// Empty accumulator — return null struct.
321322
Scalar::null(dtype)
@@ -329,13 +330,12 @@ impl AggregateFnVTable for IsConstant {
329330
.cast(&partial.element_dtype.as_nullable())?,
330331
],
331332
),
332-
};
333+
})
334+
}
333335

334-
// Reset state.
336+
fn reset(&self, partial: &mut Self::Partial) {
335337
partial.is_constant = true;
336338
partial.first_value = None;
337-
338-
Ok(result)
339339
}
340340

341341
#[inline]
@@ -409,21 +409,15 @@ impl AggregateFnVTable for IsConstant {
409409
}
410410

411411
fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
412-
// TODO: extract is_constant field from struct array
413-
Ok(partials)
412+
partials.get_item(NAMES.get(0))
414413
}
415414

416-
fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
417-
if partial.is_null() {
415+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
416+
if partial.first_value.is_none() {
418417
// Empty accumulator → return false.
419418
return Ok(Scalar::bool(false, Nullability::NonNullable));
420419
}
421-
let is_constant_val = partial
422-
.as_struct()
423-
.field_by_idx(0)
424-
.map(|s| s.as_bool().value().unwrap_or(false))
425-
.unwrap_or(false);
426-
Ok(Scalar::bool(is_constant_val, Nullability::NonNullable))
420+
Ok(Scalar::bool(partial.is_constant, Nullability::NonNullable))
427421
}
428422
}
429423

vortex-array/src/aggregate_fn/fns/is_sorted/mod.rs

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use crate::aggregate_fn::AggregateFnVTable;
3030
use crate::aggregate_fn::DynAccumulator;
3131
use crate::arrays::Constant;
3232
use crate::arrays::Null;
33+
use crate::builtins::ArrayBuiltins;
3334
use crate::dtype::DType;
3435
use crate::dtype::FieldNames;
3536
use crate::dtype::Nullability;
@@ -200,10 +201,11 @@ pub struct IsSortedPartial {
200201
element_dtype: DType,
201202
}
202203

204+
static NAMES: std::sync::LazyLock<FieldNames> = std::sync::LazyLock::new(|| {
205+
FieldNames::from(["is_sorted", "strict", "first_value", "last_value"])
206+
});
207+
203208
pub fn make_is_sorted_partial_dtype(element_dtype: &DType) -> DType {
204-
static NAMES: std::sync::LazyLock<FieldNames> = std::sync::LazyLock::new(|| {
205-
FieldNames::from(["is_sorted", "strict", "first_value", "last_value"])
206-
});
207209
DType::Struct(
208210
StructFields::new(
209211
NAMES.clone(),
@@ -315,56 +317,48 @@ impl AggregateFnVTable for IsSorted {
315317
Ok(())
316318
}
317319

318-
fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
320+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
319321
let dtype = make_is_sorted_partial_dtype(&partial.element_dtype);
320-
321-
// Take ownership of the values to avoid cloning.
322-
let first = partial.first_value.take();
323-
let last = partial.last_value.take();
324-
let is_sorted = partial.is_sorted;
325-
let strict = partial.strict;
326-
327-
// Reset state.
328-
partial.is_sorted = true;
329-
330-
let result = match (first, last) {
322+
Ok(match (&partial.first_value, &partial.last_value) {
331323
(None, _) => {
332324
// Empty accumulator — return null struct.
333325
Scalar::null(dtype)
334326
}
335327
(Some(first_value), Some(last_value)) => {
336-
// Values are already nullable from into_nullable_unchecked in accumulate.
337328
// SAFETY: We constructed partial_dtype and the children match its field dtypes.
338329
unsafe {
339330
Scalar::struct_unchecked(
340331
dtype,
341332
[
342-
Scalar::bool(is_sorted, Nullability::NonNullable),
343-
Scalar::bool(strict, Nullability::NonNullable),
344-
first_value,
345-
last_value,
333+
Scalar::bool(partial.is_sorted, Nullability::NonNullable),
334+
Scalar::bool(partial.strict, Nullability::NonNullable),
335+
first_value.clone(),
336+
last_value.clone(),
346337
],
347338
)
348339
}
349340
}
350341
(Some(first_value), None) => {
351-
let cloned = first_value.clone();
352342
// SAFETY: We constructed partial_dtype and the children match its field dtypes.
353343
unsafe {
354344
Scalar::struct_unchecked(
355345
dtype,
356346
[
357-
Scalar::bool(is_sorted, Nullability::NonNullable),
358-
Scalar::bool(strict, Nullability::NonNullable),
359-
first_value,
360-
cloned,
347+
Scalar::bool(partial.is_sorted, Nullability::NonNullable),
348+
Scalar::bool(partial.strict, Nullability::NonNullable),
349+
first_value.clone(),
350+
first_value.clone(),
361351
],
362352
)
363353
}
364354
}
365-
};
355+
})
356+
}
366357

367-
Ok(result)
358+
fn reset(&self, partial: &mut Self::Partial) {
359+
partial.is_sorted = true;
360+
partial.first_value = None;
361+
partial.last_value = None;
368362
}
369363

370364
#[inline]
@@ -479,21 +473,15 @@ impl AggregateFnVTable for IsSorted {
479473
}
480474

481475
fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
482-
// TODO: extract is_sorted field from struct array
483-
Ok(partials)
476+
partials.get_item(NAMES.get(0))
484477
}
485478

486-
fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
487-
if partial.is_null() {
479+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
480+
if partial.first_value.is_none() {
488481
// Empty accumulator → vacuously sorted.
489482
return Ok(Scalar::bool(true, Nullability::NonNullable));
490483
}
491-
let is_sorted_val = partial
492-
.as_struct()
493-
.field_by_idx(0)
494-
.map(|s| s.as_bool().value().unwrap_or(false))
495-
.unwrap_or(false);
496-
Ok(Scalar::bool(is_sorted_val, Nullability::NonNullable))
484+
Ok(Scalar::bool(partial.is_sorted, Nullability::NonNullable))
497485
}
498486
}
499487

vortex-array/src/aggregate_fn/fns/min_max/mod.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,17 @@ impl AggregateFnVTable for MinMax {
210210
Ok(())
211211
}
212212

213-
fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
213+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
214214
let dtype = make_minmax_dtype(&partial.element_dtype);
215-
let result = match (partial.min.take(), partial.max.take()) {
216-
(Some(min), Some(max)) => Scalar::struct_(dtype, vec![min, max]),
215+
Ok(match (&partial.min, &partial.max) {
216+
(Some(min), Some(max)) => Scalar::struct_(dtype, vec![min.clone(), max.clone()]),
217217
_ => Scalar::null(dtype),
218-
};
219-
Ok(result)
218+
})
219+
}
220+
221+
fn reset(&self, partial: &mut Self::Partial) {
222+
partial.min = None;
223+
partial.max = None;
220224
}
221225

222226
#[inline]
@@ -266,8 +270,8 @@ impl AggregateFnVTable for MinMax {
266270
Ok(partials)
267271
}
268272

269-
fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
270-
Ok(partial)
273+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
274+
self.to_scalar(partial)
271275
}
272276
}
273277

@@ -450,7 +454,7 @@ mod tests {
450454
let scalar2 = Scalar::struct_(struct_dtype, vec![Scalar::from(2i32), Scalar::from(10i32)]);
451455
MinMax.combine_partials(&mut state, scalar2)?;
452456

453-
let result = MinMaxResult::from_scalar(MinMax.flush(&mut state)?)?
457+
let result = MinMaxResult::from_scalar(MinMax.to_scalar(&state)?)?
454458
.vortex_expect("should have result");
455459
assert_eq!(result.min, Scalar::from(2i32));
456460
assert_eq!(result.max, Scalar::from(15i32));

vortex-array/src/aggregate_fn/fns/nan_count/mod.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,12 @@ impl AggregateFnVTable for NanCount {
115115
Ok(())
116116
}
117117

118-
fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
119-
let result = Scalar::primitive(*partial, NonNullable);
118+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
119+
Ok(Scalar::primitive(*partial, NonNullable))
120+
}
121+
122+
fn reset(&self, partial: &mut Self::Partial) {
120123
*partial = 0;
121-
Ok(result)
122124
}
123125

124126
#[inline]
@@ -157,8 +159,8 @@ impl AggregateFnVTable for NanCount {
157159
Ok(partials)
158160
}
159161

160-
fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
161-
Ok(partial)
162+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
163+
self.to_scalar(partial)
162164
}
163165
}
164166

@@ -236,7 +238,8 @@ mod tests {
236238
let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable);
237239
NanCount.combine_partials(&mut state, scalar2)?;
238240

239-
let result = NanCount.flush(&mut state)?;
241+
let result = NanCount.to_scalar(&state)?;
242+
NanCount.reset(&mut state);
240243
assert_eq!(result.as_primitive().typed_value::<u64>(), Some(8));
241244
Ok(())
242245
}

vortex-array/src/aggregate_fn/fns/sum/decimal.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ mod tests {
320320
let small = Scalar::decimal(DecimalValue::from(9i64), DecimalDType::new(14, 0), Nullable);
321321
Sum.combine_partials(&mut state, small)?;
322322

323-
let result = Sum.flush(&mut state)?;
323+
let result = Sum.to_scalar(&state)?;
324324
assert!(!result.is_null());
325325
assert_eq!(
326326
result.as_decimal().decimal_value(),
@@ -351,7 +351,7 @@ mod tests {
351351
Scalar::decimal(DecimalValue::from(1i64), DecimalDType::new(14, 0), Nullable);
352352
Sum.combine_partials(&mut state, one_more)?;
353353

354-
let result = Sum.flush(&mut state)?;
354+
let result = Sum.to_scalar(&state)?;
355355
assert!(result.is_null());
356356
assert_eq!(
357357
result.dtype(),
@@ -380,7 +380,7 @@ mod tests {
380380
);
381381
Sum.combine_partials(&mut state, one_more)?;
382382

383-
let result = Sum.flush(&mut state)?;
383+
let result = Sum.to_scalar(&state)?;
384384
assert!(result.is_null());
385385
Ok(())
386386
}
@@ -413,7 +413,7 @@ mod tests {
413413
let mut ctx = LEGACY_SESSION.create_execution_ctx();
414414
Sum.accumulate(&mut state, &columnar, &mut ctx)?;
415415

416-
let result = Sum.flush(&mut state)?;
416+
let result = Sum.to_scalar(&state)?;
417417
assert!(result.is_null());
418418
Ok(())
419419
}

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ impl AggregateFnVTable for Sum {
180180
Ok(())
181181
}
182182

183-
fn flush(&self, partial: &mut Self::Partial) -> VortexResult<Scalar> {
184-
let result = match &partial.current {
183+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
184+
Ok(match &partial.current {
185185
None => Scalar::null(partial.return_dtype.as_nullable()),
186186
Some(SumState::Unsigned(v)) => Scalar::primitive(*v, Nullability::Nullable),
187187
Some(SumState::Signed(v)) => Scalar::primitive(*v, Nullability::Nullable),
@@ -193,12 +193,11 @@ impl AggregateFnVTable for Sum {
193193
.vortex_expect("return dtype must be decimal");
194194
Scalar::decimal(*value, decimal_dtype, Nullability::Nullable)
195195
}
196-
};
196+
})
197+
}
197198

198-
// Reset the state
199+
fn reset(&self, partial: &mut Self::Partial) {
199200
partial.current = Some(make_zero_state(&partial.return_dtype));
200-
201-
Ok(result)
202201
}
203202

204203
#[inline]
@@ -250,8 +249,8 @@ impl AggregateFnVTable for Sum {
250249
Ok(partials)
251250
}
252251

253-
fn finalize_scalar(&self, partial: Scalar) -> VortexResult<Scalar> {
254-
Ok(partial)
252+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
253+
self.to_scalar(partial)
255254
}
256255
}
257256

@@ -461,7 +460,8 @@ mod tests {
461460
let scalar2 = Scalar::primitive(50i64, Nullable);
462461
Sum.combine_partials(&mut state, scalar2)?;
463462

464-
let result = Sum.flush(&mut state)?;
463+
let result = Sum.to_scalar(&state)?;
464+
Sum.reset(&mut state);
465465
assert_eq!(result.as_primitive().typed_value::<i64>(), Some(150));
466466
Ok(())
467467
}

0 commit comments

Comments
 (0)