Skip to content

Commit 37b7625

Browse files
break: apply should use apply(self: Arc<Self>, ...) over apply(&self, ...) (#7259)
This change breaks the `array.apply(expr)` method by making it take the `ArrayRef` (`self: Arc<Self>`) instead of `&self`. This fixes a huge performance regression in `apply`, which cloned the inner array rather than the `ArrayRef`. --------- Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 13937c0 commit 37b7625

11 files changed

Lines changed: 39 additions & 24 deletions

File tree

encodings/sequence/src/compute/list_contains.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ mod tests {
5151
use std::sync::Arc;
5252

5353
use vortex_array::DynArray;
54+
use vortex_array::IntoArray;
5455
use vortex_array::arrays::BoolArray;
5556
use vortex_array::assert_arrays_eq;
5657
use vortex_array::dtype::Nullability;
@@ -77,7 +78,7 @@ mod tests {
7778
let array = SequenceArray::try_new_typed(1, 1, Nullability::NonNullable, 3).unwrap();
7879

7980
let expr = list_contains(lit(list_scalar.clone()), root());
80-
let result = array.apply(&expr).unwrap();
81+
let result = array.into_array().apply(&expr).unwrap();
8182
let expected = BoolArray::from_iter([Some(true), Some(false), Some(true)]);
8283
assert_arrays_eq!(result, expected);
8384
}
@@ -89,7 +90,7 @@ mod tests {
8990
let array = SequenceArray::try_new_typed(1, 2, Nullability::NonNullable, 3).unwrap();
9091

9192
let expr = list_contains(lit(list_scalar), root());
92-
let result = array.apply(&expr).unwrap();
93+
let result = array.into_array().apply(&expr).unwrap();
9394
let expected = BoolArray::from_iter([Some(true), Some(true), Some(false)]);
9495
assert_arrays_eq!(result, expected);
9596
}

fuzz/fuzz_targets/file_io.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ fuzz_target!(|fuzz: FuzzFileAction| -> Corpus {
4646

4747
let expected_array = {
4848
let bool_mask = array_data
49+
.clone()
4950
.apply(&filter_expr.clone().unwrap_or_else(|| lit(true)))
5051
.vortex_expect("filter expression evaluation should succeed in fuzz test");
5152
let mask = bool_mask.to_bool().to_mask_fill_null_false();

vortex-array/benches/dict_compare.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ fn bench_compare_sliced_dict_primitive(
130130
bencher
131131
.with_inputs(|| (&dict, session.create_execution_ctx()))
132132
.bench_refs(|(dict, ctx)| {
133-
dict.apply(&eq(root(), lit(value)))
133+
dict.clone()
134+
.apply(&eq(root(), lit(value)))
134135
.unwrap()
135136
.execute::<RecursiveCanonical>(ctx)
136137
.unwrap()
@@ -152,7 +153,8 @@ fn bench_compare_sliced_dict_varbinview(
152153
bencher
153154
.with_inputs(|| (&dict, session.create_execution_ctx()))
154155
.bench_refs(|(dict, ctx)| {
155-
dict.apply(&eq(root(), lit(value)))
156+
dict.clone()
157+
.apply(&eq(root(), lit(value)))
156158
.unwrap()
157159
.execute::<RecursiveCanonical>(ctx)
158160
.unwrap()

vortex-array/benches/expr/case_when_bench.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ fn case_when_simple(bencher: Bencher, size: usize) {
7474
.bench_refs(|(expr, array)| {
7575
let mut ctx = SESSION.create_execution_ctx();
7676
array
77+
.clone()
7778
.apply(expr)
7879
.unwrap()
7980
.execute::<Canonical>(&mut ctx)
@@ -101,6 +102,7 @@ fn case_when_nary_3_conditions(bencher: Bencher, size: usize) {
101102
.bench_refs(|(expr, array)| {
102103
let mut ctx = SESSION.create_execution_ctx();
103104
array
105+
.clone()
104106
.apply(expr)
105107
.unwrap()
106108
.execute::<Canonical>(&mut ctx)
@@ -129,6 +131,7 @@ fn case_when_nary_10_conditions(bencher: Bencher, size: usize) {
129131
.bench_refs(|(expr, array)| {
130132
let mut ctx = SESSION.create_execution_ctx();
131133
array
134+
.clone()
132135
.apply(expr)
133136
.unwrap()
134137
.execute::<Canonical>(&mut ctx)
@@ -152,6 +155,7 @@ fn case_when_nary_equality_lookup(bencher: Bencher, size: usize) {
152155
.bench_refs(|(expr, array)| {
153156
let mut ctx = SESSION.create_execution_ctx();
154157
array
158+
.clone()
155159
.apply(expr)
156160
.unwrap()
157161
.execute::<Canonical>(&mut ctx)
@@ -172,6 +176,7 @@ fn case_when_without_else(bencher: Bencher, size: usize) {
172176
.bench_refs(|(expr, array)| {
173177
let mut ctx = SESSION.create_execution_ctx();
174178
array
179+
.clone()
175180
.apply(expr)
176181
.unwrap()
177182
.execute::<Canonical>(&mut ctx)
@@ -196,6 +201,7 @@ fn case_when_all_true(bencher: Bencher, size: usize) {
196201
.bench_refs(|(expr, array)| {
197202
let mut ctx = SESSION.create_execution_ctx();
198203
array
204+
.clone()
199205
.apply(expr)
200206
.unwrap()
201207
.execute::<Canonical>(&mut ctx)
@@ -229,6 +235,7 @@ fn case_when_nary_early_dominant(bencher: Bencher, size: usize) {
229235
.bench_refs(|(expr, array)| {
230236
let mut ctx = SESSION.create_execution_ctx();
231237
array
238+
.clone()
232239
.apply(expr)
233240
.unwrap()
234241
.execute::<Canonical>(&mut ctx)
@@ -253,6 +260,7 @@ fn case_when_all_false(bencher: Bencher, size: usize) {
253260
.bench_refs(|(expr, array)| {
254261
let mut ctx = SESSION.create_execution_ctx();
255262
array
263+
.clone()
256264
.apply(expr)
257265
.unwrap()
258266
.execute::<Canonical>(&mut ctx)
@@ -280,6 +288,7 @@ fn case_when_fragmented(bencher: Bencher, size: usize) {
280288
.bench_refs(|(expr, array)| {
281289
let mut ctx = SESSION.create_execution_ctx();
282290
array
291+
.clone()
283292
.apply(expr)
284293
.unwrap()
285294
.execute::<Canonical>(&mut ctx)

vortex-array/src/arrays/datetime/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ fn test222() -> VortexResult<()> {
224224
)),
225225
);
226226

227-
let _result = temporal.as_ref().apply(&filter_expr);
227+
let _result = temporal.into_array().apply(&filter_expr);
228228

229229
// let err = result.is_err().unwrap();
230230
// println!("Expected error: {}", err);

vortex-array/src/expression.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::sync::Arc;
5+
46
use itertools::Itertools;
57
use vortex_error::VortexResult;
68

@@ -16,7 +18,7 @@ use crate::scalar_fn::fns::root::Root;
1618

1719
impl dyn DynArray + '_ {
1820
/// Apply the expression to this array, producing a new array in constant time.
19-
pub fn apply(&self, expr: &Expression) -> VortexResult<ArrayRef> {
21+
pub fn apply(self: Arc<Self>, expr: &Expression) -> VortexResult<ArrayRef> {
2022
// If the expression is a root, return self.
2123
if expr.is::<Root>() {
2224
return Ok(self.to_array());
@@ -31,7 +33,7 @@ impl dyn DynArray + '_ {
3133
let children: Vec<_> = expr
3234
.children()
3335
.iter()
34-
.map(|e| self.apply(e))
36+
.map(|e| self.clone().apply(e))
3537
.try_collect()?;
3638

3739
// And wrap the scalar function up in an array.

vortex-array/src/scalar_fn/fns/case_when.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ mod tests {
406406
fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
407407
let mut ctx = SESSION.create_execution_ctx();
408408
array
409+
.clone()
409410
.apply(expr)
410411
.unwrap()
411412
.execute::<Canonical>(&mut ctx)

vortex-array/src/scalar_fn/fns/dynamic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ mod tests {
403403
);
404404
let input = buffer![1i32, 5, 10].into_array();
405405

406-
let result = input.apply(&expr)?;
406+
let result = input.clone().apply(&expr)?;
407407
assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
408408

409409
threshold.store(10, Ordering::SeqCst);

vortex-array/src/scalar_fn/fns/list_contains/mod.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -670,15 +670,15 @@ mod tests {
670670

671671
// Test contains true
672672
let expr = list_contains(lit(list_scalar.clone()), lit(2i32));
673-
let result = arr.apply(&expr).unwrap();
673+
let result = arr.clone().apply(&expr).unwrap();
674674
assert_eq!(
675675
result.scalar_at(0).unwrap(),
676676
Scalar::bool(true, Nullability::NonNullable)
677677
);
678678

679679
// Test contains false
680680
let expr = list_contains(lit(list_scalar), lit(42i32));
681-
let result = arr.apply(&expr).unwrap();
681+
let result = arr.clone().apply(&expr).unwrap();
682682
assert_eq!(
683683
result.scalar_at(0).unwrap(),
684684
Scalar::bool(false, Nullability::NonNullable)
@@ -856,7 +856,7 @@ mod tests {
856856
};
857857

858858
let expr = list_contains(root(), lit(42i32));
859-
let result = list_array.apply(&expr).unwrap();
859+
let result = list_array.into_array().apply(&expr).unwrap();
860860

861861
let expected = BoolArray::from_iter([false, false, false, false]);
862862
assert_arrays_eq!(result, expected);
@@ -881,7 +881,7 @@ mod tests {
881881
// Searching for null
882882
let null_scalar = Scalar::null(DType::Primitive(I32, Nullability::Nullable));
883883
let expr = list_contains(root(), lit(null_scalar));
884-
let result = list_array.apply(&expr).unwrap();
884+
let result = list_array.clone().into_array().apply(&expr).unwrap();
885885

886886
let expected = BoolArray::new(
887887
[false, false, false].into_iter().collect(),
@@ -891,7 +891,7 @@ mod tests {
891891

892892
// Searching for non-null
893893
let expr2 = list_contains(root(), lit(42i32));
894-
let result2 = list_array.apply(&expr2).unwrap();
894+
let result2 = list_array.into_array().apply(&expr2).unwrap();
895895

896896
let expected2 = BoolArray::from_iter([false, false, false]);
897897
assert_arrays_eq!(result2, expected2);
@@ -908,13 +908,13 @@ mod tests {
908908
ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);
909909

910910
let expr = list_contains(root(), lit(2i32));
911-
let result = list_array.apply(&expr).unwrap();
911+
let result = list_array.clone().into_array().apply(&expr).unwrap();
912912

913913
let expected = BoolArray::from_iter([false, true, false, false]);
914914
assert_arrays_eq!(result, expected);
915915

916916
let expr5 = list_contains(root(), lit(5i32));
917-
let result5 = list_array.apply(&expr5).unwrap();
917+
let result5 = list_array.into_array().apply(&expr5).unwrap();
918918

919919
let expected5 = BoolArray::from_iter([false, false, true, false]);
920920
assert_arrays_eq!(result5, expected5);
@@ -930,13 +930,13 @@ mod tests {
930930
ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);
931931

932932
let expr = list_contains(root(), lit(255i32));
933-
let result = list_array.apply(&expr).unwrap();
933+
let result = list_array.clone().into_array().apply(&expr).unwrap();
934934

935935
let expected = BoolArray::from_iter([false, false, false, true]);
936936
assert_arrays_eq!(result, expected);
937937

938938
let expr_zero = list_contains(root(), lit(0i32));
939-
let result_zero = list_array.apply(&expr_zero).unwrap();
939+
let result_zero = list_array.into_array().apply(&expr_zero).unwrap();
940940

941941
let expected_zero = BoolArray::from_iter([true, false, false, false]);
942942
assert_arrays_eq!(result_zero, expected_zero);

vortex-ffi/src/array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ pub unsafe extern "C" fn vx_array_apply(
414414
vortex_ensure!(!expression.is_null());
415415
let array = vx_array::as_ref(array);
416416
let expression = vx_expression::as_ref(expression);
417-
Ok(vx_array::new(Arc::new(array.apply(expression)?)))
417+
Ok(vx_array::new(array.clone().apply(expression)?))
418418
})
419419
}
420420

0 commit comments

Comments
 (0)