Skip to content

Commit c3f0807

Browse files
neilconwaymartin-gJefffrey
authored
perf: Optimize translate() UDF for scalar inputs (#20305)
## Which issue does this PR close? - Closes #20302. ## Rationale for this change `translate()` is commonly invoked with constant values for its second and third arguments. We can take advantage of that to significantly optimize its performance by precomputing the translation lookup table, rather than recomputing it for every row. For ASCII-only inputs, we can further replace the hashmap lookup table with a fixed-size array that maps ASCII byte values directly. For scalar ASCII inputs, this yields roughly a 10x performance improvement. For scalar UTF8 inputs, the performance improvement is more like 50%, although less so for long strings. Along the way, add support for `translate()` on `LargeUtf8` input, along with an SLT test, and improve the docs. ## What changes are included in this PR? * Add a benchmark for scalar/constant input to translate * Add a missing test case * Improve translate() docs * Support translate() on LargeUtf8 input * Optimize translate() for scalar inputs by precomputing lookup hashmap * Optimize translate() for ASCII inputs by precomputing ASCII byte-wise lookup table ## Are these changes tested? Yes. Added an extra test case and did a bunch of benchmarking. ## Are there any user-facing changes? No. --------- Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com> Co-authored-by: Jeffrey Vo <jeffrey.vo.australia@gmail.com>
1 parent 4f4e814 commit c3f0807

File tree

4 files changed

+221
-26
lines changed

4 files changed

+221
-26
lines changed

datafusion/functions/benches/translate.rs

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@ use arrow::array::OffsetSizeTrait;
1919
use arrow::datatypes::{DataType, Field};
2020
use arrow::util::bench_util::create_string_array_with_len;
2121
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
22-
use datafusion_common::DataFusionError;
2322
use datafusion_common::config::ConfigOptions;
23+
use datafusion_common::{DataFusionError, ScalarValue};
2424
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2525
use datafusion_functions::unicode;
2626
use std::hint::black_box;
2727
use std::sync::Arc;
2828
use std::time::Duration;
2929

30-
fn create_args<O: OffsetSizeTrait>(size: usize, str_len: usize) -> Vec<ColumnarValue> {
30+
fn create_args_array_from_to<O: OffsetSizeTrait>(
31+
size: usize,
32+
str_len: usize,
33+
) -> Vec<ColumnarValue> {
3134
let string_array = Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
32-
// Create simple from/to strings for translation
3335
let from_array = Arc::new(create_string_array_with_len::<O>(size, 0.1, 3));
3436
let to_array = Arc::new(create_string_array_with_len::<O>(size, 0.1, 2));
3537

@@ -40,6 +42,19 @@ fn create_args<O: OffsetSizeTrait>(size: usize, str_len: usize) -> Vec<ColumnarV
4042
]
4143
}
4244

45+
fn create_args_scalar_from_to<O: OffsetSizeTrait>(
46+
size: usize,
47+
str_len: usize,
48+
) -> Vec<ColumnarValue> {
49+
let string_array = Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
50+
51+
vec![
52+
ColumnarValue::Array(string_array),
53+
ColumnarValue::Scalar(ScalarValue::from("aeiou")),
54+
ColumnarValue::Scalar(ScalarValue::from("AEIOU")),
55+
]
56+
}
57+
4358
fn invoke_translate_with_args(
4459
args: Vec<ColumnarValue>,
4560
number_rows: usize,
@@ -67,17 +82,22 @@ fn criterion_benchmark(c: &mut Criterion) {
6782
group.sample_size(10);
6883
group.measurement_time(Duration::from_secs(10));
6984

70-
for str_len in [8, 32] {
71-
let args = create_args::<i32>(size, str_len);
72-
group.bench_function(
73-
format!("translate_string [size={size}, str_len={str_len}]"),
74-
|b| {
75-
b.iter(|| {
76-
let args_cloned = args.clone();
77-
black_box(invoke_translate_with_args(args_cloned, size))
78-
})
79-
},
80-
);
85+
for str_len in [8, 32, 128, 1024] {
86+
let args = create_args_array_from_to::<i32>(size, str_len);
87+
group.bench_function(format!("array_from_to [str_len={str_len}]"), |b| {
88+
b.iter(|| {
89+
let args_cloned = args.clone();
90+
black_box(invoke_translate_with_args(args_cloned, size))
91+
})
92+
});
93+
94+
let args = create_args_scalar_from_to::<i32>(size, str_len);
95+
group.bench_function(format!("scalar_from_to [str_len={str_len}]"), |b| {
96+
b.iter(|| {
97+
let args_cloned = args.clone();
98+
black_box(invoke_translate_with_args(args_cloned, size))
99+
})
100+
});
81101
}
82102

83103
group.finish();

datafusion/functions/src/unicode/translate.rs

Lines changed: 178 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use datafusion_macros::user_doc;
3535

3636
#[user_doc(
3737
doc_section(label = "String Functions"),
38-
description = "Translates characters in a string to specified translation characters.",
39-
syntax_example = "translate(str, chars, translation)",
38+
description = "Performs character-wise substitution based on a mapping.",
39+
syntax_example = "translate(str, from, to)",
4040
sql_example = r#"```sql
4141
> select translate('twice', 'wic', 'her');
4242
+--------------------------------------------------+
@@ -46,10 +46,10 @@ use datafusion_macros::user_doc;
4646
+--------------------------------------------------+
4747
```"#,
4848
standard_argument(name = "str", prefix = "String"),
49-
argument(name = "chars", description = "Characters to translate."),
49+
argument(name = "from", description = "The characters to be replaced."),
5050
argument(
51-
name = "translation",
52-
description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string."
51+
name = "to",
52+
description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping."
5353
)
5454
)]
5555
#[derive(Debug, PartialEq, Eq, Hash)]
@@ -71,6 +71,7 @@ impl TranslateFunc {
7171
vec![
7272
Exact(vec![Utf8View, Utf8, Utf8]),
7373
Exact(vec![Utf8, Utf8, Utf8]),
74+
Exact(vec![LargeUtf8, Utf8, Utf8]),
7475
],
7576
Volatility::Immutable,
7677
),
@@ -99,6 +100,61 @@ impl ScalarUDFImpl for TranslateFunc {
99100
&self,
100101
args: datafusion_expr::ScalarFunctionArgs,
101102
) -> Result<ColumnarValue> {
103+
// When from and to are scalars, pre-build the translation map once
104+
if let (Some(from_str), Some(to_str)) = (
105+
try_as_scalar_str(&args.args[1]),
106+
try_as_scalar_str(&args.args[2]),
107+
) {
108+
let to_graphemes: Vec<&str> = to_str.graphemes(true).collect();
109+
110+
let mut from_map: HashMap<&str, usize> = HashMap::new();
111+
for (index, c) in from_str.graphemes(true).enumerate() {
112+
// Ignore characters that already exist in from_map
113+
from_map.entry(c).or_insert(index);
114+
}
115+
116+
let ascii_table = build_ascii_translate_table(from_str, to_str);
117+
118+
let string_array = args.args[0].to_array_of_size(args.number_rows)?;
119+
120+
let result = match string_array.data_type() {
121+
DataType::Utf8View => {
122+
let arr = string_array.as_string_view();
123+
translate_with_map::<i32, _>(
124+
arr,
125+
&from_map,
126+
&to_graphemes,
127+
ascii_table.as_ref(),
128+
)
129+
}
130+
DataType::Utf8 => {
131+
let arr = string_array.as_string::<i32>();
132+
translate_with_map::<i32, _>(
133+
arr,
134+
&from_map,
135+
&to_graphemes,
136+
ascii_table.as_ref(),
137+
)
138+
}
139+
DataType::LargeUtf8 => {
140+
let arr = string_array.as_string::<i64>();
141+
translate_with_map::<i64, _>(
142+
arr,
143+
&from_map,
144+
&to_graphemes,
145+
ascii_table.as_ref(),
146+
)
147+
}
148+
other => {
149+
return exec_err!(
150+
"Unsupported data type {other:?} for function translate"
151+
);
152+
}
153+
}?;
154+
155+
return Ok(ColumnarValue::Array(result));
156+
}
157+
102158
make_scalar_function(invoke_translate, vec![])(&args.args)
103159
}
104160

@@ -107,6 +163,14 @@ impl ScalarUDFImpl for TranslateFunc {
107163
}
108164
}
109165

166+
/// If `cv` is a non-null scalar string, return its value.
167+
fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> {
168+
match cv {
169+
ColumnarValue::Scalar(s) => s.try_as_str().flatten(),
170+
_ => None,
171+
}
172+
}
173+
110174
fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
111175
match args[0].data_type() {
112176
DataType::Utf8View => {
@@ -123,8 +187,8 @@ fn invoke_translate(args: &[ArrayRef]) -> Result<ArrayRef> {
123187
}
124188
DataType::LargeUtf8 => {
125189
let string_array = args[0].as_string::<i64>();
126-
let from_array = args[1].as_string::<i64>();
127-
let to_array = args[2].as_string::<i64>();
190+
let from_array = args[1].as_string::<i32>();
191+
let to_array = args[2].as_string::<i32>();
128192
translate::<i64, _, _>(string_array, from_array, to_array)
129193
}
130194
other => {
@@ -170,7 +234,7 @@ where
170234
// Build from_map using reusable buffer
171235
from_graphemes.extend(from.graphemes(true));
172236
for (index, c) in from_graphemes.iter().enumerate() {
173-
// Ignore characters that already exist in from_map, else insert
237+
// Ignore characters that already exist in from_map
174238
from_map.entry(*c).or_insert(index);
175239
}
176240

@@ -199,6 +263,97 @@ where
199263
Ok(Arc::new(result) as ArrayRef)
200264
}
201265

266+
/// Sentinel value in the ASCII translate table indicating the character should
267+
/// be deleted (the `from` character has no corresponding `to` character). Any
268+
/// value > 127 works since valid ASCII is 0–127.
269+
const ASCII_DELETE: u8 = 0xFF;
270+
271+
/// If `from` and `to` are both ASCII, build a fixed-size lookup table for
272+
/// translation. Each entry maps an input byte to its replacement byte, or to
273+
/// [`ASCII_DELETE`] if the character should be removed. Returns `None` if
274+
/// either string contains non-ASCII characters.
275+
fn build_ascii_translate_table(from: &str, to: &str) -> Option<[u8; 128]> {
276+
if !from.is_ascii() || !to.is_ascii() {
277+
return None;
278+
}
279+
let mut table = [0u8; 128];
280+
for i in 0..128u8 {
281+
table[i as usize] = i;
282+
}
283+
let to_bytes = to.as_bytes();
284+
let mut seen = [false; 128];
285+
for (i, from_byte) in from.bytes().enumerate() {
286+
let idx = from_byte as usize;
287+
if !seen[idx] {
288+
seen[idx] = true;
289+
if i < to_bytes.len() {
290+
table[idx] = to_bytes[i];
291+
} else {
292+
table[idx] = ASCII_DELETE;
293+
}
294+
}
295+
}
296+
Some(table)
297+
}
298+
299+
/// Optimized translate for constant `from` and `to` arguments: uses a pre-built
300+
/// translation map instead of rebuilding it for every row. When an ASCII byte
301+
/// lookup table is provided, ASCII input rows use the lookup table; non-ASCII
302+
/// inputs fallback to using the map.
303+
fn translate_with_map<'a, T: OffsetSizeTrait, V>(
304+
string_array: V,
305+
from_map: &HashMap<&str, usize>,
306+
to_graphemes: &[&str],
307+
ascii_table: Option<&[u8; 128]>,
308+
) -> Result<ArrayRef>
309+
where
310+
V: ArrayAccessor<Item = &'a str>,
311+
{
312+
let mut result_graphemes: Vec<&str> = Vec::new();
313+
let mut ascii_buf: Vec<u8> = Vec::new();
314+
315+
let result = ArrayIter::new(string_array)
316+
.map(|string| {
317+
string.map(|s| {
318+
// Fast path: byte-level table lookup for ASCII strings
319+
if let Some(table) = ascii_table
320+
&& s.is_ascii()
321+
{
322+
ascii_buf.clear();
323+
for &b in s.as_bytes() {
324+
let mapped = table[b as usize];
325+
if mapped != ASCII_DELETE {
326+
ascii_buf.push(mapped);
327+
}
328+
}
329+
// SAFETY: all bytes are ASCII, hence valid UTF-8.
330+
return unsafe {
331+
std::str::from_utf8_unchecked(&ascii_buf).to_owned()
332+
};
333+
}
334+
335+
// Slow path: grapheme-based translation
336+
result_graphemes.clear();
337+
338+
for c in s.graphemes(true) {
339+
match from_map.get(c) {
340+
Some(n) => {
341+
if let Some(replacement) = to_graphemes.get(*n) {
342+
result_graphemes.push(*replacement);
343+
}
344+
}
345+
None => result_graphemes.push(c),
346+
}
347+
}
348+
349+
result_graphemes.concat()
350+
})
351+
})
352+
.collect::<GenericStringArray<T>>();
353+
354+
Ok(Arc::new(result) as ArrayRef)
355+
}
356+
202357
#[cfg(test)]
203358
mod tests {
204359
use arrow::array::{Array, StringArray};
@@ -284,6 +439,21 @@ mod tests {
284439
Utf8,
285440
StringArray
286441
);
442+
// Non-ASCII input with ASCII scalar from/to: exercises the
443+
// grapheme fallback within translate_with_map.
444+
test_function!(
445+
TranslateFunc::new(),
446+
vec![
447+
ColumnarValue::Scalar(ScalarValue::from("café")),
448+
ColumnarValue::Scalar(ScalarValue::from("ae")),
449+
ColumnarValue::Scalar(ScalarValue::from("AE"))
450+
],
451+
Ok(Some("cAfé")),
452+
&str,
453+
Utf8,
454+
StringArray
455+
);
456+
287457
#[cfg(not(feature = "unicode_expressions"))]
288458
test_function!(
289459
TranslateFunc::new(),

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ SELECT translate('12345', '143', NULL)
239239
----
240240
NULL
241241

242+
query T
243+
SELECT translate(arrow_cast('12345', 'LargeUtf8'), '143', 'ax')
244+
----
245+
a2x5
246+
242247
statement ok
243248
CREATE TABLE test(
244249
c1 VARCHAR

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2068,17 +2068,17 @@ to_hex(int)
20682068

20692069
### `translate`
20702070

2071-
Translates characters in a string to specified translation characters.
2071+
Performs character-wise substitution based on a mapping.
20722072

20732073
```sql
2074-
translate(str, chars, translation)
2074+
translate(str, from, to)
20752075
```
20762076

20772077
#### Arguments
20782078

20792079
- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators.
2080-
- **chars**: Characters to translate.
2081-
- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string.
2080+
- **from**: The characters to be replaced.
2081+
- **to**: The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping.
20822082

20832083
#### Example
20842084

0 commit comments

Comments
 (0)