Skip to content

Commit f1d439a

Browse files
committed
New MIR Pass: SsaRangePropagation
1 parent 0ac9e59 commit f1d439a

9 files changed

Lines changed: 509 additions & 0 deletions

compiler/rustc_middle/src/mir/statement.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> {
374374
self.projection.iter().any(|elem| elem.is_indirect())
375375
}
376376

377+
/// Returns `true` if the `Place` always refers to the same memory region
378+
/// whatever the state of the program.
379+
pub fn is_stable_offset(&self) -> bool {
380+
self.projection.iter().all(|elem| elem.is_stable_offset())
381+
}
382+
377383
/// Returns `true` if this `Place`'s first projection is `Deref`.
378384
///
379385
/// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later,

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ declare_passes! {
198198
mod single_use_consts : SingleUseConsts;
199199
mod sroa : ScalarReplacementOfAggregates;
200200
mod strip_debuginfo : StripDebugInfo;
201+
mod ssa_range_prop: SsaRangePropagation;
201202
mod unreachable_enum_branching : UnreachableEnumBranching;
202203
mod unreachable_prop : UnreachablePropagation;
203204
mod validate : Validator;
@@ -741,6 +742,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
741742
&dead_store_elimination::DeadStoreElimination::Initial,
742743
&gvn::GVN,
743744
&simplify::SimplifyLocals::AfterGVN,
745+
&ssa_range_prop::SsaRangePropagation,
744746
&match_branches::MatchBranchSimplification,
745747
&dataflow_const_prop::DataflowConstProp,
746748
&single_use_consts::SingleUseConsts,
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
//! A pass that propagates the known ranges of SSA locals.
2+
//! We can know the ranges of SSA locals at some locations for the following code:
3+
//! ```
4+
//! fn foo(a: u32) {
5+
//! let b = a < 9;
6+
//! if b {
7+
//! let c = b; // c is true since b is whitin the range [1, 2)
8+
//! let d = a < 8; // d is true since b whitin the range [0, 9)
9+
//! }
10+
//! }
11+
//! ```
12+
use rustc_abi::WrappingRange;
13+
use rustc_const_eval::interpret::Scalar;
14+
use rustc_data_structures::fx::FxHashMap;
15+
use rustc_data_structures::graph::dominators::Dominators;
16+
use rustc_index::bit_set::DenseBitSet;
17+
use rustc_middle::mir::visit::MutVisitor;
18+
use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
19+
use rustc_middle::ty::{TyCtxt, TypingEnv};
20+
use rustc_span::DUMMY_SP;
21+
22+
use crate::ssa::SsaLocals;
23+
24+
pub(super) struct SsaRangePropagation;
25+
26+
impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
27+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
28+
sess.mir_opt_level() > 1
29+
}
30+
31+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32+
let typing_env = body.typing_env(tcx);
33+
let ssa = SsaLocals::new(tcx, body, typing_env);
34+
// Clone dominators because we need them while mutating the body.
35+
let dominators = body.basic_blocks.dominators().clone();
36+
let mut range_set =
37+
RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
38+
39+
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
40+
for bb in reverse_postorder {
41+
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
42+
range_set.visit_basic_block_data(bb, data);
43+
}
44+
}
45+
46+
fn is_required(&self) -> bool {
47+
false
48+
}
49+
}
50+
51+
struct RangeSet<'tcx, 'body, 'a> {
52+
tcx: TyCtxt<'tcx>,
53+
typing_env: TypingEnv<'tcx>,
54+
ssa: &'a SsaLocals,
55+
local_decls: &'body LocalDecls<'tcx>,
56+
dominators: Dominators<BasicBlock>,
57+
/// Known ranges at each locations.
58+
ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
59+
/// Determines if the basic block has a single unique predecessor.
60+
unique_predecessors: DenseBitSet<BasicBlock>,
61+
}
62+
63+
impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
64+
fn new(
65+
tcx: TyCtxt<'tcx>,
66+
typing_env: TypingEnv<'tcx>,
67+
body: &Body<'tcx>,
68+
ssa: &'a SsaLocals,
69+
local_decls: &'body LocalDecls<'tcx>,
70+
dominators: Dominators<BasicBlock>,
71+
) -> Self {
72+
let predecessors = body.basic_blocks.predecessors();
73+
let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
74+
for bb in body.basic_blocks.indices() {
75+
if predecessors[bb].len() == 1 {
76+
unique_predecessors.insert(bb);
77+
}
78+
}
79+
RangeSet {
80+
tcx,
81+
typing_env,
82+
ssa,
83+
local_decls,
84+
dominators,
85+
ranges: FxHashMap::default(),
86+
unique_predecessors,
87+
}
88+
}
89+
90+
/// Create a new known range at the location.
91+
fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
92+
self.ranges.entry(place).or_default().push((location, range));
93+
}
94+
95+
/// Get the known range at the location.
96+
fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
97+
let Some(ranges) = self.ranges.get(place) else {
98+
return None;
99+
};
100+
// FIXME: This should use the intersection of all valid ranges.
101+
let (_, range) =
102+
ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
103+
Some(*range)
104+
}
105+
106+
fn try_as_constant(
107+
&mut self,
108+
place: Place<'tcx>,
109+
location: Location,
110+
) -> Option<ConstOperand<'tcx>> {
111+
if let Some(range) = self.get_range(&place, location)
112+
&& range.start == range.end
113+
{
114+
let ty = place.ty(self.local_decls, self.tcx).ty;
115+
let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
116+
let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
117+
let const_ = Const::Val(value, ty);
118+
return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
119+
}
120+
None
121+
}
122+
123+
/// Attempts to simplify an operand to a constant value.
124+
///
125+
/// Returns
126+
/// - `Ok(())` if the operand is or can be simplified to a constant.
127+
/// - `Err(Some(place))` if simplification fails for an SSA local.
128+
/// - `Err(None)` if simplification fails with no further optimization possible.
129+
fn simplify_operand(
130+
&mut self,
131+
operand: &mut Operand<'tcx>,
132+
location: Location,
133+
) -> Result<(), Option<Place<'tcx>>> {
134+
let Some(place) = operand.place() else {
135+
return Ok(());
136+
};
137+
let Some(const_) = self.try_as_constant(place, location) else {
138+
if self.is_ssa(place) {
139+
return Err(Some(place));
140+
} else {
141+
return Err(None);
142+
}
143+
};
144+
*operand = Operand::Constant(Box::new(const_));
145+
Ok(())
146+
}
147+
148+
fn is_ssa(&self, place: Place<'tcx>) -> bool {
149+
self.ssa.is_ssa(place.local) && place.is_stable_offset()
150+
}
151+
}
152+
153+
impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
154+
fn tcx(&self) -> TyCtxt<'tcx> {
155+
self.tcx
156+
}
157+
158+
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
159+
let _ = self.simplify_operand(operand, location);
160+
}
161+
162+
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
163+
match &mut terminator.kind {
164+
TerminatorKind::Assert { cond, expected, target, .. } => {
165+
if let Err(Some(place)) = self.simplify_operand(cond, location) {
166+
let successor = Location { block: *target, statement_index: 0 };
167+
if location.block != successor.block
168+
&& self.unique_predecessors.contains(successor.block)
169+
{
170+
let val = *expected as u128;
171+
let range = WrappingRange { start: val, end: val };
172+
self.insert_range(place, successor, range);
173+
}
174+
}
175+
}
176+
TerminatorKind::SwitchInt { discr, targets } => {
177+
if let Err(Some(place)) = self.simplify_operand(discr, location)
178+
&& targets.all_targets().len() < 8
179+
{
180+
let mut distinct_targets: FxHashMap<BasicBlock, u8> = FxHashMap::default();
181+
for (_, target) in targets.iter() {
182+
let targets = distinct_targets.entry(target).or_default();
183+
if *targets == 0 {
184+
*targets = 1;
185+
} else {
186+
*targets = 2;
187+
}
188+
}
189+
for (val, target) in targets.iter() {
190+
if distinct_targets[&target] != 1 {
191+
continue;
192+
}
193+
let successor = Location { block: target, statement_index: 0 };
194+
if location.block != successor.block
195+
&& self.unique_predecessors.contains(successor.block)
196+
{
197+
let range = WrappingRange { start: val, end: val };
198+
self.insert_range(place, successor, range);
199+
}
200+
}
201+
202+
let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
203+
if place.ty(self.local_decls, self.tcx).ty.is_bool()
204+
&& let [val] = targets.all_values()
205+
&& location.block != otherwise.block
206+
&& self.unique_predecessors.contains(otherwise.block)
207+
{
208+
let range = if val.get() == 0 {
209+
WrappingRange { start: 1, end: 1 }
210+
} else {
211+
WrappingRange { start: 0, end: 0 }
212+
};
213+
self.insert_range(place, otherwise, range);
214+
}
215+
}
216+
}
217+
_ => {}
218+
}
219+
}
220+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
- // MIR for `on_assert` before SsaRangePropagation
2+
+ // MIR for `on_assert` after SsaRangePropagation
3+
4+
fn on_assert(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let _3: ();
9+
let mut _4: bool;
10+
let mut _5: usize;
11+
let mut _6: usize;
12+
let mut _7: &[u8];
13+
let mut _8: !;
14+
let _9: usize;
15+
let mut _10: usize;
16+
let mut _11: bool;
17+
scope 1 (inlined core::slice::<impl [u8]>::len) {
18+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
19+
}
20+
}
21+
22+
bb0: {
23+
StorageLive(_3);
24+
nop;
25+
StorageLive(_5);
26+
_5 = copy _1;
27+
nop;
28+
StorageLive(_7);
29+
_7 = &(*_2);
30+
_6 = PtrMetadata(copy _2);
31+
StorageDead(_7);
32+
_4 = Lt(copy _1, copy _6);
33+
switchInt(copy _4) -> [0: bb2, otherwise: bb1];
34+
}
35+
36+
bb1: {
37+
nop;
38+
StorageDead(_5);
39+
_3 = const ();
40+
nop;
41+
StorageDead(_3);
42+
StorageLive(_9);
43+
_9 = copy _1;
44+
_10 = copy _6;
45+
- _11 = copy _4;
46+
- assert(copy _4, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
47+
+ _11 = const true;
48+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
49+
}
50+
51+
bb2: {
52+
nop;
53+
StorageDead(_5);
54+
StorageLive(_8);
55+
_8 = panic(const "assertion failed: i < v.len()") -> unwind unreachable;
56+
}
57+
58+
bb3: {
59+
_0 = copy (*_2)[_1];
60+
StorageDead(_9);
61+
return;
62+
}
63+
}
64+
65+
ALLOC0 (size: 29, align: 1) {
66+
0x00 │ 61 73 73 65 72 74 69 6f 6e 20 66 61 69 6c 65 64 │ assertion failed
67+
0x10 │ 3a 20 69 20 3c 20 76 2e 6c 65 6e 28 29 │ : i < v.len()
68+
}
69+
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
- // MIR for `on_if` before SsaRangePropagation
2+
+ // MIR for `on_if` after SsaRangePropagation
3+
4+
fn on_if(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let mut _3: bool;
9+
let mut _4: usize;
10+
let mut _5: usize;
11+
let mut _6: &[u8];
12+
let _7: usize;
13+
let mut _8: usize;
14+
let mut _9: bool;
15+
scope 1 (inlined core::slice::<impl [u8]>::len) {
16+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
17+
}
18+
}
19+
20+
bb0: {
21+
nop;
22+
StorageLive(_4);
23+
_4 = copy _1;
24+
nop;
25+
StorageLive(_6);
26+
_6 = &(*_2);
27+
_5 = PtrMetadata(copy _2);
28+
StorageDead(_6);
29+
_3 = Lt(copy _1, copy _5);
30+
switchInt(copy _3) -> [0: bb3, otherwise: bb1];
31+
}
32+
33+
bb1: {
34+
nop;
35+
StorageDead(_4);
36+
StorageLive(_7);
37+
_7 = copy _1;
38+
_8 = copy _5;
39+
- _9 = copy _3;
40+
- assert(copy _3, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable];
41+
+ _9 = const true;
42+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _5, copy _1) -> [success: bb2, unwind unreachable];
43+
}
44+
45+
bb2: {
46+
_0 = copy (*_2)[_1];
47+
StorageDead(_7);
48+
goto -> bb4;
49+
}
50+
51+
bb3: {
52+
nop;
53+
StorageDead(_4);
54+
_0 = const 0_u8;
55+
goto -> bb4;
56+
}
57+
58+
bb4: {
59+
nop;
60+
return;
61+
}
62+
}
63+

0 commit comments

Comments
 (0)