Commit f8fb5bd

fix: Avoid unnecessary type casts in concat_ws (apache#20436)
## Which issue does this PR close?

- Closes apache#20434.

## Rationale for this change

1. `concat_ws` returned `Utf8` regardless of the input types it was called with. If it was called with `LargeUtf8`, returning `Utf8` might overflow. In general, functions like these should operate on all three string representations unless there is a compelling reason not to (e.g., this is how `concat` works).
2. `simplify_concat_ws` always constructed new literals with type `Utf8`. This led to unnecessary casts when its inputs were of a different string type.

## What changes are included in this PR?

* Support a `concat_ws` return type that matches its input types, following how `concat` does it.
* In `simplify_concat_ws`, construct literals with the right type, not always `Utf8`.
* Refactor `return_type` for `concat` to be more readable.
* Make the `StringViewArrayBuilder` API more similar to the other string array builders with respect to null handling.
* Add new unit and SLT tests.
* Update test output for changed types.

## Are these changes tested?

Yes.

## Are there any user-facing changes?

Yes: some queries involving `concat_ws` will now omit unnecessary cast operations, and the return type of `concat_ws` might be any of the three string types. Generally these changes should match user expectations better than the previous behavior.
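The overflow concern in the rationale comes from offset width: Arrow's `Utf8` layout stores string offsets as 32-bit signed integers, while `LargeUtf8` uses 64-bit offsets. The following is a minimal sketch of that arithmetic only. The helpers `next_offset_i32` and `next_offset_i64` are hypothetical names invented for illustration; they are not part of arrow-rs or DataFusion, and the sketch makes no claim about how the real builders report the failure.

```rust
/// Hypothetical helper: compute the next end offset for a buffer that uses
/// 32-bit signed offsets (the width used by the `Utf8` layout).
/// Returns `None` when the offset can no longer be represented.
fn next_offset_i32(current: i32, value_len: usize) -> Option<i32> {
    let len = i32::try_from(value_len).ok()?;
    current.checked_add(len)
}

/// Hypothetical helper: the same computation with 64-bit signed offsets
/// (the width used by the `LargeUtf8` layout).
fn next_offset_i64(current: i64, value_len: usize) -> Option<i64> {
    let len = i64::try_from(value_len).ok()?;
    current.checked_add(len)
}

fn main() {
    // Two values of 1.5 billion bytes each: each fits in an i32 offset on
    // its own, but their running total does not.
    let value_len: usize = 1_500_000_000;

    let first = next_offset_i32(0, value_len);
    let second = first.and_then(|off| next_offset_i32(off, value_len));
    println!("i32 offsets: first = {:?}, second = {:?}", first, second);

    let first64 = next_offset_i64(0, value_len);
    let second64 = first64.and_then(|off| next_offset_i64(off, value_len));
    println!("i64 offsets: first = {:?}, second = {:?}", first64, second64);
}
```

With 32-bit offsets the second append has no representable end offset, which is why narrowing `LargeUtf8` input to a `Utf8` result is unsafe in general.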
1 parent d68b800 commit f8fb5bd

4 files changed: +380 -68 lines changed

datafusion/functions/src/string/concat.rs

Lines changed: 19 additions & 23 deletions
@@ -88,37 +88,33 @@ impl ScalarUDFImpl for ConcatFunc {
         &self.signature
     }
 
+    /// Match the return type to the input types to avoid unnecessary casts. On
+    /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid
+    /// potential overflow on LargeUtf8 input.
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
         use DataType::*;
-        let mut dt = &Utf8;
-        arg_types.iter().for_each(|data_type| {
-            if data_type == &Utf8View {
-                dt = data_type;
-            }
-            if data_type == &LargeUtf8 && dt != &Utf8View {
-                dt = data_type;
-            }
-        });
-
-        Ok(dt.to_owned())
+        if arg_types.contains(&Utf8View) {
+            Ok(Utf8View)
+        } else if arg_types.contains(&LargeUtf8) {
+            Ok(LargeUtf8)
+        } else {
+            Ok(Utf8)
+        }
     }
 
     /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
     /// concat('abcde', 2, NULL, 22) = 'abcde222'
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let ScalarFunctionArgs { args, .. } = args;
 
-        let mut return_datatype = DataType::Utf8;
-        args.iter().for_each(|col| {
-            if col.data_type() == DataType::Utf8View {
-                return_datatype = col.data_type();
-            }
-            if col.data_type() == DataType::LargeUtf8
-                && return_datatype != DataType::Utf8View
-            {
-                return_datatype = col.data_type();
-            }
-        });
+        let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View)
+        {
+            DataType::Utf8View
+        } else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) {
+            DataType::LargeUtf8
+        } else {
+            DataType::Utf8
+        };
 
         let array_len = args.iter().find_map(|x| match x {
             ColumnarValue::Array(array) => Some(array.len()),
@@ -247,7 +243,7 @@ impl ScalarUDFImpl for ConcatFunc {
                     builder.append_offset();
                 }
 
-                let string_array = builder.finish();
+                let string_array = builder.finish(None);
                 Ok(ColumnarValue::Array(Arc::new(string_array)))
             }
             DataType::LargeUtf8 => {
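
The `return_type` refactor in this file is meant to be behavior-preserving: `Utf8View` wins over `LargeUtf8`, which wins over `Utf8`. A minimal, dependency-free sketch of that precedence rule follows. The `StrType` enum and the functions `return_type_old` and `return_type_new` are hypothetical stand-ins for arrow's `DataType` and the two versions of the method in the diff, so the snippet compiles without DataFusion; it only illustrates the logic and is not the project's test code.

```rust
/// Hypothetical stand-in for the three string variants of arrow's `DataType`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StrType {
    Utf8,
    LargeUtf8,
    Utf8View,
}

/// Mirrors the removed loop: walk the arguments and upgrade the result type.
fn return_type_old(arg_types: &[StrType]) -> StrType {
    use StrType::*;
    let mut dt = Utf8;
    for &t in arg_types {
        if t == Utf8View {
            dt = t;
        }
        if t == LargeUtf8 && dt != Utf8View {
            dt = t;
        }
    }
    dt
}

/// Mirrors the added code: check for each type in precedence order.
fn return_type_new(arg_types: &[StrType]) -> StrType {
    use StrType::*;
    if arg_types.contains(&Utf8View) {
        Utf8View
    } else if arg_types.contains(&LargeUtf8) {
        LargeUtf8
    } else {
        Utf8
    }
}

fn main() {
    use StrType::*;
    let all = [Utf8, LargeUtf8, Utf8View];

    // Compare both versions on every argument list of length 0 through 3.
    assert_eq!(return_type_old(&[]), return_type_new(&[]));
    for &a in &all {
        assert_eq!(return_type_old(&[a]), return_type_new(&[a]));
        for &b in &all {
            assert_eq!(return_type_old(&[a, b]), return_type_new(&[a, b]));
            for &c in &all {
                assert_eq!(return_type_old(&[a, b, c]), return_type_new(&[a, b, c]));
            }
        }
    }

    // Mixed Utf8 and LargeUtf8 input resolves to LargeUtf8.
    println!("{:?}", return_type_new(&[Utf8, LargeUtf8]));
}
```

Because the result depends only on which types are present and not on their order, the `contains` form expresses the same rule without mutable state.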
