Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions native/shuffle/src/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1509,4 +1509,31 @@ mod test {
assert_eq!(struct_array.len(), 1);
assert!(struct_array.is_null(0));
}

// Spark's `UnsafeRow.getUTF8String` performs no UTF-8 validation, and
// `cast(BinaryType -> StringType)` is a zero-copy reinterpret -- so a StringType field can
// hold arbitrary non-UTF-8 bytes. `get_string` must not panic on those; it should decode
// lossily, matching Spark treating the bytes as opaque.
#[test]
fn get_string_tolerates_non_utf8_bytes() {
    // Backing storage for a one-field row, forced to 8-byte alignment like a real Spark
    // UnsafeRow: bytes 0..8 hold the null bitset, bytes 8..16 hold the field-0 slot, and the
    // variable-length region starts at byte 16.
    #[repr(align(8))]
    struct Aligned([u8; 24]);

    const VAR_REGION_START: usize = 16;
    // Two bytes that are invalid in UTF-8, followed by ASCII 'A'.
    let payload: [u8; 3] = [0xFF, 0xFE, b'A'];

    let mut buf = Aligned([0u8; 24]);
    buf.0[VAR_REGION_START..VAR_REGION_START + payload.len()].copy_from_slice(&payload);

    // The field-0 slot packs the payload location as (offset << 32) | len.
    let slot = ((VAR_REGION_START as i64) << 32) | payload.len() as i64;
    buf.0[8..VAR_REGION_START].copy_from_slice(&slot.to_ne_bytes());

    let mut row = SparkUnsafeRow::new_with_num_fields(1);
    row.point_to_slice(&buf.0);

    // A strict `from_utf8(..).unwrap()` would panic on this row; a lossy decode substitutes
    // U+FFFD for each invalid byte. Dereferencing with `&*` keeps the assertion valid whether
    // `get_string` hands back a `&str` or a `Cow<str>`.
    let decoded = row.get_string(0);
    assert_eq!(&*decoded, "\u{FFFD}\u{FFFD}A");
}
}
19 changes: 13 additions & 6 deletions native/shuffle/src/spark_unsafe/unsafe_object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::list::SparkUnsafeArray;
use super::map::SparkUnsafeMap;
use super::row::SparkUnsafeRow;
use datafusion_comet_common::bytes_to_i128;
use std::str::from_utf8;
use std::borrow::Cow;

const MAX_LONG_DIGITS: u8 = 18;

Expand Down Expand Up @@ -75,19 +75,26 @@ pub trait SparkUnsafeObject {
}

/// Returns string value at the given index of the object.
fn get_string(&self, index: usize) -> &str {
///
/// The bytes are decoded with `String::from_utf8_lossy` instead of a strict
/// `from_utf8(..).unwrap()`. Spark's `UnsafeRow.getUTF8String` wraps the bytes via
/// `UTF8String.fromAddress` with no UTF-8 validation, and Spark's
/// `cast(BinaryType -> StringType)` is a zero-copy reinterpret that can leave arbitrary bytes
/// in a `StringType` column. A strict decode panics on those rows even though Spark itself
/// treats them as opaque.
///
/// * Valid UTF-8 comes back as `Cow::Borrowed`: no allocation and no copy, although the bytes
///   are still scanned once for validation.
/// * Invalid UTF-8 comes back as `Cow::Owned`: a `String` in which invalid sequences are
///   replaced by `U+FFFD` (defined behavior, no UB).
///
/// `from_utf8_unchecked` is deliberately avoided: it would construct a `&str` from arbitrary
/// bytes, which is undefined behavior and would propagate into downstream Arrow ops that
/// internally call `str::from_utf8_unchecked` on the buffer.
fn get_string(&self, index: usize) -> Cow<'_, str> {
let (offset, len) = self.get_offset_and_len(index);
let addr = self.get_row_addr() + offset as i64;
// SAFETY: `offset` and `len` are read from the fixed-length portion of the row/array, and
// `addr..addr + len` is assumed to be a readable byte range inside the variable-length
// region. The bytes are NOT assumed to be valid UTF-8; validation happens in the decode below.
debug_assert!(addr != 0, "get_string: null address at index {index}");
debug_assert!(
len >= 0,
"get_string: negative length {len} at index {index}"
);
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };

from_utf8(slice).unwrap()
String::from_utf8_lossy(slice)
}

/// Returns binary value at the given index of the object.
Expand Down