Skip to content

Commit 83d1d30

Browse files
andygrove and claude committed
refactor: add safety comments to remaining unsafe blocks
Add SAFETY comments to:
- SparkUnsafeRow::is_null_at and set_not_null_at
- SparkUnsafeArray::new and is_null_at
- Batch processing functions (append_list_column_batch, append_map_column_batch, append_struct_fields_field_major)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent ac8c182 commit 83d1d30

2 files changed

Lines changed: 34 additions & 1 deletion

File tree

native/core/src/execution/shuffle/list.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,8 @@ impl SparkUnsafeObject for SparkUnsafeArray {
9999
impl SparkUnsafeArray {
100100
/// Creates a `SparkUnsafeArray` which points to the given address and size in bytes.
101101
pub fn new(addr: i64) -> Self {
102-
// Read the number of elements from the first 8 bytes.
102+
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
103+
// The first 8 bytes contain the element count as a little-endian i64.
103104
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
104105
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());
105106

@@ -131,6 +132,9 @@ impl SparkUnsafeArray {
131132
/// Returns true if the null bit at the given index of the array is set.
132133
#[inline]
133134
pub(crate) fn is_null_at(&self, index: usize) -> bool {
135+
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
136+
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
137+
// index < num_elements, so word_offset is within the bitset region.
134138
unsafe {
135139
let mask: i64 = 1i64 << (index & 0x3f);
136140
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;

native/core/src/execution/shuffle/row.rs

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -277,6 +277,9 @@ impl SparkUnsafeRow {
277277
/// Returns true if the null bit at the given index of the row is set.
278278
#[inline]
279279
pub(crate) fn is_null_at(&self, index: usize) -> bool {
280+
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
281+
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
282+
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
280283
unsafe {
281284
let mask: i64 = 1i64 << (index & 0x3f);
282285
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
@@ -287,6 +290,10 @@ impl SparkUnsafeRow {
287290

288291
/// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null).
289292
pub fn set_not_null_at(&mut self, index: usize) {
293+
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
294+
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
295+
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
296+
// Writing is safe because we have mutable access and the memory is owned by the JVM.
290297
unsafe {
291298
let mask: i64 = 1i64 << (index & 0x3f);
292299
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
@@ -716,6 +723,8 @@ fn append_list_column_batch(
716723
macro_rules! process_primitive_lists {
717724
($builder_type:ty, $append_fn:ident) => {{
718725
for i in row_start..row_end {
726+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
727+
// are valid for indices [row_start, row_end) as provided by the JVM
719728
let row_addr = unsafe { *row_addresses_ptr.add(i) };
720729
let row_size = unsafe { *row_sizes_ptr.add(i) };
721730
row.point_to(row_addr, row_size);
@@ -768,6 +777,8 @@ fn append_list_column_batch(
768777
// For complex element types, fall back to per-row dispatch
769778
_ => {
770779
for i in row_start..row_end {
780+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
781+
// are valid for indices [row_start, row_end) as provided by the JVM
771782
let row_addr = unsafe { *row_addresses_ptr.add(i) };
772783
let row_size = unsafe { *row_sizes_ptr.add(i) };
773784
row.point_to(row_addr, row_size);
@@ -807,6 +818,8 @@ fn append_map_column_batch(
807818
macro_rules! process_primitive_maps {
808819
($key_builder:ty, $key_append:ident, $val_builder:ty, $val_append:ident) => {{
809820
for i in row_start..row_end {
821+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
822+
// are valid for indices [row_start, row_end) as provided by the JVM
810823
let row_addr = unsafe { *row_addresses_ptr.add(i) };
811824
let row_size = unsafe { *row_sizes_ptr.add(i) };
812825
row.point_to(row_addr, row_size);
@@ -880,6 +893,8 @@ fn append_map_column_batch(
880893
// For other types, fall back to per-row dispatch
881894
_ => {
882895
for i in row_start..row_end {
896+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
897+
// are valid for indices [row_start, row_end) as provided by the JVM
883898
let row_addr = unsafe { *row_addresses_ptr.add(i) };
884899
let row_size = unsafe { *row_sizes_ptr.add(i) };
885900
row.point_to(row_addr, row_size);
@@ -918,6 +933,8 @@ fn append_struct_fields_field_major(
918933
let mut struct_is_null = Vec::with_capacity(num_rows);
919934

920935
for i in row_start..row_end {
936+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
937+
// are valid for indices [row_start, row_end) as provided by the JVM
921938
let row_addr = unsafe { *row_addresses_ptr.add(i) };
922939
let row_size = unsafe { *row_sizes_ptr.add(i) };
923940
parent_row.point_to(row_addr, row_size);
@@ -942,6 +959,8 @@ fn append_struct_fields_field_major(
942959
// Struct is null, field is also null
943960
field_builder.append_null();
944961
} else {
962+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
963+
// are valid for indices [row_start, row_end) as provided by the JVM
945964
let row_addr = unsafe { *row_addresses_ptr.add(i) };
946965
let row_size = unsafe { *row_sizes_ptr.add(i) };
947966
parent_row.point_to(row_addr, row_size);
@@ -1006,6 +1025,8 @@ fn append_struct_fields_field_major(
10061025
if struct_is_null[row_idx] {
10071026
field_builder.append_null();
10081027
} else {
1028+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
1029+
// are valid for indices [row_start, row_end) as provided by the JVM
10091030
let row_addr = unsafe { *row_addresses_ptr.add(i) };
10101031
let row_size = unsafe { *row_sizes_ptr.add(i) };
10111032
parent_row.point_to(row_addr, row_size);
@@ -1026,6 +1047,8 @@ fn append_struct_fields_field_major(
10261047
if struct_is_null[row_idx] {
10271048
field_builder.append_null();
10281049
} else {
1050+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
1051+
// are valid for indices [row_start, row_end) as provided by the JVM
10291052
let row_addr = unsafe { *row_addresses_ptr.add(i) };
10301053
let row_size = unsafe { *row_sizes_ptr.add(i) };
10311054
parent_row.point_to(row_addr, row_size);
@@ -1048,6 +1071,8 @@ fn append_struct_fields_field_major(
10481071
if struct_is_null[row_idx] {
10491072
field_builder.append_null();
10501073
} else {
1074+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
1075+
// are valid for indices [row_start, row_end) as provided by the JVM
10511076
let row_addr = unsafe { *row_addresses_ptr.add(i) };
10521077
let row_size = unsafe { *row_sizes_ptr.add(i) };
10531078
parent_row.point_to(row_addr, row_size);
@@ -1078,6 +1103,8 @@ fn append_struct_fields_field_major(
10781103
nested_addresses.push(0);
10791104
nested_sizes.push(0);
10801105
} else {
1106+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
1107+
// are valid for indices [row_start, row_end) as provided by the JVM
10811108
let row_addr = unsafe { *row_addresses_ptr.add(i) };
10821109
let row_size = unsafe { *row_sizes_ptr.add(i) };
10831110
parent_row.point_to(row_addr, row_size);
@@ -1116,6 +1143,8 @@ fn append_struct_fields_field_major(
11161143
let null_row = SparkUnsafeRow::default();
11171144
append_field(dt, struct_builder, &null_row, field_idx)?;
11181145
} else {
1146+
// SAFETY: Caller (append_columns) guarantees row_addresses_ptr and row_sizes_ptr
1147+
// are valid for indices [row_start, row_end) as provided by the JVM
11191148
let row_addr = unsafe { *row_addresses_ptr.add(i) };
11201149
let row_size = unsafe { *row_sizes_ptr.add(i) };
11211150
parent_row.point_to(row_addr, row_size);

0 commit comments

Comments
 (0)