Skip to content

Commit 126b263

Browse files
committed
executable to use cost model
1 parent 203a491 commit 126b263

2 files changed

Lines changed: 162 additions & 0 deletions

File tree

cranelift/codegen/Cargo.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ features = ["all-arch"]
2424
anyhow = { workspace = true, optional = true, features = ['std'] }
2525
bumpalo = "3"
2626
capstone = { workspace = true, optional = true }
27+
env_logger = { workspace = true, optional = true }
2728
cranelift-assembler-x64 = { workspace = true }
2829
cranelift-codegen-shared = { path = "./shared", version = "0.128.0" }
2930
cranelift-entity = { workspace = true }
3031
cranelift-bforest = { workspace = true }
3132
cranelift-bitset = { workspace = true }
3233
cranelift-control = { workspace = true }
34+
cranelift-isle = { path = "../isle/isle", version = "=0.128.0", optional = true }
3335
hashbrown = { workspace = true }
3436
target-lexicon = { workspace = true }
3537
log = { workspace = true }
@@ -124,7 +126,15 @@ souper-harvest = ["souper-ir", "souper-ir/stringify"]
124126
# Report any ISLE errors in pretty-printed style.
125127
isle-errors = ["cranelift-isle/fancy-errors"]
126128

129+
# Enable the `isle_cost_check` binary in `src/bin/`.
130+
isle-cost-check = ["dep:cranelift-isle", "dep:env_logger"]
131+
127132
# Enable tracking how long passes take in Cranelift.
128133
#
129134
# Enabled by default.
130135
timing = []
136+
137+
[[bin]]
138+
name = "isle_cost_check"
139+
path = "src/bin/isle_cost_check.rs"
140+
required-features = ["isle-cost-check"]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
//! ISLE rule cost checker using Cranelift's egraph cost model.
2+
//!
3+
//! This binary uses Cranelift's existing ISLE parser
4+
//! (`cranelift_isle`) to parse a rule and compute a cost for the RHS
5+
//! expression using the same internal egraph cost model as Cranelift itself.
6+
//!
7+
//! Exit status:
8+
//! - 0: acceptable cost
9+
//! - 1: unacceptable cost or parse error
10+
//!
11+
//! Examples (from `wasmtime/` workspace root):
12+
//!
13+
//! echo '(rule (lower (iadd x y)) (iadd y x))' | \
14+
//! cargo +stable run -p cranelift-codegen --features isle-cost-check --bin isle_cost_check
15+
//!
16+
//! # Debug:
17+
//! echo '(rule (simplify x) (imul x x))' | \
18+
//! env RUST_LOG=trace cargo +stable run -p cranelift-codegen --features isle-cost-check --bin isle_cost_check
19+
20+
use cranelift_codegen::ir::Opcode;
21+
use std::io::Read;
22+
23+
// The included file expects `crate::ir::Opcode`. In this binary-crate, define a
24+
// tiny shim module so we can include it verbatim.
25+
mod ir {
26+
pub use cranelift_codegen::ir::Opcode;
27+
}
28+
29+
// Keep everything local to this binary: directly include the internal cost model
30+
// implementation without re-exporting it from the library.
31+
#[allow(dead_code)]
32+
mod cost_model {
33+
include!("../egraph/cost.rs");
34+
35+
impl Cost {
36+
pub fn is_infinite(self) -> bool {
37+
self == Cost::infinity()
38+
}
39+
40+
pub fn op_cost_part(self) -> u32 {
41+
(self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
42+
}
43+
44+
pub fn depth_part(self) -> u8 {
45+
let depth = self.0 & Self::DEPTH_MASK;
46+
depth as u8
47+
}
48+
49+
pub fn bump_depth(self) -> Self {
50+
Cost::new(self.op_cost(), self.depth().saturating_add(1))
51+
}
52+
}
53+
}
54+
55+
fn main() {
56+
// Match the logging style used across Cranelift tools: use `log` macros
57+
// gated by `RUST_LOG`, with `env_logger` initialization.
58+
let _ = env_logger::try_init();
59+
60+
// Read stdin unconditionally.
61+
let mut input = String::new();
62+
if let Err(e) = std::io::stdin().read_to_string(&mut input) {
63+
eprintln!("failed to read stdin: {e}");
64+
std::process::exit(1);
65+
}
66+
67+
// Parse the first `(rule ...)` in the input.
68+
let lexer = cranelift_isle::lexer::Lexer::new(0, &input).expect("ISLE lex error");
69+
let defs = cranelift_isle::parser::parse(lexer).expect("ISLE parse error");
70+
let rule = defs
71+
.into_iter()
72+
.find_map(|d| match d {
73+
cranelift_isle::ast::Def::Rule(r) => Some(r),
74+
_ => None,
75+
})
76+
.expect("no `(rule ...)` found in input");
77+
78+
let lhs = pattern_cost(&rule.pattern);
79+
let rhs = expr_cost(&rule.expr);
80+
81+
cranelift_codegen::trace!(
82+
"rule_cost: lhs op_cost={} depth={} infinite={} | rhs op_cost={} depth={} infinite={}",
83+
lhs.op_cost_part(),
84+
lhs.depth_part(),
85+
lhs.is_infinite(),
86+
rhs.op_cost_part(),
87+
rhs.depth_part(),
88+
rhs.is_infinite(),
89+
);
90+
91+
// Deoptimization check: RHS should not be more expensive than LHS.
92+
if rhs > lhs {
93+
std::process::exit(1);
94+
}
95+
}
96+
97+
fn pattern_cost(pat: &cranelift_isle::ast::Pattern) -> cost_model::Cost {
98+
use cost_model::Cost;
99+
use cranelift_isle::ast::Pattern::*;
100+
match pat {
101+
Var { .. } | Wildcard { .. } | MacroArg { .. } => Cost::zero(),
102+
103+
ConstBool { .. } | ConstInt { .. } => {
104+
Cost::of_pure_op(ir::Opcode::Iconst, std::iter::empty())
105+
}
106+
ConstPrim { .. } => Cost::zero(),
107+
108+
BindPattern { subpat, .. } => pattern_cost(subpat),
109+
110+
And { subpats, .. } => subpats.iter().map(pattern_cost).sum::<Cost>().bump_depth(),
111+
112+
Term { sym, args, .. } => {
113+
let head = sym.0.as_str();
114+
if let Ok(op) = head.parse::<Opcode>() {
115+
return Cost::of_pure_op(op, args.iter().map(pattern_cost));
116+
}
117+
args.iter().map(pattern_cost).sum::<Cost>().bump_depth()
118+
}
119+
}
120+
}
121+
122+
fn expr_cost(expr: &cranelift_isle::ast::Expr) -> cost_model::Cost {
123+
use cost_model::Cost;
124+
use cranelift_isle::ast::Expr::*;
125+
match expr {
126+
Var { .. } | ConstPrim { .. } => Cost::zero(),
127+
ConstBool { .. } | ConstInt { .. } => {
128+
// Approximate constants with the cheapest Cranelift constant op.
129+
Cost::of_pure_op(ir::Opcode::Iconst, std::iter::empty())
130+
}
131+
Let { defs, body, .. } => {
132+
let c = defs
133+
.iter()
134+
.map(|d| expr_cost(&d.val))
135+
.chain(std::iter::once(expr_cost(body)))
136+
.sum::<Cost>();
137+
c.bump_depth()
138+
}
139+
Term { sym, args, .. } => {
140+
let head = sym.0.as_str();
141+
142+
// Try to interpret the head symbol as a Cranelift opcode.
143+
if let Ok(op) = head.parse::<Opcode>() {
144+
return Cost::of_pure_op(op, args.iter().map(expr_cost));
145+
}
146+
147+
// Otherwise, treat it as an ISLE constructor/helper and just propagate
148+
// operand costs (still bumping depth to penalize nesting).
149+
args.iter().map(expr_cost).sum::<Cost>().bump_depth()
150+
}
151+
}
152+
}

0 commit comments

Comments
 (0)