Skip to content

Commit 38fc512

Browse files
committed
ZJIT: Minimize FixnumMult overflow reproducer
1 parent 0d09244 commit 38fc512

3 files changed

Lines changed: 52 additions & 90 deletions

File tree

zjit/src/backend/arm64/mod.rs

Lines changed: 31 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -799,31 +799,18 @@ impl Assembler {
799799
let mem_out = split_memory_write(out, SCRATCH0_OPND);
800800
let reg_out = out.clone();
801801

802-
let has_jo_mul = idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_));
803-
804802
asm.push_insn(insn);
805803

806-
// When JoMul follows, the emit pass needs Mul → RShift → JoMul
807-
// to be contiguous so it can pair smulh+mul+asr+cmp. The spill
808-
// Store must NOT be between Mul and RShift. Instead, we record
809-
// the spill destination in the RShift and have the emit pass
810-
// emit the store between mul and asr (before asr clobbers the
811-
// mul output register).
812-
if has_jo_mul {
813-
// Emit RShift immediately after Mul (before any Store)
804+
if let Some(mem_out) = mem_out {
805+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
806+
asm.store(mem_out, SCRATCH0_OPND);
807+
};
808+
809+
// If the next instruction is JoMul
810+
if idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_)) {
811+
// Produce a register that is all zeros or all ones
812+
// Based on the sign bit of the 64-bit mul result
814813
asm.push_insn(Insn::RShift { out: SCRATCH0_OPND, opnd: reg_out, shift: Opnd::UImm(63) });
815-
// Emit spill Store after RShift. The emit pass will
816-
// skip it along with the RShift, and emit the spill
817-
// at the right point (between mul and asr).
818-
if let Some(mem_out) = mem_out {
819-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
820-
asm.store(mem_out, reg_out);
821-
}
822-
} else {
823-
if let Some(mem_out) = mem_out {
824-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
825-
asm.store(mem_out, SCRATCH0_OPND);
826-
}
827814
}
828815
}
829816
Insn::LShift { opnd, out, .. } |
@@ -1252,48 +1239,30 @@ impl Assembler {
12521239
}
12531240
},
12541241
Insn::Mul { left, right, out } => {
1255-
// Look for the RShift+JoMul overflow check sequence inserted
1256-
// by arm64_scratch_split. When the Mul output is spilled,
1257-
// scratch_split emits [Mul, RShift, Store, JoMul] with the
1258-
// Store after the RShift. Without a spill, it's just
1259-
// [Mul, RShift, JoMul].
1260-
let rshift_insn = match (insns.get(insn_idx + 1), insns.get(insn_idx + 2), insns.get(insn_idx + 3)) {
1261-
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(&Insn::Store { dest: spill_dest, src: spill_src }), Some(Insn::JoMul(_))) => {
1262-
Some((out_sign, out_opnd, out_shift, Some((spill_dest, spill_src))))
1263-
}
1264-
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_)), _) => {
1265-
Some((out_sign, out_opnd, out_shift, None))
1242+
// If the next instruction is JoMul with RShift created by arm64_scratch_split
1243+
match (insns.get(insn_idx + 1), insns.get(insn_idx + 2)) {
1244+
(Some(Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_))) => {
1245+
// Compute the high 64 bits
1246+
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1247+
1248+
// Compute the low 64 bits
1249+
// This may clobber one of the input registers,
1250+
// so we do it after smulh
1251+
mul(cb, out.into(), left.into(), right.into());
1252+
1253+
// Insert the shift instruction created by arm64_scratch_split
1254+
// to prepare the register that has the sign bit of the high 64 bits after mul.
1255+
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1256+
insn_idx += 1; // skip the next Insn::RShift
1257+
1258+
// If the high 64-bits are not all zeros or all ones,
1259+
// matching the sign bit, then we have an overflow
1260+
cmp(cb, Self::EMIT_OPND, out_sign.into());
1261+
// Insn::JoMul will emit_conditional_jump::<{Condition::NE}>
12661262
}
1267-
_ => None,
1268-
};
1269-
1270-
if let Some((out_sign, out_opnd, out_shift, spill)) = rshift_insn {
1271-
// Compute the high 64 bits into EMIT_OPND (X16)
1272-
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1273-
1274-
// Compute the low 64 bits into `out` (may clobber inputs,
1275-
// so this must come after smulh)
1276-
mul(cb, out.into(), left.into(), right.into());
1277-
1278-
// If the mul result was spilled, emit the store now
1279-
// BEFORE asr clobbers the output register with the sign
1280-
// bit. The spill source is always a register (SCRATCH0),
1281-
// not EMIT_OPND (X16), so the smulh result is preserved.
1282-
if let Some((spill_dest, spill_src)) = spill {
1283-
stur(cb, spill_src.into(), spill_dest.into());
1284-
insn_idx += 1; // will skip the Store insn
1263+
_ => {
1264+
mul(cb, out.into(), left.into(), right.into());
12851265
}
1286-
1287-
// Shift to extract the sign bit of the 64-bit mul result
1288-
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1289-
insn_idx += 1; // skip the RShift
1290-
1291-
// If the high 64-bits are not all zeros or all ones,
1292-
// matching the sign bit, then we have an overflow
1293-
cmp(cb, Self::EMIT_OPND, out_sign.into());
1294-
// JoMul will emit_conditional_jump::<{Condition::NE}>
1295-
} else {
1296-
mul(cb, out.into(), left.into(), right.into());
12971266
}
12981267
},
12991268
Insn::And { left, right, out } => {

zjit/tmp/REPORT-fixnum-mult-overflow.md

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,21 @@ this in practice but the possibility exists.
2222

2323
```ruby
2424
# frozen_string_literal: true
25-
# tmp/muloverflow.rb — two FixnumMult + getbyte + >>32 in a loop
26-
def repro(str)
27-
lo = 5381
28-
hi = 0
29-
i = 0
30-
len = str.bytesize
31-
while i < len
32-
prod_lo = lo * 33 + str.getbyte(i)
33-
carry = prod_lo >> 32
34-
lo = prod_lo & 0xFFFFFFFF
35-
hi = (hi * 5 + carry) & 0xFFFFFFFF
25+
def f(s)
26+
a = 0; b = 0; i = 0
27+
while i < s.bytesize
28+
a = a * 3 + s.getbyte(i)
29+
b = b * 3 + (a >> 32)
3630
i += 1
3731
end
38-
lo
32+
a
3933
end
40-
41-
100.times { repro("hello world") }
34+
100.times { f("x") }
4235
```
4336

37+
The trigger requires: two `FixnumMult` in one block, a cfunc call (`getbyte`)
38+
for register pressure, and enough live values (`>> 32`) to force a spill.
39+
4440
**Before fix:**
4541
```
4642
$ ruby --zjit-stats tmp/muloverflow.rb 2>&1 | grep -E "mult_overflow|ratio_in_zjit"

zjit/tmp/muloverflow.rb

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
# frozen_string_literal: true
2-
# Minimal: two mults + getbyte + >> 32
3-
def repro(str)
4-
lo = 5381
5-
hi = 0
6-
i = 0
7-
len = str.bytesize
8-
while i < len
9-
prod_lo = lo * 33 + str.getbyte(i)
10-
carry = prod_lo >> 32
11-
lo = prod_lo & 0xFFFFFFFF
12-
hi = (hi * 5 + carry) & 0xFFFFFFFF
2+
# Minimal reproducer: ARM64 FixnumMult spurious overflow side-exits.
3+
# Needs: two multiplies, getbyte (for register pressure), and >>32.
4+
# Before fix: fixnum_mult_overflow: 71, ratio_in_zjit: 3.7%
5+
# After fix: side_exit_count: 0, ratio_in_zjit: 69%
6+
def f(s)
7+
a = 0; b = 0; i = 0
8+
while i < s.bytesize
9+
a = a * 3 + s.getbyte(i)
10+
b = b * 3 + (a >> 32)
1311
i += 1
1412
end
15-
lo
13+
a
1614
end
17-
18-
100.times { repro("hello world") }
15+
100.times { f("x") }

0 commit comments

Comments
 (0)