Skip to content

Commit 19a1fb3

Browse files
Propagate ExecutionCtx through CASE WHEN binary merges (#8040)
Signed-off-by: "Dimitar Dimitrov" <dimitar@spiraldb.com> Signed-off-by: Dimitar Dimitrov <dimitar@spiraldb.com>
1 parent 971aa1c commit 19a1fb3

2 files changed

Lines changed: 48 additions & 19 deletions

File tree

vortex-array/src/scalar_fn/fns/case_when.rs

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ impl ScalarFnVTable for CaseWhen {
258258
}
259259
}
260260

261-
/// Average run length at which slicing + `extend_from_array` becomes cheaper than `scalar_at`.
261+
/// Average run length at which slicing + context-aware builder appends become cheaper than `scalar_at`.
262262
/// Measured empirically via benchmarks.
263263
const SLICE_CROSSOVER_RUN_LEN: usize = 4;
264264

@@ -272,7 +272,7 @@ fn merge_case_branches(
272272
) -> VortexResult<ArrayRef> {
273273
if branches.len() == 1 {
274274
let (mask, then_value) = &branches[0];
275-
return zip_impl(then_value, &else_value, mask);
275+
return zip_impl(then_value, &else_value, mask, ctx);
276276
}
277277

278278
let output_nullability = branches
@@ -314,7 +314,14 @@ fn merge_case_branches(
314314
ctx,
315315
)
316316
} else {
317-
merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
317+
merge_run_by_run(
318+
&branch_arrays,
319+
&else_value,
320+
&spans,
321+
&output_dtype,
322+
builder,
323+
ctx,
324+
)
318325
}
319326
}
320327

@@ -348,7 +355,7 @@ fn merge_row_by_row(
348355
Ok(builder.finish())
349356
}
350357

351-
/// Bulk-copies each span via `slice()` + `extend_from_array`.
358+
/// Bulk-copies each span via `slice()` and context-aware builder appends.
352359
/// Preferred when runs are long enough that memcpy dominates over per-slice allocation cost.
353360
/// Lazy cast via `arr.cast(output_dtype)` is executed once per span as a block.
354361
fn merge_run_by_run(
@@ -357,21 +364,25 @@ fn merge_run_by_run(
357364
spans: &[(usize, usize, usize)],
358365
output_dtype: &DType,
359366
mut builder: Box<dyn ArrayBuilder>,
367+
ctx: &mut ExecutionCtx,
360368
) -> VortexResult<ArrayRef> {
361369
let else_value = else_value.cast(output_dtype.clone())?;
362370
let len = else_value.len();
363371
for (start, end, branch_idx) in spans {
364372
if builder.len() < *start {
365-
builder.extend_from_array(&else_value.slice(builder.len()..*start)?);
373+
else_value
374+
.slice(builder.len()..*start)?
375+
.append_to_builder(builder.as_mut(), ctx)?;
366376
}
367-
builder.extend_from_array(
368-
&branch_arrays[*branch_idx]
369-
.cast(output_dtype.clone())?
370-
.slice(*start..*end)?,
371-
);
377+
branch_arrays[*branch_idx]
378+
.cast(output_dtype.clone())?
379+
.slice(*start..*end)?
380+
.append_to_builder(builder.as_mut(), ctx)?;
372381
}
373382
if builder.len() < len {
374-
builder.extend_from_array(&else_value.slice(builder.len()..len)?);
383+
else_value
384+
.slice(builder.len()..len)?
385+
.append_to_builder(builder.as_mut(), ctx)?;
375386
}
376387

377388
Ok(builder.finish())

vortex-array/src/scalar_fn/fns/zip/mod.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ impl ScalarFnVTable for Zip {
139139
return mask.into_array().zip(if_true, if_false);
140140
}
141141

142-
zip_impl(&if_true, &if_false, &mask)
142+
zip_impl(&if_true, &if_false, &mask, ctx)
143143
}
144144

145145
fn simplify(
@@ -176,6 +176,7 @@ pub(crate) fn zip_impl(
176176
if_true: &ArrayRef,
177177
if_false: &ArrayRef,
178178
mask: &Mask,
179+
ctx: &mut ExecutionCtx,
179180
) -> VortexResult<ArrayRef> {
180181
assert_eq!(
181182
if_true.len(),
@@ -195,12 +196,18 @@ pub(crate) fn zip_impl(
195196
return if_false.cast(return_type);
196197
}
197198

199+
// `append_to_builder` requires exact dtype equality, so normalize branch
200+
// nullability to the output dtype before appending slices into the builder.
201+
let if_true = if_true.cast(return_type.clone())?;
202+
let if_false = if_false.cast(return_type.clone())?;
203+
198204
zip_impl_with_builder(
199-
if_true,
200-
if_false,
205+
&if_true,
206+
&if_false,
201207
mask.values()
202208
.vortex_expect("zip_impl_with_builder: mask is not all-true or all-false"),
203209
builder_with_capacity(&return_type, if_true.len()),
210+
ctx,
204211
)
205212
}
206213

@@ -209,13 +216,22 @@ fn zip_impl_with_builder(
209216
if_false: &ArrayRef,
210217
mask: &MaskValues,
211218
mut builder: Box<dyn ArrayBuilder>,
219+
ctx: &mut ExecutionCtx,
212220
) -> VortexResult<ArrayRef> {
213221
for (start, end) in mask.slices() {
214-
builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
215-
builder.extend_from_array(&if_true.slice(*start..*end)?);
222+
if builder.len() < *start {
223+
if_false
224+
.slice(builder.len()..*start)?
225+
.append_to_builder(builder.as_mut(), ctx)?;
226+
}
227+
if_true
228+
.slice(*start..*end)?
229+
.append_to_builder(builder.as_mut(), ctx)?;
216230
}
217231
if builder.len() < if_false.len() {
218-
builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
232+
if_false
233+
.slice(builder.len()..if_false.len())?
234+
.append_to_builder(builder.as_mut(), ctx)?;
219235
}
220236
Ok(builder.finish())
221237
}
@@ -319,7 +335,8 @@ mod tests {
319335
let if_false =
320336
PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
321337

322-
let result = zip_impl(&if_true, &if_false, &mask)?;
338+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
339+
let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
323340
assert_arrays_eq!(
324341
result,
325342
PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40)])
@@ -336,7 +353,8 @@ mod tests {
336353
PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), None]).into_array();
337354
let if_false = buffer![1i32, 2, 3, 4].into_array();
338355

339-
let result = zip_impl(&if_true, &if_false, &mask)?;
356+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
357+
let result = zip_impl(&if_true, &if_false, &mask, &mut ctx)?;
340358
assert_arrays_eq!(
341359
result,
342360
PrimitiveArray::from_option_iter([Some(1i32), Some(2), Some(3), Some(4)]).into_array()

0 commit comments

Comments
 (0)