Skip to content

Commit 089c9a9

Browse files
committed
Restore non-null fastpath but unsafely
1 parent 57cacc4 commit 089c9a9

1 file changed

Lines changed: 31 additions & 15 deletions

File tree

arrow-select/src/take.rs

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -495,14 +495,19 @@ fn take_boolean<IndexType: ArrowPrimitiveType>(
495495
/// # Safety
496496
/// Each `(start, end)` in `ranges` must be in-bounds of `src`, and
497497
/// `capacity` must equal the total bytes across all ranges.
498-
unsafe fn copy_byte_ranges(src: &[u8], ranges: &[(usize, usize)], capacity: usize) -> Vec<u8> {
498+
unsafe fn copy_byte_ranges(
499+
src: &[u8],
500+
ranges: &[(usize, usize)],
501+
capacity: usize,
502+
values: &mut Vec<u8>,
503+
) {
504+
values.reserve(capacity);
499505
debug_assert_eq!(
500506
ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
501507
capacity,
502508
"capacity must equal total bytes across all ranges"
503509
);
504510
let src_len = src.len();
505-
let mut values = Vec::with_capacity(capacity);
506511
let src = src.as_ptr();
507512
let mut dst = values.as_mut_ptr();
508513
for &(start, end) in ranges {
@@ -523,28 +528,25 @@ unsafe fn copy_byte_ranges(src: &[u8], ranges: &[(usize, usize)], capacity: usiz
523528
// SAFETY: caller guarantees `capacity` == total bytes across all ranges,
524529
// so the loop above wrote exactly `capacity` bytes.
525530
unsafe { values.set_len(capacity) };
526-
values
527531
}
528532

529533
/// `take` implementation for string arrays
530534
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
531535
array: &GenericByteArray<T>,
532536
indices: &PrimitiveArray<IndexType>,
533537
) -> Result<GenericByteArray<T>, ArrowError> {
538+
let mut values = Vec::new();
534539
let mut offsets = Vec::with_capacity(indices.len() + 1);
535540
offsets.push(T::Offset::default());
536541

537542
let input_offsets = array.value_offsets();
538-
let input_values = array.value_data();
539543
let mut capacity = 0;
540544
let nulls = take_nulls(array.nulls(), indices);
541545

542-
// Pass 1: compute offsets and collect byte ranges.
543546
// Branch on output nulls — `None` means every output slot is valid.
544-
let ranges = match nulls.as_ref().filter(|n| n.null_count() > 0) {
547+
match nulls.as_ref().filter(|n| n.null_count() > 0) {
545548
// Fast path: no nulls in output, every index is valid.
546549
None => {
547-
let mut ranges = Vec::with_capacity(indices.len());
548550
for index in indices.values() {
549551
let index = index.as_usize();
550552
let start = input_offsets[index].as_usize();
@@ -554,9 +556,26 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
554556
T::Offset::from_usize(capacity)
555557
.ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
556558
);
557-
ranges.push((start, end));
558559
}
559-
ranges
560+
561+
values.reserve(capacity);
562+
563+
let mut dst = values.as_mut_ptr();
564+
565+
for index in indices.values() {
566+
// SAFETY: in-bounds proven by the first loop's bounds-checked offset access.
567+
// dst stays within reserved capacity computed from the same indices.
568+
unsafe {
569+
let data: &[u8] = array.value_unchecked(index.as_usize()).as_ref();
570+
std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
571+
dst = dst.add(data.len());
572+
}
573+
}
574+
575+
// SAFETY: wrote exactly `capacity` bytes above; reserved on line above.
576+
unsafe {
577+
values.set_len(capacity);
578+
}
560579
}
561580
// Nullable path: only process valid (non-null) output positions.
562581
Some(output_nulls) => {
@@ -566,6 +585,7 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
566585
// Pre-fill offsets; we overwrite valid positions below.
567586
offsets.resize(indices.len() + 1, T::Offset::default());
568587

588+
// Pass 1: find all valid ranges that need to be copied.
569589
for i in output_nulls.valid_indices() {
570590
let current_offset = T::Offset::from_usize(capacity)
571591
.ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
@@ -589,15 +609,11 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
589609
let final_offset = T::Offset::from_usize(capacity)
590610
.ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
591611
offsets[last_filled + 1..].fill(final_offset);
592-
ranges
612+
// Pass 2: copy byte data for all collected ranges.
613+
unsafe { copy_byte_ranges(array.value_data(), &ranges, capacity, &mut values) };
593614
}
594615
};
595616

596-
// Pass 2: copy byte data for all collected ranges.
597-
let values = unsafe { copy_byte_ranges(input_values, &ranges, capacity) };
598-
599-
debug_assert_eq!(capacity, values.len());
600-
601617
// SAFETY: offsets are monotonically increasing and in-bounds of `values`,
602618
// and `nulls` (if present) has length == `indices.len()`.
603619
let array = unsafe {

0 commit comments

Comments
 (0)