Skip to content

Commit b10b4f4

Browse files
break: zip arg (mask, true, false) (#6766)
The zip array is usually displayed as mask(cond, true, false). This PR does the same for the public API ## break zip is now mask.zip(true, false) --------- Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 6a59b88 commit b10b4f4

File tree

8 files changed

+101
-71
lines changed

8 files changed

+101
-71
lines changed

vortex-array/benches/varbinview_zip.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
use divan::Bencher;
77
use vortex_array::IntoArray;
8+
use vortex_array::LEGACY_SESSION;
9+
use vortex_array::RecursiveCanonical;
10+
use vortex_array::VortexSessionExecute;
811
use vortex_array::arrays::VarBinViewArray;
912
use vortex_array::builtins::ArrayBuiltins;
1013
use vortex_array::dtype::DType;
@@ -24,9 +27,19 @@ fn varbinview_zip_fragmented_mask(bencher: Bencher) {
2427
let mask = alternating_mask(len);
2528

2629
bencher
27-
.with_inputs(|| (&if_true, &if_false, &mask))
28-
.bench_refs(|(t, f, m)| {
29-
t.zip(f.clone(), m.clone().into_array()).unwrap();
30+
.with_inputs(|| {
31+
(
32+
if_true.clone(),
33+
if_false.clone(),
34+
mask.clone().into_array(),
35+
LEGACY_SESSION.create_execution_ctx(),
36+
)
37+
})
38+
.bench_refs(|(t, f, m, ctx)| {
39+
m.zip(t.clone(), f.clone())
40+
.unwrap()
41+
.execute::<RecursiveCanonical>(ctx)
42+
.unwrap();
3043
});
3144
}
3245

@@ -39,9 +52,19 @@ fn varbinview_zip_block_mask(bencher: Bencher) {
3952
let mask = block_mask(len, 128);
4053

4154
bencher
42-
.with_inputs(|| (&if_true, &if_false, &mask))
43-
.bench_refs(|(t, f, m)| {
44-
t.zip(f.clone(), m.clone().into_array()).unwrap();
55+
.with_inputs(|| {
56+
(
57+
if_true.clone(),
58+
if_false.clone(),
59+
mask.clone().into_array(),
60+
LEGACY_SESSION.create_execution_ctx(),
61+
)
62+
})
63+
.bench_refs(|(t, f, m, ctx)| {
64+
m.zip(t.clone(), f.clone())
65+
.unwrap()
66+
.execute::<RecursiveCanonical>(ctx)
67+
.unwrap();
4568
});
4669
}
4770

vortex-array/public-api.lock

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5614,7 +5614,7 @@ pub fn vortex_array::builtins::ArrayBuiltins::mask(self, mask: vortex_array::Arr
56145614

56155615
pub fn vortex_array::builtins::ArrayBuiltins::not(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
56165616

5617-
pub fn vortex_array::builtins::ArrayBuiltins::zip(&self, if_false: vortex_array::ArrayRef, mask: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
5617+
pub fn vortex_array::builtins::ArrayBuiltins::zip(&self, if_true: vortex_array::ArrayRef, if_false: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
56185618

56195619
impl vortex_array::builtins::ArrayBuiltins for vortex_array::ArrayRef
56205620

@@ -5636,7 +5636,7 @@ pub fn vortex_array::ArrayRef::mask(self, mask: vortex_array::ArrayRef) -> vorte
56365636

56375637
pub fn vortex_array::ArrayRef::not(&self) -> vortex_error::VortexResult<vortex_array::ArrayRef>
56385638

5639-
pub fn vortex_array::ArrayRef::zip(&self, if_false: vortex_array::ArrayRef, mask: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
5639+
pub fn vortex_array::ArrayRef::zip(&self, if_true: vortex_array::ArrayRef, if_false: vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
56405640

56415641
pub trait vortex_array::builtins::ExprBuiltins: core::marker::Sized
56425642

@@ -5656,7 +5656,7 @@ pub fn vortex_array::builtins::ExprBuiltins::mask(&self, mask: vortex_array::exp
56565656

56575657
pub fn vortex_array::builtins::ExprBuiltins::not(&self) -> vortex_error::VortexResult<vortex_array::expr::Expression>
56585658

5659-
pub fn vortex_array::builtins::ExprBuiltins::zip(&self, if_false: vortex_array::expr::Expression, mask: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
5659+
pub fn vortex_array::builtins::ExprBuiltins::zip(&self, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
56605660

56615661
impl vortex_array::builtins::ExprBuiltins for vortex_array::expr::Expression
56625662

@@ -5676,7 +5676,7 @@ pub fn vortex_array::expr::Expression::mask(&self, mask: vortex_array::expr::Exp
56765676

56775677
pub fn vortex_array::expr::Expression::not(&self) -> vortex_error::VortexResult<vortex_array::expr::Expression>
56785678

5679-
pub fn vortex_array::expr::Expression::zip(&self, if_false: vortex_array::expr::Expression, mask: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
5679+
pub fn vortex_array::expr::Expression::zip(&self, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
56805680

56815681
pub mod vortex_array::compute
56825682

@@ -10044,7 +10044,7 @@ pub fn vortex_array::expr::Expression::mask(&self, mask: vortex_array::expr::Exp
1004410044

1004510045
pub fn vortex_array::expr::Expression::not(&self) -> vortex_error::VortexResult<vortex_array::expr::Expression>
1004610046

10047-
pub fn vortex_array::expr::Expression::zip(&self, if_false: vortex_array::expr::Expression, mask: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
10047+
pub fn vortex_array::expr::Expression::zip(&self, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_error::VortexResult<vortex_array::expr::Expression>
1004810048

1004910049
impl vortex_array::expr::VortexExprExt for vortex_array::expr::Expression
1005010050

@@ -10176,7 +10176,7 @@ pub fn vortex_array::expr::select_exclude(fields: impl core::convert::Into<vorte
1017610176

1017710177
pub fn vortex_array::expr::split_conjunction(expr: &vortex_array::expr::Expression) -> alloc::vec::Vec<vortex_array::expr::Expression>
1017810178

10179-
pub fn vortex_array::expr::zip_expr(if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression, mask: vortex_array::expr::Expression) -> vortex_array::expr::Expression
10179+
pub fn vortex_array::expr::zip_expr(mask: vortex_array::expr::Expression, if_true: vortex_array::expr::Expression, if_false: vortex_array::expr::Expression) -> vortex_array::expr::Expression
1018010180

1018110181
pub type vortex_array::expr::Annotations<'a, A> = vortex_utils::aliases::hash_map::HashMap<&'a vortex_array::expr::Expression, vortex_utils::aliases::hash_set::HashSet<A>>
1018210182

vortex-array/src/arrays/chunked/compute/zip.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl ZipKernel for ChunkedVTable {
4848
let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until)?;
4949
let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until)?;
5050

51-
out_chunks.push(lhs_slice.zip(rhs_slice, mask_slice)?);
51+
out_chunks.push(mask_slice.zip(lhs_slice, rhs_slice)?);
5252

5353
pos += take_until;
5454
lhs_offset += take_until;
@@ -75,8 +75,11 @@ mod tests {
7575
use vortex_buffer::buffer;
7676
use vortex_mask::Mask;
7777

78+
use crate::ArrayRef;
7879
use crate::IntoArray;
80+
use crate::LEGACY_SESSION;
7981
use crate::ToCanonical;
82+
use crate::VortexSessionExecute;
8083
use crate::arrays::ChunkedArray;
8184
use crate::arrays::ChunkedVTable;
8285
use crate::builtins::ArrayBuiltins;
@@ -108,9 +111,14 @@ mod tests {
108111

109112
let mask = Mask::from_iter([true, false, true, false, true]);
110113

111-
let zipped = &if_true
112-
.to_array()
113-
.zip(if_false.to_array(), mask.into_array())
114+
let zipped = &mask
115+
.into_array()
116+
.zip(if_true.to_array(), if_false.to_array())
117+
.unwrap();
118+
// One step of execution will push down the zip.
119+
let zipped = zipped
120+
.clone()
121+
.execute::<ArrayRef>(&mut LEGACY_SESSION.create_execution_ctx())
114122
.unwrap();
115123
let zipped = zipped
116124
.as_opt::<ChunkedVTable>()

vortex-array/src/arrays/struct_/compute/zip.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ impl ZipKernel for StructVTable {
3636
.unmasked_fields()
3737
.iter()
3838
.zip(if_false.unmasked_fields().iter())
39-
.map(|(t, f)| ArrayBuiltins::zip(t, f.clone(), mask.clone()))
39+
.map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
4040
.collect::<VortexResult<Vec<_>>>()?;
4141

4242
let validity = match (if_true.validity(), if_false.validity()) {
@@ -96,9 +96,9 @@ mod tests {
9696

9797
let mask = Mask::from_iter([false, false, true, false]);
9898

99-
let result = if_true
100-
.clone()
101-
.zip(if_false.clone(), mask.into_array())
99+
let result = mask
100+
.into_array()
101+
.zip(if_true.clone(), if_false.clone())
102102
.unwrap();
103103

104104
insta::assert_snapshot!(result.display_table(), @r"
@@ -136,9 +136,9 @@ mod tests {
136136

137137
let mask = Mask::from_iter([true, false, false, false]);
138138

139-
let result = if_true
140-
.clone()
141-
.zip(if_false.clone(), mask.into_array())
139+
let result = mask
140+
.into_array()
141+
.zip(if_true.clone(), if_false.clone())
142142
.unwrap();
143143

144144
insta::assert_snapshot!(result.display_table(), @r"

vortex-array/src/arrays/varbinview/compute/zip.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,10 @@ mod tests {
243243

244244
let mask = Mask::from_iter([true, false, true, false, false, true]);
245245

246-
let zipped = a
247-
.to_array()
248-
.zip(b.to_array(), mask.clone().into_array())
246+
let zipped = mask
247+
.clone()
248+
.into_array()
249+
.zip(a.to_array(), b.to_array())
249250
.unwrap()
250251
.to_varbinview();
251252

vortex-array/src/builtins.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@
1010
//! the equivalent Arrow compute function.
1111
1212
use vortex_error::VortexResult;
13-
use vortex_session::VortexSession;
1413

1514
use crate::Array;
1615
use crate::ArrayRef;
17-
use crate::ExecutionCtx;
1816
use crate::IntoArray;
1917
use crate::arrays::ConstantArray;
2018
use crate::arrays::ScalarFnArrayExt;
@@ -63,8 +61,8 @@ pub trait ExprBuiltins: Sized {
6361
/// Check if a list contains a value.
6462
fn list_contains(&self, value: Expression) -> VortexResult<Expression>;
6563

66-
/// Conditional selection: `result[i] = if mask[i] then self[i] else if_false[i]`.
67-
fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult<Expression>;
64+
/// Conditional selection: `result[i] = if mask[i] then if_true[i] else if_false[i]`.
65+
fn zip(&self, if_true: Expression, if_false: Expression) -> VortexResult<Expression>;
6866

6967
/// Apply a binary operator to this expression and another.
7068
fn binary(&self, rhs: Expression, op: Operator) -> VortexResult<Expression>;
@@ -99,8 +97,8 @@ impl ExprBuiltins for Expression {
9997
ListContains.try_new_expr(EmptyOptions, [self.clone(), value])
10098
}
10199

102-
fn zip(&self, if_false: Expression, mask: Expression) -> VortexResult<Expression> {
103-
Zip.try_new_expr(EmptyOptions, [self.clone(), if_false, mask])
100+
fn zip(&self, if_true: Expression, if_false: Expression) -> VortexResult<Expression> {
101+
Zip.try_new_expr(EmptyOptions, [if_true, if_false, self.clone()])
104102
}
105103

106104
fn binary(&self, rhs: Expression, op: Operator) -> VortexResult<Expression> {
@@ -129,8 +127,8 @@ pub trait ArrayBuiltins: Sized {
129127
/// Boolean negation.
130128
fn not(&self) -> VortexResult<ArrayRef>;
131129

132-
/// Conditional selection: `result[i] = if mask[i] then self[i] else if_false[i]`.
133-
fn zip(&self, if_false: ArrayRef, mask: ArrayRef) -> VortexResult<ArrayRef>;
130+
/// Conditional selection: `result[i] = if mask[i] then if_true[i] else if_false[i]`.
131+
fn zip(&self, if_true: ArrayRef, if_false: ArrayRef) -> VortexResult<ArrayRef>;
134132

135133
/// Check if a list contains a value.
136134
fn list_contains(&self, value: ArrayRef) -> VortexResult<ArrayRef>;
@@ -195,11 +193,8 @@ impl ArrayBuiltins for ArrayRef {
195193
.optimize()
196194
}
197195

198-
fn zip(&self, if_false: ArrayRef, mask: ArrayRef) -> VortexResult<ArrayRef> {
199-
let scalar_fn =
200-
Zip.try_new_array(self.len(), EmptyOptions, [self.clone(), if_false, mask])?;
201-
let mut ctx = ExecutionCtx::new(VortexSession::empty());
202-
scalar_fn.execute::<ArrayRef>(&mut ctx)
196+
fn zip(&self, if_true: ArrayRef, if_false: ArrayRef) -> VortexResult<ArrayRef> {
197+
Zip.try_new_array(self.len(), EmptyOptions, [if_true, if_false, self.clone()])
203198
}
204199

205200
fn list_contains(&self, value: ArrayRef) -> VortexResult<ArrayRef> {

vortex-array/src/expr/exprs.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -577,9 +577,9 @@ pub fn merge_opts(
577577
///
578578
/// ```rust
579579
/// # use vortex_array::expr::{zip_expr, root, lit};
580-
/// let expr = zip_expr(root(), lit(0i32), lit(true));
580+
/// let expr = zip_expr(lit(true), root(), lit(0i32));
581581
/// ```
582-
pub fn zip_expr(if_true: Expression, if_false: Expression, mask: Expression) -> Expression {
582+
pub fn zip_expr(mask: Expression, if_true: Expression, if_false: Expression) -> Expression {
583583
Zip.new_expr(EmptyOptions, [if_true, if_false, mask])
584584
}
585585

0 commit comments

Comments
 (0)