Skip to content

Commit 18fee37

Browse files
committed
Is this the right way to execute?
1 parent 553afaf commit 18fee37

2 files changed

Lines changed: 97 additions & 7 deletions

File tree

  • encodings/parquet-variant/src/variant_get
  • vortex-array/src/scalar_fn/fns/variant_get

encodings/parquet-variant/src/variant_get/tests.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use vortex_array::expr::variant_get;
3434
use vortex_array::expr::variant_get_as;
3535
use vortex_array::scalar_fn::fns::variant_get::VariantPath as VortexVariantPath;
3636
use vortex_error::VortexResult;
37+
use vortex_mask::Mask;
3738

3839
use crate::ParquetVariant;
3940
use crate::ParquetVariantArrayExt;
@@ -458,6 +459,45 @@ fn test_variant_get_different_field() -> VortexResult<()> {
458459
Ok(())
459460
}
460461

462+
#[test]
463+
fn test_variant_get_through_slice_wrapper() -> VortexResult<()> {
464+
let arr = make_object_array()?;
465+
466+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
467+
468+
let expr = variant_get("a", root());
469+
let actual = arr
470+
.slice(1..3)?
471+
.apply(&expr)?
472+
.execute::<ArrayRef>(&mut ctx)?;
473+
474+
let expected = apply_variant_get(&arr, "a")?;
475+
476+
assert_eq!(actual.len(), 2);
477+
assert_eq!(actual.scalar_at(0)?, expected.scalar_at(1)?);
478+
assert_eq!(actual.scalar_at(1)?, expected.scalar_at(2)?);
479+
Ok(())
480+
}
481+
482+
#[test]
483+
fn test_variant_get_through_filter_wrapper() -> VortexResult<()> {
484+
let arr = make_object_array()?;
485+
let mask = Mask::from_iter([true, false, true]);
486+
487+
let expr = variant_get("a", root());
488+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
489+
490+
let array = arr.filter(mask.clone())?.apply(&expr)?;
491+
let actual = array.execute::<ArrayRef>(&mut ctx)?;
492+
let expected = apply_variant_get(&arr, "a")?;
493+
494+
assert_eq!(mask.true_count(), 2);
495+
assert_eq!(actual.len(), 2);
496+
assert_eq!(actual.scalar_at(0)?, expected.scalar_at(0)?);
497+
assert_eq!(actual.scalar_at(1)?, expected.scalar_at(2)?);
498+
Ok(())
499+
}
500+
461501
// ---------------------------------------------------------------------------
462502
// Test data helpers
463503
// ---------------------------------------------------------------------------

vortex-array/src/scalar_fn/fns/variant_get/mod.rs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ use vortex_session::VortexSession;
1414

1515
use crate::ArrayRef;
1616
use crate::ExecutionCtx;
17+
use crate::arrays::Filter;
18+
use crate::arrays::Slice;
1719
use crate::arrays::Variant;
20+
use crate::arrays::filter::FilterArrayExt;
21+
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
22+
use crate::arrays::slice::SliceArrayExt;
1823
use crate::arrays::variant::VariantArrayExt;
1924
use crate::dtype::DType;
2025
use crate::dtype::Nullability;
@@ -213,14 +218,17 @@ impl ScalarFnVTable for VariantGet {
213218

214219
fn execute(
215220
&self,
216-
_options: &VariantGetOptions,
217-
_args: &dyn ExecutionArgs,
218-
_ctx: &mut ExecutionCtx,
221+
options: &VariantGetOptions,
222+
args: &dyn ExecutionArgs,
223+
ctx: &mut ExecutionCtx,
219224
) -> VortexResult<ArrayRef> {
220-
vortex_bail!(
221-
"variant_get cannot be executed directly; \
222-
it must be pushed down to a variant encoding via execute_parent"
223-
)
225+
let input = args.get(0)?;
226+
227+
if let Some(rewritten) = rewrite_wrapped_variant_input(input, options, ctx)? {
228+
return Ok(rewritten);
229+
}
230+
231+
vortex_bail!("variant_get cannot be executed directly")
224232
}
225233

226234
fn reduce(
@@ -254,6 +262,48 @@ impl ScalarFnVTable for VariantGet {
254262
}
255263
}
256264

265+
fn rewrite_wrapped_variant_input(
266+
input: ArrayRef,
267+
options: &VariantGetOptions,
268+
ctx: &mut ExecutionCtx,
269+
) -> VortexResult<Option<ArrayRef>> {
270+
if let Some(slice) = input.as_opt::<Slice>() {
271+
let Some(inner) = unwrap_variant_input(slice.child()) else {
272+
return Ok(None);
273+
};
274+
let sliced = inner
275+
.slice(slice.slice_range().clone())?
276+
.execute::<ArrayRef>(ctx)?;
277+
return VariantGet
278+
.try_new_array(input.len(), options.clone(), [sliced])?
279+
.execute::<ArrayRef>(ctx)
280+
.map(Some);
281+
}
282+
283+
if let Some(filter) = input.as_opt::<Filter>() {
284+
let Some(inner) = unwrap_variant_input(filter.child()) else {
285+
return Ok(None);
286+
};
287+
let filtered = inner
288+
.filter(filter.filter_mask().clone())?
289+
.execute::<ArrayRef>(ctx)?;
290+
return VariantGet
291+
.try_new_array(input.len(), options.clone(), [filtered])?
292+
.execute::<ArrayRef>(ctx)
293+
.map(Some);
294+
}
295+
296+
Ok(None)
297+
}
298+
299+
fn unwrap_variant_input(input: &ArrayRef) -> Option<ArrayRef> {
300+
if let Some(variant) = input.as_opt::<Variant>() {
301+
return Some(variant.child().clone());
302+
}
303+
304+
matches!(input.dtype(), DType::Variant(_)).then(|| input.clone())
305+
}
306+
257307
#[cfg(test)]
258308
mod tests {
259309
use vortex_session::VortexSession;

0 commit comments

Comments
 (0)