Skip to content

Commit edacb2e

Browse files
Auto merge of #150309 - dianqk:ssa-range, r=<try>
[EXPERIMENT] New MIR Pass: SsaRangePropagation
2 parents 99ff3fb + 230527f commit edacb2e

10 files changed

Lines changed: 572 additions & 0 deletions

compiler/rustc_middle/src/mir/statement.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,12 @@ impl<'tcx> Place<'tcx> {
374374
self.projection.iter().any(|elem| elem.is_indirect())
375375
}
376376

377+
/// Returns `true` if the `Place` always refers to the same memory region
378+
/// whatever the state of the program.
379+
pub fn is_stable_offset(&self) -> bool {
380+
self.projection.iter().all(|elem| elem.is_stable_offset())
381+
}
382+
377383
/// Returns `true` if this `Place`'s first projection is `Deref`.
378384
///
379385
/// This is useful because for MIR phases `AnalysisPhase::PostCleanup` and later,

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ declare_passes! {
198198
mod single_use_consts : SingleUseConsts;
199199
mod sroa : ScalarReplacementOfAggregates;
200200
mod strip_debuginfo : StripDebugInfo;
201+
mod ssa_range_prop: SsaRangePropagation;
201202
mod unreachable_enum_branching : UnreachableEnumBranching;
202203
mod unreachable_prop : UnreachablePropagation;
203204
mod validate : Validator;
@@ -743,6 +744,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
743744
&simplify::SimplifyLocals::AfterGVN,
744745
&match_branches::MatchBranchSimplification,
745746
&dataflow_const_prop::DataflowConstProp,
747+
&ssa_range_prop::SsaRangePropagation,
746748
&single_use_consts::SingleUseConsts,
747749
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
748750
&jump_threading::JumpThreading,
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
use rustc_abi::WrappingRange;
2+
use rustc_const_eval::interpret::Scalar;
3+
use rustc_data_structures::fx::FxHashMap;
4+
use rustc_data_structures::graph::dominators::Dominators;
5+
use rustc_index::bit_set::DenseBitSet;
6+
use rustc_middle::mir::visit::MutVisitor;
7+
use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
8+
use rustc_middle::ty::{TyCtxt, TypingEnv};
9+
use rustc_span::DUMMY_SP;
10+
11+
use crate::ssa::SsaLocals;
12+
13+
pub(super) struct SsaRangePropagation;
14+
15+
impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
16+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17+
sess.mir_opt_level() > 1
18+
}
19+
20+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21+
let typing_env = body.typing_env(tcx);
22+
let ssa = SsaLocals::new(tcx, body, typing_env);
23+
// Clone dominators because we need them while mutating the body.
24+
let dominators = body.basic_blocks.dominators().clone();
25+
let mut range_set =
26+
RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
27+
28+
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
29+
for bb in reverse_postorder {
30+
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
31+
range_set.visit_basic_block_data(bb, data);
32+
}
33+
}
34+
35+
fn is_required(&self) -> bool {
36+
false
37+
}
38+
}
39+
40+
struct RangeSet<'tcx, 'body, 'a> {
41+
tcx: TyCtxt<'tcx>,
42+
typing_env: TypingEnv<'tcx>,
43+
ssa: &'a SsaLocals,
44+
local_decls: &'body LocalDecls<'tcx>,
45+
dominators: Dominators<BasicBlock>,
46+
/// Known ranges at each locations.
47+
ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
48+
/// Determines if the basic block has a single unique predecessor.
49+
unique_predecessors: DenseBitSet<BasicBlock>,
50+
}
51+
52+
impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
53+
fn new(
54+
tcx: TyCtxt<'tcx>,
55+
typing_env: TypingEnv<'tcx>,
56+
body: &Body<'tcx>,
57+
ssa: &'a SsaLocals,
58+
local_decls: &'body LocalDecls<'tcx>,
59+
dominators: Dominators<BasicBlock>,
60+
) -> Self {
61+
let predecessors = body.basic_blocks.predecessors();
62+
let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
63+
for (bb, _) in body.basic_blocks.iter_enumerated() {
64+
if predecessors[bb].len() == 1 {
65+
unique_predecessors.insert(bb);
66+
}
67+
}
68+
RangeSet {
69+
tcx,
70+
typing_env,
71+
ssa,
72+
local_decls,
73+
dominators,
74+
ranges: FxHashMap::default(),
75+
unique_predecessors,
76+
}
77+
}
78+
79+
/// Create a new known range at the location.
80+
fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
81+
self.ranges.entry(place).or_default().push((location, range));
82+
}
83+
84+
/// Get the known range at the location.
85+
fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
86+
let Some(ranges) = self.ranges.get(place) else {
87+
return None;
88+
};
89+
// FIXME: This should use the intersection of all valid ranges.
90+
let (_, range) =
91+
ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
92+
Some(*range)
93+
}
94+
95+
fn try_as_constant(
96+
&mut self,
97+
place: Place<'tcx>,
98+
location: Location,
99+
) -> Option<ConstOperand<'tcx>> {
100+
if let Some(range) = self.get_range(&place, location)
101+
&& range.start == range.end
102+
{
103+
let ty = place.ty(self.local_decls, self.tcx).ty;
104+
let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
105+
let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
106+
let const_ = Const::Val(value, ty);
107+
return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
108+
}
109+
None
110+
}
111+
112+
fn simplify_operand(
113+
&mut self,
114+
operand: &mut Operand<'tcx>,
115+
location: Location,
116+
) -> Result<(), Option<Place<'tcx>>> {
117+
let Some(place) = operand.place() else {
118+
return Ok(());
119+
};
120+
let Some(const_) = self.try_as_constant(place, location) else {
121+
if self.is_ssa(place) {
122+
return Err(Some(place));
123+
} else {
124+
return Err(None);
125+
}
126+
};
127+
*operand = Operand::Constant(Box::new(const_));
128+
Ok(())
129+
}
130+
131+
fn is_ssa(&self, place: Place<'tcx>) -> bool {
132+
self.ssa.is_ssa(place.local) && place.is_stable_offset()
133+
}
134+
}
135+
136+
impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
137+
fn tcx(&self) -> TyCtxt<'tcx> {
138+
self.tcx
139+
}
140+
141+
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
142+
let _ = self.simplify_operand(operand, location);
143+
}
144+
145+
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
146+
match &mut statement.kind {
147+
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(operand)) => {
148+
if let Err(Some(place)) = self.simplify_operand(operand, location) {
149+
let successor = location.successor_within_block();
150+
let range = WrappingRange { start: 1, end: 1 };
151+
self.insert_range(place, successor, range);
152+
}
153+
}
154+
_ => {}
155+
}
156+
self.super_statement(statement, location);
157+
}
158+
159+
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
160+
match &mut terminator.kind {
161+
TerminatorKind::Assert { cond, expected, target, .. } => {
162+
if let Err(Some(place)) = self.simplify_operand(cond, location) {
163+
let successor = Location { block: *target, statement_index: 0 };
164+
if location.block != successor.block
165+
&& self.unique_predecessors.contains(successor.block)
166+
{
167+
let val = *expected as u128;
168+
let range = WrappingRange { start: val, end: val };
169+
self.insert_range(place, successor, range);
170+
}
171+
}
172+
}
173+
TerminatorKind::SwitchInt { discr, targets } => {
174+
if let Err(Some(place)) = self.simplify_operand(discr, location)
175+
&& targets.all_targets().len() < 8
176+
{
177+
let mut distinct_targets: FxHashMap<BasicBlock, u8> = FxHashMap::default();
178+
for (_, target) in targets.iter() {
179+
let targets = distinct_targets.entry(target).or_default();
180+
if *targets == 0 {
181+
*targets = 1;
182+
} else {
183+
*targets = 2;
184+
}
185+
}
186+
for (val, target) in targets.iter() {
187+
if distinct_targets[&target] != 1 {
188+
continue;
189+
}
190+
let successor = Location { block: target, statement_index: 0 };
191+
if location.block != successor.block
192+
&& self.unique_predecessors.contains(successor.block)
193+
{
194+
let range = WrappingRange { start: val, end: val };
195+
self.insert_range(place, successor, range);
196+
}
197+
}
198+
199+
let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
200+
if place.ty(self.local_decls, self.tcx).ty.is_bool()
201+
&& let [val] = targets.all_values()
202+
&& location.block != otherwise.block
203+
&& self.unique_predecessors.contains(otherwise.block)
204+
{
205+
let range = if val.get() == 0 {
206+
WrappingRange { start: 1, end: 1 }
207+
} else {
208+
WrappingRange { start: 0, end: 0 }
209+
};
210+
self.insert_range(place, otherwise, range);
211+
}
212+
}
213+
}
214+
_ => {}
215+
}
216+
}
217+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
- // MIR for `on_assert` before SsaRangePropagation
2+
+ // MIR for `on_assert` after SsaRangePropagation
3+
4+
fn on_assert(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let _3: ();
9+
let mut _4: bool;
10+
let mut _5: usize;
11+
let mut _6: usize;
12+
let mut _7: &[u8];
13+
let mut _8: !;
14+
let _9: usize;
15+
let mut _10: usize;
16+
let mut _11: bool;
17+
scope 1 (inlined core::slice::<impl [u8]>::len) {
18+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
19+
}
20+
}
21+
22+
bb0: {
23+
StorageLive(_3);
24+
nop;
25+
StorageLive(_5);
26+
_5 = copy _1;
27+
nop;
28+
StorageLive(_7);
29+
_7 = &(*_2);
30+
_6 = PtrMetadata(copy _2);
31+
StorageDead(_7);
32+
_4 = Lt(copy _1, copy _6);
33+
switchInt(copy _4) -> [0: bb2, otherwise: bb1];
34+
}
35+
36+
bb1: {
37+
nop;
38+
StorageDead(_5);
39+
_3 = const ();
40+
nop;
41+
StorageDead(_3);
42+
StorageLive(_9);
43+
_9 = copy _1;
44+
_10 = copy _6;
45+
- _11 = copy _4;
46+
- assert(copy _4, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
47+
+ _11 = const true;
48+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
49+
}
50+
51+
bb2: {
52+
nop;
53+
StorageDead(_5);
54+
StorageLive(_8);
55+
_8 = panic(const "assertion failed: i < v.len()") -> unwind unreachable;
56+
}
57+
58+
bb3: {
59+
_0 = copy (*_2)[_1];
60+
StorageDead(_9);
61+
return;
62+
}
63+
}
64+
65+
ALLOC0 (size: 29, align: 1) {
66+
0x00 │ 61 73 73 65 72 74 69 6f 6e 20 66 61 69 6c 65 64 │ assertion failed
67+
0x10 │ 3a 20 69 20 3c 20 76 2e 6c 65 6e 28 29 │ : i < v.len()
68+
}
69+
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
- // MIR for `on_assume` before SsaRangePropagation
2+
+ // MIR for `on_assume` after SsaRangePropagation
3+
4+
fn on_assume(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let _3: ();
9+
let _4: ();
10+
let mut _5: bool;
11+
let mut _6: usize;
12+
let mut _7: usize;
13+
let mut _8: &[u8];
14+
let _9: usize;
15+
let mut _10: usize;
16+
let mut _11: bool;
17+
scope 1 (inlined core::slice::<impl [u8]>::len) {
18+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
19+
}
20+
}
21+
22+
bb0: {
23+
StorageLive(_3);
24+
StorageLive(_4);
25+
nop;
26+
StorageLive(_6);
27+
_6 = copy _1;
28+
nop;
29+
StorageLive(_8);
30+
_8 = &(*_2);
31+
_7 = PtrMetadata(copy _2);
32+
StorageDead(_8);
33+
_5 = Lt(copy _1, copy _7);
34+
nop;
35+
StorageDead(_6);
36+
assume(copy _5);
37+
nop;
38+
StorageDead(_4);
39+
_3 = const ();
40+
StorageDead(_3);
41+
StorageLive(_9);
42+
_9 = copy _1;
43+
_10 = copy _7;
44+
- _11 = copy _5;
45+
- assert(copy _5, "index out of bounds: the length is {} but the index is {}", copy _7, copy _1) -> [success: bb1, unwind unreachable];
46+
+ _11 = const true;
47+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _7, copy _1) -> [success: bb1, unwind unreachable];
48+
}
49+
50+
bb1: {
51+
_0 = copy (*_2)[_1];
52+
StorageDead(_9);
53+
return;
54+
}
55+
}
56+

0 commit comments

Comments
 (0)