Skip to content

Commit d555221

Browse files
authored
feat: Support for StringSplit (#2772)
1 parent dffcb1b commit d555221

7 files changed

Lines changed: 500 additions & 2 deletions

File tree

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ object Meta {
200200
FunctionSignature(Seq(SparkStringType, SparkIntegralType)),
201201
FunctionSignature(Seq(SparkStringType, SparkIntegralType, SparkStringType)))),
202202
createUnaryStringFunction("rtrim"),
203+
createFunctions(
204+
"split",
205+
Seq(
206+
FunctionSignature(Seq(SparkStringType, SparkStringType)),
207+
FunctionSignature(Seq(SparkStringType, SparkStringType, SparkIntType)))),
203208
createFunctionWithInputTypes("starts_with", Seq(SparkStringType, SparkStringType)),
204209
createFunctionWithInputTypes("string_space", Seq(SparkIntType)),
205210
createFunctionWithInputTypes("substring", Seq(SparkStringType, SparkIntType, SparkIntType)),

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
181181
let func = Arc::new(abs);
182182
make_comet_scalar_udf!("abs", func, without data_type)
183183
}
184+
"split" => {
185+
let func = Arc::new(crate::string_funcs::spark_split);
186+
make_comet_scalar_udf!("split", func, without data_type)
187+
}
184188
_ => registry.udf(fun_name).map_err(|e| {
185189
DataFusionError::Execution(format!(
186190
"Function {fun_name} not found in the registry: {e}",

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
// under the License.
1717

1818
mod contains;
19+
mod split;
1920
mod string_space;
2021
mod substring;
2122

2223
pub use contains::SparkContains;
24+
pub use split::spark_split;
2325
pub use string_space::SparkStringSpace;
2426
pub use substring::SubstringExpr;
Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef, GenericStringArray, ListArray};
19+
use arrow::datatypes::{DataType, Field};
20+
use datafusion::common::{
21+
cast::as_generic_string_array, exec_err, DataFusionError, Result as DataFusionResult,
22+
ScalarValue,
23+
};
24+
use datafusion::logical_expr::ColumnarValue;
25+
use regex::Regex;
26+
use std::sync::Arc;
27+
28+
/// Spark-compatible split function
29+
/// Splits a string around matches of a regex pattern with optional limit
30+
///
31+
/// Arguments:
32+
/// - string: The string to split
33+
/// - pattern: The regex pattern to split on
34+
/// - limit (optional): Controls the number of splits
35+
/// - limit > 0: At most limit-1 splits, array length <= limit
36+
/// - limit = 0: As many splits as possible, trailing empty strings removed
37+
/// - limit < 0: As many splits as possible, trailing empty strings kept
38+
pub fn spark_split(args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
39+
if args.len() < 2 || args.len() > 3 {
40+
return exec_err!(
41+
"split expects 2 or 3 arguments (string, pattern, [limit]), got {}",
42+
args.len()
43+
);
44+
}
45+
46+
// Get limit parameter (default to -1 if not provided)
47+
let limit = if args.len() == 3 {
48+
match &args[2] {
49+
ColumnarValue::Scalar(ScalarValue::Int32(Some(l))) => *l,
50+
ColumnarValue::Scalar(ScalarValue::Int32(None)) => {
51+
// NULL limit, return NULL
52+
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
53+
}
54+
_ => {
55+
return exec_err!("split limit argument must be an Int32 scalar");
56+
}
57+
}
58+
} else {
59+
-1
60+
};
61+
62+
match (&args[0], &args[1]) {
63+
(ColumnarValue::Array(string_array), ColumnarValue::Scalar(ScalarValue::Utf8(pattern)))
64+
| (
65+
ColumnarValue::Array(string_array),
66+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(pattern)),
67+
) => {
68+
if pattern.is_none() {
69+
// NULL pattern returns NULL
70+
let null_array = new_null_list_array(string_array.len());
71+
return Ok(ColumnarValue::Array(null_array));
72+
}
73+
74+
let pattern_str = pattern.as_ref().unwrap();
75+
split_array(string_array.as_ref(), pattern_str, limit)
76+
}
77+
(ColumnarValue::Scalar(ScalarValue::Utf8(string)), ColumnarValue::Scalar(pattern_val))
78+
| (
79+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(string)),
80+
ColumnarValue::Scalar(pattern_val),
81+
) => {
82+
if string.is_none() {
83+
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
84+
}
85+
86+
let pattern_str = match pattern_val {
87+
ScalarValue::Utf8(Some(p)) | ScalarValue::LargeUtf8(Some(p)) => p,
88+
ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) => {
89+
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
90+
}
91+
_ => {
92+
return exec_err!("split pattern must be a string");
93+
}
94+
};
95+
96+
let result = split_string(string.as_ref().unwrap(), pattern_str, limit)?;
97+
let string_array = GenericStringArray::<i32>::from(result);
98+
let list_array = create_list_array(Arc::new(string_array));
99+
100+
Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(
101+
list_array,
102+
))))
103+
}
104+
_ => exec_err!("split expects (array, scalar) or (scalar, scalar) arguments"),
105+
}
106+
}
107+
108+
fn split_array(
109+
string_array: &dyn arrow::array::Array,
110+
pattern: &str,
111+
limit: i32,
112+
) -> DataFusionResult<ColumnarValue> {
113+
// Compile regex once for the entire array
114+
let regex = Regex::new(pattern).map_err(|e| {
115+
DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern, e))
116+
})?;
117+
118+
let string_array = match string_array.data_type() {
119+
DataType::Utf8 => as_generic_string_array::<i32>(string_array)?,
120+
DataType::LargeUtf8 => {
121+
// Convert LargeUtf8 to Utf8 for processing
122+
let large_array = as_generic_string_array::<i64>(string_array)?;
123+
return split_large_string_array(large_array, &regex, limit);
124+
}
125+
_ => {
126+
return exec_err!(
127+
"split expects Utf8 or LargeUtf8 string array, got {:?}",
128+
string_array.data_type()
129+
);
130+
}
131+
};
132+
133+
// Build the result ListArray
134+
let mut offsets: Vec<i32> = Vec::with_capacity(string_array.len() + 1);
135+
let mut values: Vec<String> = Vec::new();
136+
let mut null_buffer_builder = arrow::array::BooleanBufferBuilder::new(string_array.len());
137+
offsets.push(0);
138+
139+
for i in 0..string_array.len() {
140+
if string_array.is_null(i) {
141+
// NULL input produces NULL in result (Spark behavior)
142+
offsets.push(offsets[i]);
143+
null_buffer_builder.append(false); // false = NULL
144+
} else {
145+
let string_val = string_array.value(i);
146+
let parts = split_with_regex(string_val, &regex, limit);
147+
values.extend(parts);
148+
offsets.push(values.len() as i32);
149+
null_buffer_builder.append(true); // true = valid
150+
}
151+
}
152+
153+
let values_array = Arc::new(GenericStringArray::<i32>::from(values)) as ArrayRef;
154+
let field = Arc::new(Field::new("item", DataType::Utf8, false));
155+
let nulls = arrow::buffer::NullBuffer::new(null_buffer_builder.finish());
156+
let list_array = ListArray::new(
157+
field,
158+
arrow::buffer::OffsetBuffer::new(offsets.into()),
159+
values_array,
160+
Some(nulls),
161+
);
162+
163+
Ok(ColumnarValue::Array(Arc::new(list_array)))
164+
}
165+
166+
fn split_large_string_array(
167+
string_array: &GenericStringArray<i64>,
168+
regex: &Regex,
169+
limit: i32,
170+
) -> DataFusionResult<ColumnarValue> {
171+
let mut offsets: Vec<i32> = Vec::with_capacity(string_array.len() + 1);
172+
let mut values: Vec<String> = Vec::new();
173+
let mut null_buffer_builder = arrow::array::BooleanBufferBuilder::new(string_array.len());
174+
offsets.push(0);
175+
176+
for i in 0..string_array.len() {
177+
if string_array.is_null(i) {
178+
// NULL input produces NULL in result (Spark behavior)
179+
offsets.push(offsets[i]);
180+
null_buffer_builder.append(false); // false = NULL
181+
} else {
182+
let string_val = string_array.value(i);
183+
let parts = split_with_regex(string_val, regex, limit);
184+
values.extend(parts);
185+
offsets.push(values.len() as i32);
186+
null_buffer_builder.append(true); // true = valid
187+
}
188+
}
189+
190+
let values_array = Arc::new(GenericStringArray::<i32>::from(values)) as ArrayRef;
191+
let field = Arc::new(Field::new("item", DataType::Utf8, false));
192+
let nulls = arrow::buffer::NullBuffer::new(null_buffer_builder.finish());
193+
let list_array = ListArray::new(
194+
field,
195+
arrow::buffer::OffsetBuffer::new(offsets.into()),
196+
values_array,
197+
Some(nulls),
198+
);
199+
200+
Ok(ColumnarValue::Array(Arc::new(list_array)))
201+
}
202+
203+
fn split_string(string: &str, pattern: &str, limit: i32) -> DataFusionResult<Vec<String>> {
204+
let regex = Regex::new(pattern).map_err(|e| {
205+
DataFusionError::Execution(format!("Invalid regex pattern '{}': {}", pattern, e))
206+
})?;
207+
208+
Ok(split_with_regex(string, &regex, limit))
209+
}
210+
211+
fn split_with_regex(string: &str, regex: &Regex, limit: i32) -> Vec<String> {
212+
if limit == 0 {
213+
// limit = 0: split as many times as possible, discard trailing empty strings
214+
let mut parts: Vec<String> = regex.split(string).map(|s| s.to_string()).collect();
215+
// Remove trailing empty strings
216+
while parts.last().is_some_and(|s| s.is_empty()) {
217+
parts.pop();
218+
}
219+
if parts.is_empty() {
220+
vec!["".to_string()]
221+
} else {
222+
parts
223+
}
224+
} else if limit > 0 {
225+
// limit > 0: at most limit-1 splits (array length <= limit)
226+
let mut parts: Vec<String> = Vec::new();
227+
let mut last_end = 0;
228+
229+
for (count, mat) in regex.find_iter(string).enumerate() {
230+
if count >= (limit - 1) as usize {
231+
break;
232+
}
233+
parts.push(string[last_end..mat.start()].to_string());
234+
last_end = mat.end();
235+
}
236+
// Add the remaining string
237+
parts.push(string[last_end..].to_string());
238+
parts
239+
} else {
240+
// limit < 0: split as many times as possible, keep trailing empty strings
241+
regex.split(string).map(|s| s.to_string()).collect()
242+
}
243+
}
244+
245+
fn create_list_array(values: ArrayRef) -> ListArray {
246+
let field = Arc::new(Field::new("item", DataType::Utf8, false));
247+
let offsets = vec![0i32, values.len() as i32];
248+
ListArray::new(
249+
field,
250+
arrow::buffer::OffsetBuffer::new(offsets.into()),
251+
values,
252+
None,
253+
)
254+
}
255+
256+
fn new_null_list_array(len: usize) -> ArrayRef {
257+
let field = Arc::new(Field::new("item", DataType::Utf8, false));
258+
let values = Arc::new(GenericStringArray::<i32>::from(Vec::<String>::new())) as ArrayRef;
259+
let offsets = vec![0i32; len + 1];
260+
let nulls = arrow::buffer::NullBuffer::new_null(len);
261+
262+
Arc::new(ListArray::new(
263+
field,
264+
arrow::buffer::OffsetBuffer::new(offsets.into()),
265+
values,
266+
Some(nulls),
267+
))
268+
}
269+
270+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    /// Extracts the strings stored in row `row` of an array-valued result.
    /// Panics if the result is not a list array of non-null `Utf8` strings.
    fn list_row(result: &ColumnarValue, row: usize) -> Vec<String> {
        let ColumnarValue::Array(arr) = result else {
            panic!("Expected Array result");
        };
        let list_array = arr
            .as_any()
            .downcast_ref::<ListArray>()
            .expect("result should be a ListArray");
        let row_values = list_array.value(row);
        let strings = row_values
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("list items should be Utf8 strings");
        strings
            .iter()
            .map(|item| item.expect("list items should be non-null").to_string())
            .collect()
    }

    #[test]
    fn test_split_basic() {
        let string_array = Arc::new(StringArray::from(vec!["a,b,c", "x,y,z"])) as ArrayRef;
        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let args = vec![ColumnarValue::Array(string_array), pattern];

        let result = spark_split(&args).unwrap();
        // Should produce [["a", "b", "c"], ["x", "y", "z"]]
        assert!(matches!(result, ColumnarValue::Array(_)));
        assert_eq!(list_row(&result, 0), vec!["a", "b", "c"]);
        assert_eq!(list_row(&result, 1), vec!["x", "y", "z"]);
    }

    #[test]
    fn test_split_with_limit() {
        let string_array = Arc::new(StringArray::from(vec!["a,b,c,d"])) as ArrayRef;
        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let limit = ColumnarValue::Scalar(ScalarValue::Int32(Some(2)));
        let args = vec![ColumnarValue::Array(string_array), pattern, limit];

        let result = spark_split(&args).unwrap();
        // Should produce [["a", "b,c,d"]]
        assert!(matches!(result, ColumnarValue::Array(_)));
        assert_eq!(list_row(&result, 0), vec!["a", "b,c,d"]);
    }

    #[test]
    fn test_split_regex() {
        let parts = split_string("foo123bar456baz", r"\d+", -1).unwrap();
        assert_eq!(parts, vec!["foo", "bar", "baz"]);
    }

    #[test]
    fn test_split_limit_positive() {
        let parts = split_string("a,b,c,d,e", ",", 3).unwrap();
        assert_eq!(parts, vec!["a", "b", "c,d,e"]);
    }

    #[test]
    fn test_split_limit_zero() {
        let parts = split_string("a,b,c,,", ",", 0).unwrap();
        assert_eq!(parts, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_split_limit_negative() {
        let parts = split_string("a,b,c,,", ",", -1).unwrap();
        assert_eq!(parts, vec!["a", "b", "c", "", ""]);
    }

    #[test]
    fn test_split_with_nulls() {
        // Test that NULL inputs produce NULL outputs (not empty arrays)
        let string_array = Arc::new(StringArray::from(vec![
            Some("a,b,c"),
            None,
            Some("x,y"),
            None,
        ])) as ArrayRef;
        let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let args = vec![ColumnarValue::Array(string_array), pattern];

        let result = spark_split(&args).unwrap();
        match &result {
            ColumnarValue::Array(arr) => {
                let list_array = arr.as_any().downcast_ref::<ListArray>().unwrap();
                assert_eq!(list_array.len(), 4);
                // First row: valid ["a", "b", "c"]
                assert!(!list_array.is_null(0));
                // Second row: NULL
                assert!(list_array.is_null(1));
                // Third row: valid ["x", "y"]
                assert!(!list_array.is_null(2));
                // Fourth row: NULL
                assert!(list_array.is_null(3));
            }
            _ => panic!("Expected Array result"),
        }
        // The valid rows must also carry the expected parts.
        assert_eq!(list_row(&result, 0), vec!["a", "b", "c"]);
        assert_eq!(list_row(&result, 2), vec!["x", "y"]);
    }

    #[test]
    fn test_split_empty_string() {
        // Test that empty string input produces array with single empty string
        let parts = split_string("", ",", -1).unwrap();
        assert_eq!(parts, vec![""]);
    }
}

0 commit comments

Comments
 (0)