Skip to content

Commit e7b5202

Browse files
committed
ZJIT: Add MulHighBits LIR instruction, eliminate implicit X16
Replace the speculative smulh emission (Mul always writing X16) with an explicit MulHighBits instruction that the register allocator handles like any other value. JoMul now takes both the high and low operands, and on ARM64 emits CMP high, low, ASR Shopify#62 using the barrel shifter. No implicit register contracts between instructions. The full multiply overflow sequence is: MulHighBits → Mul → JoMul, each independently emitted, each with explicit register-allocated operands. On x86, MulHighBits emits nothing (imul sets OF directly).
1 parent b4a2de3 commit e7b5202

4 files changed

Lines changed: 94 additions & 32 deletions

File tree

zjit/src/backend/arm64/mod.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,8 @@ impl Assembler {
664664
*opnd = split_load_operand(asm, *opnd);
665665
asm.push_insn(insn);
666666
},
667-
Insn::Mul { left, right, .. } => {
667+
Insn::Mul { left, right, .. } |
668+
Insn::MulHighBits { left, right, .. } => {
668669
*left = split_load_operand(asm, *left);
669670
*right = split_load_operand(asm, *right);
670671
asm.push_insn(insn);
@@ -793,7 +794,8 @@ impl Assembler {
793794
asm.store(mem_out, SCRATCH0_OPND);
794795
}
795796
}
796-
Insn::Mul { left, right, out } => {
797+
Insn::Mul { left, right, out } |
798+
Insn::MulHighBits { left, right, out } => {
797799
*left = split_memory_read(asm, *left, SCRATCH0_OPND);
798800
*right = split_memory_read(asm, *right, SCRATCH1_OPND);
799801
let mem_out = split_memory_write(out, SCRATCH0_OPND);
@@ -907,8 +909,9 @@ impl Assembler {
907909
}
908910
}
909911
}
910-
Insn::JoMul(opnd, _) => {
911-
*opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND);
912+
Insn::JoMul(high, low, _) => {
913+
*high = split_memory_read(asm, *high, SCRATCH0_OPND);
914+
*low = split_memory_read(asm, *low, SCRATCH1_OPND);
912915
asm.push_insn(insn);
913916
}
914917
&mut Insn::PatchPoint { ref target, invariant, version } => {
@@ -1235,13 +1238,11 @@ impl Assembler {
12351238
}
12361239
},
12371240
Insn::Mul { left, right, out } => {
1238-
// Speculatively emit smulh into EMIT_OPND (X16) for a
1239-
// potential following JoMul. If no JoMul follows, X16 is
1240-
// simply overwritten later. Must come before mul since mul
1241-
// may clobber an input register.
1242-
smulh(cb, Self::EMIT_OPND, left.into(), right.into());
12431241
mul(cb, out.into(), left.into(), right.into());
12441242
},
1243+
Insn::MulHighBits { left, right, out } => {
1244+
smulh(cb, out.into(), left.into(), right.into());
1245+
},
12451246
Insn::And { left, right, out } => {
12461247
and(cb, out.into(), left.into(), right.into());
12471248
},
@@ -1507,11 +1508,11 @@ impl Assembler {
15071508
Insn::Jne(target) | Insn::Jnz(target) => {
15081509
emit_conditional_jump::<{Condition::NE}>(self, cb, target.clone());
15091510
},
1510-
Insn::JoMul(val, target) => {
1511-
// Compare smulh result (in EMIT_OPND/X16 from preceding Mul)
1512-
// with the mul output sign-extended from bit 62. Uses the
1513-
// barrel shifter built into CMP for a single instruction.
1514-
cmp_shifted(cb, Self::EMIT_OPND, val.into(), 0b10, 62); // ASR #62
1511+
Insn::JoMul(high, low, target) => {
1512+
// Compare MulHighBits result with the Mul output
1513+
// sign-extended from bit 62. Uses the barrel shifter
1514+
// built into CMP for a single instruction.
1515+
cmp_shifted(cb, high.into(), low.into(), 0b10, 62); // ASR #62
15151516
emit_conditional_jump::<{Condition::NE}>(self, cb, target.clone());
15161517
},
15171518
Insn::Jl(target) => {
@@ -1763,11 +1764,10 @@ mod tests {
17631764

17641765
assert_disasm_snapshot!(cb.disasm(), @"
17651766
0x0: mov x0, #3
1766-
0x4: smulh x16, x9, x0
1767-
0x8: mul x0, x9, x0
1768-
0xc: mov x1, x0
1767+
0x4: mul x0, x9, x0
1768+
0x8: mov x1, x0
17691769
");
1770-
assert_snapshot!(cb.hexdump(), @"600080d2307d409b207d009be10300aa");
1770+
assert_snapshot!(cb.hexdump(), @"600080d2207d009be10300aa");
17711771
}
17721772

17731773
#[test]

zjit/src/backend/lir.rs

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -622,8 +622,9 @@ pub enum Insn {
622622
Jo(Target),
623623

624624
/// Jump if overflow in multiplication.
625-
/// The operand is the Mul output, used on ARM64 for the barrel-shifted compare.
626-
JoMul(Opnd, Target),
625+
/// First operand is MulHighBits output, second is Mul output.
626+
/// On ARM64, emits CMP high, low, ASR #62. On x86, emits jo (ignores both).
627+
JoMul(Opnd, Opnd, Target),
627628

628629
/// Jump if zero
629630
Jz(Target),
@@ -698,6 +699,9 @@ pub enum Insn {
698699
// Integer multiplication
699700
Mul { left: Opnd, right: Opnd, out: Opnd },
700701

702+
// High 64 bits of a signed 128-bit multiply (smulh on ARM64, no-op on x86)
703+
MulHighBits { left: Opnd, right: Opnd, out: Opnd },
704+
701705
// Bitwise AND test instruction
702706
Test { left: Opnd, right: Opnd },
703707

@@ -735,7 +739,7 @@ impl Insn {
735739
Insn::Jne(target) |
736740
Insn::Jnz(target) |
737741
Insn::Jo(target) |
738-
Insn::JoMul(_, target) |
742+
Insn::JoMul(_, _, target) |
739743
Insn::Jz(target) |
740744
Insn::Joz(_, target) |
741745
Insn::Jonz(_, target) |
@@ -810,6 +814,7 @@ impl Insn {
810814
Insn::Store { .. } => "Store",
811815
Insn::Sub { .. } => "Sub",
812816
Insn::Mul { .. } => "Mul",
817+
Insn::MulHighBits { .. } => "MulHighBits",
813818
Insn::Test { .. } => "Test",
814819
Insn::URShift { .. } => "URShift",
815820
Insn::Xor { .. } => "Xor"
@@ -843,6 +848,7 @@ impl Insn {
843848
Insn::RShift { out, .. } |
844849
Insn::Sub { out, .. } |
845850
Insn::Mul { out, .. } |
851+
Insn::MulHighBits { out, .. } |
846852
Insn::URShift { out, .. } |
847853
Insn::Xor { out, .. } => Some(out),
848854
_ => None
@@ -876,6 +882,7 @@ impl Insn {
876882
Insn::RShift { out, .. } |
877883
Insn::Sub { out, .. } |
878884
Insn::Mul { out, .. } |
885+
Insn::MulHighBits { out, .. } |
879886
Insn::URShift { out, .. } |
880887
Insn::Xor { out, .. } => Some(out),
881888
_ => None
@@ -895,7 +902,7 @@ impl Insn {
895902
Insn::Jne(target) |
896903
Insn::Jnz(target) |
897904
Insn::Jo(target) |
898-
Insn::JoMul(_, target) |
905+
Insn::JoMul(_, _, target) |
899906
Insn::Jz(target) |
900907
Insn::Joz(_, target) |
901908
Insn::Jonz(_, target) |
@@ -967,7 +974,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
967974
Insn::Jne(target) |
968975
Insn::Jnz(target) |
969976
Insn::Jo(target) |
970-
Insn::JoMul(_, target) |
977+
Insn::JoMul(_, _, target) |
971978
Insn::Jz(target) |
972979
Insn::Label(target) |
973980
Insn::LeaJumpTarget { target, .. } |
@@ -1084,6 +1091,7 @@ impl<'a> Iterator for InsnOpndIterator<'a> {
10841091
Insn::Store { dest: opnd0, src: opnd1 } |
10851092
Insn::Sub { left: opnd0, right: opnd1, .. } |
10861093
Insn::Mul { left: opnd0, right: opnd1, .. } |
1094+
Insn::MulHighBits { left: opnd0, right: opnd1, .. } |
10871095
Insn::Test { left: opnd0, right: opnd1 } |
10881096
Insn::URShift { opnd: opnd0, shift: opnd1, .. } |
10891097
Insn::Xor { left: opnd0, right: opnd1, .. } => {
@@ -1192,7 +1200,43 @@ impl<'a> InsnOpndMutIterator<'a> {
11921200
}
11931201
}
11941202

1195-
Insn::JoMul(opnd, target) |
1203+
Insn::JoMul(high, low, target) => {
1204+
match self.idx {
1205+
0 => { self.idx += 1; return Some(high); }
1206+
1 => { self.idx += 1; return Some(low); }
1207+
_ => {}
1208+
}
1209+
1210+
match target {
1211+
Target::SideExit { exit: SideExit { stack, locals, .. }, .. } => {
1212+
let stack_idx = self.idx - 2;
1213+
if stack_idx < stack.len() {
1214+
let opnd = &mut stack[stack_idx];
1215+
self.idx += 1;
1216+
return Some(opnd);
1217+
}
1218+
1219+
let local_idx = stack_idx - stack.len();
1220+
if local_idx < locals.len() {
1221+
let opnd = &mut locals[local_idx];
1222+
self.idx += 1;
1223+
return Some(opnd);
1224+
}
1225+
None
1226+
}
1227+
Target::Block(edge) => {
1228+
let arg_idx = self.idx - 2;
1229+
if arg_idx < edge.args.len() {
1230+
let opnd = &mut edge.args[arg_idx];
1231+
self.idx += 1;
1232+
return Some(opnd);
1233+
}
1234+
None
1235+
}
1236+
_ => None
1237+
}
1238+
}
1239+
11961240
Insn::Joz(opnd, target) |
11971241
Insn::Jonz(opnd, target) => {
11981242
if self.idx == 0 {
@@ -1278,6 +1322,7 @@ impl<'a> InsnOpndMutIterator<'a> {
12781322
Insn::Store { dest: opnd0, src: opnd1 } |
12791323
Insn::Sub { left: opnd0, right: opnd1, .. } |
12801324
Insn::Mul { left: opnd0, right: opnd1, .. } |
1325+
Insn::MulHighBits { left: opnd0, right: opnd1, .. } |
12811326
Insn::Test { left: opnd0, right: opnd1 } |
12821327
Insn::URShift { opnd: opnd0, shift: opnd1, .. } |
12831328
Insn::Xor { left: opnd0, right: opnd1, .. } => {
@@ -1799,7 +1844,7 @@ impl Assembler
17991844
Insn::Jbe(Target::Block(edge)) => Insn::Jbe(Target::Label(process_edge(edge))),
18001845
Insn::Jb(Target::Block(edge)) => Insn::Jb(Target::Label(process_edge(edge))),
18011846
Insn::Jo(Target::Block(edge)) => Insn::Jo(Target::Label(process_edge(edge))),
1802-
Insn::JoMul(opnd, Target::Block(edge)) => Insn::JoMul(*opnd, Target::Label(process_edge(edge))),
1847+
Insn::JoMul(high, low, Target::Block(edge)) => Insn::JoMul(*high, *low, Target::Label(process_edge(edge))),
18031848
Insn::Joz(opnd, Target::Block(edge)) => Insn::Joz(*opnd, Target::Label(process_edge(edge))),
18041849
Insn::Jonz(opnd, Target::Block(edge)) => Insn::Jonz(*opnd, Target::Label(process_edge(edge))),
18051850
_ => insn.clone()
@@ -2453,7 +2498,9 @@ impl fmt::Display for Assembler {
24532498
// If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals.
24542499
// Here, only handle instructions that have both Opnd and Target.
24552500
match insn {
2456-
Insn::JoMul(opnd, _) |
2501+
Insn::JoMul(high, low, _) => {
2502+
write!(f, ", {high}, {low}")?;
2503+
}
24572504
Insn::Joz(opnd, _) |
24582505
Insn::Jonz(opnd, _) |
24592506
Insn::LeaJumpTarget { out: opnd, target: _ } => {
@@ -2465,7 +2512,9 @@ impl fmt::Display for Assembler {
24652512
// If the instruction has a Block target, avoid using opnd_iter() for branch args
24662513
// since they're already printed inline with the target. Only print non-target operands.
24672514
match insn {
2468-
Insn::JoMul(opnd, _) |
2515+
Insn::JoMul(high, low, _) => {
2516+
write!(f, ", {high}, {low}")?;
2517+
}
24692518
Insn::Joz(opnd, _) |
24702519
Insn::Jonz(opnd, _) |
24712520
Insn::LeaJumpTarget { out: opnd, target: _ } => {
@@ -2790,8 +2839,8 @@ impl Assembler {
27902839
self.push_insn(Insn::Jo(target));
27912840
}
27922841

2793-
pub fn jo_mul(&mut self, val: Opnd, target: Target) {
2794-
self.push_insn(Insn::JoMul(val, target));
2842+
pub fn jo_mul(&mut self, high: Opnd, low: Opnd, target: Target) {
2843+
self.push_insn(Insn::JoMul(high, low, target));
27952844
}
27962845

27972846
pub fn jz(&mut self, target: Target) {
@@ -2919,6 +2968,15 @@ impl Assembler {
29192968
out
29202969
}
29212970

2971+
/// High 64 bits of signed 128-bit multiply. On ARM64 this emits smulh;
2972+
/// on x86 it's a no-op since imul sets the overflow flag directly.
2973+
#[must_use]
2974+
pub fn mul_high_bits(&mut self, left: Opnd, right: Opnd) -> Opnd {
2975+
let out = self.new_vreg(Opnd::match_num_bits(&[left, right]));
2976+
self.push_insn(Insn::MulHighBits { left, right, out });
2977+
out
2978+
}
2979+
29222980
pub fn test(&mut self, left: Opnd, right: Opnd) {
29232981
self.push_insn(Insn::Test { left, right });
29242982
}

zjit/src/backend/x86_64/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,8 @@ impl Assembler {
785785
Insn::Mul { left, right, .. } => {
786786
imul(cb, left.into(), right.into());
787787
},
788+
// x86 imul sets OF directly, no high bits needed
789+
Insn::MulHighBits { .. } => {},
788790

789791
Insn::And { left, right, .. } => {
790792
and(cb, left.into(), right.into());
@@ -1007,7 +1009,7 @@ impl Assembler {
10071009
}
10081010

10091011
Insn::Jo(target) |
1010-
Insn::JoMul(_, target) => {
1012+
Insn::JoMul(_, _, target) => {
10111013
match *target {
10121014
Target::CodePtr(code_ptr) => jo_ptr(cb, code_ptr),
10131015
Target::Label(label) => jo_label(cb, label),

zjit/src/codegen.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,8 +1984,10 @@ fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, rig
19841984
let right_untag = asm.sub(right, Opnd::UImm(1));
19851985
let out_val = asm.mul(left_untag, right_untag);
19861986

1987-
// Test for overflow (on ARM64, JoMul uses out_val for barrel-shifted cmp)
1988-
asm.jo_mul(out_val, side_exit(jit, state, FixnumMultOverflow));
1987+
// Test for overflow: on ARM64, compare smulh (high bits) with sign-extended
1988+
// low bits. On x86, mul_high_bits is a no-op and jo_mul just reads OF.
1989+
let high = asm.mul_high_bits(left_untag, right_untag);
1990+
asm.jo_mul(high, out_val, side_exit(jit, state, FixnumMultOverflow));
19891991
asm.add(out_val, Opnd::UImm(1))
19901992
}
19911993

0 commit comments

Comments
 (0)