Skip to content

Commit b4a2de3

Browse files
committed
ZJIT: Replace fragile Mul+RShift+JoMul pattern with speculative smulh
Instead of scratch_split inserting an RShift between Mul and JoMul (which broke when a spill Store disrupted the pattern), make each instruction emit independently: - Mul always speculatively emits smulh into X16 before mul - JoMul carries the Mul output operand and emits a barrel-shifted cmp (CMP X16, val, ASR Shopify#62) to check overflow in one instruction - No cross-pass coordination, no pattern matching, no synthetic RShift Also adds cmp_shifted (CMP with ASR) to the ARM64 assembler.
1 parent f52a1f2 commit b4a2de3

6 files changed

Lines changed: 77 additions & 86 deletions

File tree

zjit/src/asm/arm64/inst/data_reg.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,26 @@ impl DataReg {
9292
Self::subs(31, rn, rm, num_bits)
9393
}
9494

95+
/// CMP (shifted register) with explicit shift
96+
/// Encodes: CMP <Xn>, <Xm>, <shift> #<amount>
97+
pub fn cmp_shifted(rn: u8, rm: u8, shift: u8, amount: u8, num_bits: u8) -> Self {
98+
Self {
99+
rd: 31,
100+
rn,
101+
imm6: amount,
102+
rm,
103+
shift: match shift {
104+
0b00 => Shift::LSL,
105+
0b01 => Shift::LSR,
106+
0b10 => Shift::ASR,
107+
_ => panic!("Invalid shift type"),
108+
},
109+
s: S::UpdateFlags,
110+
op: Op::Sub,
111+
sf: num_bits.into()
112+
}
113+
}
114+
95115
/// SUB (shifted register)
96116
/// <https://developer.arm.com/documentation/ddi0596/2021-12/Base-Instructions/SUB--shifted-register---Subtract--shifted-register--?lang=en>
97117
pub fn sub(rd: u8, rn: u8, rm: u8, num_bits: u8) -> Self {

zjit/src/asm/arm64/mod.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,20 @@ pub fn cmp(cb: &mut CodeBlock, rn: A64Opnd, rm: A64Opnd) {
344344
cb.write_bytes(&bytes);
345345
}
346346

347+
/// CMP (shifted register) - compare with shifted second operand
348+
/// shift: 0b00=LSL, 0b01=LSR, 0b10=ASR
349+
pub fn cmp_shifted(cb: &mut CodeBlock, rn: A64Opnd, rm: A64Opnd, shift: u8, amount: u8) {
350+
let bytes: [u8; 4] = match (rn, rm) {
351+
(A64Opnd::Reg(rn), A64Opnd::Reg(rm)) => {
352+
assert!(rn.num_bits == rm.num_bits, "All operands must be of the same size.");
353+
DataReg::cmp_shifted(rn.reg_no, rm.reg_no, shift, amount, rn.num_bits).into()
354+
},
355+
_ => panic!("Invalid operand combination to cmp_shifted instruction."),
356+
};
357+
358+
cb.write_bytes(&bytes);
359+
}
360+
347361
/// CSEL - conditionally select between two registers
348362
pub fn csel(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd, cond: u8) {
349363
let bytes: [u8; 4] = match (rd, rn, rm) {

zjit/src/backend/arm64/mod.rs

Lines changed: 26 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -797,33 +797,12 @@ impl Assembler {
797797
*left = split_memory_read(asm, *left, SCRATCH0_OPND);
798798
*right = split_memory_read(asm, *right, SCRATCH1_OPND);
799799
let mem_out = split_memory_write(out, SCRATCH0_OPND);
800-
let reg_out = out.clone();
801-
802-
let has_jo_mul = idx + 1 < linearized_insns.len() && matches!(linearized_insns[idx + 1], Insn::JoMul(_));
803800

804801
asm.push_insn(insn);
805802

806-
// When JoMul follows, the emit pass needs Mul → RShift → JoMul
807-
// to be contiguous so it can pair smulh+mul+asr+cmp. The spill
808-
// Store must NOT be between Mul and RShift. Instead, we record
809-
// the spill destination in the RShift and have the emit pass
810-
// emit the store between mul and asr (before asr clobbers the
811-
// mul output register).
812-
if has_jo_mul {
813-
// Emit RShift immediately after Mul (before any Store)
814-
asm.push_insn(Insn::RShift { out: SCRATCH0_OPND, opnd: reg_out, shift: Opnd::UImm(63) });
815-
// Emit spill Store after RShift. The emit pass will
816-
// skip it along with the RShift, and emit the spill
817-
// at the right point (between mul and asr).
818-
if let Some(mem_out) = mem_out {
819-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
820-
asm.store(mem_out, reg_out);
821-
}
822-
} else {
823-
if let Some(mem_out) = mem_out {
824-
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
825-
asm.store(mem_out, SCRATCH0_OPND);
826-
}
803+
if let Some(mem_out) = mem_out {
804+
let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND);
805+
asm.store(mem_out, SCRATCH0_OPND);
827806
}
828807
}
829808
Insn::LShift { opnd, out, .. } |
@@ -928,6 +907,10 @@ impl Assembler {
928907
}
929908
}
930909
}
910+
Insn::JoMul(opnd, _) => {
911+
*opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND);
912+
asm.push_insn(insn);
913+
}
931914
&mut Insn::PatchPoint { ref target, invariant, version } => {
932915
split_patch_point(asm, target, invariant, version);
933916
}
@@ -1252,49 +1235,12 @@ impl Assembler {
12521235
}
12531236
},
12541237
Insn::Mul { left, right, out } => {
1255-
// Look for the RShift+JoMul overflow check sequence inserted
1256-
// by arm64_scratch_split. When the Mul output is spilled,
1257-
// scratch_split emits [Mul, RShift, Store, JoMul] with the
1258-
// Store after the RShift. Without a spill, it's just
1259-
// [Mul, RShift, JoMul].
1260-
let rshift_insn = match (insns.get(insn_idx + 1), insns.get(insn_idx + 2), insns.get(insn_idx + 3)) {
1261-
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(&Insn::Store { dest: spill_dest, src: spill_src }), Some(Insn::JoMul(_))) => {
1262-
Some((out_sign, out_opnd, out_shift, Some((spill_dest, spill_src))))
1263-
}
1264-
(Some(&Insn::RShift { out: out_sign, opnd: out_opnd, shift: out_shift }), Some(Insn::JoMul(_)), _) => {
1265-
Some((out_sign, out_opnd, out_shift, None))
1266-
}
1267-
_ => None,
1268-
};
1269-
1270-
if let Some((out_sign, out_opnd, out_shift, spill)) = rshift_insn {
1271-
// Compute the high 64 bits into EMIT_OPND (X16)
1272-
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1273-
1274-
// Compute the low 64 bits into `out` (may clobber inputs,
1275-
// so this must come after smulh)
1276-
mul(cb, out.into(), left.into(), right.into());
1277-
1278-
// If the mul result was spilled, emit the store now
1279-
// BEFORE asr clobbers the output register with the sign
1280-
// bit. The spill source is always a register (SCRATCH0),
1281-
// not EMIT_OPND (X16), so the smulh result is preserved.
1282-
if let Some((spill_dest, spill_src)) = spill {
1283-
stur(cb, spill_src.into(), spill_dest.into());
1284-
insn_idx += 1; // will skip the Store insn
1285-
}
1286-
1287-
// Shift to extract the sign bit of the 64-bit mul result
1288-
asr(cb, out_sign.into(), out_opnd.into(), out_shift.into());
1289-
insn_idx += 1; // skip the RShift
1290-
1291-
// If the high 64-bits are not all zeros or all ones,
1292-
// matching the sign bit, then we have an overflow
1293-
cmp(cb, Self::EMIT_OPND, out_sign.into());
1294-
// JoMul will emit_conditional_jump::<{Condition::NE}>
1295-
} else {
1296-
mul(cb, out.into(), left.into(), right.into());
1297-
}
1238+
// Speculatively emit smulh into EMIT_OPND (X16) for a
1239+
// potential following JoMul. If no JoMul follows, X16 is
1240+
// simply overwritten later. Must come before mul since mul
1241+
// may clobber an input register.
1242+
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
1243+
mul(cb, out.into(), left.into(), right.into());
12981244
},
12991245
Insn::And { left, right, out } => {
13001246
and(cb, out.into(), left.into(), right.into());
@@ -1558,7 +1504,14 @@ impl Assembler {
15581504
Insn::Je(target) | Insn::Jz(target) => {
15591505
emit_conditional_jump::<{Condition::EQ}>(self, cb, target.clone());
15601506
},
1561-
Insn::Jne(target) | Insn::Jnz(target) | Insn::JoMul(target) => {
1507+
Insn::Jne(target) | Insn::Jnz(target) => {
1508+
emit_conditional_jump::<{Condition::NE}>(self, cb, target.clone());
1509+
},
1510+
Insn::JoMul(val, target) => {
1511+
// Compare smulh result (in EMIT_OPND/X16 from preceding Mul)
1512+
// with the mul output sign-extended from bit 62. Uses the
1513+
// barrel shifter built into CMP for a single instruction.
1514+
cmp_shifted(cb, Self::EMIT_OPND, val.into(), 0b10, 62); // ASR #62
15621515
emit_conditional_jump::<{Condition::NE}>(self, cb, target.clone());
15631516
},
15641517
Insn::Jl(target) => {
@@ -1809,11 +1762,12 @@ mod tests {
18091762
asm.compile_with_num_regs(&mut cb, 2);
18101763

18111764
assert_disasm_snapshot!(cb.disasm(), @"
1812-
0x0: mov x0, #3
1813-
0x4: mul x0, x9, x0
1814-
0x8: mov x1, x0
1765+
0x0: mov x0, #3
1766+
0x4: smulh x16, x9, x0
1767+
0x8: mul x0, x9, x0
1768+
0xc: mov x1, x0
18151769
");
1816-
assert_snapshot!(cb.hexdump(), @"600080d2207d009be10300aa");
1770+
assert_snapshot!(cb.hexdump(), @"600080d2307d409b207d009be10300aa");
18171771
}
18181772

18191773
#[test]

zjit/src/backend/lir.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -621,8 +621,9 @@ pub enum Insn {
621621
/// Jump if overflow
622622
Jo(Target),
623623

624-
/// Jump if overflow in multiplication
625-
JoMul(Target),
624+
/// Jump if overflow in multiplication.
625+
/// The operand is the Mul output, used on ARM64 for the barrel-shifted compare.
626+
JoMul(Opnd, Target),
626627

627628
/// Jump if zero
628629
Jz(Target),
@@ -734,7 +735,7 @@ impl Insn {
734735
Insn::Jne(target) |
735736
Insn::Jnz(target) |
736737
Insn::Jo(target) |
737-
Insn::JoMul(target) |
738+
Insn::JoMul(_, target) |
738739
Insn::Jz(target) |
739740
Insn::Joz(_, target) |
740741
Insn::Jonz(_, target) |
@@ -786,7 +787,7 @@ impl Insn {
786787
Insn::Jne(_) => "Jne",
787788
Insn::Jnz(_) => "Jnz",
788789
Insn::Jo(_) => "Jo",
789-
Insn::JoMul(_) => "JoMul",
790+
Insn::JoMul(..) => "JoMul",
790791
Insn::Jz(_) => "Jz",
791792
Insn::Joz(..) => "Joz",
792793
Insn::Jonz(..) => "Jonz",
@@ -894,7 +895,7 @@ impl Insn {
894895
Insn::Jne(target) |
895896
Insn::Jnz(target) |
896897
Insn::Jo(target) |
897-
Insn::JoMul(target) |
898+
Insn::JoMul(_, target) |
898899
Insn::Jz(target) |
899900
Insn::Joz(_, target) |
900901
Insn::Jonz(_, target) |
@@ -928,7 +929,7 @@ impl Insn {
928929
Insn::Jne(_) |
929930
Insn::Jnz(_) |
930931
Insn::Jo(_) |
931-
Insn::JoMul(_) |
932+
Insn::JoMul(..) |
932933
Insn::Jz(_) |
933934
Insn::Joz(..) |
934935
Insn::Jonz(..) |
@@ -966,7 +967,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
966967
Insn::Jne(target) |
967968
Insn::Jnz(target) |
968969
Insn::Jo(target) |
969-
Insn::JoMul(target) |
970+
Insn::JoMul(_, target) |
970971
Insn::Jz(target) |
971972
Insn::Label(target) |
972973
Insn::LeaJumpTarget { target, .. } |
@@ -1158,7 +1159,6 @@ impl<'a> InsnOpndMutIterator<'a> {
11581159
Insn::Jne(target) |
11591160
Insn::Jnz(target) |
11601161
Insn::Jo(target) |
1161-
Insn::JoMul(target) |
11621162
Insn::Jz(target) |
11631163
Insn::Label(target) |
11641164
Insn::LeaJumpTarget { target, .. } |
@@ -1192,6 +1192,7 @@ impl<'a> InsnOpndMutIterator<'a> {
11921192
}
11931193
}
11941194

1195+
Insn::JoMul(opnd, target) |
11951196
Insn::Joz(opnd, target) |
11961197
Insn::Jonz(opnd, target) => {
11971198
if self.idx == 0 {
@@ -1798,7 +1799,7 @@ impl Assembler
17981799
Insn::Jbe(Target::Block(edge)) => Insn::Jbe(Target::Label(process_edge(edge))),
17991800
Insn::Jb(Target::Block(edge)) => Insn::Jb(Target::Label(process_edge(edge))),
18001801
Insn::Jo(Target::Block(edge)) => Insn::Jo(Target::Label(process_edge(edge))),
1801-
Insn::JoMul(Target::Block(edge)) => Insn::JoMul(Target::Label(process_edge(edge))),
1802+
Insn::JoMul(opnd, Target::Block(edge)) => Insn::JoMul(*opnd, Target::Label(process_edge(edge))),
18021803
Insn::Joz(opnd, Target::Block(edge)) => Insn::Joz(*opnd, Target::Label(process_edge(edge))),
18031804
Insn::Jonz(opnd, Target::Block(edge)) => Insn::Jonz(*opnd, Target::Label(process_edge(edge))),
18041805
_ => insn.clone()
@@ -2452,6 +2453,7 @@ impl fmt::Display for Assembler {
24522453
// If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals.
24532454
// Here, only handle instructions that have both Opnd and Target.
24542455
match insn {
2456+
Insn::JoMul(opnd, _) |
24552457
Insn::Joz(opnd, _) |
24562458
Insn::Jonz(opnd, _) |
24572459
Insn::LeaJumpTarget { out: opnd, target: _ } => {
@@ -2463,6 +2465,7 @@ impl fmt::Display for Assembler {
24632465
// If the instruction has a Block target, avoid using opnd_iter() for branch args
24642466
// since they're already printed inline with the target. Only print non-target operands.
24652467
match insn {
2468+
Insn::JoMul(opnd, _) |
24662469
Insn::Joz(opnd, _) |
24672470
Insn::Jonz(opnd, _) |
24682471
Insn::LeaJumpTarget { out: opnd, target: _ } => {
@@ -2787,8 +2790,8 @@ impl Assembler {
27872790
self.push_insn(Insn::Jo(target));
27882791
}
27892792

2790-
pub fn jo_mul(&mut self, target: Target) {
2791-
self.push_insn(Insn::JoMul(target));
2793+
pub fn jo_mul(&mut self, val: Opnd, target: Target) {
2794+
self.push_insn(Insn::JoMul(val, target));
27922795
}
27932796

27942797
pub fn jz(&mut self, target: Target) {

zjit/src/backend/x86_64/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,7 @@ impl Assembler {
10071007
}
10081008

10091009
Insn::Jo(target) |
1010-
Insn::JoMul(target) => {
1010+
Insn::JoMul(_, target) => {
10111011
match *target {
10121012
Target::CodePtr(code_ptr) => jo_ptr(cb, code_ptr),
10131013
Target::Label(label) => jo_label(cb, label),

zjit/src/codegen.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,8 +1984,8 @@ fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, rig
19841984
let right_untag = asm.sub(right, Opnd::UImm(1));
19851985
let out_val = asm.mul(left_untag, right_untag);
19861986

1987-
// Test for overflow
1988-
asm.jo_mul(side_exit(jit, state, FixnumMultOverflow));
1987+
// Test for overflow (on ARM64, JoMul uses out_val for barrel-shifted cmp)
1988+
asm.jo_mul(out_val, side_exit(jit, state, FixnumMultOverflow));
19891989
asm.add(out_val, Opnd::UImm(1))
19901990
}
19911991

0 commit comments

Comments
 (0)