rustc_mir_transform/
simplify_branches.rs1use rustc_middle::mir::*;
2use rustc_middle::ty::TyCtxt;
3use tracing::trace;
4
5use crate::patch::MirPatch;
6
7pub(super) enum SimplifyConstCondition {
8 AfterConstProp,
9 Final,
10}
11
12impl<'tcx> crate::MirPass<'tcx> for SimplifyConstCondition {
14 fn name(&self) -> &'static str {
15 match self {
16 SimplifyConstCondition::AfterConstProp => "SimplifyConstCondition-after-const-prop",
17 SimplifyConstCondition::Final => "SimplifyConstCondition-final",
18 }
19 }
20
21 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22 trace!("Running SimplifyConstCondition on {:?}", body.source);
23 let typing_env = body.typing_env(tcx);
24 let mut patch = MirPatch::new(body);
25
26 'blocks: for (bb, block) in body.basic_blocks.iter_enumerated() {
27 for (statement_index, stmt) in block.statements.iter().enumerate() {
28 if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
30 && let NonDivergingIntrinsic::Assume(discr) = intrinsic
31 && let Operand::Constant(c) = discr
32 && let Some(constant) = c.const_.try_eval_bool(tcx, typing_env)
33 {
34 if constant {
35 patch.nop_statement(Location { block: bb, statement_index });
36 } else {
37 patch.patch_terminator(bb, TerminatorKind::Unreachable);
38 continue 'blocks;
39 }
40 }
41 }
42
43 let terminator = block.terminator();
44 let terminator = match terminator.kind {
45 TerminatorKind::SwitchInt {
46 discr: Operand::Constant(ref c), ref targets, ..
47 } => {
48 let constant = c.const_.try_eval_bits(tcx, typing_env);
49 if let Some(constant) = constant {
50 let target = targets.target_for_value(constant);
51 TerminatorKind::Goto { target }
52 } else {
53 continue;
54 }
55 }
56 TerminatorKind::Assert {
57 target, cond: Operand::Constant(ref c), expected, ..
58 } => match c.const_.try_eval_bool(tcx, typing_env) {
59 Some(v) if v == expected => TerminatorKind::Goto { target },
60 _ => continue,
61 },
62 _ => continue,
63 };
64 patch.patch_terminator(bb, terminator);
65 }
66 patch.apply(body);
67 }
68
69 fn is_required(&self) -> bool {
70 false
71 }
72}