From 37ee929c799b9b5ba4f03fc549171cab4360dde8 Mon Sep 17 00:00:00 2001 From: Scott Schenkein Date: Fri, 29 May 2026 20:14:21 -0400 Subject: [PATCH] fix(shuffle): get_string tolerates non-UTF-8 bytes (lossy decode) Spark's UnsafeRow.getUTF8String performs no UTF-8 validation, and cast(BinaryType -> StringType) is a zero-copy reinterpret, so a StringType column can legitimately hold arbitrary non-UTF-8 bytes. get_string decoded with from_utf8(..).unwrap(), which panics on such rows even though Spark treats them as opaque. Use from_utf8_lossy (returning Cow): a zero-cost borrow for valid UTF-8 and a String with U+FFFD replacements otherwise -- defined behavior, no UB. Avoids from_utf8_unchecked, which would construct a &str from arbitrary bytes (UB) and propagate into downstream Arrow ops. Adds a standalone unit test that panics without the fix and passes with it. Closes #4521 Co-Authored-By: Claude Opus 4.7 --- native/shuffle/src/spark_unsafe/row.rs | 27 +++++++++++++++++++ .../shuffle/src/spark_unsafe/unsafe_object.rs | 19 ++++++++----- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index 6ffe9d0b6e..449371463d 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -1509,4 +1509,31 @@ mod test { assert_eq!(struct_array.len(), 1); assert!(struct_array.is_null(0)); } + + // Spark's `UnsafeRow.getUTF8String` performs no UTF-8 validation, and + // `cast(BinaryType -> StringType)` is a zero-copy reinterpret -- so a StringType field can + // hold arbitrary non-UTF-8 bytes. `get_string` must not panic on those; it should decode + // lossily, matching Spark treating the bytes as opaque. + #[test] + fn get_string_tolerates_non_utf8_bytes() { + // One string field. Row layout: 8-byte null bitset + an 8-byte (offset<<32 | len) slot, + // then the variable-length region. 8-byte aligned to match a real Spark UnsafeRow. 
+ #[repr(align(8))] + struct Aligned([u8; 24]); + let mut data = Aligned([0u8; 24]); + // Invalid UTF-8 bytes at offset 16: 0xFF, 0xFE, then ASCII 'A'. + data.0[16] = 0xFF; + data.0[17] = 0xFE; + data.0[18] = b'A'; + // Field 0 slot: offset = 16, len = 3. + let offset_and_len: i64 = (16i64 << 32) | 3; + data.0[8..16].copy_from_slice(&offset_and_len.to_ne_bytes()); + + let mut row = SparkUnsafeRow::new_with_num_fields(1); + row.point_to_slice(&data.0); + + // Strict `from_utf8(..).unwrap()` panics here; lossy decode replaces each invalid byte + // with U+FFFD. `&*` works whether get_string returns `&str` or `Cow`. + assert_eq!(&*row.get_string(0), "\u{FFFD}\u{FFFD}A"); + } } diff --git a/native/shuffle/src/spark_unsafe/unsafe_object.rs b/native/shuffle/src/spark_unsafe/unsafe_object.rs index f32ea8c23b..5b4ce42e36 100644 --- a/native/shuffle/src/spark_unsafe/unsafe_object.rs +++ b/native/shuffle/src/spark_unsafe/unsafe_object.rs @@ -19,7 +19,7 @@ use super::list::SparkUnsafeArray; use super::map::SparkUnsafeMap; use super::row::SparkUnsafeRow; use datafusion_comet_common::bytes_to_i128; -use std::str::from_utf8; +use std::borrow::Cow; const MAX_LONG_DIGITS: u8 = 18; @@ -75,19 +75,26 @@ pub trait SparkUnsafeObject { } /// Returns string value at the given index of the object. - fn get_string(&self, index: usize) -> &str { + /// + /// Spark's `UnsafeRow.getUTF8String` wraps the bytes via `UTF8String.fromAddress` with no + /// UTF-8 validation, and Spark's `cast(BinaryType -> StringType)` is a zero-copy reinterpret + /// that can leave arbitrary bytes in a `StringType` column. Strict `from_utf8(..).unwrap()` + /// here panics on those rows even though Spark itself treats them as opaque. We use + /// `from_utf8_lossy`: it returns the original `&str` borrow for valid UTF-8 (zero-copy) and a + /// `String` with `U+FFFD` replacements for invalid bytes (defined behavior, no UB). 
This + /// avoids `from_utf8_unchecked`, which would construct a `&str` from arbitrary bytes -- UB per + /// the Rust reference, and would propagate into downstream Arrow ops that internally call + /// `str::from_utf8_unchecked` on the buffer. + fn get_string(&self, index: usize) -> Cow<'_, str> { let (offset, len) = self.get_offset_and_len(index); let addr = self.get_row_addr() + offset as i64; - // SAFETY: addr points to valid UTF-8 string data within the variable-length region. - // Offset and length are read from the fixed-length portion of the row/array. debug_assert!(addr != 0, "get_string: null address at index {index}"); debug_assert!( len >= 0, "get_string: negative length {len} at index {index}" ); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; - - from_utf8(slice).unwrap() + String::from_utf8_lossy(slice) } /// Returns binary value at the given index of the object.