Skip to content

Commit c699361

Browse files
authored
fix: Handle Utf8View and LargeUtf8 separators in concat_ws (#20361)
## Which issue does this PR close?

- Closes #20360

## Rationale for this change

concat_ws only handled Utf8 separators (despite its signature claiming otherwise). Attempting to pass a Utf8View or LargeUtf8 separator would result in a panic or internal error.

## What changes are included in this PR?

* Add SLT test case for array Utf8View separator
* Add unit test for scalar Utf8View separator
* Fix behavior: add support for LargeUtf8 and Utf8View separators, both array and scalar
* Other minor code cleanups and improvements

## Are these changes tested?

Yes. Added new test cases.

Note that we can't easily test the scalar separator case via SQL, because `simplify_concat_ws` casts constant/scalar separators to Utf8. That behavior is dubious and IMO should be changed, but I'll tackle that in a subsequent PR.

## Are there any user-facing changes?

No, aside from a previously failing query now succeeding.
1 parent 08c09db commit c699361

File tree

3 files changed

+159
-51
lines changed

3 files changed

+159
-51
lines changed

datafusion/functions/src/string/concat.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,10 @@ impl ScalarUDFImpl for ConcatFunc {
120120
}
121121
});
122122

123-
let array_len = args
124-
.iter()
125-
.filter_map(|x| match x {
126-
ColumnarValue::Array(array) => Some(array.len()),
127-
_ => None,
128-
})
129-
.next();
123+
let array_len = args.iter().find_map(|x| match x {
124+
ColumnarValue::Array(array) => Some(array.len()),
125+
_ => None,
126+
});
130127

131128
// Scalar
132129
if array_len.is_none() {

datafusion/functions/src/string/concat_ws.rs

Lines changed: 136 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{Array, StringArray, as_largestring_array};
18+
use arrow::array::Array;
1919
use std::any::Any;
2020
use std::sync::Arc;
2121

@@ -25,7 +25,9 @@ use crate::string::concat;
2525
use crate::string::concat::simplify_concat;
2626
use crate::string::concat_ws;
2727
use crate::strings::{ColumnarValueRef, StringArrayBuilder};
28-
use datafusion_common::cast::{as_string_array, as_string_view_array};
28+
use datafusion_common::cast::{
29+
as_large_string_array, as_string_array, as_string_view_array,
30+
};
2931
use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err};
3032
use datafusion_expr::expr::ScalarFunction;
3133
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
@@ -105,26 +107,21 @@ impl ScalarUDFImpl for ConcatWsFunc {
105107
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106108
let ScalarFunctionArgs { args, .. } = args;
107109

108-
// do not accept 0 arguments.
109110
if args.len() < 2 {
110111
return exec_err!(
111112
"concat_ws was called with {} arguments. It requires at least 2.",
112113
args.len()
113114
);
114115
}
115116

116-
let array_len = args
117-
.iter()
118-
.filter_map(|x| match x {
119-
ColumnarValue::Array(array) => Some(array.len()),
120-
_ => None,
121-
})
122-
.next();
117+
let array_len = args.iter().find_map(|x| match x {
118+
ColumnarValue::Array(array) => Some(array.len()),
119+
_ => None,
120+
});
123121

124122
// Scalar
125123
if array_len.is_none() {
126124
let ColumnarValue::Scalar(scalar) = &args[0] else {
127-
// loop above checks for all args being scalar
128125
unreachable!()
129126
};
130127
let sep = match scalar.try_as_str() {
@@ -139,7 +136,6 @@ impl ScalarUDFImpl for ConcatWsFunc {
139136
let mut values = Vec::with_capacity(args.len() - 1);
140137
for arg in &args[1..] {
141138
let ColumnarValue::Scalar(scalar) = arg else {
142-
// loop above checks for all args being scalar
143139
unreachable!()
144140
};
145141

@@ -162,23 +158,53 @@ impl ScalarUDFImpl for ConcatWsFunc {
162158

163159
// parse sep
164160
let sep = match &args[0] {
165-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => {
166-
data_size += s.len() * len * (args.len() - 2); // estimate
167-
ColumnarValueRef::Scalar(s.as_bytes())
168-
}
169-
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
170-
return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len))));
171-
}
172-
ColumnarValue::Array(array) => {
173-
let string_array = as_string_array(array)?;
174-
data_size += string_array.values().len() * (args.len() - 2); // estimate
175-
if array.is_nullable() {
176-
ColumnarValueRef::NullableArray(string_array)
177-
} else {
178-
ColumnarValueRef::NonNullableArray(string_array)
161+
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
162+
Some(Some(s)) => {
163+
data_size += s.len() * len * (args.len() - 2); // estimate
164+
ColumnarValueRef::Scalar(s.as_bytes())
179165
}
180-
}
181-
_ => unreachable!("concat ws"),
166+
Some(None) => {
167+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
168+
}
169+
None => {
170+
return internal_err!("Expected string separator, got {scalar:?}");
171+
}
172+
},
173+
ColumnarValue::Array(array) => match array.data_type() {
174+
DataType::Utf8 => {
175+
let string_array = as_string_array(array)?;
176+
data_size += string_array.values().len() * (args.len() - 2);
177+
if array.is_nullable() {
178+
ColumnarValueRef::NullableArray(string_array)
179+
} else {
180+
ColumnarValueRef::NonNullableArray(string_array)
181+
}
182+
}
183+
DataType::LargeUtf8 => {
184+
let string_array = as_large_string_array(array)?;
185+
data_size += string_array.values().len() * (args.len() - 2);
186+
if array.is_nullable() {
187+
ColumnarValueRef::NullableLargeStringArray(string_array)
188+
} else {
189+
ColumnarValueRef::NonNullableLargeStringArray(string_array)
190+
}
191+
}
192+
DataType::Utf8View => {
193+
let string_array = as_string_view_array(array)?;
194+
data_size +=
195+
string_array.total_buffer_bytes_used() * (args.len() - 2);
196+
if array.is_nullable() {
197+
ColumnarValueRef::NullableStringViewArray(string_array)
198+
} else {
199+
ColumnarValueRef::NonNullableStringViewArray(string_array)
200+
}
201+
}
202+
other => {
203+
return plan_err!(
204+
"Input was {other} which is not a supported datatype for concat_ws separator"
205+
);
206+
}
207+
},
182208
};
183209

184210
let mut columns = Vec::with_capacity(args.len() - 1);
@@ -206,7 +232,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
206232
columns.push(column);
207233
}
208234
DataType::LargeUtf8 => {
209-
let string_array = as_largestring_array(array);
235+
let string_array = as_large_string_array(array)?;
210236

211237
data_size += string_array.values().len();
212238
let column = if array.is_nullable() {
@@ -221,11 +247,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
221247
DataType::Utf8View => {
222248
let string_array = as_string_view_array(array)?;
223249

224-
data_size += string_array
225-
.data_buffers()
226-
.iter()
227-
.map(|buf| buf.len())
228-
.sum::<usize>();
250+
data_size += string_array.total_buffer_bytes_used();
229251
let column = if array.is_nullable() {
230252
ColumnarValueRef::NullableStringViewArray(string_array)
231253
} else {
@@ -251,18 +273,14 @@ impl ScalarUDFImpl for ConcatWsFunc {
251273
continue;
252274
}
253275

254-
let mut iter = columns.iter();
255-
for column in iter.by_ref() {
276+
let mut first = true;
277+
for column in &columns {
256278
if column.is_valid(i) {
279+
if !first {
280+
builder.write::<false>(&sep, i);
281+
}
257282
builder.write::<false>(column, i);
258-
break;
259-
}
260-
}
261-
262-
for column in iter {
263-
if column.is_valid(i) {
264-
builder.write::<false>(&sep, i);
265-
builder.write::<false>(column, i);
283+
first = false;
266284
}
267285
}
268286

@@ -546,4 +564,78 @@ mod tests {
546564

547565
Ok(())
548566
}
567+
568+
#[test]
569+
fn concat_ws_utf8view_scalar_separator() -> Result<()> {
570+
let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
571+
let c1 =
572+
ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
573+
let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
574+
Some("x"),
575+
None,
576+
Some("z"),
577+
])));
578+
579+
let arg_fields = vec![
580+
Field::new("a", Utf8, true).into(),
581+
Field::new("a", Utf8, true).into(),
582+
Field::new("a", Utf8, true).into(),
583+
];
584+
let args = ScalarFunctionArgs {
585+
args: vec![c0, c1, c2],
586+
arg_fields,
587+
number_rows: 3,
588+
return_field: Field::new("f", Utf8, true).into(),
589+
config_options: Arc::new(ConfigOptions::default()),
590+
};
591+
592+
let result = ConcatWsFunc::new().invoke_with_args(args)?;
593+
let expected =
594+
Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
595+
match &result {
596+
ColumnarValue::Array(array) => {
597+
assert_eq!(&expected, array);
598+
}
599+
_ => panic!("Expected array result"),
600+
}
601+
602+
Ok(())
603+
}
604+
605+
#[test]
606+
fn concat_ws_largeutf8_scalar_separator() -> Result<()> {
607+
let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string())));
608+
let c1 =
609+
ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
610+
let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
611+
Some("x"),
612+
None,
613+
Some("z"),
614+
])));
615+
616+
let arg_fields = vec![
617+
Field::new("a", Utf8, true).into(),
618+
Field::new("a", Utf8, true).into(),
619+
Field::new("a", Utf8, true).into(),
620+
];
621+
let args = ScalarFunctionArgs {
622+
args: vec![c0, c1, c2],
623+
arg_fields,
624+
number_rows: 3,
625+
return_field: Field::new("f", Utf8, true).into(),
626+
config_options: Arc::new(ConfigOptions::default()),
627+
};
628+
629+
let result = ConcatWsFunc::new().invoke_with_args(args)?;
630+
let expected =
631+
Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
632+
match &result {
633+
ColumnarValue::Array(array) => {
634+
assert_eq!(&expected, array);
635+
}
636+
_ => panic!("Expected array result"),
637+
}
638+
639+
Ok(())
640+
}
549641
}

datafusion/sqllogictest/test_files/expr.slt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,25 @@ abc
504504
statement ok
505505
drop table foo
506506

507+
# concat_ws with a Utf8View or LargeUtf8 column as separator
508+
statement ok
509+
create table test_concat_ws_sep (sep varchar, val1 varchar, val2 varchar) as values (',', 'foo', 'bar'), ('|', 'a', 'b');
510+
511+
query T
512+
SELECT concat_ws(arrow_cast(sep, 'Utf8View'), val1, val2) FROM test_concat_ws_sep ORDER BY val1
513+
----
514+
a|b
515+
foo,bar
516+
517+
query T
518+
SELECT concat_ws(arrow_cast(sep, 'LargeUtf8'), val1, val2) FROM test_concat_ws_sep ORDER BY val1
519+
----
520+
a|b
521+
foo,bar
522+
523+
statement ok
524+
drop table test_concat_ws_sep
525+
507526
query T
508527
SELECT initcap('')
509528
----

0 commit comments

Comments
 (0)