Skip to main content

rustc_mir_transform/
ssa_range_prop.rs

//! A pass that propagates the known ranges of SSA locals.
//! The range of an SSA local is known at certain locations, as in the following code:
//! ```
//! fn foo(a: u32) {
//!   let b = a < 9; // the integer representation of b is within the full range [0, 2).
//!   if b {
//!     let c = b; // c is true since b is within the range [1, 2).
//!     let d = a < 10; // d is true since a is within the range [0, 9).
//!   }
//! }
//! ```
12use rustc_abi::WrappingRange;
13use rustc_const_eval::interpret::Scalar;
14use rustc_data_structures::fx::FxHashMap;
15use rustc_data_structures::graph::dominators::Dominators;
16use rustc_index::bit_set::DenseBitSet;
17use rustc_middle::mir::visit::MutVisitor;
18use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
19use rustc_middle::ty::{TyCtxt, TypingEnv};
20use rustc_span::DUMMY_SP;
21
22use crate::ssa::SsaLocals;
23
24pub(super) struct SsaRangePropagation;
25
26impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
27    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
28        sess.mir_opt_level() > 1
29    }
30
31    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32        let typing_env = body.typing_env(tcx);
33        let ssa = SsaLocals::new(tcx, body, typing_env);
34        // Clone dominators because we need them while mutating the body.
35        let dominators = body.basic_blocks.dominators().clone();
36        let mut range_set =
37            RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
38
39        let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
40        for bb in reverse_postorder {
41            let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
42            range_set.visit_basic_block_data(bb, data);
43        }
44    }
45
46    fn is_required(&self) -> bool {
47        false
48    }
49}
50
51struct RangeSet<'tcx, 'body, 'a> {
52    tcx: TyCtxt<'tcx>,
53    typing_env: TypingEnv<'tcx>,
54    ssa: &'a SsaLocals,
55    local_decls: &'body LocalDecls<'tcx>,
56    dominators: Dominators<BasicBlock>,
57    /// Known ranges at each locations.
58    ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
59    /// Determines if the basic block has a single unique predecessor.
60    unique_predecessors: DenseBitSet<BasicBlock>,
61}
62
63impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
64    fn new(
65        tcx: TyCtxt<'tcx>,
66        typing_env: TypingEnv<'tcx>,
67        body: &Body<'tcx>,
68        ssa: &'a SsaLocals,
69        local_decls: &'body LocalDecls<'tcx>,
70        dominators: Dominators<BasicBlock>,
71    ) -> Self {
72        let predecessors = body.basic_blocks.predecessors();
73        let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
74        for bb in body.basic_blocks.indices() {
75            if predecessors[bb].len() == 1 {
76                unique_predecessors.insert(bb);
77            }
78        }
79        RangeSet {
80            tcx,
81            typing_env,
82            ssa,
83            local_decls,
84            dominators,
85            ranges: FxHashMap::default(),
86            unique_predecessors,
87        }
88    }
89
90    /// Create a new known range at the location.
91    fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
92        assert!(self.is_ssa(place));
93        self.ranges.entry(place).or_default().push((location, range));
94    }
95
96    /// Get the known range at the location.
97    fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
98        let Some(ranges) = self.ranges.get(place) else {
99            return None;
100        };
101        // FIXME: This should use the intersection of all valid ranges.
102        let (_, range) =
103            ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
104        Some(*range)
105    }
106
107    fn try_as_constant(
108        &mut self,
109        place: Place<'tcx>,
110        location: Location,
111    ) -> Option<ConstOperand<'tcx>> {
112        if let Some(range) = self.get_range(&place, location)
113            && range.start == range.end
114        {
115            let ty = place.ty(self.local_decls, self.tcx).ty;
116            let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
117            let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
118            let const_ = Const::Val(value, ty);
119            return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
120        }
121        None
122    }
123
124    fn is_ssa(&self, place: Place<'tcx>) -> bool {
125        self.ssa.is_ssa(place.local) && place.is_stable_offset()
126    }
127}
128
129impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
130    fn tcx(&self) -> TyCtxt<'tcx> {
131        self.tcx
132    }
133
134    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
135        // Attempts to simplify an operand to a constant value.
136        if let Some(place) = operand.place()
137            && let Some(const_) = self.try_as_constant(place, location)
138        {
139            *operand = Operand::Constant(Box::new(const_));
140        };
141    }
142
143    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
144        self.super_statement(statement, location);
145        match &statement.kind {
146            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(operand)) => {
147                if let Some(place) = operand.place()
148                    && self.is_ssa(place)
149                {
150                    let successor = location.successor_within_block();
151                    let range = WrappingRange { start: 1, end: 1 };
152                    self.insert_range(place, successor, range);
153                }
154            }
155            _ => {}
156        }
157    }
158
159    fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
160        self.super_terminator(terminator, location);
161        match &terminator.kind {
162            TerminatorKind::Assert { cond, expected, target, .. } => {
163                if let Some(place) = cond.place()
164                    && self.is_ssa(place)
165                {
166                    let successor = Location { block: *target, statement_index: 0 };
167                    if location.dominates(successor, &self.dominators) {
168                        assert_ne!(location.block, successor.block);
169                        let val = *expected as u128;
170                        let range = WrappingRange { start: val, end: val };
171                        self.insert_range(place, successor, range);
172                    }
173                }
174            }
175            TerminatorKind::SwitchInt { discr, targets } => {
176                if let Some(place) = discr.place()
177                    && self.is_ssa(place)
178                    // Reduce the potential compile-time overhead.
179                    && targets.all_targets().len() < 16
180                {
181                    let mut distinct_targets: FxHashMap<BasicBlock, u64> = FxHashMap::default();
182                    for (_, target) in targets.iter() {
183                        let targets = distinct_targets.entry(target).or_default();
184                        *targets += 1;
185                    }
186                    for (val, target) in targets.iter() {
187                        if distinct_targets[&target] != 1 {
188                            // FIXME: For multiple targets, the range can be the union of their values.
189                            continue;
190                        }
191                        let successor = Location { block: target, statement_index: 0 };
192                        if self.unique_predecessors.contains(successor.block) {
193                            assert_ne!(location.block, successor.block);
194                            let range = WrappingRange { start: val, end: val };
195                            self.insert_range(place, successor, range);
196                        }
197                    }
198
199                    // FIXME: The range for the otherwise target be extend to more types.
200                    // For instance, `val` is within the range [4, 1) at the otherwise target of `matches!(val, 1 | 2 | 3)`.
201                    let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
202                    if place.ty(self.local_decls, self.tcx).ty.is_bool()
203                        && let [val] = targets.all_values()
204                        && self.unique_predecessors.contains(otherwise.block)
205                    {
206                        assert_ne!(location.block, otherwise.block);
207                        let range = if val.get() == 0 {
208                            WrappingRange { start: 1, end: 1 }
209                        } else {
210                            WrappingRange { start: 0, end: 0 }
211                        };
212                        self.insert_range(place, otherwise, range);
213                    }
214                }
215            }
216            _ => {}
217        }
218    }
219}