1use rustc_index::{Idx, IndexSlice, IndexVec};
31use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
32use rustc_middle::mir::*;
33use rustc_middle::ty::TyCtxt;
34use rustc_span::DUMMY_SP;
35use smallvec::SmallVec;
36use tracing::{debug, trace};
37
/// A MIR pass that simplifies the control-flow graph of a body (see `simplify_cfg`).
///
/// Judging by their names, the variants distinguish the different points at which the
/// pass is scheduled. The variant only affects the name reported for the pass (see
/// `SimplifyCfg::name`); `run_pass` performs the same simplification for every variant.
pub(super) enum SimplifyCfg {
    Initial,
    PromoteConsts,
    RemoveFalseEdges,
    PostAnalysis,
    PreOptimizations,
    Final,
    MakeShim,
    AfterUnreachableEnumBranching,
}
51
52impl SimplifyCfg {
53 fn name(&self) -> &'static str {
54 match self {
55 SimplifyCfg::Initial => "SimplifyCfg-initial",
56 SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts",
57 SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges",
58 SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis",
59 SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations",
60 SimplifyCfg::Final => "SimplifyCfg-final",
61 SimplifyCfg::MakeShim => "SimplifyCfg-make_shim",
62 SimplifyCfg::AfterUnreachableEnumBranching => {
63 "SimplifyCfg-after-unreachable-enum-branching"
64 }
65 }
66 }
67}
68
69pub(super) fn simplify_cfg(body: &mut Body<'_>) {
70 CfgSimplifier::new(body).simplify();
71 remove_dead_blocks(body);
72
73 body.basic_blocks_mut().raw.shrink_to_fit();
75}
76
77impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg {
78 fn name(&self) -> &'static str {
79 self.name()
80 }
81
82 fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
83 debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
84 simplify_cfg(body);
85 }
86
87 fn is_required(&self) -> bool {
88 false
89 }
90}
91
/// Working state for the control-flow simplification driven by `simplify_cfg`.
struct CfgSimplifier<'a, 'tcx> {
    /// The body's basic blocks, borrowed mutably for the whole simplification.
    basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
    /// Number of incoming edges per block.
    ///
    /// Only edges out of blocks visited from the start block are counted, and the start
    /// block gets one extra implicit edge. A count of 0 marks a block that `simplify`
    /// skips as dead.
    pred_count: IndexVec<BasicBlock, u32>,
}
96
impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
    /// Builds the simplifier and computes predecessor counts for `body`.
    ///
    /// Only successor edges of blocks visited by `traversal::preorder` are counted. A
    /// block the traversal does not visit (presumably dead code) therefore keeps a count
    /// of 0. `START_BLOCK` gets one extra predecessor for the implicit entry edge, so it
    /// stays live even if no edge targets it.
    fn new(body: &'a mut Body<'tcx>) -> Self {
        let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks);

        pred_count[START_BLOCK] = 1;

        for (_, data) in traversal::preorder(body) {
            if let Some(ref term) = data.terminator {
                for tgt in term.successors() {
                    pred_count[tgt] += 1;
                }
            }
        }

        let basic_blocks = body.basic_blocks_mut();

        CfgSimplifier { basic_blocks, pred_count }
    }

    /// Runs the simplification to a fixed point.
    ///
    /// For every live block (non-zero predecessor count), this:
    /// - collapses chains of empty `goto` blocks among its successors;
    /// - turns `SwitchInt`s whose targets are all identical into `goto`s;
    /// - absorbs a `goto` successor whose predecessor count is 1, appending that
    ///   block's statements.
    ///
    /// The outer loop repeats until a full sweep reports no change. Blocks whose count
    /// drops to 0 are left in place; the caller `simplify_cfg` then runs
    /// `remove_dead_blocks`.
    fn simplify(mut self) {
        self.strip_nops();

        // Indices of the blocks absorbed into `bb` during the current iteration. Their
        // statements are moved after the inner loop, so `bb`'s statement vector can be
        // grown with a single `reserve` instead of reallocating per merged block.
        let mut merged_blocks = Vec::new();
        loop {
            let mut changed = false;

            for bb in self.basic_blocks.indices() {
                if self.pred_count[bb] == 0 {
                    continue;
                }

                debug!("simplifying {:?}", bb);

                // The terminator is held outside the block while it is rewritten. A
                // block whose terminator was taken by `merge_successor` has a count of 0
                // and was skipped above. While `bb` has no terminator,
                // `take_terminator_if_simple_goto` will not take it, which stops goto
                // chains that loop back to `bb`.
                let mut terminator =
                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");

                for successor in terminator.successors_mut() {
                    self.collapse_goto_chain(successor, &mut changed);
                }

                // Merging a successor installs that successor's terminator. That
                // terminator may itself be a foldable switch or another mergeable goto,
                // so repeat until neither step applies.
                let mut inner_changed = true;
                merged_blocks.clear();
                while inner_changed {
                    inner_changed = false;
                    inner_changed |= self.simplify_branch(&mut terminator);
                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
                    changed |= inner_changed;
                }

                let statements_to_merge =
                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();

                if statements_to_merge > 0 {
                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
                    statements.reserve(statements_to_merge);
                    // Merged blocks are appended in the order they were absorbed, which
                    // is the control-flow order along the goto path.
                    for &from in &merged_blocks {
                        statements.append(&mut self.basic_blocks[from].statements);
                    }
                    self.basic_blocks[bb].statements = statements;
                }

                self.basic_blocks[bb].terminator = Some(terminator);
            }

            if !changed {
                break;
            }
        }
    }

    /// Takes the terminator out of `bb` (leaving `None`) and returns it, if `bb` has no
    /// statements and its terminator is a `Goto`.
    ///
    /// Returns `None` in three cases: the block has statements; its terminator is not a
    /// `Goto`; or it currently has no terminator. The last case happens when the
    /// terminator was already taken, e.g. the block `simplify` is working on, or an
    /// earlier link of a goto chain that cycles back. That case is what stops
    /// `collapse_goto_chain` on loops.
    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
        match self.basic_blocks[bb] {
            BasicBlockData {
                ref statements,
                terminator:
                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
                ..
            } if statements.is_empty() => terminator.take(),
            _ => None,
        }
    }

    /// Follows the chain of empty `goto` blocks starting at `*start`. It then points
    /// `*start`, and the `goto` of every block on the chain, directly at the chain's
    /// final target.
    ///
    /// Sets `*changed` if a block on the chain had its own `goto` retargeted.
    /// Redirecting `*start` alone does not set it.
    ///
    /// Predecessor counts are kept consistent as edges move. If a chain block had a
    /// single predecessor, it becomes dead (count 0) and the rerouted edge replaces its
    /// own edge to the final target. Otherwise the final target gains a predecessor and
    /// the chain block loses one.
    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
        // The terminators of all chain blocks are held here, temporarily removed from
        // their blocks, until the final target is known.
        let mut terminators: SmallVec<[_; 1]> = Default::default();
        let mut current = *start;
        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
                unreachable!();
            };
            terminators.push((current, terminator));
            current = target;
        }
        let last = current;
        *start = last;
        // Walk the chain backwards, putting each (retargeted) terminator back.
        while let Some((current, mut terminator)) = terminators.pop() {
            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator
            else {
                unreachable!();
            };
            *changed |= *target != last;
            *target = last;
            debug!("collapsing goto chain from {:?} to {:?}", current, target);

            if self.pred_count[current] == 1 {
                self.pred_count[current] = 0;
            } else {
                self.pred_count[*target] += 1;
                self.pred_count[current] -= 1;
            }
            self.basic_blocks[current].terminator = Some(terminator);
        }
    }

    /// Absorbs the target of a `Goto` terminator when that target has a predecessor
    /// count of 1.
    ///
    /// On success, `*terminator` is replaced by the target block's terminator and the
    /// target is pushed onto `merged_blocks`; `simplify` later moves its statements.
    /// The target's count is set to 0. Returns whether a merge happened. A target that
    /// has no terminator is left alone and `false` is returned.
    fn merge_successor(
        &mut self,
        merged_blocks: &mut Vec<BasicBlock>,
        terminator: &mut Terminator<'tcx>,
    ) -> bool {
        let target = match terminator.kind {
            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
            _ => return false,
        };

        debug!("merging block {:?} into {:?}", target, terminator);
        *terminator = match self.basic_blocks[target].terminator.take() {
            Some(terminator) => terminator,
            None => {
                return false;
            }
        };

        merged_blocks.push(target);
        self.pred_count[target] = 0;

        true
    }

    /// Replaces a `SwitchInt` whose successors are all the same block with a `Goto` to
    /// that block.
    ///
    /// The block keeps one incoming edge from this terminator, so its predecessor count
    /// drops by (number of successors - 1). Returns whether the terminator was
    /// rewritten. Non-`SwitchInt` terminators and switches with no successors are left
    /// unchanged.
    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
        match terminator.kind {
            TerminatorKind::SwitchInt { .. } => {}
            _ => return false,
        };

        let first_succ = {
            if let Some(first_succ) = terminator.successors().next() {
                if terminator.successors().all(|s| s == first_succ) {
                    let count = terminator.successors().count();
                    self.pred_count[first_succ] -= (count - 1) as u32;
                    first_succ
                } else {
                    return false;
                }
            } else {
                return false;
            }
        };

        debug!("simplifying branch {:?}", terminator);
        terminator.kind = TerminatorKind::Goto { target: first_succ };
        true
    }

    /// Removes every `StatementKind::Nop` statement from every block, live or dead.
    /// Among other things, this lets more blocks qualify as empty `goto` blocks.
    fn strip_nops(&mut self) {
        for blk in self.basic_blocks.iter_mut() {
            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
        }
    }
}
286
287pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
288 if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
289 let otherwise = targets.otherwise();
290 if targets.iter().any(|t| t.1 == otherwise) {
291 *targets = SwitchTargets::new(
292 targets.iter().filter(|t| t.1 != otherwise),
293 targets.otherwise(),
294 );
295 }
296 }
297}
298
/// Deletes basic blocks that are not in `traversal::reachable_as_bitset(body)`, and merges
/// duplicate empty `unreachable` blocks. Every remaining successor edge is then renumbered.
///
/// A reachable block is a deduplication candidate when it has a terminator,
/// `is_empty_unreachable()` holds, and it is not a cleanup block. Only the first
/// candidate (in index order) is kept; edges to the others are redirected to it.
///
/// Returns without touching the body when every block is reachable and there is at most
/// one candidate.
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
    // Blocks may lack a terminator here: `CfgSimplifier::merge_successor` takes the
    // terminator out of the blocks it merges and leaves them behind. The `is_some()`
    // check short-circuits before `is_empty_unreachable`, which presumably expects a
    // terminator. The `is_cleanup` check keeps cleanup and non-cleanup blocks from being
    // merged into one another.
    let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
        bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
    };

    let reachable = traversal::reachable_as_bitset(body);
    let empty_unreachable_blocks = body
        .basic_blocks
        .iter_enumerated()
        .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
        .count();

    let num_blocks = body.basic_blocks.len();
    if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
        return;
    }

    let basic_blocks = body.basic_blocks.as_mut();

    // `replacements[old]` is the new index for original block `old`. Entries for deleted
    // unreachable blocks keep their initial value. No surviving block should target
    // them, since successors of reachable blocks are themselves reachable.
    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
    // `retain` visits blocks in index order. `orig_index` is the original index of the
    // block being visited; `used_index` is the index it gets if it is kept.
    let mut orig_index = 0;
    let mut used_index = 0;
    let mut kept_unreachable = None;
    let mut deduplicated_unreachable = false;
    basic_blocks.raw.retain(|bbdata| {
        let orig_bb = BasicBlock::new(orig_index);
        if !reachable.contains(orig_bb) {
            orig_index += 1;
            return false;
        }

        let used_bb = BasicBlock::new(used_index);
        if should_deduplicate_unreachable(bbdata) {
            // The first candidate is kept; later ones are removed and mapped onto it.
            let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
            if kept_unreachable != used_bb {
                replacements[orig_index] = kept_unreachable;
                deduplicated_unreachable = true;
                orig_index += 1;
                return false;
            }
        }

        replacements[orig_index] = used_bb;
        used_index += 1;
        orig_index += 1;
        true
    });

    // The surviving empty `unreachable` block now stands for several original blocks, so
    // its terminator's source info is replaced with a dummy span at the outermost scope.
    if deduplicated_unreachable {
        basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info =
            SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE };
    }

    // Rewrite every remaining edge to the new block numbering.
    for block in basic_blocks {
        for target in block.terminator_mut().successors_mut() {
            *target = replacements[target.index()];
        }
    }
}
366
/// A MIR pass that removes locals with no remaining uses and renumbers the rest.
///
/// The variant only selects the pass name reported by `MirPass::name`; `run_pass`
/// behaves identically for every variant.
pub(super) enum SimplifyLocals {
    BeforeConstProp,
    AfterGVN,
    Final,
}
372
373impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
374 fn name(&self) -> &'static str {
375 match &self {
376 SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
377 SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
378 SimplifyLocals::Final => "SimplifyLocals-final",
379 }
380 }
381
382 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
383 sess.mir_opt_level() > 0
384 }
385
386 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
387 trace!("running SimplifyLocals on {:?}", body.source);
388
389 let mut used_locals = UsedLocals::new(body);
391
392 remove_unused_definitions_helper(&mut used_locals, body);
398
399 let map = make_local_map(&mut body.local_decls, &used_locals);
402
403 if map.iter().any(Option::is_none) {
405 let mut updater = LocalUpdater { map, tcx };
407 updater.visit_body_preserves_cfg(body);
408
409 body.local_decls.shrink_to_fit();
410 }
411 }
412
413 fn is_required(&self) -> bool {
414 false
415 }
416}
417
418pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
419 let mut used_locals = UsedLocals::new(body);
421
422 remove_unused_definitions_helper(&mut used_locals, body);
428}
429
430fn make_local_map<V>(
432 local_decls: &mut IndexVec<Local, V>,
433 used_locals: &UsedLocals,
434) -> IndexVec<Local, Option<Local>> {
435 let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls);
436 let mut used = Local::ZERO;
437
438 for alive_index in local_decls.indices() {
439 if !used_locals.is_used(alive_index) {
441 continue;
442 }
443
444 map[alive_index] = Some(used);
445 if alive_index != used {
446 local_decls.swap(alive_index, used);
447 }
448 used.increment_by(1);
449 }
450 local_decls.truncate(used.index());
451 map
452}
453
/// Tracks how many times each local of a MIR body is used.
struct UsedLocals {
    /// `true` while counts are being accumulated. `statement_removed` sets it to `false`,
    /// so that re-visiting a removed statement subtracts its uses instead of adding them.
    increment: bool,
    /// The body's `arg_count`; `is_used` reports every local whose index is at most this
    /// value as used.
    arg_count: u32,
    /// Per-local use count. Occurrences that `visit_lhs` treats as definitions (the base
    /// local of a direct destination) are not included.
    use_count: IndexVec<Local, u32>,
}
460
461impl UsedLocals {
462 fn new(body: &Body<'_>) -> Self {
464 let mut this = Self {
465 increment: true,
466 arg_count: body.arg_count.try_into().unwrap(),
467 use_count: IndexVec::from_elem(0, &body.local_decls),
468 };
469 this.visit_body(body);
470 this
471 }
472
473 fn is_used(&self, local: Local) -> bool {
477 trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]);
478 local.as_u32() <= self.arg_count || self.use_count[local] != 0
479 }
480
481 fn statement_removed(&mut self, statement: &Statement<'_>) {
483 self.increment = false;
484
485 let location = Location::START;
487 self.visit_statement(statement, location);
488 }
489
490 fn visit_lhs(&mut self, place: &Place<'_>, location: Location) {
492 if place.is_indirect() {
493 self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
495 } else {
496 self.super_projection(
500 place.as_ref(),
501 PlaceContext::MutatingUse(MutatingUseContext::Projection),
502 location,
503 );
504 }
505 }
506}
507
508impl<'tcx> Visitor<'tcx> for UsedLocals {
509 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
510 match statement.kind {
511 StatementKind::Intrinsic(..)
512 | StatementKind::Retag(..)
513 | StatementKind::Coverage(..)
514 | StatementKind::FakeRead(..)
515 | StatementKind::PlaceMention(..)
516 | StatementKind::AscribeUserType(..) => {
517 self.super_statement(statement, location);
518 }
519
520 StatementKind::ConstEvalCounter | StatementKind::Nop => {}
521
522 StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
523
524 StatementKind::Assign(box (ref place, ref rvalue)) => {
525 if rvalue.is_safe_to_remove() {
526 self.visit_lhs(place, location);
527 self.visit_rvalue(rvalue, location);
528 } else {
529 self.super_statement(statement, location);
530 }
531 }
532
533 StatementKind::SetDiscriminant { ref place, variant_index: _ }
534 | StatementKind::Deinit(ref place)
535 | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => {
536 self.visit_lhs(place, location);
537 }
538 }
539 }
540
541 fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) {
542 if self.increment {
543 self.use_count[local] += 1;
544 } else {
545 assert_ne!(self.use_count[local], 0);
546 self.use_count[local] -= 1;
547 }
548 }
549}
550
551fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
553 let mut modified = true;
559 while modified {
560 modified = false;
561
562 for data in body.basic_blocks.as_mut_preserves_cfg() {
563 data.statements.retain(|statement| {
565 let keep = match &statement.kind {
566 StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
567 used_locals.is_used(*local)
568 }
569 StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local),
570
571 StatementKind::SetDiscriminant { ref place, .. }
572 | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ }
573 | StatementKind::Deinit(ref place) => used_locals.is_used(place.local),
574 StatementKind::Nop => false,
575 _ => true,
576 };
577
578 if !keep {
579 trace!("removing statement {:?}", statement);
580 modified = true;
581 used_locals.statement_removed(statement);
582 }
583
584 keep
585 });
586 }
587 }
588}
589
/// A MIR visitor that rewrites every `Local` it visits according to `map`.
/// In `SimplifyLocals::run_pass`, `map` is the renumbering produced by `make_local_map`.
struct LocalUpdater<'tcx> {
    /// Old local -> new local. Every visited local must map to `Some`; a `None` entry
    /// makes `visit_local` panic on `unwrap`.
    map: IndexVec<Local, Option<Local>>,
    tcx: TyCtxt<'tcx>,
}
594
595impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
596 fn tcx(&self) -> TyCtxt<'tcx> {
597 self.tcx
598 }
599
600 fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
601 if let StatementKind::BackwardIncompatibleDropHint { place, reason: _ } =
602 &mut statement.kind
603 {
604 self.visit_local(
605 &mut place.local,
606 PlaceContext::MutatingUse(MutatingUseContext::Store),
607 location,
608 );
609 } else {
610 self.super_statement(statement, location);
611 }
612 }
613
614 fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
615 *l = self.map[*l].unwrap();
616 }
617}