Skip to content

Commit 6952775

Browse files
committed
chore: datafusion-spark substring to support Binary types
1 parent ba038e9 commit 6952775

2 files changed

Lines changed: 353 additions & 40 deletions

File tree

datafusion/spark/src/function/string/substring.rs

Lines changed: 192 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
// under the License.
1717

1818
use arrow::array::{
19-
Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array,
20-
OffsetSizeTrait, StringArrayType, StringViewBuilder,
19+
Array, ArrayAccessor, ArrayBuilder, ArrayRef, AsArray, BinaryViewBuilder,
20+
GenericBinaryBuilder, GenericStringBuilder, Int64Array, OffsetSizeTrait,
21+
StringViewBuilder,
2122
};
2223
use arrow::datatypes::DataType;
2324
use datafusion_common::arrow::datatypes::{Field, FieldRef};
@@ -56,6 +57,7 @@ impl Default for SparkSubstring {
5657
impl SparkSubstring {
5758
pub fn new() -> Self {
5859
let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
60+
let binary = Coercion::new_exact(TypeSignatureClass::Binary);
5961
let int64 = Coercion::new_implicit(
6062
TypeSignatureClass::Native(logical_int64()),
6163
vec![TypeSignatureClass::Native(logical_int32())],
@@ -70,6 +72,12 @@ impl SparkSubstring {
7072
int64.clone(),
7173
int64.clone(),
7274
]),
75+
TypeSignature::Coercible(vec![binary.clone(), int64.clone()]),
76+
TypeSignature::Coercible(vec![
77+
binary.clone(),
78+
int64.clone(),
79+
int64.clone(),
80+
]),
7381
],
7482
Volatility::Immutable,
7583
)
@@ -128,26 +136,65 @@ fn spark_substring(args: &[ArrayRef]) -> Result<ArrayRef> {
128136
};
129137

130138
match args[0].data_type() {
131-
DataType::Utf8 => spark_substring_impl(
132-
&args[0].as_string::<i32>(),
139+
DataType::Utf8 => {
140+
let array = args[0].as_string::<i32>();
141+
let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
142+
spark_substring_generic(
143+
&array,
144+
start_array,
145+
length_array,
146+
GenericStringBuilder::<i32>::new(),
147+
is_ascii,
148+
)
149+
}
150+
DataType::LargeUtf8 => {
151+
let array = args[0].as_string::<i64>();
152+
let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
153+
spark_substring_generic(
154+
&array,
155+
start_array,
156+
length_array,
157+
GenericStringBuilder::<i64>::new(),
158+
is_ascii,
159+
)
160+
}
161+
DataType::Utf8View => {
162+
let array = args[0].as_string_view();
163+
let is_ascii = enable_ascii_fast_path(&array, start_array, length_array);
164+
spark_substring_generic(
165+
&array,
166+
start_array,
167+
length_array,
168+
StringViewBuilder::new(),
169+
is_ascii,
170+
)
171+
}
172+
// Binary paths always use byte-level indexing, so `is_ascii` is irrelevant
173+
// and set to `true` (its value is ignored by the `[u8]` impl of
174+
// `SubstringItem`).
175+
DataType::Binary => spark_substring_generic(
176+
&args[0].as_binary::<i32>(),
133177
start_array,
134178
length_array,
135-
GenericStringBuilder::<i32>::new(),
179+
GenericBinaryBuilder::<i32>::new(),
180+
true,
136181
),
137-
DataType::LargeUtf8 => spark_substring_impl(
138-
&args[0].as_string::<i64>(),
182+
DataType::LargeBinary => spark_substring_generic(
183+
&args[0].as_binary::<i64>(),
139184
start_array,
140185
length_array,
141-
GenericStringBuilder::<i64>::new(),
186+
GenericBinaryBuilder::<i64>::new(),
187+
true,
142188
),
143-
DataType::Utf8View => spark_substring_impl(
144-
&args[0].as_string_view(),
189+
DataType::BinaryView => spark_substring_generic(
190+
&args[0].as_binary_view(),
145191
start_array,
146192
length_array,
147-
StringViewBuilder::new(),
193+
BinaryViewBuilder::new(),
194+
true,
148195
),
149196
other => exec_err!(
150-
"Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8."
197+
"Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8, LargeUtf8, Binary, LargeBinary or BinaryView."
151198
),
152199
}
153200
}
@@ -173,43 +220,156 @@ fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 {
173220
}
174221
}
175222

176-
trait StringArrayBuilder: ArrayBuilder {
177-
fn append_value(&mut self, val: &str);
223+
trait SubstringItem {
224+
/// Length used for Spark's negative-position adjustment.
225+
/// For `str` this is characters (or bytes in ASCII mode); for `[u8]` it is
226+
/// always byte count.
227+
fn positional_len(&self, is_ascii: bool) -> usize;
228+
229+
/// Converts Spark's 1-indexed adjusted start + optional length into a
230+
/// byte range clamped to `[0, byte_len]`.
231+
fn byte_range(
232+
&self,
233+
adjusted_start: i64,
234+
len: Option<i64>,
235+
is_ascii: bool,
236+
) -> Result<(usize, usize)>;
237+
238+
fn byte_slice(&self, start: usize, end: usize) -> &Self;
239+
}
240+
241+
impl SubstringItem for str {
242+
fn positional_len(&self, is_ascii: bool) -> usize {
243+
if is_ascii {
244+
self.len()
245+
} else {
246+
self.chars().count()
247+
}
248+
}
249+
250+
fn byte_range(
251+
&self,
252+
adjusted_start: i64,
253+
len: Option<i64>,
254+
is_ascii: bool,
255+
) -> Result<(usize, usize)> {
256+
get_true_start_end(self, adjusted_start, len, is_ascii)
257+
}
258+
259+
fn byte_slice(&self, start: usize, end: usize) -> &Self {
260+
&self[start..end]
261+
}
262+
}
263+
264+
impl SubstringItem for [u8] {
265+
fn positional_len(&self, _is_ascii: bool) -> usize {
266+
self.len()
267+
}
268+
269+
fn byte_range(
270+
&self,
271+
adjusted_start: i64,
272+
len: Option<i64>,
273+
_is_ascii: bool,
274+
) -> Result<(usize, usize)> {
275+
let byte_len = self.len();
276+
let start0 = adjusted_start.saturating_sub(1);
277+
let end0 = match len {
278+
Some(l) => start0.saturating_add(l),
279+
None => byte_len as i64,
280+
};
281+
let byte_len_i64 = byte_len as i64;
282+
Ok((
283+
start0.clamp(0, byte_len_i64) as usize,
284+
end0.clamp(0, byte_len_i64) as usize,
285+
))
286+
}
287+
288+
fn byte_slice(&self, start: usize, end: usize) -> &Self {
289+
&self[start..end]
290+
}
291+
}
292+
293+
trait SubstringBuilder: ArrayBuilder {
294+
type Item: SubstringItem + ?Sized;
295+
fn append_value(&mut self, val: &Self::Item);
178296
fn append_null(&mut self);
297+
/// Spark's semantic "empty" for this builder's item type, used for the
298+
/// negative-length short-circuit.
299+
fn append_empty(&mut self);
179300
}
180301

181-
impl<O: OffsetSizeTrait> StringArrayBuilder for GenericStringBuilder<O> {
302+
impl<O: OffsetSizeTrait> SubstringBuilder for GenericStringBuilder<O> {
303+
type Item = str;
182304
fn append_value(&mut self, val: &str) {
183305
GenericStringBuilder::append_value(self, val);
184306
}
185307
fn append_null(&mut self) {
186308
GenericStringBuilder::append_null(self);
187309
}
310+
fn append_empty(&mut self) {
311+
GenericStringBuilder::append_value(self, "");
312+
}
188313
}
189314

190-
impl StringArrayBuilder for StringViewBuilder {
315+
impl SubstringBuilder for StringViewBuilder {
316+
type Item = str;
191317
fn append_value(&mut self, val: &str) {
192318
StringViewBuilder::append_value(self, val);
193319
}
194320
fn append_null(&mut self) {
195321
StringViewBuilder::append_null(self);
196322
}
323+
fn append_empty(&mut self) {
324+
StringViewBuilder::append_value(self, "");
325+
}
197326
}
198327

199-
fn spark_substring_impl<'a, V, B>(
200-
string_array: &V,
328+
/// `Binary` / `LargeBinary` output: each method forwards to the builder's
/// inherent method of the same name.
impl<O: OffsetSizeTrait> SubstringBuilder for GenericBinaryBuilder<O> {
    type Item = [u8];

    /// Appends `val` as the next non-null element.
    fn append_value(&mut self, val: &[u8]) {
        Self::append_value(self, val);
    }

    /// Appends a null element.
    fn append_null(&mut self) {
        Self::append_null(self);
    }

    /// Appends a zero-length (non-null) byte string.
    fn append_empty(&mut self) {
        let nothing: &[u8] = &[];
        Self::append_value(self, nothing);
    }
}
340+
341+
/// `BinaryView` output: each method forwards to the builder's inherent
/// method of the same name.
impl SubstringBuilder for BinaryViewBuilder {
    type Item = [u8];

    /// Appends `val` as the next non-null element.
    fn append_value(&mut self, val: &[u8]) {
        Self::append_value(self, val);
    }

    /// Appends a null element.
    fn append_null(&mut self) {
        Self::append_null(self);
    }

    /// Appends a zero-length (non-null) byte string.
    fn append_empty(&mut self) {
        let nothing: &[u8] = &[];
        Self::append_value(self, nothing);
    }
}
353+
354+
/// Unified implementation of Spark's `substring`, generic over the source
355+
/// array (`StringArrayType`/`BinaryArrayType` via `ArrayAccessor`) and its
356+
/// corresponding builder. Per-row indexing semantics are delegated to
357+
/// [`SubstringItem`], which differs between `str` (char-aware when
358+
/// `is_ascii` is false) and `[u8]` (always byte-level).
359+
fn spark_substring_generic<'a, Source, Item, Builder>(
360+
array: &Source,
201361
start_array: &Int64Array,
202362
length_array: Option<&Int64Array>,
203-
mut builder: B,
363+
mut builder: Builder,
364+
is_ascii: bool,
204365
) -> Result<ArrayRef>
205366
where
206-
V: StringArrayType<'a>,
207-
B: StringArrayBuilder,
367+
Source: ArrayAccessor<Item = &'a Item>,
368+
Item: SubstringItem + ?Sized + 'a,
369+
Builder: SubstringBuilder<Item = Item>,
208370
{
209-
let is_ascii = enable_ascii_fast_path(string_array, start_array, length_array);
210-
211-
for i in 0..string_array.len() {
212-
if string_array.is_null(i) || start_array.is_null(i) {
371+
for i in 0..array.len() {
372+
if array.is_null(i) || start_array.is_null(i) {
213373
builder.append_null();
214374
continue;
215375
}
@@ -221,30 +381,23 @@ where
221381
continue;
222382
}
223383

224-
let string = string_array.value(i);
384+
let value = array.value(i);
225385
let start = start_array.value(i);
226386
let len_opt = length_array.map(|arr| arr.value(i));
227387

228-
// Spark: negative length returns empty string
388+
// Spark: negative length yields an empty value
229389
if let Some(len) = len_opt
230390
&& len < 0
231391
{
232-
builder.append_value("");
392+
builder.append_empty();
233393
continue;
234394
}
235395

236-
let string_len = if is_ascii {
237-
string.len()
238-
} else {
239-
string.chars().count()
240-
};
241-
242-
let adjusted_start = spark_start_to_datafusion_start(start, string_len);
243-
396+
let positional_len = value.positional_len(is_ascii);
397+
let adjusted_start = spark_start_to_datafusion_start(start, positional_len);
244398
let (byte_start, byte_end) =
245-
get_true_start_end(string, adjusted_start, len_opt, is_ascii)?;
246-
let substr = &string[byte_start..byte_end];
247-
builder.append_value(substr);
399+
value.byte_range(adjusted_start, len_opt, is_ascii)?;
400+
builder.append_value(value.byte_slice(byte_start, byte_end));
248401
}
249402

250403
Ok(builder.finish())

0 commit comments

Comments
 (0)