Skip to content

Commit 3d087e6

Browse files
committed
Auto merge of #150309 - dianqk:ssa-range, r=cjgillot
New MIR Pass: SsaRangePropagation As an alternative to #150192. Introduces a new pass that propagates the known ranges of SSA locals. We can know the ranges of SSA locals at some locations for the following code: ```rust fn foo(a: u32) { let b = a < 9; if b { let c = b; // c is true since b is whitin the range [1, 2) let d = a < 8; // d is true since b whitin the range [0, 9) } } ``` This PR only implements a trivial range: we know one value on switch, assert, and assume.
2 parents 9b37157 + e9a67c7 commit 3d087e6

9 files changed

Lines changed: 570 additions & 0 deletions

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ declare_passes! {
197197
mod single_use_consts : SingleUseConsts;
198198
mod sroa : ScalarReplacementOfAggregates;
199199
mod strip_debuginfo : StripDebugInfo;
200+
mod ssa_range_prop: SsaRangePropagation;
200201
mod unreachable_enum_branching : UnreachableEnumBranching;
201202
mod unreachable_prop : UnreachablePropagation;
202203
mod validate : Validator;
@@ -743,6 +744,9 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
743744
&dead_store_elimination::DeadStoreElimination::Initial,
744745
&gvn::GVN,
745746
&simplify::SimplifyLocals::AfterGVN,
747+
// This pass does attempt to track assignments.
748+
// Keep it close to GVN which merges identical values into the same local.
749+
&ssa_range_prop::SsaRangePropagation,
746750
&match_branches::MatchBranchSimplification,
747751
&dataflow_const_prop::DataflowConstProp,
748752
&single_use_consts::SingleUseConsts,
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
//! A pass that propagates the known ranges of SSA locals.
2+
//! We can know the ranges of SSA locals in certain locations for the following code:
3+
//! ```
4+
//! fn foo(a: u32) {
5+
//! let b = a < 9; // the integer representation of b is within the full range [0, 2).
6+
//! if b {
7+
//! let c = b; // c is true since b is within the range [1, 2).
8+
//! let d = a < 8; // d is true since a is within the range [0, 9).
9+
//! }
10+
//! }
11+
//! ```
12+
use rustc_abi::WrappingRange;
13+
use rustc_const_eval::interpret::Scalar;
14+
use rustc_data_structures::fx::FxHashMap;
15+
use rustc_data_structures::graph::dominators::Dominators;
16+
use rustc_index::bit_set::DenseBitSet;
17+
use rustc_middle::mir::visit::MutVisitor;
18+
use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
19+
use rustc_middle::ty::{TyCtxt, TypingEnv};
20+
use rustc_span::DUMMY_SP;
21+
22+
use crate::ssa::SsaLocals;
23+
24+
pub(super) struct SsaRangePropagation;
25+
26+
impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
27+
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
28+
sess.mir_opt_level() > 1
29+
}
30+
31+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32+
let typing_env = body.typing_env(tcx);
33+
let ssa = SsaLocals::new(tcx, body, typing_env);
34+
// Clone dominators because we need them while mutating the body.
35+
let dominators = body.basic_blocks.dominators().clone();
36+
let mut range_set =
37+
RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
38+
39+
let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
40+
for bb in reverse_postorder {
41+
let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
42+
range_set.visit_basic_block_data(bb, data);
43+
}
44+
}
45+
46+
fn is_required(&self) -> bool {
47+
false
48+
}
49+
}
50+
51+
struct RangeSet<'tcx, 'body, 'a> {
52+
tcx: TyCtxt<'tcx>,
53+
typing_env: TypingEnv<'tcx>,
54+
ssa: &'a SsaLocals,
55+
local_decls: &'body LocalDecls<'tcx>,
56+
dominators: Dominators<BasicBlock>,
57+
/// Known ranges at each locations.
58+
ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
59+
/// Determines if the basic block has a single unique predecessor.
60+
unique_predecessors: DenseBitSet<BasicBlock>,
61+
}
62+
63+
impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
64+
fn new(
65+
tcx: TyCtxt<'tcx>,
66+
typing_env: TypingEnv<'tcx>,
67+
body: &Body<'tcx>,
68+
ssa: &'a SsaLocals,
69+
local_decls: &'body LocalDecls<'tcx>,
70+
dominators: Dominators<BasicBlock>,
71+
) -> Self {
72+
let predecessors = body.basic_blocks.predecessors();
73+
let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
74+
for bb in body.basic_blocks.indices() {
75+
if predecessors[bb].len() == 1 {
76+
unique_predecessors.insert(bb);
77+
}
78+
}
79+
RangeSet {
80+
tcx,
81+
typing_env,
82+
ssa,
83+
local_decls,
84+
dominators,
85+
ranges: FxHashMap::default(),
86+
unique_predecessors,
87+
}
88+
}
89+
90+
/// Create a new known range at the location.
91+
fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
92+
assert!(self.is_ssa(place));
93+
self.ranges.entry(place).or_default().push((location, range));
94+
}
95+
96+
/// Get the known range at the location.
97+
fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
98+
let Some(ranges) = self.ranges.get(place) else {
99+
return None;
100+
};
101+
// FIXME: This should use the intersection of all valid ranges.
102+
let (_, range) =
103+
ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
104+
Some(*range)
105+
}
106+
107+
fn try_as_constant(
108+
&mut self,
109+
place: Place<'tcx>,
110+
location: Location,
111+
) -> Option<ConstOperand<'tcx>> {
112+
if let Some(range) = self.get_range(&place, location)
113+
&& range.start == range.end
114+
{
115+
let ty = place.ty(self.local_decls, self.tcx).ty;
116+
let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
117+
let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
118+
let const_ = Const::Val(value, ty);
119+
return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
120+
}
121+
None
122+
}
123+
124+
fn is_ssa(&self, place: Place<'tcx>) -> bool {
125+
self.ssa.is_ssa(place.local) && place.is_stable_offset()
126+
}
127+
}
128+
129+
impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
130+
fn tcx(&self) -> TyCtxt<'tcx> {
131+
self.tcx
132+
}
133+
134+
fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
135+
// Attempts to simplify an operand to a constant value.
136+
if let Some(place) = operand.place()
137+
&& let Some(const_) = self.try_as_constant(place, location)
138+
{
139+
*operand = Operand::Constant(Box::new(const_));
140+
};
141+
}
142+
143+
fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
144+
self.super_statement(statement, location);
145+
match &statement.kind {
146+
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(operand)) => {
147+
if let Some(place) = operand.place()
148+
&& self.is_ssa(place)
149+
{
150+
let successor = location.successor_within_block();
151+
let range = WrappingRange { start: 1, end: 1 };
152+
self.insert_range(place, successor, range);
153+
}
154+
}
155+
_ => {}
156+
}
157+
}
158+
159+
fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
160+
self.super_terminator(terminator, location);
161+
match &terminator.kind {
162+
TerminatorKind::Assert { cond, expected, target, .. } => {
163+
if let Some(place) = cond.place()
164+
&& self.is_ssa(place)
165+
{
166+
let successor = Location { block: *target, statement_index: 0 };
167+
if location.dominates(successor, &self.dominators) {
168+
assert_ne!(location.block, successor.block);
169+
let val = *expected as u128;
170+
let range = WrappingRange { start: val, end: val };
171+
self.insert_range(place, successor, range);
172+
}
173+
}
174+
}
175+
TerminatorKind::SwitchInt { discr, targets } => {
176+
if let Some(place) = discr.place()
177+
&& self.is_ssa(place)
178+
// Reduce the potential compile-time overhead.
179+
&& targets.all_targets().len() < 16
180+
{
181+
let mut distinct_targets: FxHashMap<BasicBlock, u64> = FxHashMap::default();
182+
for (_, target) in targets.iter() {
183+
let targets = distinct_targets.entry(target).or_default();
184+
*targets += 1;
185+
}
186+
for (val, target) in targets.iter() {
187+
if distinct_targets[&target] != 1 {
188+
// FIXME: For multiple targets, the range can be the union of their values.
189+
continue;
190+
}
191+
let successor = Location { block: target, statement_index: 0 };
192+
if self.unique_predecessors.contains(successor.block) {
193+
assert_ne!(location.block, successor.block);
194+
let range = WrappingRange { start: val, end: val };
195+
self.insert_range(place, successor, range);
196+
}
197+
}
198+
199+
// FIXME: The range for the otherwise target be extend to more types.
200+
// For instance, `val` is within the range [4, 1) at the otherwise target of `matches!(val, 1 | 2 | 3)`.
201+
let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
202+
if place.ty(self.local_decls, self.tcx).ty.is_bool()
203+
&& let [val] = targets.all_values()
204+
&& self.unique_predecessors.contains(otherwise.block)
205+
{
206+
assert_ne!(location.block, otherwise.block);
207+
let range = if val.get() == 0 {
208+
WrappingRange { start: 1, end: 1 }
209+
} else {
210+
WrappingRange { start: 0, end: 0 }
211+
};
212+
self.insert_range(place, otherwise, range);
213+
}
214+
}
215+
}
216+
_ => {}
217+
}
218+
}
219+
}
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
- // MIR for `on_assert` before SsaRangePropagation
2+
+ // MIR for `on_assert` after SsaRangePropagation
3+
4+
fn on_assert(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let _3: ();
9+
let mut _4: bool;
10+
let mut _5: usize;
11+
let mut _6: usize;
12+
let mut _7: &[u8];
13+
let mut _8: !;
14+
let _9: usize;
15+
let mut _10: usize;
16+
let mut _11: bool;
17+
scope 1 (inlined core::slice::<impl [u8]>::len) {
18+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
19+
}
20+
}
21+
22+
bb0: {
23+
StorageLive(_3);
24+
nop;
25+
StorageLive(_5);
26+
_5 = copy _1;
27+
nop;
28+
StorageLive(_7);
29+
_7 = &(*_2);
30+
_6 = PtrMetadata(copy _2);
31+
StorageDead(_7);
32+
_4 = Lt(copy _1, copy _6);
33+
switchInt(copy _4) -> [0: bb2, otherwise: bb1];
34+
}
35+
36+
bb1: {
37+
nop;
38+
StorageDead(_5);
39+
_3 = const ();
40+
nop;
41+
StorageDead(_3);
42+
StorageLive(_9);
43+
_9 = copy _1;
44+
_10 = copy _6;
45+
- _11 = copy _4;
46+
- assert(copy _4, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
47+
+ _11 = const true;
48+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _6, copy _1) -> [success: bb3, unwind unreachable];
49+
}
50+
51+
bb2: {
52+
nop;
53+
StorageDead(_5);
54+
StorageLive(_8);
55+
_8 = panic(const "assertion failed: i < v.len()") -> unwind unreachable;
56+
}
57+
58+
bb3: {
59+
_0 = copy (*_2)[_1];
60+
StorageDead(_9);
61+
return;
62+
}
63+
}
64+
65+
ALLOC0 (size: 29, align: 1) {
66+
0x00 │ 61 73 73 65 72 74 69 6f 6e 20 66 61 69 6c 65 64 │ assertion failed
67+
0x10 │ 3a 20 69 20 3c 20 76 2e 6c 65 6e 28 29 │ : i < v.len()
68+
}
69+
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
- // MIR for `on_assume` before SsaRangePropagation
2+
+ // MIR for `on_assume` after SsaRangePropagation
3+
4+
fn on_assume(_1: usize, _2: &[u8]) -> u8 {
5+
debug i => _1;
6+
debug v => _2;
7+
let mut _0: u8;
8+
let _3: ();
9+
let _4: ();
10+
let mut _5: bool;
11+
let mut _6: usize;
12+
let mut _7: usize;
13+
let mut _8: &[u8];
14+
let _9: usize;
15+
let mut _10: usize;
16+
let mut _11: bool;
17+
scope 1 (inlined core::slice::<impl [u8]>::len) {
18+
scope 2 (inlined std::ptr::metadata::<[u8]>) {
19+
}
20+
}
21+
22+
bb0: {
23+
StorageLive(_3);
24+
StorageLive(_4);
25+
nop;
26+
StorageLive(_6);
27+
_6 = copy _1;
28+
nop;
29+
StorageLive(_8);
30+
_8 = &(*_2);
31+
_7 = PtrMetadata(copy _2);
32+
StorageDead(_8);
33+
_5 = Lt(copy _1, copy _7);
34+
nop;
35+
StorageDead(_6);
36+
assume(copy _5);
37+
nop;
38+
StorageDead(_4);
39+
_3 = const ();
40+
StorageDead(_3);
41+
StorageLive(_9);
42+
_9 = copy _1;
43+
_10 = copy _7;
44+
- _11 = copy _5;
45+
- assert(copy _5, "index out of bounds: the length is {} but the index is {}", copy _7, copy _1) -> [success: bb1, unwind unreachable];
46+
+ _11 = const true;
47+
+ assert(const true, "index out of bounds: the length is {} but the index is {}", copy _7, copy _1) -> [success: bb1, unwind unreachable];
48+
}
49+
50+
bb1: {
51+
_0 = copy (*_2)[_1];
52+
StorageDead(_9);
53+
return;
54+
}
55+
}
56+

0 commit comments

Comments
 (0)