rustc_mir_transform/simplify_branches.rs

1use rustc_middle::mir::*;
2use rustc_middle::ty::TyCtxt;
3use tracing::trace;
4
5use crate::patch::MirPatch;
6
7pub(super) enum SimplifyConstCondition {
8    AfterConstProp,
9    Final,
10}
11
12/// A pass that replaces a branch with a goto when its condition is known.
13impl<'tcx> crate::MirPass<'tcx> for SimplifyConstCondition {
14    fn name(&self) -> &'static str {
15        match self {
16            SimplifyConstCondition::AfterConstProp => "SimplifyConstCondition-after-const-prop",
17            SimplifyConstCondition::Final => "SimplifyConstCondition-final",
18        }
19    }
20
21    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22        trace!("Running SimplifyConstCondition on {:?}", body.source);
23        let typing_env = body.typing_env(tcx);
24        let mut patch = MirPatch::new(body);
25
26        'blocks: for (bb, block) in body.basic_blocks.iter_enumerated() {
27            for (statement_index, stmt) in block.statements.iter().enumerate() {
28                // Simplify `assume` of a known value: either a NOP or unreachable.
29                if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
30                    && let NonDivergingIntrinsic::Assume(discr) = intrinsic
31                    && let Operand::Constant(c) = discr
32                    && let Some(constant) = c.const_.try_eval_bool(tcx, typing_env)
33                {
34                    if constant {
35                        patch.nop_statement(Location { block: bb, statement_index });
36                    } else {
37                        patch.patch_terminator(bb, TerminatorKind::Unreachable);
38                        continue 'blocks;
39                    }
40                }
41            }
42
43            let terminator = block.terminator();
44            let terminator = match terminator.kind {
45                TerminatorKind::SwitchInt {
46                    discr: Operand::Constant(ref c), ref targets, ..
47                } => {
48                    let constant = c.const_.try_eval_bits(tcx, typing_env);
49                    if let Some(constant) = constant {
50                        let target = targets.target_for_value(constant);
51                        TerminatorKind::Goto { target }
52                    } else {
53                        continue;
54                    }
55                }
56                TerminatorKind::Assert {
57                    target, cond: Operand::Constant(ref c), expected, ..
58                } => match c.const_.try_eval_bool(tcx, typing_env) {
59                    Some(v) if v == expected => TerminatorKind::Goto { target },
60                    _ => continue,
61                },
62                _ => continue,
63            };
64            patch.patch_terminator(bb, terminator);
65        }
66        patch.apply(body);
67    }
68
69    fn is_required(&self) -> bool {
70        false
71    }
72}