Skip to content

Commit 6ea6b1d

Browse files
schenksjclaudeandygrove
authored
fix(shuffle): tolerate non-UTF-8 bytes in get_string (lossy decode) (#4524)
* fix(shuffle): get_string tolerates non-UTF-8 bytes (lossy decode) Spark's UnsafeRow.getUTF8String performs no UTF-8 validation, and cast(BinaryType -> StringType) is a zero-copy reinterpret, so a StringType column can legitimately hold arbitrary non-UTF-8 bytes. get_string decoded with from_utf8(..).unwrap(), which panics on such rows even though Spark treats them as opaque. Use from_utf8_lossy (returning Cow<str>): a zero-cost borrow for valid UTF-8 and a String with U+FFFD replacements otherwise -- defined behavior, no UB. Avoids from_utf8_unchecked, which would construct a &str from arbitrary bytes (UB) and propagate into downstream Arrow ops. Adds a standalone unit test that panics without the fix and passes with it. Closes #4521 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * test: add end-to-end shuffle test for non-UTF-8 StringType bytes (#4521) Address review feedback: add a Spark-level regression test demonstrating the bug. cast(binary -> string) is a zero-copy reinterpret in Spark, so a StringType column can hold arbitrary non-UTF-8 bytes. The test disables Comet's Cast so those raw bytes reach Comet's columnar (JVM) shuffle inside a JVM UnsafeRow, exercising the native row->Arrow get_string path that used to panic via from_utf8(..).unwrap() and now decodes lossily. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * fix(shuffle): match JVM U+FFFD granularity in get_string decode Replace `String::from_utf8_lossy` in `get_string` with `decode_utf8_spark_lossy`, which mirrors `sun.nio.cs.UTF_8.Decoder` (action REPLACE) byte-for-byte so a Comet columnar shuffle of arbitrary bytes renders identically to a Spark JVM shuffle. `from_utf8_lossy` follows the Unicode "maximal subpart" rule and can emit more than one U+FFFD per ill-formed multi-byte unit; the JDK collapses certain units (notably surrogate-range three-byte sequences `ED A0..BF ..`, e.g. CESU-8 / modified-UTF-8 supplementary chars) into a single U+FFFD. Valid UTF-8 still returns a zero-cost borrow via the fast path. Tests use JDK-17 `new String(bytes, UTF_8)` output as the oracle: a 7-case replacement-granularity table (incl. the `ED A0 80` -> single U+FFFD parity case), zero-copy borrow for valid UTF-8, and valid multibyte chars preserved around an invalid byte. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * docs(shuffle): frame JDK UTF-8 decoder parity as behavioral, note provenance Address ASF-provenance review feedback on decode_utf8_spark_lossy: reword the doc comment so it describes the *observable* replacement behavior of the JDK UTF-8 decoder rather than saying the per-class malformed lengths "mirror sun.nio.cs.UTF_8.Decoder" (which implies derivation from that class). State that they were determined from observed `new String(bytes, UTF_8)` output, not by reviewing the OpenJDK source. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * style: cargo fmt unsafe_object.rs utf8_lossy test tuple Resolves the rustfmt diff at unsafe_object.rs:370 that failed the Lint check. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> Co-authored-by: Andy Grove <agrove@apache.org>
1 parent 642e360 commit 6ea6b1d

3 files changed

Lines changed: 245 additions & 6 deletions

File tree

native/shuffle/src/spark_unsafe/row.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,4 +1509,31 @@ mod test {
15091509
assert_eq!(struct_array.len(), 1);
15101510
assert!(struct_array.is_null(0));
15111511
}
1512+
1513+
// Spark's `UnsafeRow.getUTF8String` performs no UTF-8 validation, and
1514+
// `cast(BinaryType -> StringType)` is a zero-copy reinterpret -- so a StringType field can
1515+
// hold arbitrary non-UTF-8 bytes. `get_string` must not panic on those; it should decode
1516+
// lossily, matching Spark treating the bytes as opaque.
1517+
#[test]
1518+
fn get_string_tolerates_non_utf8_bytes() {
1519+
// One string field. Row layout: 8-byte null bitset + an 8-byte (offset<<32 | len) slot,
1520+
// then the variable-length region. 8-byte aligned to match a real Spark UnsafeRow.
1521+
#[repr(align(8))]
1522+
struct Aligned([u8; 24]);
1523+
let mut data = Aligned([0u8; 24]);
1524+
// Invalid UTF-8 bytes at offset 16: 0xFF, 0xFE, then ASCII 'A'.
1525+
data.0[16] = 0xFF;
1526+
data.0[17] = 0xFE;
1527+
data.0[18] = b'A';
1528+
// Field 0 slot: offset = 16, len = 3.
1529+
let offset_and_len: i64 = (16i64 << 32) | 3;
1530+
data.0[8..16].copy_from_slice(&offset_and_len.to_ne_bytes());
1531+
1532+
let mut row = SparkUnsafeRow::new_with_num_fields(1);
1533+
row.point_to_slice(&data.0);
1534+
1535+
// Strict `from_utf8(..).unwrap()` panics here; lossy decode replaces each invalid byte
1536+
// with U+FFFD. `&*` works whether get_string returns `&str` or `Cow<str>`.
1537+
assert_eq!(&*row.get_string(0), "\u{FFFD}\u{FFFD}A");
1538+
}
15121539
}

native/shuffle/src/spark_unsafe/unsafe_object.rs

Lines changed: 188 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,127 @@ use super::list::SparkUnsafeArray;
1919
use super::map::SparkUnsafeMap;
2020
use super::row::SparkUnsafeRow;
2121
use datafusion_comet_common::bytes_to_i128;
22-
use std::str::from_utf8;
22+
use std::borrow::Cow;
2323

2424
const MAX_LONG_DIGITS: u8 = 18;
2525

26+
/// Decode `bytes` as UTF-8 the way Spark renders `StringType` -- `new String(bytes, UTF_8)` on the
27+
/// JVM -- replacing each ill-formed sequence with a single `U+FFFD` and skipping the same number of
28+
/// bytes the JDK's UTF-8 `CharsetDecoder` (action REPLACE) would. Valid UTF-8 is returned as a
29+
/// zero-cost borrow.
30+
///
31+
/// This intentionally differs from `str::from_utf8_lossy` for surrogate-range three-byte sequences
32+
/// (`ED A0..BF ..`, e.g. CESU-8 / Java modified-UTF-8 supplementary chars) and for some other
33+
/// ill-formed multi-byte units: `from_utf8_lossy` follows the Unicode "maximal subpart" rule and
34+
/// can emit one `U+FFFD` per byte, whereas the JDK collapses certain ill-formed units into a single
35+
/// `U+FFFD`. Matching the JDK byte-for-byte means a Comet columnar shuffle of arbitrary bytes
36+
/// renders identically to a Spark JVM shuffle. The per-class malformed lengths below
37+
/// (E0/ED overlong & surrogate handling, F0/F4 range checks) match the observable replacement
38+
/// behavior of the JDK UTF-8 decoder; they were determined from observed
39+
/// `new String(bytes, UTF_8)` output, not by reviewing the OpenJDK source.
40+
pub(crate) fn decode_utf8_spark_lossy(bytes: &[u8]) -> Cow<'_, str> {
41+
// Fast path: well-formed UTF-8 borrows with zero copy (the overwhelmingly common case).
42+
if let Ok(s) = std::str::from_utf8(bytes) {
43+
return Cow::Borrowed(s);
44+
}
45+
46+
const RC: char = '\u{FFFD}';
47+
let n = bytes.len();
48+
let mut out = String::with_capacity(n);
49+
let mut i = 0;
50+
while i < n {
51+
let b1 = bytes[i];
52+
if b1 < 0x80 {
53+
out.push(b1 as char);
54+
i += 1;
55+
} else if (0xC2..=0xDF).contains(&b1) {
56+
// 2-byte lead. Bad/absent continuation -> single FFFD, skip 1.
57+
if i + 1 < n && (bytes[i + 1] & 0xC0) == 0x80 {
58+
let cp = (((b1 as u32) & 0x1F) << 6) | ((bytes[i + 1] as u32) & 0x3F);
59+
out.push(char::from_u32(cp).unwrap());
60+
i += 2;
61+
} else {
62+
out.push(RC);
63+
i += 1;
64+
}
65+
} else if (0xE0..=0xEF).contains(&b1) {
66+
// 3-byte lead.
67+
if i + 1 >= n {
68+
out.push(RC); // truncated lead at EOF
69+
i = n;
70+
} else {
71+
let b2 = bytes[i + 1];
72+
if (b1 == 0xE0 && (b2 & 0xE0) == 0x80) || (b2 & 0xC0) != 0x80 {
73+
// overlong (E0 80..9F) or b2 not a continuation -> skip 1
74+
out.push(RC);
75+
i += 1;
76+
} else if i + 2 >= n {
77+
out.push(RC); // truncated after a valid b2 at EOF
78+
i = n;
79+
} else {
80+
let b3 = bytes[i + 2];
81+
if (b3 & 0xC0) != 0x80 {
82+
out.push(RC); // b3 not a continuation -> skip 2
83+
i += 2;
84+
} else {
85+
let cp = (((b1 as u32) & 0x0F) << 12)
86+
| (((b2 as u32) & 0x3F) << 6)
87+
| ((b3 as u32) & 0x3F);
88+
if (0xD800..=0xDFFF).contains(&cp) {
89+
// surrogate (e.g. ED A0 80) -> JDK skips all 3, single FFFD
90+
out.push(RC);
91+
i += 3;
92+
} else {
93+
out.push(char::from_u32(cp).unwrap());
94+
i += 3;
95+
}
96+
}
97+
}
98+
}
99+
} else if (0xF0..=0xF4).contains(&b1) {
100+
// 4-byte lead.
101+
if i + 1 >= n {
102+
out.push(RC);
103+
i = n;
104+
} else {
105+
let b2 = bytes[i + 1];
106+
if (b1 == 0xF0 && !(0x90..=0xBF).contains(&b2))
107+
|| (b1 == 0xF4 && (b2 & 0xF0) != 0x80)
108+
|| (b2 & 0xC0) != 0x80
109+
{
110+
out.push(RC); // bad b2 -> skip 1
111+
i += 1;
112+
} else if i + 2 >= n {
113+
out.push(RC);
114+
i = n;
115+
} else if (bytes[i + 2] & 0xC0) != 0x80 {
116+
out.push(RC); // b3 not a continuation -> skip 2
117+
i += 2;
118+
} else if i + 3 >= n {
119+
out.push(RC);
120+
i = n;
121+
} else if (bytes[i + 3] & 0xC0) != 0x80 {
122+
out.push(RC); // b4 not a continuation -> skip 3
123+
i += 3;
124+
} else {
125+
let cp = (((b1 as u32) & 0x07) << 18)
126+
| (((b2 as u32) & 0x3F) << 12)
127+
| (((bytes[i + 2] as u32) & 0x3F) << 6)
128+
| ((bytes[i + 3] as u32) & 0x3F);
129+
out.push(char::from_u32(cp).unwrap());
130+
i += 4;
131+
}
132+
}
133+
} else {
134+
// Lone continuation (0x80..0xBF), overlong 2-byte leads (0xC0/0xC1), or out-of-range
135+
// 4-byte leads (0xF5..0xFF): each is a single ill-formed byte -> skip 1.
136+
out.push(RC);
137+
i += 1;
138+
}
139+
}
140+
Cow::Owned(out)
141+
}
142+
26143
/// A common trait for Spark Unsafe classes that can be used to access the underlying data,
27144
/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to
28145
/// access the underlying data with index.
@@ -75,19 +192,31 @@ pub trait SparkUnsafeObject {
75192
}
76193

77194
/// Returns string value at the given index of the object.
78-
fn get_string(&self, index: usize) -> &str {
195+
///
196+
/// Spark's `UnsafeRow.getUTF8String` wraps the bytes via `UTF8String.fromAddress` with no
197+
/// UTF-8 validation, and Spark's `cast(BinaryType -> StringType)` is a zero-copy reinterpret
198+
/// that can leave arbitrary bytes in a `StringType` column. Strict `from_utf8(..).unwrap()`
199+
/// here panics on those rows even though Spark itself treats them as opaque. We use
200+
/// `from_utf8_lossy`: it returns the original `&str` borrow for valid UTF-8 (zero-cost) and a
201+
/// `String` with `U+FFFD` replacements for invalid bytes (defined behavior, no UB). This
202+
/// avoids `from_utf8_unchecked`, which would construct a `&str` from arbitrary bytes -- UB per
203+
/// the Rust reference, and would propagate into downstream Arrow ops that internally call
204+
/// `str::from_utf8_unchecked` on the buffer.
205+
///
206+
/// We decode via [`decode_utf8_spark_lossy`] rather than `String::from_utf8_lossy` so the
207+
/// `U+FFFD` replacement granularity matches Spark's `new String(bytes, UTF_8)` EXACTLY,
208+
/// including surrogate-range three-byte sequences (`ED A0..BF ..`) where the two std libraries
209+
/// disagree -- so a Comet shuffle of arbitrary bytes renders identically to a Spark shuffle.
210+
fn get_string(&self, index: usize) -> Cow<'_, str> {
79211
let (offset, len) = self.get_offset_and_len(index);
80212
let addr = self.get_row_addr() + offset as i64;
81-
// SAFETY: addr points to valid UTF-8 string data within the variable-length region.
82-
// Offset and length are read from the fixed-length portion of the row/array.
83213
debug_assert!(addr != 0, "get_string: null address at index {index}");
84214
debug_assert!(
85215
len >= 0,
86216
"get_string: negative length {len} at index {index}"
87217
);
88218
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };
89-
90-
from_utf8(slice).unwrap()
219+
decode_utf8_spark_lossy(slice)
91220
}
92221

93222
/// Returns binary value at the given index of the object.
@@ -222,3 +351,56 @@ macro_rules! impl_primitive_accessors {
222351
};
223352
}
224353
pub(crate) use impl_primitive_accessors;
354+
355+
#[cfg(test)]
356+
mod utf8_lossy_tests {
357+
use super::decode_utf8_spark_lossy;
358+
use std::borrow::Cow;
359+
360+
/// Oracle = JDK 17 `new String(bytes, StandardCharsets.UTF_8)` (the renderer Spark uses for
361+
/// StringType). Each row's expected output was verified against the JVM. The decoder must match
362+
/// it byte-for-byte -- including the surrogate-range case where `str::from_utf8_lossy` differs.
363+
#[test]
364+
fn matches_jvm_replacement_granularity() {
365+
let cases: &[(&[u8], &str)] = &[
366+
(&[0xFF, 0xFE, 0x41], "\u{FFFD}\u{FFFD}A"),
367+
(&[0x80, 0x42], "\u{FFFD}B"),
368+
(&[0xE0, 0x80], "\u{FFFD}\u{FFFD}"),
369+
(&[0xF0, 0x80, 0x80, 0x41], "\u{FFFD}\u{FFFD}\u{FFFD}A"),
370+
(&[0xC0, 0xAF], "\u{FFFD}\u{FFFD}"),
371+
// The parity case: Rust's from_utf8_lossy would give three U+FFFD here.
372+
(&[0xED, 0xA0, 0x80], "\u{FFFD}"),
373+
(
374+
&[0xF4, 0x90, 0x80, 0x80],
375+
"\u{FFFD}\u{FFFD}\u{FFFD}\u{FFFD}",
376+
),
377+
];
378+
for (bytes, expected) in cases {
379+
assert_eq!(
380+
decode_utf8_spark_lossy(bytes),
381+
*expected,
382+
"bytes {bytes:02x?} should render like the JVM"
383+
);
384+
}
385+
}
386+
387+
#[test]
388+
fn valid_utf8_is_borrowed_zero_copy() {
389+
let s = "café — 日本語 🦀";
390+
match decode_utf8_spark_lossy(s.as_bytes()) {
391+
Cow::Borrowed(b) => assert_eq!(b, s),
392+
Cow::Owned(_) => panic!("valid UTF-8 must borrow, not allocate"),
393+
}
394+
}
395+
396+
#[test]
397+
fn valid_multibyte_around_invalid_bytes_decodes() {
398+
// 'a' | é (C3 A9) | stray 0xFF | 'b' | 🦀 (F0 9F A6 80) -> valid chars preserved, one FFFD.
399+
let mut bytes = vec![b'a'];
400+
bytes.extend_from_slice("é".as_bytes());
401+
bytes.push(0xFF);
402+
bytes.push(b'b');
403+
bytes.extend_from_slice("🦀".as_bytes());
404+
assert_eq!(decode_utf8_spark_lossy(&bytes), "aé\u{FFFD}b🦀");
405+
}
406+
}

spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,36 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
761761
}
762762
}
763763

764+
// Regression test for https://github.com/apache/datafusion-comet/issues/4521.
765+
//
766+
// Spark's `cast(BinaryType -> StringType)` is a zero-copy reinterpret (and `UnsafeRow`'s
767+
// string accessor performs no UTF-8 validation), so a `StringType` column can legitimately
768+
// hold arbitrary non-UTF-8 bytes that Spark treats as opaque. Comet's columnar (JVM) shuffle
769+
// converts those `UnsafeRow`s to Arrow natively (`process_sorted_row_partition` -> `get_string`),
770+
// which used to decode with `from_utf8(..).unwrap()` and panic on such rows. It now decodes
771+
// lossily (U+FFFD replacements), matching how Spark renders the same bytes.
772+
test("columnar shuffle tolerates non-UTF-8 bytes in a StringType column") {
773+
withParquetTable(
774+
Seq(
775+
// 0xFF and 0xFE are never valid UTF-8 lead bytes; each decodes to a single U+FFFD in
776+
// both Spark and Comet (so the lossy results match exactly).
777+
(1, Array[Byte](0xff.toByte, 0xfe.toByte, 'A'.toByte)),
778+
// 0x80 is a stray continuation byte -> one U+FFFD, followed by valid ASCII.
779+
(2, Array[Byte](0x80.toByte, 'B'.toByte)),
780+
// A fully valid UTF-8 row exercises the zero-cost borrow path.
781+
(3, "valid".getBytes("UTF-8"))),
782+
"tbl") {
783+
// Disable Comet's own Cast so the `cast(binary -> string)` runs in Spark and the raw bytes
784+
// reach the shuffle inside a JVM UnsafeRow. (If Comet performed the cast it would produce a
785+
// pre-sanitized Arrow string array and never exercise get_string.)
786+
withSQLConf(CometConf.getExprEnabledConfigKey("Cast") -> "false") {
787+
val df = sql("SELECT _1, CAST(_2 AS STRING) AS s FROM tbl")
788+
val shuffled = df.repartition(2, $"_1")
789+
checkShuffleAnswer(shuffled, 1)
790+
}
791+
}
792+
}
793+
764794
/**
765795
* Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet
766796
* exchange operators.

0 commit comments

Comments
 (0)