1use rustc_abi::WrappingRange;
13use rustc_const_eval::interpret::Scalar;
14use rustc_data_structures::fx::FxHashMap;
15use rustc_data_structures::graph::dominators::Dominators;
16use rustc_index::bit_set::DenseBitSet;
17use rustc_middle::mir::visit::MutVisitor;
18use rustc_middle::mir::{BasicBlock, Body, Location, Operand, Place, TerminatorKind, *};
19use rustc_middle::ty::{TyCtxt, TypingEnv};
20use rustc_span::DUMMY_SP;
21
22use crate::ssa::SsaLocals;
23
24pub(super) struct SsaRangePropagation;
25
26impl<'tcx> crate::MirPass<'tcx> for SsaRangePropagation {
27 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
28 sess.mir_opt_level() > 1
29 }
30
31 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
32 let typing_env = body.typing_env(tcx);
33 let ssa = SsaLocals::new(tcx, body, typing_env);
34 let dominators = body.basic_blocks.dominators().clone();
36 let mut range_set =
37 RangeSet::new(tcx, typing_env, body, &ssa, &body.local_decls, dominators);
38
39 let reverse_postorder = body.basic_blocks.reverse_postorder().to_vec();
40 for bb in reverse_postorder {
41 let data = &mut body.basic_blocks.as_mut_preserves_cfg()[bb];
42 range_set.visit_basic_block_data(bb, data);
43 }
44 }
45
46 fn is_required(&self) -> bool {
47 false
48 }
49}
50
51struct RangeSet<'tcx, 'body, 'a> {
52 tcx: TyCtxt<'tcx>,
53 typing_env: TypingEnv<'tcx>,
54 ssa: &'a SsaLocals,
55 local_decls: &'body LocalDecls<'tcx>,
56 dominators: Dominators<BasicBlock>,
57 ranges: FxHashMap<Place<'tcx>, Vec<(Location, WrappingRange)>>,
59 unique_predecessors: DenseBitSet<BasicBlock>,
61}
62
63impl<'tcx, 'body, 'a> RangeSet<'tcx, 'body, 'a> {
64 fn new(
65 tcx: TyCtxt<'tcx>,
66 typing_env: TypingEnv<'tcx>,
67 body: &Body<'tcx>,
68 ssa: &'a SsaLocals,
69 local_decls: &'body LocalDecls<'tcx>,
70 dominators: Dominators<BasicBlock>,
71 ) -> Self {
72 let predecessors = body.basic_blocks.predecessors();
73 let mut unique_predecessors = DenseBitSet::new_empty(body.basic_blocks.len());
74 for bb in body.basic_blocks.indices() {
75 if predecessors[bb].len() == 1 {
76 unique_predecessors.insert(bb);
77 }
78 }
79 RangeSet {
80 tcx,
81 typing_env,
82 ssa,
83 local_decls,
84 dominators,
85 ranges: FxHashMap::default(),
86 unique_predecessors,
87 }
88 }
89
90 fn insert_range(&mut self, place: Place<'tcx>, location: Location, range: WrappingRange) {
92 assert!(self.is_ssa(place));
93 self.ranges.entry(place).or_default().push((location, range));
94 }
95
96 fn get_range(&self, place: &Place<'tcx>, location: Location) -> Option<WrappingRange> {
98 let Some(ranges) = self.ranges.get(place) else {
99 return None;
100 };
101 let (_, range) =
103 ranges.iter().find(|(range_loc, _)| range_loc.dominates(location, &self.dominators))?;
104 Some(*range)
105 }
106
107 fn try_as_constant(
108 &mut self,
109 place: Place<'tcx>,
110 location: Location,
111 ) -> Option<ConstOperand<'tcx>> {
112 if let Some(range) = self.get_range(&place, location)
113 && range.start == range.end
114 {
115 let ty = place.ty(self.local_decls, self.tcx).ty;
116 let layout = self.tcx.layout_of(self.typing_env.as_query_input(ty)).ok()?;
117 let value = ConstValue::Scalar(Scalar::from_uint(range.start, layout.size));
118 let const_ = Const::Val(value, ty);
119 return Some(ConstOperand { span: DUMMY_SP, user_ty: None, const_ });
120 }
121 None
122 }
123
124 fn is_ssa(&self, place: Place<'tcx>) -> bool {
125 self.ssa.is_ssa(place.local) && place.is_stable_offset()
126 }
127}
128
129impl<'tcx> MutVisitor<'tcx> for RangeSet<'tcx, '_, '_> {
130 fn tcx(&self) -> TyCtxt<'tcx> {
131 self.tcx
132 }
133
134 fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
135 if let Some(place) = operand.place()
137 && let Some(const_) = self.try_as_constant(place, location)
138 {
139 *operand = Operand::Constant(Box::new(const_));
140 };
141 }
142
143 fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
144 self.super_statement(statement, location);
145 match &statement.kind {
146 StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(operand)) => {
147 if let Some(place) = operand.place()
148 && self.is_ssa(place)
149 {
150 let successor = location.successor_within_block();
151 let range = WrappingRange { start: 1, end: 1 };
152 self.insert_range(place, successor, range);
153 }
154 }
155 _ => {}
156 }
157 }
158
159 fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
160 self.super_terminator(terminator, location);
161 match &terminator.kind {
162 TerminatorKind::Assert { cond, expected, target, .. } => {
163 if let Some(place) = cond.place()
164 && self.is_ssa(place)
165 {
166 let successor = Location { block: *target, statement_index: 0 };
167 if location.dominates(successor, &self.dominators) {
168 assert_ne!(location.block, successor.block);
169 let val = *expected as u128;
170 let range = WrappingRange { start: val, end: val };
171 self.insert_range(place, successor, range);
172 }
173 }
174 }
175 TerminatorKind::SwitchInt { discr, targets } => {
176 if let Some(place) = discr.place()
177 && self.is_ssa(place)
178 && targets.all_targets().len() < 16
180 {
181 let mut distinct_targets: FxHashMap<BasicBlock, u64> = FxHashMap::default();
182 for (_, target) in targets.iter() {
183 let targets = distinct_targets.entry(target).or_default();
184 *targets += 1;
185 }
186 for (val, target) in targets.iter() {
187 if distinct_targets[&target] != 1 {
188 continue;
190 }
191 let successor = Location { block: target, statement_index: 0 };
192 if self.unique_predecessors.contains(successor.block) {
193 assert_ne!(location.block, successor.block);
194 let range = WrappingRange { start: val, end: val };
195 self.insert_range(place, successor, range);
196 }
197 }
198
199 let otherwise = Location { block: targets.otherwise(), statement_index: 0 };
202 if place.ty(self.local_decls, self.tcx).ty.is_bool()
203 && let [val] = targets.all_values()
204 && self.unique_predecessors.contains(otherwise.block)
205 {
206 assert_ne!(location.block, otherwise.block);
207 let range = if val.get() == 0 {
208 WrappingRange { start: 1, end: 1 }
209 } else {
210 WrappingRange { start: 0, end: 0 }
211 };
212 self.insert_range(place, otherwise, range);
213 }
214 }
215 }
216 _ => {}
217 }
218 }
219}