Skip to content

Commit 0a03eea

Browse files
committed
fix reduction/execute cycle
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 9b11e57 commit 0a03eea

File tree

4 files changed

+38
-36
lines changed

4 files changed

+38
-36
lines changed

vortex-array/src/arrays/constant/vtable/canonical.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use vortex_error::VortexExpect;
1010
use vortex_error::VortexResult;
1111

1212
use crate::Canonical;
13+
use crate::ExecutionCtx;
1314
use crate::IntoArray;
1415
use crate::array::ArrayView;
1516
use crate::arrays::BoolArray;
@@ -36,7 +37,10 @@ use crate::scalar::Scalar;
3637
use crate::validity::Validity;
3738

3839
/// Shared implementation for both `canonicalize` and `execute` methods.
39-
pub(crate) fn constant_canonicalize(array: ArrayView<'_, Constant>) -> VortexResult<Canonical> {
40+
pub(crate) fn constant_canonicalize(
41+
array: ArrayView<'_, Constant>,
42+
ctx: &mut ExecutionCtx,
43+
) -> VortexResult<Canonical> {
4044
let scalar = array.scalar();
4145

4246
let validity = match array.dtype().nullability() {
@@ -163,7 +167,16 @@ pub(crate) fn constant_canonicalize(array: ArrayView<'_, Constant>) -> VortexRes
163167
let s = scalar.as_extension();
164168

165169
let storage_scalar = s.to_storage_scalar();
166-
let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array();
170+
171+
// NB: We need to execute the constant array to be canonical because there is a
172+
// reduction rule that turns `Extension(Constant(..))` into `Constant(Extension(..))`,
173+
// and if we don't do this we create an infinite cycle.
174+
// See `ExtensionConstantRule` for more details.
175+
let storage_self = ConstantArray::new(storage_scalar, array.len())
176+
.into_array()
177+
.execute::<Canonical>(ctx)?
178+
.into_array();
179+
167180
Canonical::Extension(ExtensionArray::new(ext_dtype.clone(), storage_self))
168181
}
169182
DType::Variant(_) => {

vortex-array/src/arrays/constant/vtable/mod.rs

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ impl VTable for Constant {
160160
PARENT_RULES.evaluate(array, parent, child_idx)
161161
}
162162

163-
fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
163+
fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
164164
Ok(ExecutionResult::done(constant_canonicalize(
165165
array.as_view(),
166+
ctx,
166167
)?))
167168
}
168169

@@ -268,10 +269,11 @@ fn append_value_or_nulls<B: ArrayBuilder + 'static>(
268269
#[cfg(test)]
269270
mod tests {
270271
use rstest::rstest;
272+
use vortex_error::VortexResult;
271273
use vortex_session::VortexSession;
272274

273-
use crate::ExecutionCtx;
274275
use crate::IntoArray;
276+
use crate::VortexSessionExecute;
275277
use crate::arrays::ConstantArray;
276278
use crate::arrays::constant::vtable::canonical::constant_canonicalize;
277279
use crate::assert_arrays_eq;
@@ -282,42 +284,37 @@ mod tests {
282284
use crate::dtype::StructFields;
283285
use crate::scalar::Scalar;
284286

285-
fn ctx() -> ExecutionCtx {
286-
ExecutionCtx::new(VortexSession::empty())
287-
}
288-
289287
/// Appends `array` into a fresh builder and asserts the result matches `constant_canonicalize`.
290-
fn assert_append_matches_canonical(array: ConstantArray) -> vortex_error::VortexResult<()> {
291-
let expected = constant_canonicalize(array.as_view())?.into_array();
288+
fn assert_append_matches_canonical(array: ConstantArray) -> VortexResult<()> {
289+
let mut ctx = VortexSession::empty().create_execution_ctx();
290+
291+
let expected = constant_canonicalize(array.as_view(), &mut ctx)?.into_array();
292292
let mut builder = builder_with_capacity(array.dtype(), array.len());
293293
array
294294
.into_array()
295-
.append_to_builder(builder.as_mut(), &mut ctx())?;
295+
.append_to_builder(builder.as_mut(), &mut ctx)?;
296296
let result = builder.finish();
297297
assert_arrays_eq!(&result, &expected);
298298
Ok(())
299299
}
300300

301301
#[test]
302-
fn test_null_constant_append() -> vortex_error::VortexResult<()> {
302+
fn test_null_constant_append() -> VortexResult<()> {
303303
assert_append_matches_canonical(ConstantArray::new(Scalar::null(DType::Null), 5))
304304
}
305305

306306
#[rstest]
307307
#[case::bool_true(true, 5)]
308308
#[case::bool_false(false, 3)]
309-
fn test_bool_constant_append(
310-
#[case] value: bool,
311-
#[case] n: usize,
312-
) -> vortex_error::VortexResult<()> {
309+
fn test_bool_constant_append(#[case] value: bool, #[case] n: usize) -> VortexResult<()> {
313310
assert_append_matches_canonical(ConstantArray::new(
314311
Scalar::bool(value, Nullability::NonNullable),
315312
n,
316313
))
317314
}
318315

319316
#[test]
320-
fn test_bool_null_constant_append() -> vortex_error::VortexResult<()> {
317+
fn test_bool_null_constant_append() -> VortexResult<()> {
321318
assert_append_matches_canonical(ConstantArray::new(
322319
Scalar::null(DType::Bool(Nullability::Nullable)),
323320
4,
@@ -332,7 +329,7 @@ mod tests {
332329
fn test_primitive_constant_append(
333330
#[case] scalar: Scalar,
334331
#[case] n: usize,
335-
) -> vortex_error::VortexResult<()> {
332+
) -> VortexResult<()> {
336333
assert_append_matches_canonical(ConstantArray::new(scalar, n))
337334
}
338335

@@ -341,18 +338,15 @@ mod tests {
341338
#[case::utf8_noninline("hello world!!", 5)] // >12 bytes: requires buffer block
342339
#[case::utf8_empty("", 3)]
343340
#[case::utf8_n_zero("hello world!!", 0)] // n=0 with non-inline: must not write orphaned bytes
344-
fn test_utf8_constant_append(
345-
#[case] value: &str,
346-
#[case] n: usize,
347-
) -> vortex_error::VortexResult<()> {
341+
fn test_utf8_constant_append(#[case] value: &str, #[case] n: usize) -> VortexResult<()> {
348342
assert_append_matches_canonical(ConstantArray::new(
349343
Scalar::utf8(value, Nullability::NonNullable),
350344
n,
351345
))
352346
}
353347

354348
#[test]
355-
fn test_utf8_null_constant_append() -> vortex_error::VortexResult<()> {
349+
fn test_utf8_null_constant_append() -> VortexResult<()> {
356350
assert_append_matches_canonical(ConstantArray::new(
357351
Scalar::null(DType::Utf8(Nullability::Nullable)),
358352
4,
@@ -362,26 +356,23 @@ mod tests {
362356
#[rstest]
363357
#[case::binary_inline(vec![1u8, 2, 3], 5)] // ≤12 bytes: inlined
364358
#[case::binary_noninline(vec![0u8; 13], 5)] // >12 bytes: buffer block
365-
fn test_binary_constant_append(
366-
#[case] value: Vec<u8>,
367-
#[case] n: usize,
368-
) -> vortex_error::VortexResult<()> {
359+
fn test_binary_constant_append(#[case] value: Vec<u8>, #[case] n: usize) -> VortexResult<()> {
369360
assert_append_matches_canonical(ConstantArray::new(
370361
Scalar::binary(value, Nullability::NonNullable),
371362
n,
372363
))
373364
}
374365

375366
#[test]
376-
fn test_binary_null_constant_append() -> vortex_error::VortexResult<()> {
367+
fn test_binary_null_constant_append() -> VortexResult<()> {
377368
assert_append_matches_canonical(ConstantArray::new(
378369
Scalar::null(DType::Binary(Nullability::Nullable)),
379370
4,
380371
))
381372
}
382373

383374
#[test]
384-
fn test_struct_constant_append() -> vortex_error::VortexResult<()> {
375+
fn test_struct_constant_append() -> VortexResult<()> {
385376
let fields = StructFields::new(
386377
["x", "y"].into(),
387378
vec![
@@ -400,7 +391,7 @@ mod tests {
400391
}
401392

402393
#[test]
403-
fn test_null_struct_constant_append() -> vortex_error::VortexResult<()> {
394+
fn test_null_struct_constant_append() -> VortexResult<()> {
404395
let fields = StructFields::new(
405396
["x"].into(),
406397
vec![DType::Primitive(PType::I32, Nullability::Nullable)],

vortex-array/src/arrays/extension/vtable/canonical.rs

Lines changed: 0 additions & 2 deletions
This file was deleted.

vortex-array/src/arrays/extension/vtable/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3-
mod canonical;
4-
mod kernel;
5-
mod operations;
6-
mod validity;
73

84
use std::hash::Hasher;
95

@@ -36,6 +32,10 @@ use crate::buffer::BufferHandle;
3632
use crate::dtype::DType;
3733
use crate::serde::ArrayChildren;
3834

35+
mod kernel;
36+
mod operations;
37+
mod validity;
38+
3939
#[derive(Clone, Debug)]
4040
pub struct Extension;
4141

0 commit comments

Comments
 (0)