Skip to content

Commit 24f94ed

Browse files
committed
FSST contains: NEON SSA fusion in Teddy pair/triple
Drops the TODO on the NEON Teddy passes. Mirrors the AVX2 / AVX-512 implementations: at setup, load the SSA nibble tables into NEON registers (or splat zero when SSA is absent); in the inner loop, compute `ssa_bits = neon_nibble_lookup(ssa_lo, ssa_hi, v1, nibble_mask)` and `vorrq_u8` it into the Teddy candidate vector before the movemask. Scalar tail picks up SSA via the same one-line nibble table check used in the AVX2 tail, and the pair NEON path adds the last-byte SSA-only check. The NEON code is `#[cfg(target_arch = "aarch64")]`-gated; no runtime change on x86_64 (which already does fused SSA via AVX2 / AVX-512). Cross-compile-checked locally is not available; logic is byte-for-byte parallel to the AVX2 path. Signed-off-by: Claude <noreply@anthropic.com>
1 parent 5e5334c commit 24f94ed

1 file changed

Lines changed: 59 additions & 10 deletions

File tree

encodings/fsst/src/dfa/anchor_scan.rs

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,11 +1708,7 @@ unsafe fn teddy_pair_pass_avx2<T, V>(
17081708
#[expect(unsafe_op_in_unsafe_fn)]
17091709
unsafe fn teddy_pair_pass_neon<T, V>(
17101710
tables: &BucketTables,
1711-
// TODO: NEON SSA fusion. For now the NEON path doesn't handle the
1712-
// SSA set inline; on AArch64 the caller falls back to the
1713-
// non-fused 1-byte path when SSA codes exist, so correctness is
1714-
// preserved.
1715-
_ssa_tables: Option<&NibbleTables>,
1711+
ssa_tables: Option<&NibbleTables>,
17161712
n: usize,
17171713
offsets: &[T],
17181714
all_bytes: &[u8],
@@ -1740,6 +1736,14 @@ unsafe fn teddy_pair_pass_neon<T, V>(
17401736
let zero = vdupq_n_u8(0);
17411737
let nibble_mask = vdupq_n_u8(0x0F);
17421738
let lane_bits = vld1q_u8(NEON_MOVEMASK_BITS.as_ptr());
1739+
// Fused SSA tables (mirrors the AVX2 path). When `ssa_tables` is
1740+
// None we splat zero — `vqtbl1q_u8` of zero is zero and the
1741+
// subsequent `vorrq_u8` is a no-op, keeping the inner loop
1742+
// branch-free.
1743+
let (ssa_lo, ssa_hi, has_ssa) = match ssa_tables {
1744+
Some(t) => (vld1q_u8(t.lo.as_ptr()), vld1q_u8(t.hi.as_ptr()), true),
1745+
None => (zero, zero, false),
1746+
};
17431747
let setup_us = setup_t
17441748
.map(|t| t.elapsed().as_secs_f64() * 1e6)
17451749
.unwrap_or_default();
@@ -1765,7 +1769,15 @@ unsafe fn teddy_pair_pass_neon<T, V>(
17651769
let c1_bits = neon_nibble_lookup(c1_lo, c1_hi, v1, nibble_mask);
17661770
let c2_bits = neon_nibble_lookup(c2_lo, c2_hi, v2, nibble_mask);
17671771
let pair = vandq_u8(c1_bits, c2_bits);
1768-
let mut mask = u32::from(neon_nonzero_mask(pair, zero, lane_bits));
1772+
// Fused SSA lookup on the same v1: bit `b` set in `ssa_bits`
1773+
// iff `v1[lane]` matches an SSA code value.
1774+
let combined = if has_ssa {
1775+
let ssa_bits = neon_nibble_lookup(ssa_lo, ssa_hi, v1, nibble_mask);
1776+
vorrq_u8(pair, ssa_bits)
1777+
} else {
1778+
pair
1779+
};
1780+
let mut mask = u32::from(neon_nonzero_mask(combined, zero, lane_bits));
17691781
if mask != 0 {
17701782
nonzero_masks += usize::from(trace);
17711783
let candidate_t = trace.then(std::time::Instant::now);
@@ -1845,7 +1857,10 @@ unsafe fn teddy_pair_pass_neon<T, V>(
18451857
tables.c1.lo[usize::from(b1 & 0x0F)] & tables.c1.hi[usize::from(b1 >> 4)];
18461858
let c2_bits_b =
18471859
tables.c2.lo[usize::from(b2 & 0x0F)] & tables.c2.hi[usize::from(b2 >> 4)];
1848-
if (c1_bits_b & c2_bits_b) == 0 {
1860+
let pair_hit = (c1_bits_b & c2_bits_b) != 0;
1861+
let ssa_hit = ssa_tables
1862+
.is_some_and(|t| (t.lo[usize::from(b1 & 0x0F)] & t.hi[usize::from(b1 >> 4)]) != 0);
1863+
if !pair_hit && !ssa_hit {
18491864
continue;
18501865
}
18511866
tail_candidates += usize::from(trace);
@@ -1885,6 +1900,31 @@ unsafe fn teddy_pair_pass_neon<T, V>(
18851900
}
18861901
}
18871902
}
1903+
// Last-position SSA-only candidate (no successor for the pair check).
1904+
if let Some(t) = ssa_tables {
1905+
let j = len - 1;
1906+
let b = *all_bytes.get_unchecked(j);
1907+
let ssa_hit = (t.lo[usize::from(b & 0x0F)] & t.hi[usize::from(b >> 4)]) != 0;
1908+
if ssa_hit && j >= scan_start {
1909+
while j >= string_end && string_idx < n {
1910+
string_idx += 1;
1911+
if string_idx < n {
1912+
string_end = (*offsets.get_unchecked(string_idx + 1)).as_();
1913+
}
1914+
}
1915+
if string_idx < n {
1916+
let already = bits.value(string_idx);
1917+
let already_match = if negated { !already } else { already };
1918+
if !already_match && verify_at(j, string_end) {
1919+
if negated {
1920+
bits.unset_unchecked(string_idx);
1921+
} else {
1922+
bits.set_unchecked(string_idx);
1923+
}
1924+
}
1925+
}
1926+
}
1927+
}
18881928
}
18891929
let tail_us = tail_t
18901930
.map(|t| t.elapsed().as_secs_f64() * 1e6)
@@ -2752,8 +2792,7 @@ unsafe fn teddy_triple_pass_avx2<T, V>(
27522792
#[expect(unsafe_op_in_unsafe_fn)]
27532793
unsafe fn teddy_triple_pass_neon<T, V>(
27542794
tables: &TripleTables,
2755-
// TODO: NEON SSA fusion. See `teddy_pair_pass_neon`.
2756-
_ssa_tables: Option<&NibbleTables>,
2795+
ssa_tables: Option<&NibbleTables>,
27572796
n: usize,
27582797
offsets: &[T],
27592798
all_bytes: &[u8],
@@ -2783,6 +2822,10 @@ unsafe fn teddy_triple_pass_neon<T, V>(
27832822
let zero = vdupq_n_u8(0);
27842823
let nibble_mask = vdupq_n_u8(0x0F);
27852824
let lane_bits = vld1q_u8(NEON_MOVEMASK_BITS.as_ptr());
2825+
let (ssa_lo, ssa_hi, has_ssa) = match ssa_tables {
2826+
Some(t) => (vld1q_u8(t.lo.as_ptr()), vld1q_u8(t.hi.as_ptr()), true),
2827+
None => (zero, zero, false),
2828+
};
27862829
let setup_us = setup_t
27872830
.map(|t| t.elapsed().as_secs_f64() * 1e6)
27882831
.unwrap_or_default();
@@ -2812,8 +2855,14 @@ unsafe fn teddy_triple_pass_neon<T, V>(
28122855
let c2_bits = neon_nibble_lookup(c2_lo, c2_hi, v2, nibble_mask);
28132856
let c3_bits = neon_nibble_lookup(c3_lo, c3_hi, v3, nibble_mask);
28142857
let triple = vandq_u8(vandq_u8(c1_bits, c2_bits), c3_bits);
2858+
let combined = if has_ssa {
2859+
let ssa_bits = neon_nibble_lookup(ssa_lo, ssa_hi, v1, nibble_mask);
2860+
vorrq_u8(triple, ssa_bits)
2861+
} else {
2862+
triple
2863+
};
28152864

2816-
let mut mask = u32::from(neon_nonzero_mask(triple, zero, lane_bits));
2865+
let mut mask = u32::from(neon_nonzero_mask(combined, zero, lane_bits));
28172866
if mask != 0 {
28182867
nonzero_masks += usize::from(trace);
28192868
let candidate_t = trace.then(std::time::Instant::now);

0 commit comments

Comments
 (0)