Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions encodings/sequence/src/compute/list_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ mod tests {
use std::sync::Arc;

use vortex_array::DynArray;
use vortex_array::IntoArray;
use vortex_array::arrays::BoolArray;
use vortex_array::assert_arrays_eq;
use vortex_array::dtype::Nullability;
Expand All @@ -77,7 +78,7 @@ mod tests {
let array = SequenceArray::try_new_typed(1, 1, Nullability::NonNullable, 3).unwrap();

let expr = list_contains(lit(list_scalar.clone()), root());
let result = array.apply(&expr).unwrap();
let result = array.into_array().apply(&expr).unwrap();
let expected = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
assert_arrays_eq!(result, expected);
}
Expand All @@ -89,7 +90,7 @@ mod tests {
let array = SequenceArray::try_new_typed(1, 2, Nullability::NonNullable, 3).unwrap();

let expr = list_contains(lit(list_scalar), root());
let result = array.apply(&expr).unwrap();
let result = array.into_array().apply(&expr).unwrap();
let expected = BoolArray::from_iter([Some(true), Some(true), Some(false)]);
assert_arrays_eq!(result, expected);
}
Expand Down
1 change: 1 addition & 0 deletions fuzz/fuzz_targets/file_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ fuzz_target!(|fuzz: FuzzFileAction| -> Corpus {

let expected_array = {
let bool_mask = array_data
.clone()
.apply(&filter_expr.clone().unwrap_or_else(|| lit(true)))
.vortex_expect("filter expression evaluation should succeed in fuzz test");
let mask = bool_mask.to_bool().to_mask_fill_null_false();
Expand Down
6 changes: 4 additions & 2 deletions vortex-array/benches/dict_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ fn bench_compare_sliced_dict_primitive(
bencher
.with_inputs(|| (&dict, session.create_execution_ctx()))
.bench_refs(|(dict, ctx)| {
dict.apply(&eq(root(), lit(value)))
dict.clone()
.apply(&eq(root(), lit(value)))
.unwrap()
.execute::<RecursiveCanonical>(ctx)
.unwrap()
Expand All @@ -152,7 +153,8 @@ fn bench_compare_sliced_dict_varbinview(
bencher
.with_inputs(|| (&dict, session.create_execution_ctx()))
.bench_refs(|(dict, ctx)| {
dict.apply(&eq(root(), lit(value)))
dict.clone()
.apply(&eq(root(), lit(value)))
.unwrap()
.execute::<RecursiveCanonical>(ctx)
.unwrap()
Expand Down
9 changes: 9 additions & 0 deletions vortex-array/benches/expr/case_when_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ fn case_when_simple(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down Expand Up @@ -101,6 +102,7 @@ fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down Expand Up @@ -129,6 +131,7 @@ fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand All @@ -152,6 +155,7 @@ fn case_when_nary_equality_lookup(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand All @@ -172,6 +176,7 @@ fn case_when_without_else(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand All @@ -196,6 +201,7 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down Expand Up @@ -229,6 +235,7 @@ fn case_when_nary_early_dominant(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand All @@ -253,6 +260,7 @@ fn case_when_all_false(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down Expand Up @@ -280,6 +288,7 @@ fn case_when_fragmented(bencher: Bencher, size: usize) {
.bench_refs(|(expr, array)| {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/arrays/datetime/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ fn test222() -> VortexResult<()> {
)),
);

let _result = temporal.as_ref().apply(&filter_expr);
let _result = temporal.into_array().apply(&filter_expr);

// let err = result.is_err().unwrap();
// println!("Expected error: {}", err);
Expand Down
6 changes: 4 additions & 2 deletions vortex-array/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::sync::Arc;

use itertools::Itertools;
use vortex_error::VortexResult;

Expand All @@ -16,7 +18,7 @@ use crate::scalar_fn::fns::root::Root;

impl dyn DynArray + '_ {
/// Apply the expression to this array, producing a new array in constant time.
pub fn apply(&self, expr: &Expression) -> VortexResult<ArrayRef> {
pub fn apply(self: Arc<Self>, expr: &Expression) -> VortexResult<ArrayRef> {
// If the expression is a root, return self.
if expr.is::<Root>() {
return Ok(self.to_array());
Expand All @@ -31,7 +33,7 @@ impl dyn DynArray + '_ {
let children: Vec<_> = expr
.children()
.iter()
.map(|e| self.apply(e))
.map(|e| self.clone().apply(e))
.try_collect()?;

// And wrap the scalar function up in an array.
Expand Down
1 change: 1 addition & 0 deletions vortex-array/src/scalar_fn/fns/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ mod tests {
fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
let mut ctx = SESSION.create_execution_ctx();
array
.clone()
.apply(expr)
.unwrap()
.execute::<Canonical>(&mut ctx)
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/scalar_fn/fns/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ mod tests {
);
let input = buffer![1i32, 5, 10].into_array();

let result = input.apply(&expr)?;
let result = input.clone().apply(&expr)?;
assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

threshold.store(10, Ordering::SeqCst);
Expand Down
18 changes: 9 additions & 9 deletions vortex-array/src/scalar_fn/fns/list_contains/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,15 +670,15 @@ mod tests {

// Test contains true
let expr = list_contains(lit(list_scalar.clone()), lit(2i32));
let result = arr.apply(&expr).unwrap();
let result = arr.clone().apply(&expr).unwrap();
assert_eq!(
result.scalar_at(0).unwrap(),
Scalar::bool(true, Nullability::NonNullable)
);

// Test contains false
let expr = list_contains(lit(list_scalar), lit(42i32));
let result = arr.apply(&expr).unwrap();
let result = arr.clone().apply(&expr).unwrap();
assert_eq!(
result.scalar_at(0).unwrap(),
Scalar::bool(false, Nullability::NonNullable)
Expand Down Expand Up @@ -856,7 +856,7 @@ mod tests {
};

let expr = list_contains(root(), lit(42i32));
let result = list_array.apply(&expr).unwrap();
let result = list_array.into_array().apply(&expr).unwrap();

let expected = BoolArray::from_iter([false, false, false, false]);
assert_arrays_eq!(result, expected);
Expand All @@ -881,7 +881,7 @@ mod tests {
// Searching for null
let null_scalar = Scalar::null(DType::Primitive(I32, Nullability::Nullable));
let expr = list_contains(root(), lit(null_scalar));
let result = list_array.apply(&expr).unwrap();
let result = list_array.clone().into_array().apply(&expr).unwrap();

let expected = BoolArray::new(
[false, false, false].into_iter().collect(),
Expand All @@ -891,7 +891,7 @@ mod tests {

// Searching for non-null
let expr2 = list_contains(root(), lit(42i32));
let result2 = list_array.apply(&expr2).unwrap();
let result2 = list_array.into_array().apply(&expr2).unwrap();

let expected2 = BoolArray::from_iter([false, false, false]);
assert_arrays_eq!(result2, expected2);
Expand All @@ -908,13 +908,13 @@ mod tests {
ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

let expr = list_contains(root(), lit(2i32));
let result = list_array.apply(&expr).unwrap();
let result = list_array.clone().into_array().apply(&expr).unwrap();

let expected = BoolArray::from_iter([false, true, false, false]);
assert_arrays_eq!(result, expected);

let expr5 = list_contains(root(), lit(5i32));
let result5 = list_array.apply(&expr5).unwrap();
let result5 = list_array.into_array().apply(&expr5).unwrap();

let expected5 = BoolArray::from_iter([false, false, true, false]);
assert_arrays_eq!(result5, expected5);
Expand All @@ -930,13 +930,13 @@ mod tests {
ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

let expr = list_contains(root(), lit(255i32));
let result = list_array.apply(&expr).unwrap();
let result = list_array.clone().into_array().apply(&expr).unwrap();

let expected = BoolArray::from_iter([false, false, false, true]);
assert_arrays_eq!(result, expected);

let expr_zero = list_contains(root(), lit(0i32));
let result_zero = list_array.apply(&expr_zero).unwrap();
let result_zero = list_array.into_array().apply(&expr_zero).unwrap();

let expected_zero = BoolArray::from_iter([true, false, false, false]);
assert_arrays_eq!(result_zero, expected_zero);
Expand Down
2 changes: 1 addition & 1 deletion vortex-ffi/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ pub unsafe extern "C" fn vx_array_apply(
vortex_ensure!(!expression.is_null());
let array = vx_array::as_ref(array);
let expression = vx_expression::as_ref(expression);
Ok(vx_array::new(Arc::new(array.apply(expression)?)))
Ok(vx_array::new(array.clone().apply(expression)?))
})
}

Expand Down
11 changes: 5 additions & 6 deletions vortex-ffi/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ pub unsafe extern "C" fn vx_expression_list_contains(
#[cfg(test)]
mod tests {
use std::ptr;
use std::sync::Arc;

use vortex::array::IntoArray;
use vortex::array::ToCanonical;
Expand Down Expand Up @@ -355,7 +354,7 @@ mod tests {
let column = vx_expression_get_item(c"age".as_ptr(), root);
assert_ne!(column, ptr::null_mut());

let array = vx_array::new(Arc::new(array.into_array()));
let array = vx_array::new(array.into_array());
let mut error = ptr::null_mut();

let applied_array = vx_array_apply(array, column, &raw mut error);
Expand All @@ -378,7 +377,7 @@ mod tests {
assert!(!error.is_null());
vx_error_free(error);

let names_array_vx = vx_array::new(Arc::new(names_array.into_array()));
let names_array_vx = vx_array::new(names_array.into_array());
let applied_array = vx_array_apply(names_array_vx, column, &raw mut error);
assert!(applied_array.is_null());
assert!(!error.is_null());
Expand All @@ -399,7 +398,7 @@ mod tests {
unsafe {
let root = vx_expression_root();

let array = vx_array::new(Arc::new(array.into_array()));
let array = vx_array::new(array.into_array());

let columns = [c"name".as_ptr(), c"age".as_ptr()];
let column = vx_expression_select(columns.as_ptr(), 2, root);
Expand Down Expand Up @@ -441,7 +440,7 @@ mod tests {
let array = StructArray::try_new(names, fields, 4, Validity::NonNullable);

unsafe {
let array = vx_array::new(Arc::new(array.unwrap().into_array()));
let array = vx_array::new(array.unwrap().into_array());

let root = vx_expression_root();
let expression_col1 = vx_expression_get_item(c"col1".as_ptr(), root);
Expand Down Expand Up @@ -524,7 +523,7 @@ mod tests {

unsafe {
let root = vx_expression_root();
let array = vx_array::new(Arc::new(array.into_array()));
let array = vx_array::new(array.into_array());
let expression_value = vx_expression::new(Box::new(lit(1)));

let expression = vx_expression_list_contains(root, expression_value);
Expand Down
Loading