Skip to content

Commit 8fdb99a

Browse files
authored
Add register repacking (#413)
This PR sorts bytecode registers by frequency of use. The goal is to optimize GPU evaluators, which use a select chain to dynamically pick a register from an array.
1 parent 1735e79 commit 8fdb99a

5 files changed

Lines changed: 207 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@
3333
code using `anyhow` (or similar) may not need to change
3434
- Fix a bug in bulk evaluator argument checks where mismatched slices could be
3535
allowed under some circumstances
36+
- Add `VmData::asm` to get an immutable reference to the inner `RegTape`
37+
- Add `RegTape::repack_map` and `RegTape::repack` to repack registers by
38+
frequency (making register 0 the most frequently used, etc)
39+
- Add `RegOp::visit_regs` and `RegOp::visit_regs_mut` to visit registers in an
40+
operation
3641

3742
# 0.4.3
3843
- Fixed bug in x86 interval `OR` function ([#395](https://github.com/mkeeter/fidget/pull/395)),

fidget-bytecode/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,16 @@ impl Bytecode {
189189

190190
/// Builds a new bytecode object from VM data
191191
///
192-
/// Returns an error if the reserved register (255) is in use
192+
/// Registers are reordered by frequency of use, e.g. the most frequently
193+
/// used register becomes register 0.
194+
///
195+
/// Returns an error if the reserved register (255) is in use, which should
196+
/// only happen if the incoming tape has 256 active registers.
193197
pub fn new<const N: usize>(
194198
t: &VmData<N>,
195199
) -> Result<Self, ReservedRegister> {
200+
// Build a map for repacking registers by frequency
201+
let map = t.asm().repack_map();
196202
// The initial opcode is `OP_JUMP 0x0000_0000`
197203
let mut data = vec![u32::MAX, 0u32];
198204
let mut reg_count = 0u8;
@@ -202,6 +208,7 @@ impl Bytecode {
202208
let mut word = [0xFF; 4];
203209
let mut imm = None;
204210
let mut store_reg = |i, r| {
211+
let r = map[&r];
205212
if r == u8::MAX {
206213
Err(ReservedRegister)
207214
} else {

fidget-core/src/compiler/op.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,159 @@ opcodes!(
288288
Store(u8, u32),
289289
}
290290
);
291+
292+
impl RegOp {
293+
/// Apply a mutating function to every register in the op
294+
///
295+
/// Both inputs and outputs are visited
296+
pub fn visit_regs_mut<F: FnMut(&mut u8)>(&mut self, mut f: F) {
297+
match self {
298+
RegOp::CopyImm(out, imm) => {
299+
let _: f32 = *imm;
300+
f(out)
301+
}
302+
RegOp::NegReg(out, arg)
303+
| RegOp::AbsReg(out, arg)
304+
| RegOp::RecipReg(out, arg)
305+
| RegOp::SqrtReg(out, arg)
306+
| RegOp::SquareReg(out, arg)
307+
| RegOp::FloorReg(out, arg)
308+
| RegOp::CeilReg(out, arg)
309+
| RegOp::RoundReg(out, arg)
310+
| RegOp::CopyReg(out, arg)
311+
| RegOp::SinReg(out, arg)
312+
| RegOp::CosReg(out, arg)
313+
| RegOp::TanReg(out, arg)
314+
| RegOp::AsinReg(out, arg)
315+
| RegOp::AcosReg(out, arg)
316+
| RegOp::AtanReg(out, arg)
317+
| RegOp::ExpReg(out, arg)
318+
| RegOp::LnReg(out, arg)
319+
| RegOp::NotReg(out, arg) => {
320+
f(out);
321+
f(arg);
322+
}
323+
RegOp::AddRegImm(out, arg, imm)
324+
| RegOp::MulRegImm(out, arg, imm)
325+
| RegOp::DivRegImm(out, arg, imm)
326+
| RegOp::DivImmReg(out, arg, imm)
327+
| RegOp::SubImmReg(out, arg, imm)
328+
| RegOp::SubRegImm(out, arg, imm)
329+
| RegOp::AtanRegImm(out, arg, imm)
330+
| RegOp::AtanImmReg(out, arg, imm)
331+
| RegOp::MinRegImm(out, arg, imm)
332+
| RegOp::MaxRegImm(out, arg, imm)
333+
| RegOp::CompareRegImm(out, arg, imm)
334+
| RegOp::CompareImmReg(out, arg, imm)
335+
| RegOp::ModRegImm(out, arg, imm)
336+
| RegOp::ModImmReg(out, arg, imm)
337+
| RegOp::AndRegImm(out, arg, imm)
338+
| RegOp::OrRegImm(out, arg, imm) => {
339+
let _: f32 = *imm; // type-checking pattern
340+
f(out);
341+
f(arg);
342+
}
343+
344+
RegOp::AddRegReg(out, lhs, rhs)
345+
| RegOp::MulRegReg(out, lhs, rhs)
346+
| RegOp::DivRegReg(out, lhs, rhs)
347+
| RegOp::SubRegReg(out, lhs, rhs)
348+
| RegOp::AtanRegReg(out, lhs, rhs)
349+
| RegOp::MinRegReg(out, lhs, rhs)
350+
| RegOp::MaxRegReg(out, lhs, rhs)
351+
| RegOp::CompareRegReg(out, lhs, rhs)
352+
| RegOp::ModRegReg(out, lhs, rhs)
353+
| RegOp::AndRegReg(out, lhs, rhs)
354+
| RegOp::OrRegReg(out, lhs, rhs) => {
355+
f(out);
356+
f(lhs);
357+
f(rhs);
358+
}
359+
360+
RegOp::Output(reg, imm)
361+
| RegOp::Input(reg, imm)
362+
| RegOp::Store(reg, imm)
363+
| RegOp::Load(reg, imm) => {
364+
let _: u32 = *imm; // type-checking pattern
365+
f(reg)
366+
}
367+
}
368+
}
369+
370+
/// Apply a function to every register in the op
371+
///
372+
/// Both inputs and outputs are visited
373+
pub fn visit_regs<F: FnMut(u8)>(&self, mut f: F) {
374+
match self {
375+
RegOp::CopyImm(out, imm) => {
376+
let _: f32 = *imm;
377+
f(*out)
378+
}
379+
RegOp::NegReg(out, arg)
380+
| RegOp::AbsReg(out, arg)
381+
| RegOp::RecipReg(out, arg)
382+
| RegOp::SqrtReg(out, arg)
383+
| RegOp::SquareReg(out, arg)
384+
| RegOp::FloorReg(out, arg)
385+
| RegOp::CeilReg(out, arg)
386+
| RegOp::RoundReg(out, arg)
387+
| RegOp::CopyReg(out, arg)
388+
| RegOp::SinReg(out, arg)
389+
| RegOp::CosReg(out, arg)
390+
| RegOp::TanReg(out, arg)
391+
| RegOp::AsinReg(out, arg)
392+
| RegOp::AcosReg(out, arg)
393+
| RegOp::AtanReg(out, arg)
394+
| RegOp::ExpReg(out, arg)
395+
| RegOp::LnReg(out, arg)
396+
| RegOp::NotReg(out, arg) => {
397+
f(*out);
398+
f(*arg);
399+
}
400+
RegOp::AddRegImm(out, arg, imm)
401+
| RegOp::MulRegImm(out, arg, imm)
402+
| RegOp::DivRegImm(out, arg, imm)
403+
| RegOp::DivImmReg(out, arg, imm)
404+
| RegOp::SubImmReg(out, arg, imm)
405+
| RegOp::SubRegImm(out, arg, imm)
406+
| RegOp::AtanRegImm(out, arg, imm)
407+
| RegOp::AtanImmReg(out, arg, imm)
408+
| RegOp::MinRegImm(out, arg, imm)
409+
| RegOp::MaxRegImm(out, arg, imm)
410+
| RegOp::CompareRegImm(out, arg, imm)
411+
| RegOp::CompareImmReg(out, arg, imm)
412+
| RegOp::ModRegImm(out, arg, imm)
413+
| RegOp::ModImmReg(out, arg, imm)
414+
| RegOp::AndRegImm(out, arg, imm)
415+
| RegOp::OrRegImm(out, arg, imm) => {
416+
let _: f32 = *imm; // type-checking pattern
417+
f(*out);
418+
f(*arg);
419+
}
420+
421+
RegOp::AddRegReg(out, lhs, rhs)
422+
| RegOp::MulRegReg(out, lhs, rhs)
423+
| RegOp::DivRegReg(out, lhs, rhs)
424+
| RegOp::SubRegReg(out, lhs, rhs)
425+
| RegOp::AtanRegReg(out, lhs, rhs)
426+
| RegOp::MinRegReg(out, lhs, rhs)
427+
| RegOp::MaxRegReg(out, lhs, rhs)
428+
| RegOp::CompareRegReg(out, lhs, rhs)
429+
| RegOp::ModRegReg(out, lhs, rhs)
430+
| RegOp::AndRegReg(out, lhs, rhs)
431+
| RegOp::OrRegReg(out, lhs, rhs) => {
432+
f(*out);
433+
f(*lhs);
434+
f(*rhs);
435+
}
436+
437+
RegOp::Output(reg, imm)
438+
| RegOp::Input(reg, imm)
439+
| RegOp::Store(reg, imm)
440+
| RegOp::Load(reg, imm) => {
441+
let _: u32 = *imm; // type-checking pattern
442+
f(*reg)
443+
}
444+
}
445+
}
446+
}

fidget-core/src/compiler/reg_tape.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Tape used for evaluation
22
use crate::compiler::{RegOp, RegisterAllocator, SsaTape};
33
use serde::{Deserialize, Serialize};
4+
use std::collections::HashMap;
45

56
/// Low-level tape for use with the Fidget virtual machine (or to be lowered
67
/// further into machine instructions).
@@ -9,6 +10,9 @@ pub struct RegTape {
910
tape: Vec<RegOp>,
1011

1112
/// Total allocated slots
13+
///
14+
/// This is a contiguous space of registers (`0..N`) and memory (`N..`),
15+
/// where `N` is the parameter in [`RegTape::new`].
1216
pub(super) slot_count: u32,
1317
}
1418

@@ -27,6 +31,35 @@ impl RegTape {
2731
alloc.finalize()
2832
}
2933

34+
/// Repacks registers by frequency (so that register 0 is the most frequent)
35+
pub fn repack(&mut self) {
36+
let map = self.repack_map();
37+
for op in &mut self.tape {
38+
op.visit_regs_mut(|reg| *reg = map[reg]);
39+
}
40+
}
41+
42+
/// Returns a map for register repacking
43+
///
44+
/// The map repacks registers in the tape by frequency, so that register 0
45+
/// is the most frequent.
46+
pub fn repack_map(&self) -> HashMap<u8, u8> {
47+
let mut reg_counts: HashMap<u8, usize> = HashMap::new();
48+
for op in &self.tape {
49+
op.visit_regs(|reg| *reg_counts.entry(reg).or_default() += 1);
50+
}
51+
let mut sorted = reg_counts
52+
.into_iter()
53+
.map(|(reg, count)| (std::cmp::Reverse(count), reg))
54+
.collect::<Vec<_>>();
55+
sorted.sort_unstable();
56+
sorted
57+
.into_iter()
58+
.enumerate()
59+
.map(|(i, (_count, reg))| (reg, u8::try_from(i).unwrap()))
60+
.collect()
61+
}
62+
3063
/// Builds a new empty tape
3164
pub(crate) fn empty() -> Self {
3265
Self {

fidget-core/src/vm/data.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,11 @@ impl<const N: usize> VmData<N> {
318318
self.asm.iter().cloned().rev()
319319
}
320320

321+
/// Returns a reference to the inner [`RegTape`]
322+
pub fn asm(&self) -> &RegTape {
323+
&self.asm
324+
}
325+
321326
/// Pretty-prints the inner SSA tape
322327
pub fn pretty_print(&self) {
323328
self.ssa.pretty_print();

0 commit comments

Comments
 (0)