// rustc_mir_transform/simplify_comparison_integral.rs

1use std::iter;
2
3use rustc_middle::bug;
4use rustc_middle::mir::interpret::Scalar;
5use rustc_middle::mir::{
6    BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets,
7    TerminatorKind,
8};
9use rustc_middle::ty::{Ty, TyCtxt};
10use tracing::trace;
11
12use crate::ssa::SsaLocals;
13
/// Pass to convert `if` conditions on integrals into switches on the integral.
///
/// When a block ends in a `switchInt` on the result of an `Eq`/`Ne` comparison between an SSA
/// place and an integral or `char` constant (both in the same block), the switch is rewritten
/// to branch on the place directly, with the constant as the switch value. For example, it
/// turns something like
///
/// ```ignore (MIR)
/// _3 = Eq(move _4, const 43i32);
/// StorageDead(_4);
/// switchInt(move _3) -> [false: bb2, otherwise: bb3];
/// ```
///
/// into:
///
/// ```ignore (MIR)
/// switchInt(_4) -> [43i32: bb3, otherwise: bb2];
/// ```
///
/// The comparison statement is deleted only when the original switch moved its result, as in
/// the example. Otherwise it is kept, and a `move` of the compared place inside it is turned
/// into a copy, because the new switch reads that place too. A `StorageDead` of the compared
/// place found in the rewritten block is turned into a no-op. Matching `StorageDead`s are
/// inserted at the start of the switch targets instead.
pub(super) struct SimplifyComparisonIntegral;
29
30impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
31    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
32        sess.mir_opt_level() > 1
33    }
34
35    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
36        trace!("Running SimplifyComparisonIntegral on {:?}", body.source);
37
38        let typing_env = body.typing_env(tcx);
39        let ssa = SsaLocals::new(tcx, body, typing_env);
40        let helper = OptimizationFinder { body };
41        let opts = helper.find_optimizations(&ssa);
42        let mut storage_deads_to_insert = vec![];
43        let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
44        for opt in opts {
45            trace!("SUCCESS: Applying {:?}", opt);
46            // replace terminator with a switchInt that switches on the integer directly
47            let bbs = &mut body.basic_blocks_mut();
48            let bb = &mut bbs[opt.bb_idx];
49            let new_value = match opt.branch_value_scalar {
50                Scalar::Int(int) => {
51                    let layout = tcx
52                        .layout_of(typing_env.as_query_input(opt.branch_value_ty))
53                        .expect("if we have an evaluated constant we must know the layout");
54                    int.to_bits(layout.size)
55                }
56                Scalar::Ptr(..) => continue,
57            };
58            const FALSE: u128 = 0;
59
60            let mut new_targets = opt.targets;
61            let first_value = new_targets.iter().next().unwrap().0;
62            let first_is_false_target = first_value == FALSE;
63            match opt.op {
64                BinOp::Eq => {
65                    // if the assignment was Eq we want the true case to be first
66                    if first_is_false_target {
67                        new_targets.all_targets_mut().swap(0, 1);
68                    }
69                }
70                BinOp::Ne => {
71                    // if the assignment was Ne we want the false case to be first
72                    if !first_is_false_target {
73                        new_targets.all_targets_mut().swap(0, 1);
74                    }
75                }
76                _ => unreachable!(),
77            }
78
79            // delete comparison statement if it the value being switched on was moved, which means
80            // it can not be used later on
81            if opt.can_remove_bin_op_stmt {
82                bb.statements[opt.bin_op_stmt_idx].make_nop(true);
83            } else {
84                // if the integer being compared to a const integral is being moved into the
85                // comparison, e.g `_2 = Eq(move _3, const 'x');`
86                // we want to avoid making a double move later on in the switchInt on _3.
87                // So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`,
88                // we convert the move in the comparison statement to a copy.
89
90                // unwrap is safe as we know this statement is an assign
91                let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();
92
93                use Operand::*;
94                match rhs {
95                    Rvalue::BinaryOp(_, (left @ Move(_), Constant(_))) => {
96                        *left = Copy(opt.to_switch_on);
97                    }
98                    Rvalue::BinaryOp(_, (Constant(_), right @ Move(_))) => {
99                        *right = Copy(opt.to_switch_on);
100                    }
101                    _ => (),
102                }
103            }
104
105            let terminator = bb.terminator();
106
107            // remove StorageDead (if it exists) being used in the assign of the comparison
108            for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
109                if !matches!(
110                    stmt.kind,
111                    StatementKind::StorageDead(local) if local == opt.to_switch_on.local
112                ) {
113                    continue;
114                }
115                storage_deads_to_remove.push((stmt_idx, opt.bb_idx));
116                // if we have StorageDeads to remove then make sure to insert them at the top of
117                // each target
118                for bb_idx in new_targets.all_targets() {
119                    storage_deads_to_insert.push((
120                        *bb_idx,
121                        Statement::new(
122                            terminator.source_info,
123                            StatementKind::StorageDead(opt.to_switch_on.local),
124                        ),
125                    ));
126                }
127            }
128
129            let [bb_cond, bb_otherwise] = match new_targets.all_targets() {
130                [a, b] => [*a, *b],
131                e => bug!("expected 2 switch targets, got: {:?}", e),
132            };
133
134            let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise);
135
136            let terminator = bb.terminator_mut();
137            terminator.kind =
138                TerminatorKind::SwitchInt { discr: Operand::Copy(opt.to_switch_on), targets };
139        }
140
141        for (idx, bb_idx) in storage_deads_to_remove {
142            body.basic_blocks_mut()[bb_idx].statements[idx].make_nop(true);
143        }
144
145        for (idx, stmt) in storage_deads_to_insert {
146            body.basic_blocks_mut()[idx].statements.insert(0, stmt);
147        }
148    }
149
150    fn is_required(&self) -> bool {
151        false
152    }
153}
154
/// Read-only scan of a MIR body. It looks for `Eq`/`Ne` comparisons feeding a `switchInt`
/// that `SimplifyComparisonIntegral` can rewrite.
struct OptimizationFinder<'a, 'tcx> {
    /// Body being scanned; it is only read (via `find_optimizations`).
    body: &'a Body<'tcx>,
}
158
159impl<'tcx> OptimizationFinder<'_, 'tcx> {
160    fn find_optimizations(&self, ssa: &SsaLocals) -> Vec<OptimizationInfo<'tcx>> {
161        self.body
162            .basic_blocks
163            .iter_enumerated()
164            .filter_map(|(bb_idx, bb)| {
165                // find switch
166                let (discr, targets) = bb.terminator().kind.as_switch()?;
167                let place_switched_on = discr.place()?;
168                // Make sure that the place is not modified.
169                if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() {
170                    return None;
171                }
172
173                // find the statement that assigns the place being switched on
174                bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
175                    match &stmt.kind {
176                        rustc_middle::mir::StatementKind::Assign((lhs, rhs))
177                            if *lhs == place_switched_on =>
178                        {
179                            match rhs {
180                                Rvalue::BinaryOp(op @ (BinOp::Eq | BinOp::Ne), (left, right)) => {
181                                    let (branch_value_scalar, branch_value_ty, to_switch_on) =
182                                        find_branch_value_info(left, right, ssa)?;
183
184                                    Some(OptimizationInfo {
185                                        bin_op_stmt_idx: stmt_idx,
186                                        bb_idx,
187                                        can_remove_bin_op_stmt: discr.is_move(),
188                                        to_switch_on,
189                                        branch_value_scalar,
190                                        branch_value_ty,
191                                        op: *op,
192                                        targets: targets.clone(),
193                                    })
194                                }
195                                _ => None,
196                            }
197                        }
198                        _ => None,
199                    }
200                })
201            })
202            .collect()
203    }
204}
205
206fn find_branch_value_info<'tcx>(
207    left: &Operand<'tcx>,
208    right: &Operand<'tcx>,
209    ssa: &SsaLocals,
210) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
211    // check that either left or right is a constant.
212    // if any are, we can use the other to switch on, and the constant as a value in a switch
213    use Operand::*;
214    match (left, right) {
215        (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
216        | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
217            // Make sure that the place is not modified.
218            if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() {
219                return None;
220            }
221            let branch_value_ty = branch_value.const_.ty();
222            // we only want to apply this optimization if we are matching on integrals (and chars),
223            // as it is not possible to switch on floats
224            if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
225                return None;
226            };
227            let branch_value_scalar = branch_value.const_.try_to_scalar()?;
228            Some((branch_value_scalar, branch_value_ty, *to_switch_on))
229        }
230        _ => None,
231    }
232}
233
234#[derive(Debug)]
235struct OptimizationInfo<'tcx> {
236    /// Basic block to apply the optimization
237    bb_idx: BasicBlock,
238    /// Statement index of Eq/Ne assignment that can be removed. None if the assignment can not be
239    /// removed - i.e the statement is used later on
240    bin_op_stmt_idx: usize,
241    /// Can remove Eq/Ne assignment
242    can_remove_bin_op_stmt: bool,
243    /// Place that needs to be switched on. This place is of type integral
244    to_switch_on: Place<'tcx>,
245    /// Constant to use in switch target value
246    branch_value_scalar: Scalar,
247    /// Type of the constant value
248    branch_value_ty: Ty<'tcx>,
249    /// Either Eq or Ne
250    op: BinOp,
251    /// Current targets used in the switch
252    targets: SwitchTargets,
253}