Skip to content

Commit a87bdc9

Browse files
authored
perf: optimize array_remove for scalar needle (#22390)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #. ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> Similar to #22387 (array_replace scalar optimization) `array_remove` / `array_remove_n` / `array_remove_all` perform element-wise comparison by invoking `compare_element_to_list` against each row's sub-array individually. When the needle is a scalar, this can be optimized by performing a single vectorized `distinct` comparison over the entire flattened values buffer. ## What changes are included in this PR? - Add a specialized removal kernel (`general_remove_with_scalar`) that uses `arrow_ord::cmp::distinct` with `Scalar` wrapper for a single bulk comparison pass over the flat values buffer. - Extend SLT tests with multi-row scalar-argument coverage, NULL-containing arrays, empty-array edge cases, boundary `n` values, and LargeList type coverage. <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ### Benchmarks ``` group main optimized ----- ---- --------- array_remove_all_int64/remove/list size: 10, num_rows: 4000 4.35 856.8±97.81µs ? ?/sec 1.00 196.9±4.48µs ? ?/sec array_remove_all_int64/remove/list size: 100, num_rows: 10000 1.90 5.5±0.09ms ? ?/sec 1.00 2.9±0.09ms ? ?/sec array_remove_all_int64/remove/list size: 500, num_rows: 10000 1.35 19.2±0.21ms ? ?/sec 1.00 14.2±0.48ms ? ?/sec array_remove_all_int64_nested/remove/list size: 10, num_rows: 4000 1.00 7.1±0.12ms ? ?/sec 1.04 7.4±0.12ms ? ?/sec array_remove_all_int64_nested/remove/list size: 100, num_rows: 3000 1.00 36.5±0.39ms ? ?/sec 1.05 38.3±2.61ms ? ?/sec array_remove_all_int64_nested/remove/list size: 300, num_rows: 1500 1.01 53.5±2.26ms ? ?/sec 1.00 53.0±0.99ms ? ?/sec array_remove_boolean/remove/list size: 10, num_rows: 4000 3.83 813.9±7.08µs ? ?/sec 1.00 212.4±2.28µs ? ?/sec array_remove_boolean/remove/list size: 100, num_rows: 10000 2.73 3.7±0.03ms ? ?/sec 1.00 1364.7±177.83µs ? ?/sec array_remove_boolean/remove/list size: 500, num_rows: 10000 2.34 9.8±0.14ms ? ?/sec 1.00 4.2±0.25ms ? ?/sec array_remove_fixed_size_binary/remove/list size: 10, num_rows: 4000 3.16 918.2±16.76µs ? ?/sec 1.00 290.6±9.79µs ? ?/sec array_remove_fixed_size_binary/remove/list size: 100, num_rows: 10000 1.56 6.9±0.13ms ? ?/sec 1.00 4.4±0.15ms ? ?/sec array_remove_fixed_size_binary/remove/list size: 500, num_rows: 10000 1.17 27.7±0.84ms ? ?/sec 1.00 23.6±2.04ms ? ?/sec array_remove_int64/remove/list size: 10, num_rows: 4000 4.55 825.7±6.30µs ? ?/sec 1.00 181.3±4.32µs ? ?/sec array_remove_int64/remove/list size: 100, num_rows: 10000 3.35 3.8±0.11ms ? ?/sec 1.00 1135.6±54.87µs ? ?/sec array_remove_int64/remove/list size: 500, num_rows: 10000 2.04 10.3±0.35ms ? ?/sec 1.00 5.1±0.39ms ? ?/sec array_remove_int64_nested/remove/list size: 10, num_rows: 4000 1.00 7.1±0.18ms ? ?/sec 1.02 7.2±0.07ms ? ?/sec array_remove_int64_nested/remove/list size: 100, num_rows: 3000 1.00 36.1±1.35ms ? ?/sec 1.07 38.5±3.67ms ? ?/sec array_remove_int64_nested/remove/list size: 300, num_rows: 1500 1.00 51.7±0.57ms ? ?/sec 1.05 54.1±2.13ms ? ?/sec array_remove_n_int64/remove/list size: 10, num_rows: 4000 4.43 845.3±5.00µs ? ?/sec 1.00 190.6±2.84µs ? ?/sec array_remove_n_int64/remove/list size: 100, num_rows: 10000 2.29 4.7±0.11ms ? ?/sec 1.00 2.0±0.12ms ? ?/sec array_remove_n_int64/remove/list size: 500, num_rows: 10000 1.63 14.8±0.42ms ? ?/sec 1.00 9.0±0.51ms ? ?/sec array_remove_n_int64_nested/remove/list size: 10, num_rows: 4000 1.00 7.0±0.09ms ? ?/sec 1.29 8.9±3.44ms ? ?/sec array_remove_n_int64_nested/remove/list size: 100, num_rows: 3000 1.00 36.6±0.42ms ? ?/sec 1.03 37.7±0.68ms ? ?/sec array_remove_n_int64_nested/remove/list size: 300, num_rows: 1500 1.00 52.7±3.68ms ? ?/sec 1.03 54.5±4.49ms ? ?/sec array_remove_strings/remove/list size: 10, num_rows: 4000 2.50 1144.6±21.95µs ? ?/sec 1.00 457.0±14.15µs ? ?/sec array_remove_strings/remove/list size: 100, num_rows: 10000 1.42 10.5±1.16ms ? ?/sec 1.00 7.4±0.34ms ? ?/sec array_remove_strings/remove/list size: 500, num_rows: 10000 1.12 39.8±0.91ms ? ?/sec 1.00 35.5±1.51ms ? ?/sec ``` ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Yes, existing and new SLT edge-case tests in `array_remove.slt`. ## Are there any user-facing changes? No. <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent fcc9cc4 commit a87bdc9

2 files changed

Lines changed: 286 additions & 30 deletions

File tree

datafusion/functions-nested/src/remove.rs

Lines changed: 202 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@
1818
//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.
1919
2020
use crate::utils;
21-
use crate::utils::make_scalar_function;
2221
use arrow::array::{
23-
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
24-
cast::AsArray, make_array,
22+
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetBufferBuilder,
23+
OffsetSizeTrait, Scalar, cast::AsArray, make_array,
2524
};
2625
use arrow::buffer::{NullBuffer, OffsetBuffer};
2726
use arrow::datatypes::{DataType, FieldRef};
2827
use datafusion_common::cast::as_int64_array;
2928
use datafusion_common::utils::ListCoercion;
30-
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
29+
use datafusion_common::{
30+
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
31+
};
3132
use datafusion_expr::{
3233
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
3334
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -113,7 +114,24 @@ impl ScalarUDFImpl for ArrayRemove {
113114
}
114115

115116
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
116-
make_scalar_function(array_remove_inner)(&args.args)
117+
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
118+
let num_rows = args.number_rows;
119+
let list_array = list_arg.to_array(num_rows)?;
120+
match element_arg {
121+
ColumnarValue::Scalar(scalar_element)
122+
if !scalar_element.is_null()
123+
&& !scalar_element.data_type().is_nested() =>
124+
{
125+
let result =
126+
array_remove_with_scalar_args(&list_array, scalar_element, 1i64)?;
127+
Ok(ColumnarValue::Array(result))
128+
}
129+
element_arg => {
130+
let element_array = element_arg.to_array(num_rows)?;
131+
let result = array_remove_internal(&list_array, &element_array, &[1])?;
132+
Ok(ColumnarValue::Array(result))
133+
}
134+
}
117135
}
118136

119137
fn aliases(&self) -> &[String] {
@@ -214,7 +232,40 @@ impl ScalarUDFImpl for ArrayRemoveN {
214232
}
215233

216234
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
217-
make_scalar_function(array_remove_n_inner)(&args.args)
235+
let [list_arg, element_arg, max_arg] =
236+
take_function_args(self.name(), &args.args)?;
237+
let num_rows = args.number_rows;
238+
let list_array = list_arg.to_array(num_rows)?;
239+
match (element_arg, max_arg) {
240+
(
241+
ColumnarValue::Scalar(scalar_element),
242+
ColumnarValue::Scalar(scalar_max),
243+
) if !scalar_element.is_null() && !scalar_element.data_type().is_nested() => {
244+
let ScalarValue::Int64(Some(n)) = scalar_max else {
245+
// null max means no remove
246+
return Ok(ColumnarValue::Array(list_array));
247+
};
248+
let result =
249+
array_remove_with_scalar_args(&list_array, scalar_element, *n)?;
250+
Ok(ColumnarValue::Array(result))
251+
}
252+
(element_arg, max_arg) => {
253+
let element_array = element_arg.to_array(num_rows)?;
254+
let max_array = max_arg.to_array(num_rows)?;
255+
let max_array = as_int64_array(&max_array)?;
256+
let arr_n = (0..max_array.len())
257+
.map(|i| {
258+
if max_array.is_null(i) {
259+
0
260+
} else {
261+
max_array.value(i)
262+
}
263+
})
264+
.collect::<Vec<_>>();
265+
let result = array_remove_internal(&list_array, &element_array, &arr_n)?;
266+
Ok(ColumnarValue::Array(result))
267+
}
268+
}
218269
}
219270

220271
fn aliases(&self) -> &[String] {
@@ -304,7 +355,25 @@ impl ScalarUDFImpl for ArrayRemoveAll {
304355
}
305356

306357
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
307-
make_scalar_function(array_remove_all_inner)(&args.args)
358+
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
359+
let num_rows = args.number_rows;
360+
let list_array = list_arg.to_array(num_rows)?;
361+
match element_arg {
362+
ColumnarValue::Scalar(scalar_element)
363+
if !scalar_element.is_null()
364+
&& !scalar_element.data_type().is_nested() =>
365+
{
366+
let result =
367+
array_remove_with_scalar_args(&list_array, scalar_element, i64::MAX)?;
368+
Ok(ColumnarValue::Array(result))
369+
}
370+
element_arg => {
371+
let element_array = element_arg.to_array(num_rows)?;
372+
let result =
373+
array_remove_internal(&list_array, &element_array, &[i64::MAX])?;
374+
Ok(ColumnarValue::Array(result))
375+
}
376+
}
308377
}
309378

310379
fn aliases(&self) -> &[String] {
@@ -316,27 +385,6 @@ impl ScalarUDFImpl for ArrayRemoveAll {
316385
}
317386
}
318387

319-
fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
320-
let [array, element] = take_function_args("array_remove", args)?;
321-
322-
let arr_n = vec![1; array.len()];
323-
array_remove_internal(array, element, &arr_n)
324-
}
325-
326-
fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
327-
let [array, element, max] = take_function_args("array_remove_n", args)?;
328-
329-
let arr_n = as_int64_array(max)?.values().to_vec();
330-
array_remove_internal(array, element, &arr_n)
331-
}
332-
333-
fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
334-
let [array, element] = take_function_args("array_remove_all", args)?;
335-
336-
let arr_n = vec![i64::MAX; array.len()];
337-
array_remove_internal(array, element, &arr_n)
338-
}
339-
340388
fn array_remove_internal(
341389
array: &ArrayRef,
342390
element_array: &ArrayRef,
@@ -357,6 +405,28 @@ fn array_remove_internal(
357405
}
358406
}
359407

408+
/// Fast path for `array_remove` when the needle is a non-null, non-nested scalar.
409+
/// Dispatches to the bulk `not_distinct` comparison kernel.
410+
fn array_remove_with_scalar_args(
411+
array: &ArrayRef,
412+
scalar_needle: &ScalarValue,
413+
max_removals: i64,
414+
) -> Result<ArrayRef> {
415+
match array.data_type() {
416+
DataType::List(_) => {
417+
let list_array = array.as_list::<i32>();
418+
general_remove_with_scalar::<i32>(list_array, scalar_needle, max_removals)
419+
}
420+
DataType::LargeList(_) => {
421+
let list_array = array.as_list::<i64>();
422+
general_remove_with_scalar::<i64>(list_array, scalar_needle, max_removals)
423+
}
424+
array_type => exec_err!(
425+
"array_remove/array_remove_n/array_remove_all does not support type '{array_type}'."
426+
),
427+
}
428+
}
429+
360430
/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
361431
/// of `element_array[i]`.
362432
///
@@ -411,7 +481,11 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
411481
let start = offset_window[0].to_usize().unwrap();
412482
let end = offset_window[1].to_usize().unwrap();
413483
// n is the number of elements to remove in this row
414-
let n = arr_n[row_index];
484+
let n = if arr_n.len() == 1 {
485+
arr_n[0]
486+
} else {
487+
arr_n[row_index]
488+
};
415489

416490
// compare each element in the list, `false` means the element matches and should be removed
417491
let eq_array = utils::compare_element_to_list(
@@ -468,6 +542,105 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
468542
)?))
469543
}
470544

545+
/// For each element of `list_array[i]`, removes up to `max_removals` occurrences
546+
/// of the scalar needle.
547+
///
548+
/// This is a specialized version of `general_remove` for scalar elements that
549+
/// uses bulk comparison for better performance.
550+
fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>(
551+
list_array: &GenericListArray<OffsetSize>,
552+
scalar_needle: &ScalarValue,
553+
max_removals: i64,
554+
) -> Result<ArrayRef> {
555+
if max_removals <= 0 {
556+
return Ok(Arc::new(list_array.clone()));
557+
}
558+
559+
let list_field = match list_array.data_type() {
560+
DataType::List(field) | DataType::LargeList(field) => field,
561+
_ => {
562+
return exec_err!(
563+
"Expected List or LargeList data type, got {:?}",
564+
list_array.data_type()
565+
);
566+
}
567+
};
568+
569+
let list_offsets = list_array.offsets();
570+
let first_offset = list_offsets[0].to_usize().unwrap();
571+
let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap();
572+
let values_range_len = last_offset - first_offset;
573+
let values_slice = list_array.values().slice(first_offset, values_range_len);
574+
let original_data = values_slice.to_data();
575+
let mut offsets = OffsetBufferBuilder::<OffsetSize>::new(list_array.len());
576+
577+
let mut mutable = MutableArrayData::with_capacities(
578+
vec![&original_data],
579+
false,
580+
Capacities::Array(original_data.len()),
581+
);
582+
let nulls = list_array.nulls().cloned();
583+
let needle = scalar_needle.to_array_of_size(1)?;
584+
let remove_mask = arrow_ord::cmp::not_distinct(&values_slice, &Scalar::new(needle))?;
585+
let remove_bits = remove_mask.values();
586+
587+
for (row_index, offset_window) in list_offsets.windows(2).enumerate() {
588+
if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
589+
offsets.push_length(0);
590+
continue;
591+
}
592+
593+
let start = offset_window[0].to_usize().unwrap() - first_offset;
594+
let end = offset_window[1].to_usize().unwrap() - first_offset;
595+
let row_len = end - start;
596+
597+
let row_remove_bits = remove_bits.slice(start, row_len);
598+
let num_to_remove = row_remove_bits.count_set_bits();
599+
600+
if num_to_remove == 0 {
601+
mutable.extend(0, start, end);
602+
offsets.push_length(row_len);
603+
continue;
604+
}
605+
606+
let removals_to_apply = max_removals.min(num_to_remove as i64) as usize;
607+
608+
// Iterate only over the removal positions via set_indices. This is
609+
// efficient when the number of removals is small relative to the row
610+
// length (common case), since it skips over retained elements.
611+
let mut removed = 0usize;
612+
let mut copied = 0usize;
613+
let mut prev_end = start;
614+
for remove_pos in row_remove_bits.set_indices() {
615+
let abs_pos = start + remove_pos;
616+
if abs_pos > prev_end {
617+
mutable.extend(0, prev_end, abs_pos);
618+
copied += abs_pos - prev_end;
619+
}
620+
prev_end = abs_pos + 1;
621+
removed += 1;
622+
if removed == removals_to_apply {
623+
break;
624+
}
625+
}
626+
// Copy the remaining tail after the last removal
627+
if prev_end < end {
628+
mutable.extend(0, prev_end, end);
629+
copied += end - prev_end;
630+
}
631+
632+
offsets.push_length(copied);
633+
}
634+
635+
let new_values = make_array(mutable.freeze());
636+
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
637+
Arc::clone(list_field),
638+
offsets.finish(),
639+
new_values,
640+
nulls,
641+
)?))
642+
}
643+
471644
#[cfg(test)]
472645
mod tests {
473646
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};

0 commit comments

Comments
 (0)