Skip to content

Commit f52a1f2

Browse files
committed
ZJIT: Restore fix and add no-loop reproducer for FixnumMult overflow
1 parent 38fc512 commit f52a1f2

2 files changed

Lines changed: 72 additions & 42 deletions

File tree

zjit/src/backend/arm64/mod.rs

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -799,18 +799,31 @@ impl Assembler {
799799
let mem_out = split_memory_write(out, SCRATCH0_OPND);
800800
let reg_out = out.clone();
801801

802-
asm.push_insn(insn);
802+
let has_jo_mul = idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_));
803803

804-
if let Some(mem_out) = mem_out {
805-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
806-
asm.store(mem_out, SCRATCH0_OPND);
807-
};
804+
asm.push_insn(insn);
808805

809-
// If the next instruction is JoMul
810-
if idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_)) {
811-
// Produce a register that is all zeros or all ones
812-
// Based on the sign bit of the 64-bit mul result
806+
// When JoMul follows, the emit pass needs Mul → RShift → JoMul
807+
// to be contiguous so it can pair smulh+mul+asr+cmp. The spill
808+
// Store must NOT be between Mul and RShift. Instead, we push
809+
// the spill Store right after the RShift and have the emit pass
810+
// emit the store between mul and asr (before asr clobbers the
811+
// mul output register).
812+
if has_jo_mul {
813+
// Emit RShift immediately after Mul (before any Store)
813814
asm.push_insn(Insn::RShift { out: SCRATCH0_OPND, opnd: reg_out, shift: Opnd::UImm(63) });
815+
// Emit spill Store after RShift. The emit pass will
816+
// skip it along with the RShift, and emit the spill
817+
// at the right point (between mul and asr).
818+
if let Some(mem_out) = mem_out {
819+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
820+
asm.store(mem_out, reg_out);
821+
}
822+
} else {
823+
if let Some(mem_out) = mem_out {
824+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
825+
asm.store(mem_out, SCRATCH0_OPND);
826+
}
814827
}
815828
}
816829
Insn::LShift { opnd, out, .. } |
@@ -1239,30 +1252,48 @@ impl Assembler {
12391252
}
12401253
},
12411254
Insn::Mul { left, right, out } => {
1242-
// If the next instruction is JoMul with RShift created by arm64_scratch_split
1243-
match (insns.get(insn_idx + 1), insns.get(insn_idx + 2)) {
1244-
(Some(Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_))) => {
1245-
// Compute the high 64 bits
1246-
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1247-
1248-
// Compute the low 64 bits
1249-
// This may clobber one of the input registers,
1250-
// so we do it after smulh
1251-
mul(cb, out.into(), left.into(), right.into());
1252-
1253-
// Insert the shift instruction created by arm64_scratch_split
1254-
// to prepare the register that has the sign bit of the high 64 bits after mul.
1255-
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1256-
insn_idx += 1; // skip the next Insn::RShift
1257-
1258-
// If the high 64-bits are not all zeros or all ones,
1259-
// matching the sign bit, then we have an overflow
1260-
cmp(cb, Self::EMIT_OPND, out_sign.into());
1261-
// Insn::JoMul will emit_conditional_jump::<{Condition::NE}>
1255+
// Look for the RShift+JoMul overflow check sequence inserted
1256+
// by arm64_scratch_split. When the Mul output is spilled,
1257+
// scratch_split emits [Mul, RShift, Store, JoMul] with the
1258+
// Store after the RShift. Without a spill, it's just
1259+
// [Mul, RShift, JoMul].
1260+
let rshift_insn = match (insns.get(insn_idx + 1), insns.get(insn_idx + 2), insns.get(insn_idx + 3)) {
1261+
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(&Insn::Store { dest: spill_dest, src: spill_src }), Some(Insn::JoMul(_))) => {
1262+
Some((out_sign, out_opnd, out_shift, Some((spill_dest, spill_src))))
12621263
}
1263-
_ => {
1264-
mul(cb, out.into(), left.into(), right.into());
1264+
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_)), _) => {
1265+
Some((out_sign, out_opnd, out_shift, None))
1266+
}
1267+
_ => None,
1268+
};
1269+
1270+
if let Some((out_sign, out_opnd, out_shift, spill)) = rshift_insn {
1271+
// Compute the high 64 bits into EMIT_OPND (X16)
1272+
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1273+
1274+
// Compute the low 64 bits into `out` (may clobber inputs,
1275+
// so this must come after smulh)
1276+
mul(cb, out.into(), left.into(), right.into());
1277+
1278+
// If the mul result was spilled, emit the store now
1279+
// BEFORE asr clobbers the output register with the sign
1280+
// bit. The spill source is always a register (SCRATCH0),
1281+
// not EMIT_OPND (X16), so the smulh result is preserved.
1282+
if let Some((spill_dest, spill_src)) = spill {
1283+
stur(cb, spill_src.into(), spill_dest.into());
1284+
insn_idx += 1; // will skip the Store insn
12651285
}
1286+
1287+
// Shift to extract the sign bit of the 64-bit mul result
1288+
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1289+
insn_idx += 1; // skip the RShift
1290+
1291+
// If the high 64-bits are not all zeros or all ones,
1292+
// matching the sign bit, then we have an overflow
1293+
cmp(cb, Self::EMIT_OPND, out_sign.into());
1294+
// JoMul will emit_conditional_jump::<{Condition::NE}>
1295+
} else {
1296+
mul(cb, out.into(), left.into(), right.into());
12661297
}
12671298
},
12681299
Insn::And { left, right, out } => {

zjit/tmp/muloverflow.rb

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# frozen_string_literal: true
2-
# Minimal reproducer: ARM64 FixnumMult spurious overflow side-exits.
3-
# Needs: two multiplies, getbyte (for register pressure), and >>32.
4-
# Before fix: fixnum_mult_overflow: 71, ratio_in_zjit: 3.7%
2+
# Minimal no-loop reproducer for ARM64 FixnumMult spurious overflow.
3+
# 7 getbyte calls exhaust registers, forcing Mul output to spill.
4+
# Before fix: fixnum_mult_overflow: 71, ratio_in_zjit: 35%
55
# After fix: side_exit_count: 0, ratio_in_zjit: 69%
66
def f(s)
7-
a = 0; b = 0; i = 0
8-
while i < s.bytesize
9-
a = a * 3 + s.getbyte(i)
10-
b = b * 3 + (a >> 32)
11-
i += 1
12-
end
13-
a
7+
v0 = s.getbyte(0); v1 = s.getbyte(1); v2 = s.getbyte(2)
8+
v3 = s.getbyte(3); v4 = s.getbyte(4); v5 = s.getbyte(5)
9+
v6 = s.getbyte(6)
10+
a = v0 * 3 + v1
11+
b = a * 3 + (a >> 32)
12+
a + b + v2 + v3 + v4 + v5 + v6
1413
end
15-
100.times { f("x") }
14+
100.times { f("hello!!") }

0 commit comments

Comments
 (0)