Skip to content

Commit c560bee

Browse files
authored
perf: Optimize repeat function with scalar and array fast paths (#19976)
## Which issue does this PR close?

<!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. -->

- Part of apache/datafusion-comet#2986.

## Rationale for this change

The `repeat` function currently converts scalar inputs to arrays before processing them via `make_scalar_function`. Adding a scalar fast path avoids this overhead and improves performance for constant folding and scalar expression evaluation.

<!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. -->

## What changes are included in this PR?

1. Refactored `invoke_with_args` to handle scalar inputs directly, without array conversion.
2. Added a scalar fast path for `Utf8`, `LargeUtf8`, and `Utf8View` scalar inputs.
3. Added an array fast path that skips per-element null checks when `null_count() == 0` for both input arrays.

Benchmark results:

| Benchmark | Before | After | Speedup |
|-----------|--------|-------|---------|
| **repeat/scalar_utf8** | 519 ns | 91 ns | **5.7x** |
| **repeat/scalar_utf8view** | 437 ns | 91 ns | **4.8x** |

<!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. -->

## Are these changes tested?

Yes

<!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? -->

## Are there any user-facing changes?

No

<!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. -->
<!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 58fb6e1 commit c560bee

2 files changed

Lines changed: 167 additions & 36 deletions

File tree

datafusion/functions/benches/repeat.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use arrow::util::bench_util::{
2424
};
2525
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
2626
use datafusion_common::DataFusionError;
27+
use datafusion_common::ScalarValue;
2728
use datafusion_common::config::ConfigOptions;
2829
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2930
use datafusion_functions::string;
@@ -80,6 +81,44 @@ fn invoke_repeat_with_args(
8081
}
8182

8283
fn criterion_benchmark(c: &mut Criterion) {
84+
let repeat_fn = string::repeat();
85+
let config_options = Arc::new(ConfigOptions::default());
86+
87+
// Scalar benchmarks (outside loop)
88+
c.bench_function("repeat/scalar_utf8", |b| {
89+
let args = ScalarFunctionArgs {
90+
args: vec![
91+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))),
92+
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
93+
],
94+
arg_fields: vec![
95+
Field::new("a", DataType::Utf8, false).into(),
96+
Field::new("b", DataType::Int64, false).into(),
97+
],
98+
number_rows: 1,
99+
return_field: Field::new("f", DataType::Utf8, true).into(),
100+
config_options: Arc::clone(&config_options),
101+
};
102+
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
103+
});
104+
105+
c.bench_function("repeat/scalar_utf8view", |b| {
106+
let args = ScalarFunctionArgs {
107+
args: vec![
108+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))),
109+
ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
110+
],
111+
arg_fields: vec![
112+
Field::new("a", DataType::Utf8View, false).into(),
113+
Field::new("b", DataType::Int64, false).into(),
114+
],
115+
number_rows: 1,
116+
return_field: Field::new("f", DataType::Utf8, true).into(),
117+
config_options: Arc::clone(&config_options),
118+
};
119+
b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap()))
120+
});
121+
83122
for size in [1024, 4096] {
84123
// REPEAT 3 TIMES
85124
let repeat_times = 3;

datafusion/functions/src/string/repeat.rs

Lines changed: 128 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use crate::utils::{make_scalar_function, utf8_to_str_type};
21+
use crate::utils::utf8_to_str_type;
2222
use arrow::array::{
23-
ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
23+
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
2424
OffsetSizeTrait, StringArrayType, StringViewArray,
2525
};
2626
use arrow::datatypes::DataType;
2727
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
2828
use datafusion_common::cast::as_int64_array;
2929
use datafusion_common::types::{NativeType, logical_int64, logical_string};
30-
use datafusion_common::{DataFusionError, Result, exec_err};
30+
use datafusion_common::utils::take_function_args;
31+
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, internal_err};
3132
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
3233
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
3334
use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
@@ -99,39 +100,112 @@ impl ScalarUDFImpl for RepeatFunc {
99100
}
100101

101102
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
102-
make_scalar_function(repeat, vec![])(&args.args)
103+
let return_type = args.return_field.data_type().clone();
104+
let [string_arg, count_arg] = take_function_args(self.name(), args.args)?;
105+
106+
// Early return if either argument is a scalar null
107+
if let ColumnarValue::Scalar(s) = &string_arg
108+
&& s.is_null()
109+
{
110+
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
111+
}
112+
if let ColumnarValue::Scalar(c) = &count_arg
113+
&& c.is_null()
114+
{
115+
return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?));
116+
}
117+
118+
match (&string_arg, &count_arg) {
119+
(
120+
ColumnarValue::Scalar(string_scalar),
121+
ColumnarValue::Scalar(count_scalar),
122+
) => {
123+
let count = match count_scalar {
124+
ScalarValue::Int64(Some(n)) => *n,
125+
_ => {
126+
return internal_err!(
127+
"Unexpected data type {:?} for repeat count",
128+
count_scalar.data_type()
129+
);
130+
}
131+
};
132+
133+
let result = match string_scalar {
134+
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
135+
ScalarValue::Utf8(Some(compute_repeat(
136+
s,
137+
count,
138+
i32::MAX as usize,
139+
)?))
140+
}
141+
ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
142+
compute_repeat(s, count, i64::MAX as usize)?,
143+
)),
144+
_ => {
145+
return internal_err!(
146+
"Unexpected data type {:?} for function repeat",
147+
string_scalar.data_type()
148+
);
149+
}
150+
};
151+
152+
Ok(ColumnarValue::Scalar(result))
153+
}
154+
_ => {
155+
let string_array = string_arg.to_array(args.number_rows)?;
156+
let count_array = count_arg.to_array(args.number_rows)?;
157+
Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?))
158+
}
159+
}
103160
}
104161

105162
fn documentation(&self) -> Option<&Documentation> {
106163
self.doc()
107164
}
108165
}
109166

167+
/// Computes repeat for a single string value with max size check
168+
#[inline]
169+
fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result<String> {
170+
if count <= 0 {
171+
return Ok(String::new());
172+
}
173+
let result_len = s.len().saturating_mul(count as usize);
174+
if result_len > max_size {
175+
return exec_err!(
176+
"string size overflow on repeat, max size is {}, but got {}",
177+
max_size,
178+
result_len
179+
);
180+
}
181+
Ok(s.repeat(count as usize))
182+
}
183+
110184
/// Repeats string the specified number of times.
111185
/// repeat('Pg', 4) = 'PgPgPgPg'
112-
fn repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
113-
let number_array = as_int64_array(&args[1])?;
114-
match args[0].data_type() {
186+
fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
187+
let number_array = as_int64_array(count_array)?;
188+
match string_array.data_type() {
115189
Utf8View => {
116-
let string_view_array = args[0].as_string_view();
190+
let string_view_array = string_array.as_string_view();
117191
repeat_impl::<i32, &StringViewArray>(
118192
&string_view_array,
119193
number_array,
120194
i32::MAX as usize,
121195
)
122196
}
123197
Utf8 => {
124-
let string_array = args[0].as_string::<i32>();
198+
let string_arr = string_array.as_string::<i32>();
125199
repeat_impl::<i32, &GenericStringArray<i32>>(
126-
&string_array,
200+
&string_arr,
127201
number_array,
128202
i32::MAX as usize,
129203
)
130204
}
131205
LargeUtf8 => {
132-
let string_array = args[0].as_string::<i64>();
206+
let string_arr = string_array.as_string::<i64>();
133207
repeat_impl::<i64, &GenericStringArray<i64>>(
134-
&string_array,
208+
&string_arr,
135209
number_array,
136210
i64::MAX as usize,
137211
)
@@ -150,7 +224,7 @@ fn repeat_impl<'a, T, S>(
150224
) -> Result<ArrayRef>
151225
where
152226
T: OffsetSizeTrait,
153-
S: StringArrayType<'a>,
227+
S: StringArrayType<'a> + 'a,
154228
{
155229
let mut total_capacity = 0;
156230
let mut max_item_capacity = 0;
@@ -181,37 +255,55 @@ where
181255
// Reusable buffer to avoid allocations in string.repeat()
182256
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);
183257

184-
string_array
185-
.iter()
186-
.zip(number_array.iter())
187-
.for_each(|(string, number)| {
258+
// Fills `buffer` with `string` repeated `count` times, reusing the buffer's
// existing allocation. Any previous contents are discarded.
//
// The buffer is grown by doubling: after the first copy, the bytes written so
// far are appended to themselves until the target length is reached, so only
// O(log count) copies are needed.
//
// Callers in this function only pass `count > 0`, but `count == 0` (or an
// empty `string`) is handled and simply leaves the buffer empty.
//
// Panics if `string.len() * count` overflows `usize`; the caller is expected
// to have validated the total size beforehand.
#[inline]
fn repeat_to_buffer(buffer: &mut Vec<u8>, string: &str, count: usize) {
    buffer.clear();
    let src = string.as_bytes();
    if src.is_empty() || count == 0 {
        return;
    }

    // Compute the target once; a wrapped product would silently truncate the output.
    let target_len = src
        .len()
        .checked_mul(count)
        .expect("repeat size overflow: caller must validate the total length");
    buffer.reserve(target_len);

    // Initial copy
    buffer.extend_from_slice(src);
    // Doubling strategy: copy what we have so far until we reach the target
    while buffer.len() < target_len {
        let copy_len = buffer.len().min(target_len - buffer.len());
        // The buffer always holds whole repetitions of `src`, so appending a
        // prefix of it keeps the contents a valid repetition (and valid UTF-8).
        buffer.extend_from_within(..copy_len);
    }
}
275+
276+
// Fast path: no nulls in either array
277+
if string_array.null_count() == 0 && number_array.null_count() == 0 {
278+
for i in 0..string_array.len() {
279+
// SAFETY: i is within bounds (0..len) and null_count() == 0 guarantees valid value
280+
let string = unsafe { string_array.value_unchecked(i) };
281+
let count = number_array.value(i);
282+
if count > 0 {
283+
repeat_to_buffer(&mut buffer, string, count as usize);
284+
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
285+
builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
286+
} else {
287+
builder.append_value("");
288+
}
289+
}
290+
} else {
291+
// Slow path: handle nulls
292+
for (string, number) in string_array.iter().zip(number_array.iter()) {
188293
match (string, number) {
189-
(Some(string), Some(number)) if number >= 0 => {
190-
buffer.clear();
191-
let count = number as usize;
192-
if count > 0 && !string.is_empty() {
193-
let src = string.as_bytes();
194-
// Initial copy
195-
buffer.extend_from_slice(src);
196-
// Doubling strategy: copy what we have so far until we reach the target
197-
while buffer.len() < src.len() * count {
198-
let copy_len =
199-
buffer.len().min(src.len() * count - buffer.len());
200-
// SAFETY: we're copying valid UTF-8 bytes that we already verified
201-
buffer.extend_from_within(..copy_len);
202-
}
203-
}
204-
// SAFETY: buffer contains valid UTF-8 since we only ever copy from a valid &str
294+
(Some(string), Some(count)) if count > 0 => {
295+
repeat_to_buffer(&mut buffer, string, count as usize);
296+
// SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str
205297
builder
206298
.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) });
207299
}
208300
(Some(_), Some(_)) => builder.append_value(""),
209301
_ => builder.append_null(),
210302
}
211-
});
212-
let array = builder.finish();
303+
}
304+
}
213305

214-
Ok(Arc::new(array) as ArrayRef)
306+
Ok(Arc::new(builder.finish()) as ArrayRef)
215307
}
216308

217309
#[cfg(test)]

0 commit comments

Comments
 (0)