Skip to content

Commit e6c07fc

Browse files
committed
.
1 parent ba038e9 commit e6c07fc

3 files changed

Lines changed: 314 additions & 61 deletions

File tree

datafusion/functions/src/string/common.rs

Lines changed: 163 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
2020
use std::sync::Arc;
2121

22-
use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder, append_view};
22+
use crate::strings::{
23+
GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE,
24+
StringViewArrayBuilder, append_view,
25+
};
2326
use arrow::array::{
2427
Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait,
2528
StringViewArray, new_null_array,
@@ -323,32 +326,42 @@ fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<Arr
323326
}
324327

325328
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
326-
case_conversion(args, |string| string.to_lowercase(), name)
329+
case_conversion(args, true, name)
327330
}
328331

329332
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
330-
case_conversion(args, |string| string.to_uppercase(), name)
333+
case_conversion(args, false, name)
334+
}
335+
336+
/// Applies full Unicode case mapping to `s` and returns the result as a newly
/// allocated `String`: lowercase when `lower` is `true`, uppercase otherwise.
///
/// The output may differ in byte length from the input (for example, `"ß"`
/// uppercases to `"SS"`).
#[inline]
fn unicode_case(s: &str, lower: bool) -> String {
    match lower {
        true => str::to_lowercase(s),
        false => str::to_uppercase(s),
    }
}
332344

333-
fn case_conversion<'a, F>(
334-
args: &'a [ColumnarValue],
335-
op: F,
345+
fn case_conversion(
346+
args: &[ColumnarValue],
347+
lower: bool,
336348
name: &str,
337-
) -> Result<ColumnarValue>
338-
where
339-
F: Fn(&'a str) -> String,
340-
{
349+
) -> Result<ColumnarValue> {
341350
match &args[0] {
342351
ColumnarValue::Array(array) => match array.data_type() {
343-
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
344-
array, op,
352+
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32>(
353+
array, lower,
345354
)?)),
346-
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
347-
i64,
348-
_,
349-
>(array, op)?)),
355+
DataType::LargeUtf8 => Ok(ColumnarValue::Array(
356+
case_conversion_array::<i64>(array, lower)?,
357+
)),
350358
DataType::Utf8View => {
351359
let string_array = as_string_view_array(array)?;
360+
if string_array.is_ascii() {
361+
return Ok(ColumnarValue::Array(Arc::new(
362+
case_conversion_utf8view_ascii(string_array, lower),
363+
)));
364+
}
352365
let item_len = string_array.len();
353366
// Null-preserving: reuse the input null buffer as the output null buffer.
354367
let nulls = string_array.nulls().cloned();
@@ -361,14 +374,14 @@ where
361374
} else {
362375
// SAFETY: `n.is_null(i)` was false in the branch above.
363376
let s = unsafe { string_array.value_unchecked(i) };
364-
builder.append_value(&op(s));
377+
builder.append_value(&unicode_case(s, lower));
365378
}
366379
}
367380
} else {
368381
for i in 0..item_len {
369382
// SAFETY: no null buffer means every index is valid.
370383
let s = unsafe { string_array.value_unchecked(i) };
371-
builder.append_value(&op(s));
384+
builder.append_value(&unicode_case(s, lower));
372385
}
373386
}
374387

@@ -378,32 +391,31 @@ where
378391
},
379392
ColumnarValue::Scalar(scalar) => match scalar {
380393
ScalarValue::Utf8(a) => {
381-
let result = a.as_ref().map(|x| op(x));
394+
let result = a.as_ref().map(|x| unicode_case(x, lower));
382395
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
383396
}
384397
ScalarValue::LargeUtf8(a) => {
385-
let result = a.as_ref().map(|x| op(x));
398+
let result = a.as_ref().map(|x| unicode_case(x, lower));
386399
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
387400
}
388401
ScalarValue::Utf8View(a) => {
389-
let result = a.as_ref().map(|x| op(x));
402+
let result = a.as_ref().map(|x| unicode_case(x, lower));
390403
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
391404
}
392405
other => exec_err!("Unsupported data type {other:?} for function {name}"),
393406
},
394407
}
395408
}
396409

397-
fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
398-
where
399-
O: OffsetSizeTrait,
400-
F: Fn(&'a str) -> String,
401-
{
410+
fn case_conversion_array<O: OffsetSizeTrait>(
411+
array: &ArrayRef,
412+
lower: bool,
413+
) -> Result<ArrayRef> {
402414
const PRE_ALLOC_BYTES: usize = 8;
403415

404416
let string_array = as_generic_string_array::<O>(array)?;
405417
if string_array.is_ascii() {
406-
return case_conversion_ascii_array::<O, _>(string_array, op);
418+
return case_conversion_ascii_array::<O>(string_array, lower);
407419
}
408420

409421
// Values contain non-ASCII.
@@ -423,43 +435,147 @@ where
423435
} else {
424436
// SAFETY: `n.is_null(i)` was false in the branch above.
425437
let s = unsafe { string_array.value_unchecked(i) };
426-
builder.append_value(&op(s));
438+
builder.append_value(&unicode_case(s, lower));
427439
}
428440
}
429441
} else {
430442
for i in 0..item_len {
431443
// SAFETY: no null buffer means every index is valid.
432444
let s = unsafe { string_array.value_unchecked(i) };
433-
builder.append_value(&op(s));
445+
builder.append_value(&unicode_case(s, lower));
434446
}
435447
}
436448
Ok(Arc::new(builder.finish(nulls)?))
437449
}
438450

451+
/// Fast path for case conversion on an all-ASCII `StringViewArray`.
452+
fn case_conversion_utf8view_ascii(
453+
array: &StringViewArray,
454+
lower: bool,
455+
) -> StringViewArray {
456+
// Specialize per conversion so the byte call inlines in the hot loops below.
457+
if lower {
458+
case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase)
459+
} else {
460+
case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase)
461+
}
462+
}
463+
464+
/// Walks the views once: inline rows (length ≤ 12) convert their inline bytes
465+
/// in place; long rows copy their referenced bytes into a single packed output
466+
/// buffer while converting, then rewrite the view (`buffer_index = 0`, new
467+
/// offset, new 4-byte prefix) to point at it.
468+
fn case_conversion_utf8view_ascii_inner<F: Fn(&u8) -> u8>(
469+
array: &StringViewArray,
470+
convert: F,
471+
) -> StringViewArray {
472+
let item_len = array.len();
473+
let views = array.views();
474+
let data_buffers = array.data_buffers();
475+
let nulls = array.nulls();
476+
477+
let mut new_views: Vec<u128> = Vec::with_capacity(item_len);
478+
let mut in_progress: Vec<u8> = Vec::new();
479+
let mut completed: Vec<Buffer> = Vec::new();
480+
let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE;
481+
482+
for i in 0..item_len {
483+
if nulls.is_some_and(|n| n.is_null(i)) {
484+
new_views.push(0);
485+
continue;
486+
}
487+
let view = views[i];
488+
let len = view as u32 as usize;
489+
if len == 0 {
490+
new_views.push(0);
491+
continue;
492+
}
493+
let mut bytes = view.to_le_bytes();
494+
if len <= 12 {
495+
// Inline row: convert the inline data bytes; layout unchanged.
496+
for b in &mut bytes[4..4 + len] {
497+
*b = convert(b);
498+
}
499+
new_views.push(u128::from_le_bytes(bytes));
500+
} else {
501+
// Make sure the current data block has room for this value;
502+
// otherwise flush and start a new, larger block.
503+
let required_cap = in_progress.len() + len;
504+
if in_progress.capacity() < required_cap {
505+
if !in_progress.is_empty() {
506+
completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
507+
}
508+
if block_size < STRING_VIEW_MAX_BLOCK_SIZE {
509+
block_size = block_size.saturating_mul(2);
510+
}
511+
let to_reserve = len.max(block_size as usize);
512+
in_progress.reserve(to_reserve);
513+
}
514+
515+
let buffer_index: u32 = i32::try_from(completed.len())
516+
.expect("buffer count exceeds i32::MAX")
517+
as u32;
518+
let new_offset: u32 =
519+
i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32;
520+
521+
let src_buffer_index =
522+
u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
523+
let src_offset =
524+
u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
525+
let src =
526+
&data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len];
527+
528+
let prefix_start = in_progress.len();
529+
in_progress.extend(src.iter().map(&convert));
530+
531+
// Prefix is the first 4 bytes of the converted data we just wrote.
532+
let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4]
533+
.try_into()
534+
.unwrap();
535+
bytes[4..8].copy_from_slice(&prefix);
536+
bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes());
537+
bytes[12..16].copy_from_slice(&new_offset.to_le_bytes());
538+
new_views.push(u128::from_le_bytes(bytes));
539+
}
540+
}
541+
542+
if !in_progress.is_empty() {
543+
completed.push(Buffer::from_vec(in_progress));
544+
}
545+
546+
// SAFETY: each long view's buffer_index addresses a buffer we wrote, and
547+
// its offset addresses bytes within that buffer; prefixes were copied from
548+
// those same bytes; inline views were rewritten from valid inline bytes;
549+
// null/empty rows are zero views with no buffer reference; row count is
550+
// unchanged.
551+
unsafe {
552+
StringViewArray::new_unchecked(
553+
ScalarBuffer::from(new_views),
554+
completed,
555+
array.nulls().cloned(),
556+
)
557+
}
558+
}
559+
439560
/// Fast path for case conversion on an all-ASCII string array. ASCII case
440561
/// conversion is byte-length-preserving, so we can convert the entire addressed
441-
/// range in one call and reuse the offsets and nulls buffers — rebasing the
442-
/// offsets when the input is a sliced array.
443-
fn case_conversion_ascii_array<'a, O, F>(
444-
string_array: &'a GenericStringArray<O>,
445-
op: F,
446-
) -> Result<ArrayRef>
447-
where
448-
O: OffsetSizeTrait,
449-
F: Fn(&'a str) -> String,
450-
{
562+
/// byte range in one pass over the value buffer and reuse the offsets and nulls
563+
/// buffers — rebasing the offsets when the input is a sliced array.
564+
fn case_conversion_ascii_array<O: OffsetSizeTrait>(
565+
string_array: &GenericStringArray<O>,
566+
lower: bool,
567+
) -> Result<ArrayRef> {
451568
let value_offsets = string_array.value_offsets();
452569
let start = value_offsets.first().unwrap().as_usize();
453570
let end = value_offsets.last().unwrap().as_usize();
454571
let relevant = &string_array.value_data()[start..end];
455572

456-
// SAFETY: `relevant` is a subslice of the string array's value buffer,
457-
// which is valid UTF-8.
458-
let str_values = unsafe { std::str::from_utf8_unchecked(relevant) };
459-
460-
let converted_values = op(str_values);
461-
debug_assert_eq!(converted_values.len(), str_values.len());
462-
let values = Buffer::from_vec(converted_values.into_bytes());
573+
let converted: Vec<u8> = if lower {
574+
relevant.iter().map(u8::to_ascii_lowercase).collect()
575+
} else {
576+
relevant.iter().map(u8::to_ascii_uppercase).collect()
577+
};
578+
let values = Buffer::from_vec(converted);
463579

464580
// Shift offsets from `start`-based to 0-based so they index into `values`.
465581
let offsets = if start == 0 {
@@ -468,7 +584,7 @@ where
468584
let s = O::usize_as(start);
469585
let rebased: Vec<O> = value_offsets.iter().map(|&o| o - s).collect();
470586
// SAFETY: subtracting a constant from monotonic offsets preserves
471-
// monotonicity, and `start` is the minimum offset so no underflow.
587+
// monotonicity, and `start` is the minimum offset, so no underflow.
472588
unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) }
473589
};
474590

0 commit comments

Comments
 (0)