Skip to content

Commit d019b9d

Browse files
authored
Merge branch 'main' into feat/ceil-two-args
2 parents b9098ce + 7d107f0 commit d019b9d

6 files changed

Lines changed: 510 additions & 80 deletions

File tree

datafusion/core/tests/physical_optimizer/partition_statistics.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,6 @@ mod test {
381381
let filter: Arc<dyn ExecutionPlan> =
382382
Arc::new(FilterExec::try_new(predicate, scan)?);
383383
let full_statistics = filter.partition_statistics(None)?;
384-
// Filter preserves original total_rows and byte_size from input
385-
// (4 total rows = 2 partitions * 2 rows each, byte_size = 4 * 4 = 16 bytes for int32)
386384
let expected_full_statistic = Statistics {
387385
num_rows: Precision::Inexact(0),
388386
total_byte_size: Precision::Inexact(0),
@@ -393,15 +391,15 @@ mod test {
393391
min_value: Precision::Exact(ScalarValue::Int32(None)),
394392
sum_value: Precision::Exact(ScalarValue::Int32(None)),
395393
distinct_count: Precision::Exact(0),
396-
byte_size: Precision::Exact(16),
394+
byte_size: Precision::Exact(0),
397395
},
398396
ColumnStatistics {
399397
null_count: Precision::Exact(0),
400398
max_value: Precision::Exact(ScalarValue::Date32(None)),
401399
min_value: Precision::Exact(ScalarValue::Date32(None)),
402400
sum_value: Precision::Exact(ScalarValue::Date32(None)),
403401
distinct_count: Precision::Exact(0),
404-
byte_size: Precision::Exact(16), // 4 rows * 4 bytes (Date32)
402+
byte_size: Precision::Exact(0),
405403
},
406404
],
407405
};
@@ -411,7 +409,6 @@ mod test {
411409
.map(|idx| filter.partition_statistics(Some(idx)))
412410
.collect::<Result<Vec<_>>>()?;
413411
assert_eq!(statistics.len(), 2);
414-
// Per-partition stats: each partition has 2 rows, byte_size = 2 * 4 = 8
415412
let expected_partition_statistic = Statistics {
416413
num_rows: Precision::Inexact(0),
417414
total_byte_size: Precision::Inexact(0),
@@ -422,15 +419,15 @@ mod test {
422419
min_value: Precision::Exact(ScalarValue::Int32(None)),
423420
sum_value: Precision::Exact(ScalarValue::Int32(None)),
424421
distinct_count: Precision::Exact(0),
425-
byte_size: Precision::Exact(8),
422+
byte_size: Precision::Exact(0),
426423
},
427424
ColumnStatistics {
428425
null_count: Precision::Exact(0),
429426
max_value: Precision::Exact(ScalarValue::Date32(None)),
430427
min_value: Precision::Exact(ScalarValue::Date32(None)),
431428
sum_value: Precision::Exact(ScalarValue::Date32(None)),
432429
distinct_count: Precision::Exact(0),
433-
byte_size: Precision::Exact(8), // 2 rows * 4 bytes (Date32)
430+
byte_size: Precision::Exact(0),
434431
},
435432
],
436433
};

datafusion/functions/src/string/common.rs

Lines changed: 180 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
2020
use std::sync::Arc;
2121

22-
use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder, append_view};
22+
use crate::strings::{
23+
GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE,
24+
StringViewArrayBuilder, append_view,
25+
};
2326
use arrow::array::{
2427
Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait,
2528
StringViewArray, new_null_array,
@@ -323,32 +326,42 @@ fn string_trim<T: OffsetSizeTrait, Tr: Trimmer>(args: &[ArrayRef]) -> Result<Arr
323326
}
324327

325328
pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
326-
case_conversion(args, |string| string.to_lowercase(), name)
329+
case_conversion(args, true, name)
327330
}
328331

329332
pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
330-
case_conversion(args, |string| string.to_uppercase(), name)
333+
case_conversion(args, false, name)
334+
}
335+
336+
#[inline]
337+
fn unicode_case(s: &str, lower: bool) -> String {
338+
if lower {
339+
s.to_lowercase()
340+
} else {
341+
s.to_uppercase()
342+
}
331343
}
332344

333-
fn case_conversion<'a, F>(
334-
args: &'a [ColumnarValue],
335-
op: F,
345+
fn case_conversion(
346+
args: &[ColumnarValue],
347+
lower: bool,
336348
name: &str,
337-
) -> Result<ColumnarValue>
338-
where
339-
F: Fn(&'a str) -> String,
340-
{
349+
) -> Result<ColumnarValue> {
341350
match &args[0] {
342351
ColumnarValue::Array(array) => match array.data_type() {
343-
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
344-
array, op,
352+
DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32>(
353+
array, lower,
345354
)?)),
346-
DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
347-
i64,
348-
_,
349-
>(array, op)?)),
355+
DataType::LargeUtf8 => Ok(ColumnarValue::Array(
356+
case_conversion_array::<i64>(array, lower)?,
357+
)),
350358
DataType::Utf8View => {
351359
let string_array = as_string_view_array(array)?;
360+
if string_array.is_ascii() {
361+
return Ok(ColumnarValue::Array(Arc::new(
362+
case_conversion_utf8view_ascii(string_array, lower),
363+
)));
364+
}
352365
let item_len = string_array.len();
353366
// Null-preserving: reuse the input null buffer as the output null buffer.
354367
let nulls = string_array.nulls().cloned();
@@ -361,14 +374,14 @@ where
361374
} else {
362375
// SAFETY: `n.is_null(i)` was false in the branch above.
363376
let s = unsafe { string_array.value_unchecked(i) };
364-
builder.append_value(&op(s));
377+
builder.append_value(&unicode_case(s, lower));
365378
}
366379
}
367380
} else {
368381
for i in 0..item_len {
369382
// SAFETY: no null buffer means every index is valid.
370383
let s = unsafe { string_array.value_unchecked(i) };
371-
builder.append_value(&op(s));
384+
builder.append_value(&unicode_case(s, lower));
372385
}
373386
}
374387

@@ -378,32 +391,31 @@ where
378391
},
379392
ColumnarValue::Scalar(scalar) => match scalar {
380393
ScalarValue::Utf8(a) => {
381-
let result = a.as_ref().map(|x| op(x));
394+
let result = a.as_ref().map(|x| unicode_case(x, lower));
382395
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
383396
}
384397
ScalarValue::LargeUtf8(a) => {
385-
let result = a.as_ref().map(|x| op(x));
398+
let result = a.as_ref().map(|x| unicode_case(x, lower));
386399
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
387400
}
388401
ScalarValue::Utf8View(a) => {
389-
let result = a.as_ref().map(|x| op(x));
402+
let result = a.as_ref().map(|x| unicode_case(x, lower));
390403
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result)))
391404
}
392405
other => exec_err!("Unsupported data type {other:?} for function {name}"),
393406
},
394407
}
395408
}
396409

397-
fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
398-
where
399-
O: OffsetSizeTrait,
400-
F: Fn(&'a str) -> String,
401-
{
410+
fn case_conversion_array<O: OffsetSizeTrait>(
411+
array: &ArrayRef,
412+
lower: bool,
413+
) -> Result<ArrayRef> {
402414
const PRE_ALLOC_BYTES: usize = 8;
403415

404416
let string_array = as_generic_string_array::<O>(array)?;
405417
if string_array.is_ascii() {
406-
return case_conversion_ascii_array::<O, _>(string_array, op);
418+
return case_conversion_ascii_array::<O>(string_array, lower);
407419
}
408420

409421
// Values contain non-ASCII.
@@ -423,43 +435,164 @@ where
423435
} else {
424436
// SAFETY: `n.is_null(i)` was false in the branch above.
425437
let s = unsafe { string_array.value_unchecked(i) };
426-
builder.append_value(&op(s));
438+
builder.append_value(&unicode_case(s, lower));
427439
}
428440
}
429441
} else {
430442
for i in 0..item_len {
431443
// SAFETY: no null buffer means every index is valid.
432444
let s = unsafe { string_array.value_unchecked(i) };
433-
builder.append_value(&op(s));
445+
builder.append_value(&unicode_case(s, lower));
434446
}
435447
}
436448
Ok(Arc::new(builder.finish(nulls)?))
437449
}
438450

451+
/// Fast path for case conversion on an all-ASCII `StringViewArray`.
452+
fn case_conversion_utf8view_ascii(
453+
array: &StringViewArray,
454+
lower: bool,
455+
) -> StringViewArray {
456+
// Specialize per conversion so the byte call inlines in the hot loops below.
457+
if lower {
458+
case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase)
459+
} else {
460+
case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase)
461+
}
462+
}
463+
464+
/// Walks the views once and produces a new `StringViewArray` with
465+
/// case-converted bytes. Inline strings (<= 12 bytes) are converted in-place;
466+
/// long strings copy-and-convert into output buffers and have their view fields
467+
/// rewritten to address the new bytes. ASCII case conversion preserves is byte
468+
/// length, so no row migrates between the inline and long layouts.
469+
fn case_conversion_utf8view_ascii_inner<F: Fn(&u8) -> u8>(
470+
array: &StringViewArray,
471+
convert: F,
472+
) -> StringViewArray {
473+
let item_len = array.len();
474+
let views = array.views();
475+
let data_buffers = array.data_buffers();
476+
let nulls = array.nulls();
477+
478+
let mut new_views: Vec<u128> = Vec::with_capacity(item_len);
479+
// Long values are packed into `in_progress`; when full it is sealed into
480+
// `completed` and a new, larger block is started — same block-doubling
481+
// scheme as Arrow's `GenericByteViewBuilder`.
482+
let mut in_progress: Vec<u8> = Vec::new();
483+
let mut completed: Vec<Buffer> = Vec::new();
484+
let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE;
485+
486+
for i in 0..item_len {
487+
if nulls.is_some_and(|n| n.is_null(i)) {
488+
// Zero view = empty, no buffer reference; the null buffer is what
489+
// marks the row null, so the view's value is irrelevant.
490+
new_views.push(0);
491+
continue;
492+
}
493+
let view = views[i];
494+
// Length is the low 32 bits; `as u32` discards the rest of the view.
495+
let len = view as u32 as usize;
496+
if len == 0 {
497+
new_views.push(0);
498+
continue;
499+
}
500+
let mut bytes = view.to_le_bytes();
501+
if len <= 12 {
502+
// Inline: value is in bytes[4..4+len], no buffer reference. Convert
503+
// in place; nothing else in the view needs to change.
504+
for b in &mut bytes[4..4 + len] {
505+
*b = convert(b);
506+
}
507+
new_views.push(u128::from_le_bytes(bytes));
508+
} else {
509+
// Long: input view points into shared `data_buffers` we can't
510+
// mutate, so copy-convert into our own buffer and rewrite the
511+
// view's prefix/buffer_index/offset (length is preserved).
512+
513+
// Ensure the current block has room; otherwise flush and grow.
514+
let required_cap = in_progress.len() + len;
515+
if in_progress.capacity() < required_cap {
516+
if !in_progress.is_empty() {
517+
completed.push(Buffer::from_vec(std::mem::take(&mut in_progress)));
518+
}
519+
if block_size < STRING_VIEW_MAX_BLOCK_SIZE {
520+
block_size = block_size.saturating_mul(2);
521+
}
522+
let to_reserve = len.max(block_size as usize);
523+
in_progress.reserve(to_reserve);
524+
}
525+
526+
// The in-progress block will be sealed at index `completed.len()`,
527+
// and our value starts at the current write position within it.
528+
let buffer_index: u32 = i32::try_from(completed.len())
529+
.expect("buffer count exceeds i32::MAX")
530+
as u32;
531+
let new_offset: u32 =
532+
i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32;
533+
534+
// Source location from the input view: bytes 8..12 are buffer
535+
// index, bytes 12..16 are the offset within it.
536+
let src_buffer_index =
537+
u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize;
538+
let src_offset =
539+
u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize;
540+
let src =
541+
&data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len];
542+
543+
let prefix_start = in_progress.len();
544+
in_progress.extend(src.iter().map(&convert));
545+
546+
// Rewrite the three long-view fields; bytes[0..4] (length) is
547+
// left untouched. The prefix is read back from the bytes we just
548+
// wrote so the converted value has a single source of truth.
549+
let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4]
550+
.try_into()
551+
.unwrap();
552+
bytes[4..8].copy_from_slice(&prefix);
553+
bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes());
554+
bytes[12..16].copy_from_slice(&new_offset.to_le_bytes());
555+
new_views.push(u128::from_le_bytes(bytes));
556+
}
557+
}
558+
559+
if !in_progress.is_empty() {
560+
completed.push(Buffer::from_vec(in_progress));
561+
}
562+
563+
// SAFETY: each long view's buffer_index addresses a buffer we wrote, and
564+
// its offset addresses bytes within that buffer; prefixes were copied from
565+
// those same bytes; inline views were rewritten from valid inline bytes;
566+
// null/empty rows are zero views with no buffer reference; row count is
567+
// unchanged.
568+
unsafe {
569+
StringViewArray::new_unchecked(
570+
ScalarBuffer::from(new_views),
571+
completed,
572+
array.nulls().cloned(),
573+
)
574+
}
575+
}
576+
439577
/// Fast path for case conversion on an all-ASCII string array. ASCII case
440578
/// conversion is byte-length-preserving, so we can convert the entire addressed
441-
/// range in one call and reuse the offsets and nulls buffers — rebasing the
442-
/// offsets when the input is a sliced array.
443-
fn case_conversion_ascii_array<'a, O, F>(
444-
string_array: &'a GenericStringArray<O>,
445-
op: F,
446-
) -> Result<ArrayRef>
447-
where
448-
O: OffsetSizeTrait,
449-
F: Fn(&'a str) -> String,
450-
{
579+
/// byte range in one pass over the value buffer and reuse the offsets and nulls
580+
/// buffers — rebasing the offsets when the input is a sliced array.
581+
fn case_conversion_ascii_array<O: OffsetSizeTrait>(
582+
string_array: &GenericStringArray<O>,
583+
lower: bool,
584+
) -> Result<ArrayRef> {
451585
let value_offsets = string_array.value_offsets();
452586
let start = value_offsets.first().unwrap().as_usize();
453587
let end = value_offsets.last().unwrap().as_usize();
454588
let relevant = &string_array.value_data()[start..end];
455589

456-
// SAFETY: `relevant` is a subslice of the string array's value buffer,
457-
// which is valid UTF-8.
458-
let str_values = unsafe { std::str::from_utf8_unchecked(relevant) };
459-
460-
let converted_values = op(str_values);
461-
debug_assert_eq!(converted_values.len(), str_values.len());
462-
let values = Buffer::from_vec(converted_values.into_bytes());
590+
let converted: Vec<u8> = if lower {
591+
relevant.iter().map(u8::to_ascii_lowercase).collect()
592+
} else {
593+
relevant.iter().map(u8::to_ascii_uppercase).collect()
594+
};
595+
let values = Buffer::from_vec(converted);
463596

464597
// Shift offsets from `start`-based to 0-based so they index into `values`.
465598
let offsets = if start == 0 {
@@ -468,7 +601,7 @@ where
468601
let s = O::usize_as(start);
469602
let rebased: Vec<O> = value_offsets.iter().map(|&o| o - s).collect();
470603
// SAFETY: subtracting a constant from monotonic offsets preserves
471-
// monotonicity, and `start` is the minimum offset so no underflow.
604+
// monotonicity, and `start` is the minimum offset, so no underflow.
472605
unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) }
473606
};
474607

0 commit comments

Comments
 (0)