rustc_mir_transform/
simplify.rs

1//! A number of passes which remove various redundancies in the CFG.
2//!
//! The `SimplifyCfg` pass gets rid of unnecessary blocks in the CFG, whereas the `SimplifyLocals`
//! pass gets rid of all the unnecessary local variable declarations.
5//!
6//! The `SimplifyLocals` pass is kinda expensive and therefore not very suitable to be run often.
7//! Most of the passes should not care or be impacted in meaningful ways due to extra locals
8//! either, so running the pass once, right before codegen, should suffice.
9//!
//! On the other side of the spectrum, the `SimplifyCfg` pass is cheap to run, so it should be
//! run after every pass that may modify the CFG in significant ways. This pass must also be run
//! before any analysis passes because it removes dead blocks, and some of these can be
//! ill-typed.
14//!
//! The cause of this typing issue is that typeck allows most blocks whose end is not reachable
//! to have an arbitrary return type, rather than the usual `()` return type (note that typeck's
//! notion of reachability is in fact slightly weaker than MIR CFG reachability - see #31617). A
//! standard example of the situation is:
19//!
20//! ```rust
21//!   fn example() {
22//!       let _a: char = { return; };
23//!   }
24//! ```
25//!
26//! Here the block (`{ return; }`) has the return type `char`, rather than `()`, but the MIR we
27//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
28//! return.
29
30use rustc_index::{Idx, IndexSlice, IndexVec};
31use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
32use rustc_middle::mir::*;
33use rustc_middle::ty::TyCtxt;
34use rustc_span::DUMMY_SP;
35use smallvec::SmallVec;
36use tracing::{debug, trace};
37
38pub(super) enum SimplifyCfg {
39    Initial,
40    PromoteConsts,
41    RemoveFalseEdges,
42    /// Runs at the beginning of "analysis to runtime" lowering, *before* drop elaboration.
43    PostAnalysis,
44    /// Runs at the end of "analysis to runtime" lowering, *after* drop elaboration.
45    /// This is before the main optimization passes on runtime MIR kick in.
46    PreOptimizations,
47    Final,
48    MakeShim,
49    AfterUnreachableEnumBranching,
50}
51
52impl SimplifyCfg {
53    fn name(&self) -> &'static str {
54        match self {
55            SimplifyCfg::Initial => "SimplifyCfg-initial",
56            SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts",
57            SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges",
58            SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis",
59            SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations",
60            SimplifyCfg::Final => "SimplifyCfg-final",
61            SimplifyCfg::MakeShim => "SimplifyCfg-make_shim",
62            SimplifyCfg::AfterUnreachableEnumBranching => {
63                "SimplifyCfg-after-unreachable-enum-branching"
64            }
65        }
66    }
67}
68
69pub(super) fn simplify_cfg(body: &mut Body<'_>) {
70    CfgSimplifier::new(body).simplify();
71    remove_dead_blocks(body);
72
73    // FIXME: Should probably be moved into some kind of pass manager
74    body.basic_blocks_mut().raw.shrink_to_fit();
75}
76
77impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg {
78    fn name(&self) -> &'static str {
79        self.name()
80    }
81
82    fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
83        debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
84        simplify_cfg(body);
85    }
86
87    fn is_required(&self) -> bool {
88        false
89    }
90}
91
/// In-place simplifier for a body's control-flow graph.
///
/// It collapses chains of empty `goto` blocks. It turns `SwitchInt` terminators whose
/// successors are all the same block into a `goto`. It also merges a block into its
/// predecessor when that predecessor is its only one and jumps to it with a `goto`.
///
/// Blocks that stop being reachable are left in place, some of them without a terminator.
/// `simplify_cfg` runs `remove_dead_blocks` afterwards to delete them.
struct CfgSimplifier<'a, 'tcx> {
    /// The basic blocks being rewritten.
    basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
    /// Number of incoming edges per block. Only edges out of blocks visited by the preorder walk
    /// in `new` are counted, plus one implicit edge into `START_BLOCK`. A block whose count is
    /// zero is treated as dead and skipped.
    pred_count: IndexVec<BasicBlock, u32>,
}

impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
    /// Builds a simplifier for `body`, computing `pred_count` from a preorder traversal.
    fn new(body: &'a mut Body<'tcx>) -> Self {
        let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks);

        // we can't use mir.predecessors() here because that counts
        // dead blocks, which we don't want to.
        // The start block gets one implicit predecessor (the function entry), so it is not
        // treated as dead, and a single `goto` into it still leaves its count above one, so it
        // is never merged into that predecessor.
        pred_count[START_BLOCK] = 1;

        for (_, data) in traversal::preorder(body) {
            if let Some(ref term) = data.terminator {
                for tgt in term.successors() {
                    pred_count[tgt] += 1;
                }
            }
        }

        let basic_blocks = body.basic_blocks_mut();

        CfgSimplifier { basic_blocks, pred_count }
    }

    /// Simplifies the CFG until a whole sweep over the live blocks reports no change.
    ///
    /// For every block with a non-zero `pred_count`, this does three things:
    /// 1. It collapses the goto chains behind each of the block's successor edges.
    /// 2. It alternates `simplify_branch` and `merge_successor` on the block's terminator until
    ///    neither applies.
    /// 3. It appends the statements of all blocks merged into it.
    fn simplify(mut self) {
        self.strip_nops();

        // Vec of the blocks that should be merged. We store the indices here, instead of the
        // statements themselves, to avoid moving the (relatively) large statements twice.
        // We do not push the statements directly into the target block (`bb`) as that is slower
        // due to additional reallocations.
        let mut merged_blocks = Vec::new();
        loop {
            let mut changed = false;

            for bb in self.basic_blocks.indices() {
                // Dead blocks, including ones already merged into a predecessor, are skipped.
                if self.pred_count[bb] == 0 {
                    continue;
                }

                debug!("simplifying {:?}", bb);

                // Detach the terminator while it is rewritten so `self` can be borrowed
                // mutably; it is put back at the end of this iteration.
                let mut terminator =
                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");

                for successor in terminator.successors_mut() {
                    self.collapse_goto_chain(successor, &mut changed);
                }

                // A successful merge replaces `terminator` with the merged block's terminator,
                // which may itself be simplifiable, so keep going while either step progresses.
                let mut inner_changed = true;
                merged_blocks.clear();
                while inner_changed {
                    inner_changed = false;
                    inner_changed |= self.simplify_branch(&mut terminator);
                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
                    changed |= inner_changed;
                }

                let statements_to_merge =
                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();

                // Append the merged blocks' statements in merge order, reserving space once.
                if statements_to_merge > 0 {
                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
                    statements.reserve(statements_to_merge);
                    for &from in &merged_blocks {
                        statements.append(&mut self.basic_blocks[from].statements);
                    }
                    self.basic_blocks[bb].statements = statements;
                }

                self.basic_blocks[bb].terminator = Some(terminator);
            }

            if !changed {
                break;
            }
        }
    }

    /// Takes the `goto` terminator out of `bb` (leaving `None` behind) and returns it.
    ///
    /// This function will return `None` if
    /// * the block has statements
    /// * the block has a terminator other than `goto`
    /// * the block has no terminator (meaning some other part of the current optimization stole it)
    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
        match self.basic_blocks[bb] {
            BasicBlockData {
                ref statements,
                terminator:
                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
                ..
            } if statements.is_empty() => terminator.take(),
            // if `terminator` is None, this means we are in a loop. In that
            // case, let all the loop collapse to its entry.
            _ => None,
        }
    }

    /// Collapse a goto chain starting from `start`.
    ///
    /// This follows empty blocks ending in `goto`, starting at `*start`. It then points
    /// `*start`, and every block on the chain, directly at the chain's final block, keeping
    /// `pred_count` consistent. `*changed` is set when a chain block's own `goto` target is
    /// updated.
    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
        // Using `SmallVec` here, because in some logs on libcore oli-obk saw many single-element
        // goto chains. We should probably benchmark different sizes.
        let mut terminators: SmallVec<[_; 1]> = Default::default();
        let mut current = *start;
        // Walk the chain, temporarily taking each `goto` terminator out of its block.
        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
                unreachable!();
            };
            terminators.push((current, terminator));
            current = target;
        }
        let last = current;
        *start = last;
        // Put the terminators back in reverse order, each one now targeting `last`.
        while let Some((current, mut terminator)) = terminators.pop() {
            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator
            else {
                unreachable!();
            };
            *changed |= *target != last;
            *target = last;
            debug!("collapsing goto chain from {:?} to {:?}", current, target);

            if self.pred_count[current] == 1 {
                // The edge being redirected was the only one into `current`, so `current`
                // becomes dead. The redirected edge takes the place of `current`'s own edge
                // to `target`, so `target`'s count stays the same.
                self.pred_count[current] = 0;
            } else {
                // One of `current`'s incoming edges now goes to `target` instead.
                self.pred_count[*target] += 1;
                self.pred_count[current] -= 1;
            }
            self.basic_blocks[current].terminator = Some(terminator);
        }
    }

    /// Merges a block with exactly one predecessor into that predecessor, when that
    /// predecessor reaches it through a `goto`.
    ///
    /// If `terminator` is a `goto` to a block whose `pred_count` is 1, this does four things:
    /// 1. It replaces `terminator` with that block's terminator, taking it and leaving `None`.
    /// 2. It pushes the block onto `merged_blocks`, so the caller appends its statements.
    /// 3. It marks the block dead.
    /// 4. It returns `true`.
    fn merge_successor(
        &mut self,
        merged_blocks: &mut Vec<BasicBlock>,
        terminator: &mut Terminator<'tcx>,
    ) -> bool {
        let target = match terminator.kind {
            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
            _ => return false,
        };

        debug!("merging block {:?} into {:?}", target, terminator);
        *terminator = match self.basic_blocks[target].terminator.take() {
            Some(terminator) => terminator,
            None => {
                // unreachable loop - this should not be possible, as we
                // don't strand blocks, but handle it correctly.
                return false;
            }
        };

        merged_blocks.push(target);
        self.pred_count[target] = 0;

        true
    }

    /// Turns a `SwitchInt` whose successors are all the same block into a `goto` to that block.
    ///
    /// The target's `pred_count` is reduced by the number of duplicate edges dropped (all but
    /// one). Returns `true` if the terminator was rewritten.
    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
        match terminator.kind {
            TerminatorKind::SwitchInt { .. } => {}
            _ => return false,
        };

        let first_succ = {
            if let Some(first_succ) = terminator.successors().next() {
                if terminator.successors().all(|s| s == first_succ) {
                    let count = terminator.successors().count();
                    self.pred_count[first_succ] -= (count - 1) as u32;
                    first_succ
                } else {
                    return false;
                }
            } else {
                return false;
            }
        };

        debug!("simplifying branch {:?}", terminator);
        terminator.kind = TerminatorKind::Goto { target: first_succ };
        true
    }

    /// Removes every `StatementKind::Nop` statement from every block, reachable or not.
    fn strip_nops(&mut self) {
        for blk in self.basic_blocks.iter_mut() {
            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
        }
    }
}
286
287pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
288    if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
289        let otherwise = targets.otherwise();
290        if targets.iter().any(|t| t.1 == otherwise) {
291            *targets = SwitchTargets::new(
292                targets.iter().filter(|t| t.1 != otherwise),
293                targets.otherwise(),
294            );
295        }
296    }
297}
298
299pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
300    let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
301        // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
302        // terminator. Those blocks will be deleted by remove_dead_blocks, but we run just
303        // before then so we need to handle missing terminators.
304        // We also need to prevent confusing cleanup and non-cleanup blocks. In practice we
305        // don't emit empty unreachable cleanup blocks, so this simple check suffices.
306        bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
307    };
308
309    let reachable = traversal::reachable_as_bitset(body);
310    let empty_unreachable_blocks = body
311        .basic_blocks
312        .iter_enumerated()
313        .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
314        .count();
315
316    let num_blocks = body.basic_blocks.len();
317    if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
318        return;
319    }
320
321    let basic_blocks = body.basic_blocks.as_mut();
322
323    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
324    let mut orig_index = 0;
325    let mut used_index = 0;
326    let mut kept_unreachable = None;
327    let mut deduplicated_unreachable = false;
328    basic_blocks.raw.retain(|bbdata| {
329        let orig_bb = BasicBlock::new(orig_index);
330        if !reachable.contains(orig_bb) {
331            orig_index += 1;
332            return false;
333        }
334
335        let used_bb = BasicBlock::new(used_index);
336        if should_deduplicate_unreachable(bbdata) {
337            let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
338            if kept_unreachable != used_bb {
339                replacements[orig_index] = kept_unreachable;
340                deduplicated_unreachable = true;
341                orig_index += 1;
342                return false;
343            }
344        }
345
346        replacements[orig_index] = used_bb;
347        used_index += 1;
348        orig_index += 1;
349        true
350    });
351
352    // If we deduplicated unreachable blocks we erase their source_info as we
353    // can no longer attribute their code to a particular location in the
354    // source.
355    if deduplicated_unreachable {
356        basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info =
357            SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE };
358    }
359
360    for block in basic_blocks {
361        for target in block.terminator_mut().successors_mut() {
362            *target = replacements[target.index()];
363        }
364    }
365}
366
367pub(super) enum SimplifyLocals {
368    BeforeConstProp,
369    AfterGVN,
370    Final,
371}
372
373impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
374    fn name(&self) -> &'static str {
375        match &self {
376            SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
377            SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
378            SimplifyLocals::Final => "SimplifyLocals-final",
379        }
380    }
381
382    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
383        sess.mir_opt_level() > 0
384    }
385
386    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
387        trace!("running SimplifyLocals on {:?}", body.source);
388
389        // First, we're going to get a count of *actual* uses for every `Local`.
390        let mut used_locals = UsedLocals::new(body);
391
392        // Next, we're going to remove any `Local` with zero actual uses. When we remove those
393        // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
394        // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
395        // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
396        // fixedpoint where there are no more unused locals.
397        remove_unused_definitions_helper(&mut used_locals, body);
398
399        // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the
400        // `Local`s.
401        let map = make_local_map(&mut body.local_decls, &used_locals);
402
403        // Only bother running the `LocalUpdater` if we actually found locals to remove.
404        if map.iter().any(Option::is_none) {
405            // Update references to all vars and tmps now
406            let mut updater = LocalUpdater { map, tcx };
407            updater.visit_body_preserves_cfg(body);
408
409            body.local_decls.shrink_to_fit();
410        }
411    }
412
413    fn is_required(&self) -> bool {
414        false
415    }
416}
417
418pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
419    // First, we're going to get a count of *actual* uses for every `Local`.
420    let mut used_locals = UsedLocals::new(body);
421
422    // Next, we're going to remove any `Local` with zero actual uses. When we remove those
423    // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
424    // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
425    // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
426    // fixedpoint where there are no more unused locals.
427    remove_unused_definitions_helper(&mut used_locals, body);
428}
429
430/// Construct the mapping while swapping out unused stuff out from the `vec`.
431fn make_local_map<V>(
432    local_decls: &mut IndexVec<Local, V>,
433    used_locals: &UsedLocals,
434) -> IndexVec<Local, Option<Local>> {
435    let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls);
436    let mut used = Local::ZERO;
437
438    for alive_index in local_decls.indices() {
439        // `is_used` treats the `RETURN_PLACE` and arguments as used.
440        if !used_locals.is_used(alive_index) {
441            continue;
442        }
443
444        map[alive_index] = Some(used);
445        if alive_index != used {
446            local_decls.swap(alive_index, used);
447        }
448        used.increment_by(1);
449    }
450    local_decls.truncate(used.index());
451    map
452}
453
454/// Keeps track of used & unused locals.
455struct UsedLocals {
456    increment: bool,
457    arg_count: u32,
458    use_count: IndexVec<Local, u32>,
459}
460
461impl UsedLocals {
462    /// Determines which locals are used & unused in the given body.
463    fn new(body: &Body<'_>) -> Self {
464        let mut this = Self {
465            increment: true,
466            arg_count: body.arg_count.try_into().unwrap(),
467            use_count: IndexVec::from_elem(0, &body.local_decls),
468        };
469        this.visit_body(body);
470        this
471    }
472
473    /// Checks if local is used.
474    ///
475    /// Return place and arguments are always considered used.
476    fn is_used(&self, local: Local) -> bool {
477        trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]);
478        local.as_u32() <= self.arg_count || self.use_count[local] != 0
479    }
480
481    /// Updates the use counts to reflect the removal of given statement.
482    fn statement_removed(&mut self, statement: &Statement<'_>) {
483        self.increment = false;
484
485        // The location of the statement is irrelevant.
486        let location = Location::START;
487        self.visit_statement(statement, location);
488    }
489
490    /// Visits a left-hand side of an assignment.
491    fn visit_lhs(&mut self, place: &Place<'_>, location: Location) {
492        if place.is_indirect() {
493            // A use, not a definition.
494            self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
495        } else {
496            // A definition. The base local itself is not visited, so this occurrence is not counted
497            // toward its use count. There might be other locals still, used in an indexing
498            // projection.
499            self.super_projection(
500                place.as_ref(),
501                PlaceContext::MutatingUse(MutatingUseContext::Projection),
502                location,
503            );
504        }
505    }
506}
507
508impl<'tcx> Visitor<'tcx> for UsedLocals {
509    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
510        match statement.kind {
511            StatementKind::Intrinsic(..)
512            | StatementKind::Retag(..)
513            | StatementKind::Coverage(..)
514            | StatementKind::FakeRead(..)
515            | StatementKind::PlaceMention(..)
516            | StatementKind::AscribeUserType(..) => {
517                self.super_statement(statement, location);
518            }
519
520            StatementKind::ConstEvalCounter | StatementKind::Nop => {}
521
522            StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
523
524            StatementKind::Assign(box (ref place, ref rvalue)) => {
525                if rvalue.is_safe_to_remove() {
526                    self.visit_lhs(place, location);
527                    self.visit_rvalue(rvalue, location);
528                } else {
529                    self.super_statement(statement, location);
530                }
531            }
532
533            StatementKind::SetDiscriminant { ref place, variant_index: _ }
534            | StatementKind::Deinit(ref place)
535            | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => {
536                self.visit_lhs(place, location);
537            }
538        }
539    }
540
541    fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) {
542        if self.increment {
543            self.use_count[local] += 1;
544        } else {
545            assert_ne!(self.use_count[local], 0);
546            self.use_count[local] -= 1;
547        }
548    }
549}
550
551/// Removes unused definitions. Updates the used locals to reflect the changes made.
552fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
553    // The use counts are updated as we remove the statements. A local might become unused
554    // during the retain operation, leading to a temporary inconsistency (storage statements or
555    // definitions referencing the local might remain). For correctness it is crucial that this
556    // computation reaches a fixed point.
557
558    let mut modified = true;
559    while modified {
560        modified = false;
561
562        for data in body.basic_blocks.as_mut_preserves_cfg() {
563            // Remove unnecessary StorageLive and StorageDead annotations.
564            data.statements.retain(|statement| {
565                let keep = match &statement.kind {
566                    StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
567                        used_locals.is_used(*local)
568                    }
569                    StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local),
570
571                    StatementKind::SetDiscriminant { ref place, .. }
572                    | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ }
573                    | StatementKind::Deinit(ref place) => used_locals.is_used(place.local),
574                    StatementKind::Nop => false,
575                    _ => true,
576                };
577
578                if !keep {
579                    trace!("removing statement {:?}", statement);
580                    modified = true;
581                    used_locals.statement_removed(statement);
582                }
583
584                keep
585            });
586        }
587    }
588}
589
590struct LocalUpdater<'tcx> {
591    map: IndexVec<Local, Option<Local>>,
592    tcx: TyCtxt<'tcx>,
593}
594
595impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
596    fn tcx(&self) -> TyCtxt<'tcx> {
597        self.tcx
598    }
599
600    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
601        if let StatementKind::BackwardIncompatibleDropHint { place, reason: _ } =
602            &mut statement.kind
603        {
604            self.visit_local(
605                &mut place.local,
606                PlaceContext::MutatingUse(MutatingUseContext::Store),
607                location,
608            );
609        } else {
610            self.super_statement(statement, location);
611        }
612    }
613
614    fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
615        *l = self.map[*l].unwrap();
616    }
617}