Skip to content

Commit c916849

Browse files
authored
docs: add SAFETY comments to all unsafe blocks in shuffle spark_unsafe module (#3603)
1 parent 68b2c4d commit c916849

3 files changed

Lines changed: 48 additions & 2 deletions

File tree

native/core/src/execution/shuffle/spark_unsafe/list.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ impl SparkUnsafeObject for SparkUnsafeArray {
5151
impl SparkUnsafeArray {
5252
/// Creates a `SparkUnsafeArray` which points to the given address.
5353
pub fn new(addr: i64) -> Self {
54-
// Read the number of elements from the first 8 bytes.
54+
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
55+
// The first 8 bytes contain the element count as a little-endian i64.
5556
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
5657
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
5758

@@ -83,6 +84,9 @@ impl SparkUnsafeArray {
8384
/// Returns true if the null bit at the given index of the array is set.
8485
#[inline]
8586
pub(crate) fn is_null_at(&self, index: usize) -> bool {
87+
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
88+
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
89+
// index < num_elements, so word_offset is within the bitset region.
8690
unsafe {
8791
let mask: i64 = 1i64 << (index & 0x3f);
8892
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;

native/core/src/execution/shuffle/spark_unsafe/map.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ pub struct SparkUnsafeMap {
3030
impl SparkUnsafeMap {
3131
/// Creates a `SparkUnsafeMap` which points to the given address and size in bytes.
3232
pub(crate) fn new(addr: i64, size: i32) -> Self {
33-
// Read the number of bytes of key array from the first 8 bytes.
33+
// SAFETY: addr points to valid Spark UnsafeMap data from the JVM.
34+
// The first 8 bytes contain the key array size as a little-endian i64.
3435
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
3536
let key_array_size = i64::from_le_bytes(slice.try_into().unwrap());
3637

native/core/src/execution/shuffle/spark_unsafe/row.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
5858
/// A common trait for Spark Unsafe classes that can be used to access the underlying data,
5959
/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to
6060
/// access the underlying data with index.
61+
///
62+
/// # Safety
63+
///
64+
/// Implementations must ensure that:
65+
/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
66+
/// - `get_element_offset()` returns a valid pointer within the row/array data region
67+
/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
68+
/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership)
69+
///
70+
/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are
71+
/// safe to call as long as:
72+
/// - The index is within bounds (caller's responsibility)
73+
/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
6174
pub trait SparkUnsafeObject {
6275
/// Returns the address of the row.
6376
fn get_row_addr(&self) -> i64;
@@ -77,47 +90,55 @@ pub trait SparkUnsafeObject {
7790
/// Returns the boolean value at the given index of the object; any non-zero byte reads as `true`.
7891
fn get_boolean(&self, index: usize) -> bool {
7992
let addr = self.get_element_offset(index, 1);
93+
// SAFETY: `addr` is the pointer returned by `get_element_offset(index, 1)` and must be
94+
// readable for 1 byte. Nothing is checked here; the caller ensures `index` is within bounds.
8095
unsafe { *addr != 0 }
8196
}
8297

8398
/// Returns the byte at the given index of the object, reinterpreted as a signed `i8`.
8499
fn get_byte(&self, index: usize) -> i8 {
85100
let addr = self.get_element_offset(index, 1);
101+
// SAFETY: `addr` must be readable for 1 byte; no bounds check is done here, so the caller must pass an in-bounds `index`.
86102
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
87103
i8::from_le_bytes(slice.try_into().unwrap())
88104
}
89105

90106
/// Returns the short at the given index of the object, decoded from 2 little-endian bytes.
91107
fn get_short(&self, index: usize) -> i16 {
92108
let addr = self.get_element_offset(index, 2);
109+
// SAFETY: `addr` must be readable for 2 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
93110
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
94111
i16::from_le_bytes(slice.try_into().unwrap())
95112
}
96113

97114
/// Returns the integer at the given index of the object, decoded from 4 little-endian bytes.
98115
fn get_int(&self, index: usize) -> i32 {
99116
let addr = self.get_element_offset(index, 4);
117+
// SAFETY: `addr` must be readable for 4 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
100118
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
101119
i32::from_le_bytes(slice.try_into().unwrap())
102120
}
103121

104122
/// Returns the long at the given index of the object, decoded from 8 little-endian bytes.
105123
fn get_long(&self, index: usize) -> i64 {
106124
let addr = self.get_element_offset(index, 8);
125+
// SAFETY: `addr` must be readable for 8 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
107126
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
108127
i64::from_le_bytes(slice.try_into().unwrap())
109128
}
110129

111130
/// Returns the float at the given index of the object, decoded from 4 little-endian bytes.
112131
fn get_float(&self, index: usize) -> f32 {
113132
let addr = self.get_element_offset(index, 4);
133+
// SAFETY: `addr` must be readable for 4 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
114134
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
115135
f32::from_le_bytes(slice.try_into().unwrap())
116136
}
117137

118138
/// Returns the double at the given index of the object, decoded from 8 little-endian bytes.
119139
fn get_double(&self, index: usize) -> f64 {
120140
let addr = self.get_element_offset(index, 8);
141+
// SAFETY: `addr` must be readable for 8 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
121142
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
122143
f64::from_le_bytes(slice.try_into().unwrap())
123144
}
@@ -126,6 +147,8 @@ pub trait SparkUnsafeObject {
126147
fn get_string(&self, index: usize) -> &str {
127148
let (offset, len) = self.get_offset_and_len(index);
128149
let addr = self.get_row_addr() + offset as i64;
150+
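// SAFETY: addr points to `len` readable bytes of string data within the variable-length region
// (UTF-8 validity is not assumed here; `from_utf8` below checks it and the `unwrap` panics otherwise).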
151+
// Offset and length are read from the fixed-length portion of the row/array.
129152
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };
130153

131154
from_utf8(slice).unwrap()
@@ -135,19 +158,23 @@ pub trait SparkUnsafeObject {
135158
fn get_binary(&self, index: usize) -> &[u8] {
136159
let (offset, len) = self.get_offset_and_len(index);
137160
let addr = self.get_row_addr() + offset as i64;
161+
// SAFETY: `addr` (row address plus `offset`) must be readable for `len` bytes. Both values
162+
// come from `get_offset_and_len(index)` and are not validated here — NOTE(review): confirm they are trusted.
138163
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
139164
}
140165

141166
/// Returns the date at the given index of the object, decoded from 4 little-endian bytes as `i32`.
142167
fn get_date(&self, index: usize) -> i32 {
143168
let addr = self.get_element_offset(index, 4);
169+
// SAFETY: `addr` must be readable for 4 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
144170
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
145171
i32::from_le_bytes(slice.try_into().unwrap())
146172
}
147173

148174
/// Returns the timestamp at the given index of the object, decoded from 8 little-endian bytes as `i64`.
149175
fn get_timestamp(&self, index: usize) -> i64 {
150176
let addr = self.get_element_offset(index, 8);
177+
// SAFETY: `addr` must be readable for 8 bytes; no bounds check is done here, so the caller must pass an in-bounds `index`.
151178
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
152179
i64::from_le_bytes(slice.try_into().unwrap())
153180
}
@@ -257,6 +284,9 @@ impl SparkUnsafeRow {
257284
/// Returns true if the null bit at the given index of the row is set.
258285
#[inline]
259286
pub(crate) fn is_null_at(&self, index: usize) -> bool {
287+
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
288+
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
289+
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
260290
unsafe {
261291
let mask: i64 = 1i64 << (index & 0x3f);
262292
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
@@ -267,6 +297,9 @@ impl SparkUnsafeRow {
267297

268298
/// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null).
269299
pub fn set_not_null_at(&mut self, index: usize) {
300+
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
301+
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
302+
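// Writing additionally requires that nothing else reads or writes this row's memory at the same time:
// `&mut self` gives exclusive access to this Rust handle only, not to the memory at `row_addr`.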
270303
unsafe {
271304
let mask: i64 = 1i64 << (index & 0x3f);
272305
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
@@ -463,6 +496,8 @@ fn append_columns(
463496
let mut row = SparkUnsafeRow::new(schema);
464497

465498
for i in row_start..row_end {
499+
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
500+
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
466501
let row_addr = unsafe { *row_addresses_ptr.add(i) };
467502
let row_size = unsafe { *row_sizes_ptr.add(i) };
468503
row.point_to(row_addr, row_size);
@@ -593,6 +628,8 @@ fn append_columns(
593628
let mut row = SparkUnsafeRow::new(schema);
594629

595630
for i in row_start..row_end {
631+
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
632+
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
596633
let row_addr = unsafe { *row_addresses_ptr.add(i) };
597634
let row_size = unsafe { *row_sizes_ptr.add(i) };
598635
row.point_to(row_addr, row_size);
@@ -613,6 +650,8 @@ fn append_columns(
613650
let mut row = SparkUnsafeRow::new(schema);
614651

615652
for i in row_start..row_end {
653+
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
654+
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
616655
let row_addr = unsafe { *row_addresses_ptr.add(i) };
617656
let row_size = unsafe { *row_sizes_ptr.add(i) };
618657
row.point_to(row_addr, row_size);
@@ -640,6 +679,8 @@ fn append_columns(
640679
let mut row = SparkUnsafeRow::new(schema);
641680

642681
for i in row_start..row_end {
682+
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
683+
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
643684
let row_addr = unsafe { *row_addresses_ptr.add(i) };
644685
let row_size = unsafe { *row_sizes_ptr.add(i) };
645686
row.point_to(row_addr, row_size);

0 commit comments

Comments
 (0)