Skip to content

Commit de41306

Browse files
authored
perf: optimize array_replace for scalar needle (#22387)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Closes #.

## Rationale for this change

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

Currently, `array_replace` / `array_replace_n` / `array_replace_all` perform element-wise comparison by invoking `compare_element_to_list` against each row's sub-array individually. When the needle is a scalar, this can be optimized by performing a single vectorized `not_distinct` comparison over the entire flattened values buffer.

## What changes are included in this PR?

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

- Add a specialized replacement kernel that uses `arrow_ord::cmp::not_distinct` with a `Scalar` wrapper for a single bulk comparison pass over the flat values buffer.
- Extend SLT tests with multi-row scalar-argument coverage, empty-array edge cases, NULL needle replacement, and boundary `n` values for LargeList/FixedSizeList types.

### Benchmarks

```
group                                                                       baseline                         optimized
-----                                                                       --------                         ---------
array_replace_all_int64/replace/list size: 10, num_rows: 4000               5.04 1124.5±146.98µs ? ?/sec     1.00 223.1±2.79µs ? ?/sec
array_replace_all_int64/replace/list size: 100, num_rows: 10000             1.64 7.2±0.59ms ? ?/sec          1.00 4.4±0.12ms ? ?/sec
array_replace_all_int64/replace/list size: 500, num_rows: 10000             1.16 25.3±4.09ms ? ?/sec         1.00 21.8±0.69ms ? ?/sec
array_replace_all_int64_nested/replace/list size: 10, num_rows: 4000        1.00 7.5±0.30ms ? ?/sec          1.01 7.5±0.24ms ? ?/sec
array_replace_all_int64_nested/replace/list size: 100, num_rows: 3000       1.00 38.5±0.52ms ? ?/sec         1.02 39.2±1.02ms ? ?/sec
array_replace_all_int64_nested/replace/list size: 300, num_rows: 1500       1.00 55.4±1.73ms ? ?/sec         1.02 56.5±2.13ms ? ?/sec
array_replace_boolean/replace/list size: 10, num_rows: 4000                 4.57 1072.4±82.05µs ? ?/sec      1.00 234.6±7.55µs ? ?/sec
array_replace_boolean/replace/list size: 100, num_rows: 10000               2.38 3.7±0.43ms ? ?/sec          1.00 1536.5±47.67µs ? ?/sec
array_replace_boolean/replace/list size: 500, num_rows: 10000               1.51 6.5±0.51ms ? ?/sec          1.00 4.3±0.12ms ? ?/sec
array_replace_fixed_size_binary/replace/list size: 10, num_rows: 4000       3.61 1174.3±90.82µs ? ?/sec      1.00 325.2±26.75µs ? ?/sec
array_replace_fixed_size_binary/replace/list size: 100, num_rows: 10000     1.45 7.2±0.88ms ? ?/sec          1.00 4.9±0.11ms ? ?/sec
array_replace_fixed_size_binary/replace/list size: 500, num_rows: 10000     1.05 25.9±2.34ms ? ?/sec         1.00 24.6±0.71ms ? ?/sec
array_replace_int64/replace/list size: 10, num_rows: 4000                   5.49 1025.4±24.08µs ? ?/sec      1.00 186.7±18.10µs ? ?/sec
array_replace_int64/replace/list size: 100, num_rows: 10000                 2.46 3.6±0.13ms ? ?/sec          1.00 1455.7±138.70µs ? ?/sec
array_replace_int64/replace/list size: 500, num_rows: 10000                 1.26 7.0±0.75ms ? ?/sec          1.00 5.6±0.77ms ? ?/sec
array_replace_int64_nested/replace/list size: 10, num_rows: 4000            1.03 7.3±0.14ms ? ?/sec          1.00 7.2±0.21ms ? ?/sec
array_replace_int64_nested/replace/list size: 100, num_rows: 3000           1.03 37.8±1.62ms ? ?/sec         1.00 36.7±0.43ms ? ?/sec
array_replace_int64_nested/replace/list size: 300, num_rows: 1500           1.03 53.2±1.16ms ? ?/sec         1.00 51.7±1.87ms ? ?/sec
array_replace_n_int64/replace/list size: 10, num_rows: 4000                 5.02 1074.4±30.92µs ? ?/sec      1.00 214.1±2.22µs ? ?/sec
array_replace_n_int64/replace/list size: 100, num_rows: 10000               1.83 5.0±0.15ms ? ?/sec          1.00 2.7±0.06ms ? ?/sec
array_replace_n_int64/replace/list size: 500, num_rows: 10000               1.17 15.5±1.11ms ? ?/sec         1.00 13.3±0.24ms ? ?/sec
array_replace_n_int64_nested/replace/list size: 10, num_rows: 4000          1.05 7.5±0.45ms ? ?/sec          1.00 7.1±0.07ms ? ?/sec
array_replace_n_int64_nested/replace/list size: 100, num_rows: 3000         1.02 37.4±0.51ms ? ?/sec         1.00 36.5±0.62ms ? ?/sec
array_replace_n_int64_nested/replace/list size: 300, num_rows: 1500         1.02 54.9±4.97ms ? ?/sec         1.00 53.8±3.15ms ? ?/sec
array_replace_strings/replace/list size: 10, num_rows: 4000                 2.78 1408.8±44.99µs ? ?/sec      1.00 506.6±16.32µs ? ?/sec
array_replace_strings/replace/list size: 100, num_rows: 10000               1.32 11.0±1.25ms ? ?/sec         1.00 8.3±0.37ms ? ?/sec
array_replace_strings/replace/list size: 500, num_rows: 10000               1.14 42.4±6.39ms ? ?/sec         1.00 37.2±0.74ms ? ?/sec
```

## Are these changes tested?

<!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

Yes, existing and new slt edge-case tests in `array_replace.slt`.

## Are there any user-facing changes?

No.

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->

<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 7bcb613 commit de41306

2 files changed

Lines changed: 293 additions & 51 deletions

File tree

datafusion/functions-nested/src/replace.rs

Lines changed: 222 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,20 @@
1919
2020
use arrow::array::{
2121
Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData,
22-
NullBufferBuilder, OffsetSizeTrait, new_null_array,
22+
NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array,
2323
};
24-
use arrow::datatypes::{DataType, Field};
25-
2624
use arrow::buffer::OffsetBuffer;
25+
use arrow::datatypes::{DataType, Field};
2726
use datafusion_common::cast::as_int64_array;
2827
use datafusion_common::utils::ListCoercion;
29-
use datafusion_common::{Result, exec_err, utils::take_function_args};
28+
use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args};
3029
use datafusion_expr::{
3130
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
3231
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
3332
};
3433
use datafusion_macros::user_doc;
3534

3635
use crate::utils::compare_element_to_list;
37-
use crate::utils::make_scalar_function;
3836

3937
use std::sync::Arc;
4038

@@ -125,7 +123,27 @@ impl ScalarUDFImpl for ArrayReplace {
125123
}
126124

127125
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
128-
make_scalar_function(array_replace_inner)(&args.args)
126+
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
127+
let num_rows = args.number_rows;
128+
let list_array = list_arg.to_array(num_rows)?;
129+
match (from_arg, to_arg) {
130+
(ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
131+
let result = array_replace_with_scalar_args(
132+
&list_array,
133+
scalar_from,
134+
scalar_to,
135+
1i64,
136+
)?;
137+
Ok(ColumnarValue::Array(result))
138+
}
139+
(from_arg, to_arg) => {
140+
let from_array = from_arg.to_array(num_rows)?;
141+
let to_array = to_arg.to_array(num_rows)?;
142+
let result =
143+
array_replace_internal(&list_array, &from_array, &to_array, &[1])?;
144+
Ok(ColumnarValue::Array(result))
145+
}
146+
}
129147
}
130148

131149
fn aliases(&self) -> &[String] {
@@ -200,7 +218,47 @@ impl ScalarUDFImpl for ArrayReplaceN {
200218
}
201219

202220
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
203-
make_scalar_function(array_replace_n_inner)(&args.args)
221+
let [list_arg, from_arg, to_arg, max_arg] =
222+
take_function_args(self.name(), &args.args)?;
223+
let num_rows = args.number_rows;
224+
let list_array = list_arg.to_array(num_rows)?;
225+
match (from_arg, to_arg, max_arg) {
226+
(
227+
ColumnarValue::Scalar(scalar_from),
228+
ColumnarValue::Scalar(scalar_to),
229+
ColumnarValue::Scalar(scalar_max),
230+
) => {
231+
let ScalarValue::Int64(Some(n)) = scalar_max else {
232+
// A NULL max (or any scalar that is not a non-null Int64) means no replacements: return the input list unchanged.
233+
return Ok(ColumnarValue::Array(list_array));
234+
};
235+
let result = array_replace_with_scalar_args(
236+
&list_array,
237+
scalar_from,
238+
scalar_to,
239+
*n,
240+
)?;
241+
Ok(ColumnarValue::Array(result))
242+
}
243+
(from_arg, to_arg, max_arg) => {
244+
let from_array = from_arg.to_array(num_rows)?;
245+
let to_array = to_arg.to_array(num_rows)?;
246+
let max_array = max_arg.to_array(num_rows)?;
247+
let max_array = as_int64_array(&max_array)?;
248+
let arr_n = (0..max_array.len())
249+
.map(|i| {
250+
if max_array.is_null(i) {
251+
0
252+
} else {
253+
max_array.value(i)
254+
}
255+
})
256+
.collect::<Vec<_>>();
257+
let result =
258+
array_replace_internal(&list_array, &from_array, &to_array, &arr_n)?;
259+
Ok(ColumnarValue::Array(result))
260+
}
261+
}
204262
}
205263

206264
fn aliases(&self) -> &[String] {
@@ -273,7 +331,31 @@ impl ScalarUDFImpl for ArrayReplaceAll {
273331
}
274332

275333
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
276-
make_scalar_function(array_replace_all_inner)(&args.args)
334+
let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?;
335+
let num_rows = args.number_rows;
336+
let list_array = list_arg.to_array(num_rows)?;
337+
match (from_arg, to_arg) {
338+
(ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => {
339+
let result = array_replace_with_scalar_args(
340+
&list_array,
341+
scalar_from,
342+
scalar_to,
343+
i64::MAX,
344+
)?;
345+
Ok(ColumnarValue::Array(result))
346+
}
347+
(from_arg, to_arg) => {
348+
let from_array = from_arg.to_array(num_rows)?;
349+
let to_array = to_arg.to_array(num_rows)?;
350+
let result = array_replace_internal(
351+
&list_array,
352+
&from_array,
353+
&to_array,
354+
&[i64::MAX],
355+
)?;
356+
Ok(ColumnarValue::Array(result))
357+
}
358+
}
277359
}
278360

279361
fn aliases(&self) -> &[String] {
@@ -343,7 +425,11 @@ fn general_replace<O: OffsetSizeTrait>(
343425

344426
let original_idx = O::usize_as(0);
345427
let replace_idx = O::usize_as(1);
346-
let n = arr_n[row_index];
428+
let n = if arr_n.len() == 1 {
429+
arr_n[0]
430+
} else {
431+
arr_n[row_index]
432+
};
347433
let mut counter = 0;
348434

349435
// All elements are false, no need to replace, just copy original data
@@ -412,63 +498,154 @@ fn general_replace<O: OffsetSizeTrait>(
412498
)?))
413499
}
414500

415-
fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
416-
let [array, from, to] = take_function_args("array_replace", args)?;
501+
/// Replaces up to `max_replacements` occurrences of `needle` with the single
502+
/// element produced from `scalar_to` for each row in `list_array`.
503+
///
504+
/// This is a specialized fast path for the all-scalar case that uses a single
505+
/// bulk `not_distinct` comparison over only the visible values range, then
506+
/// iterates match positions via `set_indices` instead of scanning every bit.
507+
fn general_replace_with_scalar<O: OffsetSizeTrait>(
508+
list_array: &GenericListArray<O>,
509+
needle: &Scalar<ArrayRef>,
510+
scalar_to: &ScalarValue,
511+
max_replacements: i64,
512+
) -> Result<ArrayRef> {
513+
// No replacement needed - return unchanged.
514+
if max_replacements <= 0 {
515+
return Ok(Arc::new(list_array.clone()));
516+
}
417517

418-
// replace at most one occurrence for each element
419-
let arr_n = vec![1; array.len()];
420-
match array.data_type() {
421-
DataType::List(_) => {
422-
let list_array = array.as_list::<i32>();
423-
general_replace::<i32>(list_array, from, to, &arr_n)
518+
let first_offset = list_array.offsets()[0].to_usize().unwrap();
519+
let last_offset = list_array.offsets()[list_array.len()].to_usize().unwrap();
520+
let visible_values = list_array
521+
.values()
522+
.slice(first_offset, last_offset - first_offset);
523+
524+
let to_array = scalar_to.to_array_of_size(1)?;
525+
let original_data = visible_values.to_data();
526+
let to_data = to_array.to_data();
527+
let capacity = Capacities::Array(original_data.len());
528+
529+
let mut mutable = MutableArrayData::with_capacities(
530+
vec![&original_data, &to_data],
531+
false,
532+
capacity,
533+
);
534+
535+
let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());
536+
537+
// Single bulk comparison over the visible values only.
538+
let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?;
539+
let match_bits = match_bitmap.values();
540+
541+
for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
542+
// Offsets relative to visible_values (subtract first_offset).
543+
let start = offset_window[0].to_usize().unwrap() - first_offset;
544+
let end = offset_window[1].to_usize().unwrap() - first_offset;
545+
let row_len = end - start;
546+
547+
if list_array.is_null(row_index) {
548+
offsets.push_length(0);
549+
continue;
424550
}
425-
DataType::LargeList(_) => {
426-
let list_array = array.as_list::<i64>();
427-
general_replace::<i64>(list_array, from, to, &arr_n)
551+
552+
// Slice the match bits to this row and iterate only over true positions.
553+
let row_bits = match_bits.slice(start, row_len);
554+
let mut match_positions = row_bits
555+
.set_indices()
556+
.take(max_replacements as usize)
557+
.peekable();
558+
if match_positions.peek().is_none() {
559+
mutable.extend(0, start, end);
560+
offsets.push_length(row_len);
561+
continue;
428562
}
429-
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
430-
array_type => exec_err!("array_replace does not support type '{array_type}'."),
563+
564+
// Iterate only over the positions that match using set_indices,
565+
// which is more efficient than scanning every bit because the number
566+
// of matches is typically much smaller than the total array size.
567+
let mut prev_end = 0usize;
568+
for match_pos in match_positions {
569+
// Retain elements before this match.
570+
if match_pos > prev_end {
571+
mutable.extend(0, start + prev_end, start + match_pos);
572+
}
573+
// Emit the replacement element.
574+
mutable.extend(1, 0, 1);
575+
prev_end = match_pos + 1;
576+
}
577+
578+
// Copy remaining elements after the last replacement.
579+
if prev_end < row_len {
580+
mutable.extend(0, start + prev_end, end);
581+
}
582+
583+
offsets.push_length(row_len);
431584
}
585+
586+
let data = mutable.freeze();
587+
588+
Ok(Arc::new(GenericListArray::<O>::try_new(
589+
Arc::new(Field::new_list_field(list_array.value_type(), true)),
590+
offsets.finish(),
591+
arrow::array::make_array(data),
592+
list_array.nulls().cloned(),
593+
)?))
432594
}
433595

434-
fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
435-
let [array, from, to, max] = take_function_args("array_replace_n", args)?;
596+
/// Fast path for `array_replace`, `array_replace_n`, and `array_replace_all` when the needle, replacement, and (if present) max arguments are scalars.
597+
///
598+
/// Uses a single bulk `not_distinct` comparison instead of per-row comparisons.
599+
fn array_replace_with_scalar_args(
600+
list_array: &ArrayRef,
601+
scalar_from: &ScalarValue,
602+
scalar_to: &ScalarValue,
603+
max_replacements: i64,
604+
) -> Result<ArrayRef> {
605+
// `not_distinct` doesn't support nested types, so fall back to the generic array path.
606+
if scalar_from.data_type().is_nested() {
607+
let num_rows = list_array.len();
608+
let from_array = scalar_from.to_array_of_size(num_rows)?;
609+
let to_array = scalar_to.to_array_of_size(num_rows)?;
610+
return array_replace_internal(
611+
list_array,
612+
&from_array,
613+
&to_array,
614+
&vec![max_replacements; num_rows],
615+
);
616+
}
436617

437-
// replace the specified number of occurrences
438-
let arr_n = as_int64_array(max)?.values().to_vec();
439-
match array.data_type() {
618+
let needle = Scalar::new(scalar_from.to_array_of_size(1)?);
619+
match list_array.data_type() {
440620
DataType::List(_) => {
441-
let list_array = array.as_list::<i32>();
442-
general_replace::<i32>(list_array, from, to, &arr_n)
621+
let list = list_array.as_list::<i32>();
622+
general_replace_with_scalar::<i32>(list, &needle, scalar_to, max_replacements)
443623
}
444624
DataType::LargeList(_) => {
445-
let list_array = array.as_list::<i64>();
446-
general_replace::<i64>(list_array, from, to, &arr_n)
447-
}
448-
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
449-
array_type => {
450-
exec_err!("array_replace_n does not support type '{array_type}'.")
625+
let list = list_array.as_list::<i64>();
626+
general_replace_with_scalar::<i64>(list, &needle, scalar_to, max_replacements)
451627
}
628+
DataType::Null => Ok(new_null_array(list_array.data_type(), 1)),
629+
array_type => exec_err!("array_replace does not support type '{array_type}'."),
452630
}
453631
}
454632

455-
fn array_replace_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
456-
let [array, from, to] = take_function_args("array_replace_all", args)?;
457-
458-
// replace all occurrences (up to "i64::MAX")
459-
let arr_n = vec![i64::MAX; array.len()];
633+
fn array_replace_internal(
634+
array: &ArrayRef,
635+
from: &ArrayRef,
636+
to: &ArrayRef,
637+
arr_n: &[i64],
638+
) -> Result<ArrayRef> {
460639
match array.data_type() {
461640
DataType::List(_) => {
462641
let list_array = array.as_list::<i32>();
463-
general_replace::<i32>(list_array, from, to, &arr_n)
642+
general_replace::<i32>(list_array, from, to, arr_n)
464643
}
465644
DataType::LargeList(_) => {
466645
let list_array = array.as_list::<i64>();
467-
general_replace::<i64>(list_array, from, to, &arr_n)
646+
general_replace::<i64>(list_array, from, to, arr_n)
468647
}
469648
DataType::Null => Ok(new_null_array(array.data_type(), 1)),
470-
array_type => {
471-
exec_err!("array_replace_all does not support type '{array_type}'.")
472-
}
649+
array_type => exec_err!("array_replace does not support type '{array_type}'."),
473650
}
474651
}

0 commit comments

Comments
 (0)