Skip to content

Commit 0d09244

Browse files
committed
ZJIT: Fix spurious FixnumMult overflow side-exits on ARM64
When a FixnumMult output was spilled to the stack, arm64_scratch_split inserted a Store between Mul and the RShift it creates for the overflow check. This broke the emit-time pattern match for [Mul, RShift, JoMul], causing smulh+cmp to never be emitted. JoMul then branched on stale condition flags, producing spurious overflow exits on every call. Fix by reordering scratch_split to emit RShift immediately after Mul (before the spill Store), and teaching the emit pass to handle the [Mul, RShift, Store, JoMul] pattern by emitting the spill via stur between mul and asr.
1 parent 231bd61 commit 0d09244

3 files changed

Lines changed: 336 additions & 31 deletions

File tree

zjit/src/backend/arm64/mod.rs

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -799,18 +799,31 @@ impl Assembler {
799799
let mem_out = split_memory_write(out, SCRATCH0_OPND);
800800
let reg_out = out.clone();
801801

802-
asm.push_insn(insn);
802+
let has_jo_mul = idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_));
803803

804-
if let Some(mem_out) = mem_out {
805-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
806-
asm.store(mem_out, SCRATCH0_OPND);
807-
};
804+
asm.push_insn(insn);
808805

809-
// If the next instruction is JoMul
810-
if idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_)) {
811-
// Produce a register that is all zeros or all ones
812-
// Based on the sign bit of the 64-bit mul result
806+
// When JoMul follows, the emit pass needs Mul → RShift → JoMul
807+
// to be contiguous so it can pair smulh+mul+asr+cmp. The spill
808+
// Store must NOT be between Mul and RShift. Instead, we push
809+
// the spill Store after the RShift and have the emit pass
810+
// emit the store between mul and asr (before asr clobbers the
811+
// mul output register).
812+
if has_jo_mul {
813+
// Emit RShift immediately after Mul (before any Store)
813814
asm.push_insn(Insn::RShift { out: SCRATCH0_OPND, opnd: reg_out, shift: Opnd::UImm(63) });
815+
// Emit spill Store after RShift. The emit pass will
816+
// skip it along with the RShift, and emit the spill
817+
// at the right point (between mul and asr).
818+
if let Some(mem_out) = mem_out {
819+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
820+
asm.store(mem_out, reg_out);
821+
}
822+
} else {
823+
if let Some(mem_out) = mem_out {
824+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
825+
asm.store(mem_out, SCRATCH0_OPND);
826+
}
814827
}
815828
}
816829
Insn::LShift { opnd, out, .. } |
@@ -1239,30 +1252,48 @@ impl Assembler {
12391252
}
12401253
},
12411254
Insn::Mul { left, right, out } => {
1242-
// If the next instruction is JoMul with RShift created by arm64_scratch_split
1243-
match (insns.get(insn_idx + 1), insns.get(insn_idx + 2)) {
1244-
(Some(Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_))) => {
1245-
// Compute the high 64 bits
1246-
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1247-
1248-
// Compute the low 64 bits
1249-
// This may clobber one of the input registers,
1250-
// so we do it after smulh
1251-
mul(cb, out.into(), left.into(), right.into());
1252-
1253-
// Insert the shift instruction created by arm64_scratch_split
1254-
// to prepare the register that has the sign bit of the high 64 bits after mul.
1255-
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1256-
insn_idx += 1; // skip the next Insn::RShift
1257-
1258-
// If the high 64-bits are not all zeros or all ones,
1259-
// matching the sign bit, then we have an overflow
1260-
cmp(cb, Self::EMIT_OPND, out_sign.into());
1261-
// Insn::JoMul will emit_conditional_jump::<{Condition::NE}>
1255+
// Look for the RShift+JoMul overflow check sequence inserted
1256+
// by arm64_scratch_split. When the Mul output is spilled,
1257+
// scratch_split emits [Mul, RShift, Store, JoMul] with the
1258+
// Store after the RShift. Without a spill, it's just
1259+
// [Mul, RShift, JoMul].
1260+
let rshift_insn = match (insns.get(insn_idx + 1), insns.get(insn_idx + 2), insns.get(insn_idx + 3)) {
1261+
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(&Insn::Store { dest: spill_dest, src: spill_src }), Some(Insn::JoMul(_))) => {
1262+
Some((out_sign, out_opnd, out_shift, Some((spill_dest, spill_src))))
12621263
}
1263-
_ => {
1264-
mul(cb, out.into(), left.into(), right.into());
1264+
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_)), _) => {
1265+
Some((out_sign, out_opnd, out_shift, None))
1266+
}
1267+
_ => None,
1268+
};
1269+
1270+
if let Some((out_sign, out_opnd, out_shift, spill)) = rshift_insn {
1271+
// Compute the high 64 bits into EMIT_OPND (X16)
1272+
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1273+
1274+
// Compute the low 64 bits into `out` (may clobber inputs,
1275+
// so this must come after smulh)
1276+
mul(cb, out.into(), left.into(), right.into());
1277+
1278+
// If the mul result was spilled, emit the store now
1279+
// BEFORE asr clobbers the output register with the sign
1280+
// bit. The spill source is always a register (SCRATCH0),
1281+
// not EMIT_OPND (X16), so the smulh result is preserved.
1282+
if let Some((spill_dest, spill_src)) = spill {
1283+
stur(cb, spill_src.into(), spill_dest.into());
1284+
insn_idx += 1; // will skip the Store insn
12651285
}
1286+
1287+
// Shift to extract the sign bit of the 64-bit mul result
1288+
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1289+
insn_idx += 1; // skip the RShift
1290+
1291+
// If the high 64-bits are not all zeros or all ones,
1292+
// matching the sign bit, then we have an overflow
1293+
cmp(cb, Self::EMIT_OPND, out_sign.into());
1294+
// JoMul will emit_conditional_jump::<{Condition::NE}>
1295+
} else {
1296+
mul(cb, out.into(), left.into(), right.into());
12661297
}
12671298
},
12681299
Insn::And { left, right, out } => {
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# ARM64: Spurious FixnumMult overflow side-exits due to broken smulh/cmp pattern match
2+
3+
## Summary
4+
5+
A bug in the ARM64 backend causes `FixnumMult` to emit spurious overflow
6+
side-exits. The JIT bails to the interpreter on nearly every function call,
7+
reducing `ratio_in_zjit` from ~70% to ~4%. The root cause is a fragile
8+
instruction pattern match in the emit pass that silently fails when a spill
9+
Store is inserted between `Mul` and `RShift` by `arm64_scratch_split`.
10+
11+
The failure mode is **silent**: no crash, no assertion, no error. The `JoMul`
12+
conditional branch simply reads stale CPU condition flags from a prior
13+
instruction, and since those flags happen to say "not equal" in most cases, the
14+
JIT side-exits to the interpreter. This makes it a **performance cliff** that
15+
is invisible unless you check `--zjit-stats`.
16+
17+
In theory, if stale flags happened to indicate "equal" when a real overflow
18+
occurred, the JIT would silently produce wrong results. We have not observed
19+
this in practice but the possibility exists.
20+
21+
## Minimal reproducer
22+
23+
```ruby
24+
# frozen_string_literal: true
25+
# tmp/muloverflow.rb — two FixnumMult + getbyte + >>32 in a loop
26+
def repro(str)
27+
lo = 5381
28+
hi = 0
29+
i = 0
30+
len = str.bytesize
31+
while i < len
32+
prod_lo = lo * 33 + str.getbyte(i)
33+
carry = prod_lo >> 32
34+
lo = prod_lo & 0xFFFFFFFF
35+
hi = (hi * 5 + carry) & 0xFFFFFFFF
36+
i += 1
37+
end
38+
lo
39+
end
40+
41+
100.times { repro("hello world") }
42+
```
43+
44+
**Before fix:**
45+
```
46+
$ ruby --zjit-stats tmp/muloverflow.rb 2>&1 | grep -E "mult_overflow|ratio_in_zjit"
47+
fixnum_mult_overflow: 71 (100.0%)
48+
ratio_in_zjit: 3.7%
49+
```
50+
51+
**After fix:**
52+
```
53+
$ ruby --zjit-stats tmp/muloverflow.rb 2>&1 | grep -E "mult_overflow|ratio_in_zjit"
54+
side_exit_count: 0
55+
ratio_in_zjit: 69.1%
56+
```
57+
58+
## Conditions required to trigger
59+
60+
All of these must be present simultaneously:
61+
62+
1. **Two `FixnumMult` instructions** in the same basic block
63+
2. **A cfunc call** (like `String#getbyte`) in between — this creates enough
64+
register pressure to cause the Mul output to be spilled to a stack slot
65+
3. **A right-shift by 32** (`>> 32`) — adds more live values, increasing spill
66+
pressure
67+
68+
Removing any one of these conditions makes the bug disappear.
69+
70+
## Root cause
71+
72+
### ARM64 overflow detection for multiply
73+
74+
ARM64 has no overflow flag for multiplication. The standard technique is:
75+
76+
```asm
77+
smulh x16, x0, x1 ; signed multiply high: upper 64 bits
78+
mul x0, x0, x1 ; multiply: lower 64 bits
79+
asr x15, x0, #63 ; sign-extend the low result
80+
cmp x16, x15 ; if high bits != sign-extended low, overflow
81+
b.ne overflow_exit
82+
```
83+
84+
ZJIT implements this as a three-pass pipeline:
85+
86+
1. **HIR → LIR lowering** (`codegen.rs`): Emits `Mul` + `JoMul` instructions
87+
2. **arm64_scratch_split** (`arm64/mod.rs`): Inserts `RShift` between `Mul` and
88+
`JoMul` to prepare the sign bit for the comparison
89+
3. **arm64_emit** (`arm64/mod.rs`): Pattern-matches `[Mul, RShift, JoMul]` to
90+
fuse them into the `smulh`+`mul`+`asr`+`cmp` sequence
91+
92+
### The bug
93+
94+
When the Mul output register is spilled (allocated to a stack slot by
95+
`alloc_regs`), `arm64_scratch_split` also inserts a `Store` instruction to
96+
write the result to the stack. The code inserts the Store **before** the
97+
RShift, producing:
98+
99+
```
100+
Mul x15, x15, x17
101+
Store [x29 - 8], x15 ← spill
102+
RShift x15, x15, 63 ← sign bit extraction
103+
JoMul side_exit
104+
```
105+
106+
The emit pass checks `insns[idx+1]` and `insns[idx+2]` for `RShift` and
107+
`JoMul`:
108+
109+
```rust
110+
match (insns.get(insn_idx + 1), insns.get(insn_idx + 2)) {
111+
(Some(Insn::RShift { .. }), Some(Insn::JoMul(_))) => {
112+
// emit smulh + mul + asr + cmp
113+
}
114+
_ => {
115+
mul(cb, out, left, right); // ← NO smulh, NO cmp
116+
}
117+
}
118+
```
119+
120+
With the Store in between, `insns[idx+1]` is `Store`, not `RShift`. The
121+
pattern match **falls through to the else branch**, which emits only `mul`
122+
without `smulh` or `cmp`. The `RShift` is then emitted as a standalone `asr`,
123+
and `JoMul` emits `b.ne` — but `cmp` was never executed, so `b.ne` reads
124+
**stale condition flags** from whatever instruction last set them.
125+
126+
### Consequence
127+
128+
- **If stale flags = NE (common):** Spurious side-exit. The function drops to
129+
interpreter speed. ~20x perf regression on affected code.
130+
- **If stale flags = EQ during real overflow (rare):** Missed overflow. The
131+
mul result wraps without promoting to Bignum. **Silent wrong result.**
132+
133+
## Fix
134+
135+
The fix has two parts:
136+
137+
### 1. `arm64_scratch_split`: Reorder Store to after RShift when JoMul follows
138+
139+
When the next instruction is `JoMul`, emit the `RShift` immediately after
140+
`Mul`, and move the spill `Store` to after the `RShift`. The Store now writes
141+
from the Mul output register directly (rather than SCRATCH0, which was
142+
clobbered by the RShift):
143+
144+
```
145+
Mul x15, x15, x17 ← mul result in x15
146+
RShift x15, x15, 63 ← sign bit (clobbers x15)
147+
Store [x29 - 8], x15 ← stores sign bit, NOT mul result
148+
JoMul side_exit
149+
```
150+
151+
Emitted literally, this sequence would store the sign bit rather than the mul
152+
result, so the emit pass must treat the sequence as a unit and reorder the store.
153+
154+
### 2. `arm64_emit`: Handle `[Mul, RShift, Store, JoMul]` pattern
155+
156+
The emit pass now recognizes both patterns:
157+
158+
- `[Mul, RShift, JoMul]` — original (no spill)
159+
- `[Mul, RShift, Store, JoMul]` — with spill
160+
161+
For the spill case, the emitted ARM64 code is:
162+
163+
```asm
164+
smulh x16, x0, x1 ; high 64 bits
165+
mul x15, x0, x1 ; low 64 bits
166+
stur x15, [x29, #-8] ; spill mul result BEFORE asr clobbers x15
167+
asr x15, x15, #63 ; sign bit of low result
168+
cmp x16, x15 ; overflow check
169+
b.ne overflow_exit
170+
```
171+
172+
The `stur` is emitted between `mul` and `asr`, preserving the mul result in
173+
the spill slot before `asr` clobbers the register. The `smulh` result in X16
174+
is safe because `stur` with a register source doesn't touch X16.
175+
176+
## Bisection table
177+
178+
These tests were run on the pre-fix build to isolate the trigger conditions:
179+
180+
| Variant | Overflows? | ratio_in_zjit |
181+
|---------|-----------|---------------|
182+
| `lo * 33`, no hi, no getbyte | No | 69% |
183+
| `lo * 33 + getbyte`, no hi | No | 68% |
184+
| `lo * 33 + getbyte`, `>> 32`, no hi multiply | No | 69% |
185+
| `lo * 33 + getbyte`, `>> 32`, `hi * 33 + carry` (full) | **Yes** | 3.7% |
186+
| Same but `carry` computed but unused | **Yes** | 3.7% |
187+
| Same but `(lo << 5) + lo` instead of `lo * 33` | No | 68% |
188+
| `lo * 33 + (i & 0xFF)` instead of getbyte, with hi | No | 69% |
189+
| `hi * 33 + constant` in isolation | No | 69% |
190+
191+
## Lessons and recommendations
192+
193+
### 1. Avoid fragile peephole pattern matching across passes
194+
195+
The core issue is that `arm64_scratch_split` and `arm64_emit` communicate via
196+
an **implicit contract**: "RShift will be at idx+1 after Mul". When
197+
`scratch_split` inserts a Store, this contract is silently violated. Neither
198+
pass checks or asserts the invariant.
199+
200+
**Recommendation**: Use an explicit fused instruction (`MulWithOverflowCheck`)
201+
in the LIR instead of relying on Mul+RShift+JoMul being contiguous. This is
202+
what other compilers do:
203+
204+
- **LLVM**: Uses `llvm.smul.with.overflow` intrinsic — a single instruction
205+
that produces both the result and an overflow bit.
206+
- **Cranelift**: Uses `imul` + `trapif` where the overflow flag is an explicit
207+
operand, not a side effect.
208+
- **V8 TurboFan**: Uses `Int32MulWithOverflow` as a single node that lowers to
209+
the platform-specific sequence atomically.
210+
- **GCC**: The `-ftrapv` multiply overflow check is emitted as a single
211+
inseparable sequence during final code emission, never as separate
212+
matchable instructions.
213+
214+
### 2. Reject unknown instruction sequences
215+
216+
When the emit pass encounters `Mul` without the expected `RShift + JoMul`
217+
pattern, it falls through to a plain `mul`. But if there IS a `JoMul` later
218+
(just not at idx+2), the `JoMul` will execute with wrong flags. The else
219+
branch should either:
220+
221+
- **Panic**: "Mul followed by JoMul but RShift not at expected position"
222+
- **Emit a safe fallback**: `smulh` + `mul` + `asr` + `cmp` unconditionally
223+
224+
The current silent fallthrough is the worst option.
225+
226+
### 3. Add end-to-end overflow tests with register pressure
227+
228+
The existing tests only exercise `FixnumMult` in simple functions with low
229+
register pressure (where the output isn't spilled). A test with two multiplies
230+
and a cfunc call would have caught this immediately.
231+
232+
### 4. Consider separating "instruction lowering" from "instruction selection"
233+
234+
The ARM64 backend conflates two concerns:
235+
236+
- **Lowering**: Converting abstract LIR to concrete register operations,
237+
handling spills (scratch_split)
238+
- **Selection**: Fusing multiple LIR instructions into a single ARM64 sequence
239+
(emit)
240+
241+
Other compilers (LLVM, Cranelift) keep these clearly separated, with selection
242+
happening before register allocation. Post-regalloc passes only handle
243+
mechanical concerns (encoding, addressing modes) and never need to pattern
244+
match across multiple instructions.
245+
246+
## Files changed
247+
248+
- `zjit/src/backend/arm64/mod.rs` — scratch_split Mul handling + emit Mul
249+
pattern match
250+
251+
## Test results
252+
253+
- 1454 unit tests: all pass
254+
- 27 integration tests, 7605 assertions: all pass
255+
- Reproducer: 0 side exits, 69.1% ratio_in_zjit (was 71 exits, 3.7%)
256+
- DJB2 64-bit hash: correct results match interpreter

zjit/tmp/muloverflow.rb

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# frozen_string_literal: true
2+
# Minimal: two mults + getbyte + >> 32
3+
def repro(str)
4+
lo = 5381
5+
hi = 0
6+
i = 0
7+
len = str.bytesize
8+
while i < len
9+
prod_lo = lo * 33 + str.getbyte(i)
10+
carry = prod_lo >> 32
11+
lo = prod_lo & 0xFFFFFFFF
12+
hi = (hi * 5 + carry) & 0xFFFFFFFF
13+
i += 1
14+
end
15+
lo
16+
end
17+
18+
100.times { repro("hello world") }

0 commit comments

Comments
 (0)