Skip to content

Commit a3926e2

Browse files
committed
Fix: OOB read in sz_order_skylake tail for length-mismatched strings
The tail used masked loads with different masks for `a` and `b`, then an unmasked `_mm512_cmpneq_epi8_mask`. When one string was shorter, its masked-off lanes were zero, so the longer string's bytes in those lanes compared as "not equal" against zero, producing spurious mismatch bits beyond `min(a_length, b_length)`. The follow-up `a[first_diff]` / `b[first_diff]` then read past the shorter buffer and could return a wrong ordering whenever the random byte at the OOB index outranked the real one — manifesting in `sz_sequence_argsort_with_insertion` as the empty string being placed after non-empty strings starting with `\0`, which made `test_sorting_algorithms` abort on the `"ab\0"` alphabet. Restrict the compare to `a_mask & b_mask` (the bytes valid in both) using the masked `_mm512_mask_cmpneq_epi8_mask`, mirroring how the head section and `sz_equal_skylake` already handle masked tails.
1 parent 267c757 commit a3926e2

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

include/stringzilla/compare.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,10 +374,12 @@ SZ_PUBLIC sz_ordering_t sz_order_skylake(sz_cptr_t a, sz_size_t a_length, sz_cpt
374374
b_mask = sz_u64_clamp_mask_until_(b_length);
375375
a_vec.zmm = _mm512_maskz_loadu_epi8(a_mask, a);
376376
b_vec.zmm = _mm512_maskz_loadu_epi8(b_mask, b);
377-
// The AVX-512 `_mm512_mask_cmpneq_epi8_mask` intrinsics are generally handy in such environments.
378-
// They, however, have latency 3 on most modern CPUs. Using AVX2: `_mm256_cmpeq_epi8` would have
379-
// been cheaper, if we didn't have to apply `_mm256_movemask_epi8` afterwards.
380-
mask_not_equal = _mm512_cmpneq_epi8_mask(a_vec.zmm, b_vec.zmm);
377+
// Restrict the comparison to bytes valid in both strings. The masked loads zero out lanes
378+
// past each string's end, so an unmasked compare would see spurious mismatches against
379+
// those zeros in the longer string's tail and read out of bounds for the shorter one
380+
// (e.g. `sz_order("\0baa", 4, "", 0)` would dereference `b[1]`).
381+
__mmask64 const common_mask = a_mask & b_mask;
382+
mask_not_equal = _mm512_mask_cmpneq_epi8_mask(common_mask, a_vec.zmm, b_vec.zmm);
381383
if (mask_not_equal != 0) {
382384
// Reload from original memory (L1 cached) to avoid ZMM-to-stack spill.
383385
sz_u64_t first_diff = _tzcnt_u64(mask_not_equal);

0 commit comments

Comments
 (0)