Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 174 additions & 111 deletions native/shuffle/src/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,29 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::spark_unsafe::{
map::append_map_elements,
row::{append_field, downcast_builder_ref, SparkUnsafeRow},
unsafe_object::{impl_primitive_accessors, SparkUnsafeObject},
};
use arrow::array::{
builder::{
ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
ListBuilder, NullBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder,
use arrow::datatypes::{DataType, TimeUnit};
use arrow::{
array::{
builder::{
ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder,
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
ListBuilder, NullBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder,
},
MapBuilder, PrimitiveArray,
},
buffer::{BooleanBuffer, Buffer, NullBuffer, ScalarBuffer},
datatypes::{
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
TimestampMicrosecondType,
},
MapBuilder,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_comet_jni_bridge::errors::CometError;

/// Generates bulk append methods for primitive types in SparkUnsafeArray.
Expand All @@ -38,38 +47,56 @@ use datafusion_comet_jni_bridge::errors::CometError;
/// - `null_bitset_ptr()` returns a pointer to `ceil(num_elements/64)` i64 words
/// - These invariants are guaranteed by the SparkUnsafeArray layout from the JVM
macro_rules! impl_append_to_builder {
($method_name:ident, $builder_type:ty, $element_type:ty) => {
($method_name:ident, $builder_type:ty, $element_type:ty, $arrow_type:ty) => {
pub(crate) fn $method_name<const NULLABLE: bool>(&self, builder: &mut $builder_type) {
let num_elements = self.num_elements;
if num_elements == 0 {
return;
}
// Note: alignment is not guaranteed - that is why do this
// This runtime check is needed. Look at `unsafe_object.rs:49` for more info
let ptr = self.element_offset as *const $element_type;
let aligned = (ptr as usize).is_multiple_of(std::mem::align_of::<$element_type>());

if NULLABLE {
let mut ptr = self.element_offset as *const $element_type;
let null_words = self.null_bitset_ptr();
debug_assert!(!null_words.is_null(), "null_bitset_ptr is null");
debug_assert!(!ptr.is_null(), "element_offset pointer is null");
for idx in 0..num_elements {
// SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements
let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) };

if is_null {
builder.append_null();
} else {
// SAFETY: ptr is within element data bounds
builder.append_value(unsafe { ptr.read_unaligned() });
if aligned {
// Raw values
let values = unsafe { std::slice::from_raw_parts(ptr, num_elements) };

// Note: in Spark bitmap is padded to 8 byte word-boundaries
// In Arrow we just use the needed number of whole bytes without padding
let null_mask_len = num_elements.div_ceil(8);
let null_mask = unsafe {
std::slice::from_raw_parts::<u8>(null_words as *const u8, null_mask_len)
};
// We need to perform this flip due to the null bitmap Spark vs Arrow incompatibility
// In `Spark` we have 1 set in bitmap meaning that element IS NULL
// In `Arrow` we have 1 set in bitmap meaning that element IS VALID (non-null)
let flipped: Vec<u8> = null_mask.iter().map(|n| !n).collect();
// Constructing null-buffer
let validity =
NullBuffer::new(BooleanBuffer::new(Buffer::from(flipped), 0, num_elements));

let arr = PrimitiveArray::<$arrow_type>::new(
ScalarBuffer::from(Buffer::from_slice_ref(values)),
Some(validity),
);
builder.append_array(&arr);
} else {
let mut ptr = ptr;
for idx in 0..num_elements {
if unsafe { Self::is_null_in_bitset(null_words, idx) } {
builder.append_null();
} else {
builder.append_value(unsafe { ptr.read_unaligned() });
}
ptr = unsafe { ptr.add(1) };
}
// SAFETY: ptr stays within bounds, iterating num_elements times
ptr = unsafe { ptr.add(1) };
}
} else {
// SAFETY: element_offset points to contiguous data of length num_elements
debug_assert!(self.element_offset != 0, "element_offset is null");
let ptr = self.element_offset as *const $element_type;
// Use bulk copy when data is properly aligned, fall back to
// per-element unaligned reads otherwise
if (ptr as usize).is_multiple_of(std::mem::align_of::<$element_type>()) {
if aligned {
let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) };
builder.append_slice(slice);
} else {
Expand Down Expand Up @@ -177,12 +204,12 @@ impl SparkUnsafeArray {
(null_words.add(word_idx).read_unaligned() & (1i64 << bit_idx)) != 0
}

impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32);
impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64);
impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16);
impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8);
impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32);
impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64);
impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32, Int32Type);
impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64, Int64Type);
impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16, Int16Type);
impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8, Int8Type);
impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32, Float32Type);
impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64, Float64Type);

/// Bulk append boolean values to builder.
/// Booleans are stored as 1 byte each in SparkUnsafeArray, requiring special handling.
Expand All @@ -194,37 +221,31 @@ impl SparkUnsafeArray {
if num_elements == 0 {
return;
}

let mut ptr = self.element_offset as *const u8;
// Bools have alignment == 1
// We dont have to worry about the fallback. Hence, we do not care about it
debug_assert!(
!ptr.is_null(),
self.element_offset != 0,
"append_booleans: element_offset pointer is null"
);

if NULLABLE {
let null_words = self.null_bitset_ptr();
debug_assert!(
!null_words.is_null(),
"append_booleans: null_bitset_ptr is null"
);
for idx in 0..num_elements {
// SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements
let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) };

if is_null {
let slice = unsafe {
std::slice::from_raw_parts(self.element_offset as *const bool, num_elements)
};
for (idx, &value) in slice.iter().enumerate() {
if unsafe { Self::is_null_in_bitset(null_words, idx) } {
builder.append_null();
} else {
// SAFETY: ptr is within element data bounds
builder.append_value(unsafe { *ptr != 0 });
builder.append_value(value);
}
// SAFETY: ptr stays within bounds, iterating num_elements times
ptr = unsafe { ptr.add(1) };
}
} else {
for _ in 0..num_elements {
// SAFETY: ptr is within element data bounds
builder.append_value(unsafe { *ptr != 0 });
ptr = unsafe { ptr.add(1) };
let values = unsafe {
std::slice::from_raw_parts(self.element_offset as *const u8, num_elements)
};
for &value in values {
builder.append_value(value != 0);
}
}
}
Expand All @@ -233,46 +254,68 @@ impl SparkUnsafeArray {
pub(crate) fn append_timestamps_to_builder<const NULLABLE: bool>(
&self,
builder: &mut TimestampMicrosecondBuilder,
timezone: Option<Arc<str>>,
) {
let num_elements = self.num_elements;
if num_elements == 0 {
return;
}

// SAFETY: element_offset points to contiguous i64 data of length num_elements
debug_assert!(
self.element_offset != 0,
"append_timestamps: element_offset is null"
);

let ptr = self.element_offset as *const i64;
// Note: alignment is not guaranteed - that is why do this
// This runtime check is needed. Look at `unsafe_object.rs:49` for more info
let aligned = (ptr as usize).is_multiple_of(std::mem::align_of::<i64>());

if NULLABLE {
let mut ptr = self.element_offset as *const i64;
let null_words = self.null_bitset_ptr();
debug_assert!(
!null_words.is_null(),
"append_timestamps: null_bitset_ptr is null"
);
debug_assert!(
!ptr.is_null(),
"append_timestamps: element_offset pointer is null"
);
for idx in 0..num_elements {
// SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements
let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) };
debug_assert!(!null_words.is_null(), "null_bitset_ptr is null");
if aligned {
// Raw values
let values = unsafe { std::slice::from_raw_parts(ptr, num_elements) };

// Note: in Spark bitmap is padded to 8 byte word-boundaries
// In Arrow we just use the needed number of whole bytes without padding
let null_mask_len = num_elements.div_ceil(8);
let null_mask = unsafe {
std::slice::from_raw_parts::<u8>(null_words as *const u8, null_mask_len)
};

if is_null {
builder.append_null();
} else {
// SAFETY: ptr is within element data bounds
builder.append_value(unsafe { ptr.read_unaligned() });
// We need to perform this flip due to the null bitmap Spark vs Arrow incompatibility
// In `Spark` we have 1 set in bitmap meaning that element IS NULL
// In `Arrow` we have 1 set in bitmap meaning that element IS VALID (non-null)
let flipped: Vec<u8> = null_mask.iter().map(|n| !n).collect();
// Constructing null-buffer
let validity =
NullBuffer::new(BooleanBuffer::new(Buffer::from(flipped), 0, num_elements));

// Constructing Arrow array with timezone set
let arr = PrimitiveArray::<TimestampMicrosecondType>::new(
ScalarBuffer::from(Buffer::from_slice_ref(values)),
Some(validity),
)
.with_timezone_opt(timezone);
builder.append_array(&arr);
} else {
let mut ptr = ptr;
for idx in 0..num_elements {
if unsafe { Self::is_null_in_bitset(null_words, idx) } {
builder.append_null();
} else {
builder.append_value(unsafe { ptr.read_unaligned() });
}
ptr = unsafe { ptr.add(1) }
}
// SAFETY: ptr stays within bounds, iterating num_elements times
ptr = unsafe { ptr.add(1) };
}
} else {
// SAFETY: element_offset points to contiguous i64 data of length num_elements
debug_assert!(
self.element_offset != 0,
"append_timestamps: element_offset is null"
);
let ptr = self.element_offset as *const i64;
if (ptr as usize).is_multiple_of(std::mem::align_of::<i64>()) {
let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) };
builder.append_slice(slice);
if aligned {
let values = unsafe { std::slice::from_raw_parts(ptr, num_elements) };
builder.append_slice(values);
} else {
let mut ptr = ptr;
for _ in 0..num_elements {
Expand All @@ -293,40 +336,60 @@ impl SparkUnsafeArray {
return;
}

// SAFETY: element_offset points to contiguous i64 data of length num_elements
debug_assert!(
self.element_offset != 0,
"append_timestamps: element_offset is null"
);

let ptr = self.element_offset as *const i32;
// Note: alignment is not guaranteed - that is why do this
// This runtime check is needed. Look at `unsafe_object.rs:49` for more info
let aligned = (ptr as usize).is_multiple_of(std::mem::align_of::<i32>());

if NULLABLE {
let mut ptr = self.element_offset as *const i32;
let null_words = self.null_bitset_ptr();
debug_assert!(
!null_words.is_null(),
"append_dates: null_bitset_ptr is null"
);
debug_assert!(
!ptr.is_null(),
"append_dates: element_offset pointer is null"
);
for idx in 0..num_elements {
// SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements
let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) };
debug_assert!(!null_words.is_null(), "null_bitset_ptr is null");
if aligned {
// Raw values
let values = unsafe { std::slice::from_raw_parts(ptr, num_elements) };

// Note: in Spark bitmap is padded to 8 byte word-boundaries
// In Arrow we just use the needed number of whole bytes without padding
let null_mask_len = num_elements.div_ceil(8);
let null_mask = unsafe {
std::slice::from_raw_parts::<u8>(null_words as *const u8, null_mask_len)
};

if is_null {
builder.append_null();
} else {
// SAFETY: ptr is within element data bounds
builder.append_value(unsafe { ptr.read_unaligned() });
// We need to perform this flip due to the null bitmap `Spark` vs `Arrow` incompatibility
// In `Spark` we have 1 set in bitmap meaning that element IS NULL
// In `Arrow` we have 1 set in bitmap meaning that element IS VALID (non-null)
let flipped: Vec<u8> = null_mask.iter().map(|n| !n).collect();
// Constructing null-buffer
let validity =
NullBuffer::new(BooleanBuffer::new(Buffer::from(flipped), 0, num_elements));

// Constructing Arrow array with timezone set
let arr = PrimitiveArray::<Date32Type>::new(
ScalarBuffer::from(Buffer::from_slice_ref(values)),
Some(validity),
);
builder.append_array(&arr);
} else {
let mut ptr = ptr;
for idx in 0..num_elements {
if unsafe { Self::is_null_in_bitset(null_words, idx) } {
builder.append_null();
} else {
builder.append_value(unsafe { ptr.read_unaligned() });
}
ptr = unsafe { ptr.add(1) };
}
// SAFETY: ptr stays within bounds, iterating num_elements times
ptr = unsafe { ptr.add(1) };
}
} else {
// SAFETY: element_offset points to contiguous i32 data of length num_elements
debug_assert!(
self.element_offset != 0,
"append_dates: element_offset is null"
);
let ptr = self.element_offset as *const i32;
if (ptr as usize).is_multiple_of(std::mem::align_of::<i32>()) {
let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) };
builder.append_slice(slice);
if aligned {
let values = unsafe { std::slice::from_raw_parts(ptr, num_elements) };
builder.append_slice(values);
} else {
let mut ptr = ptr;
for _ in 0..num_elements {
Expand Down Expand Up @@ -385,9 +448,9 @@ pub fn append_to_builder<const NULLABLE: bool>(
let builder = downcast_builder_ref!(Float64Builder, builder);
array.append_doubles_to_builder::<NULLABLE>(builder);
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
let builder = downcast_builder_ref!(TimestampMicrosecondBuilder, builder);
array.append_timestamps_to_builder::<NULLABLE>(builder);
array.append_timestamps_to_builder::<NULLABLE>(builder, tz.clone());
}
DataType::Date32 => {
let builder = downcast_builder_ref!(Date32Builder, builder);
Expand Down
Loading