Skip to main content

rustc_mir_transform/
simplify_comparison_integral.rs

1use std::iter;
2
3use rustc_middle::bug;
4use rustc_middle::mir::interpret::Scalar;
5use rustc_middle::mir::{
6    BasicBlock, BinOp, Body, Operand, Place, Rvalue, Statement, StatementKind, SwitchTargets,
7    TerminatorKind,
8};
9use rustc_middle::ty::{Ty, TyCtxt};
10use tracing::trace;
11
12use crate::ssa::SsaLocals;
13
/// Pass to convert `if` conditions on integrals into switches on the integral.
///
/// The pass looks for a `switchInt` on the result of an `Eq`/`Ne` comparison between a place
/// and an integral (or `char`) constant, and rewrites the terminator to switch on the compared
/// place directly. For example, it turns something like
///
/// ```ignore (MIR)
/// _3 = Eq(move _4, const 43i32);
/// StorageDead(_4);
/// switchInt(_3) -> [false: bb2, otherwise: bb3];
/// ```
///
/// into:
///
/// ```ignore (MIR)
/// switchInt(_4) -> [43i32: bb3, otherwise: bb2];
/// ```
///
/// For `Ne` the two targets are exchanged: the constant's arm jumps to the block that was taken
/// when the comparison was `false`. Any `StorageDead` of the compared local found in the block is
/// removed there and re-inserted at the start of every switch target, because the new terminator
/// reads that local.
pub(super) struct SimplifyComparisonIntegral;
29
impl<'tcx> crate::MirPass<'tcx> for SimplifyComparisonIntegral {
    /// The pass only runs when the session's MIR optimization level is greater than 1.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() > 1
    }

    /// Rewrites every candidate found by `OptimizationFinder::find_optimizations` so that the
    /// block's terminator switches on the compared integer place instead of on the boolean
    /// result of the `Eq`/`Ne` comparison.
    ///
    /// `StorageDead` removals and insertions are only recorded while the candidates are being
    /// processed and are applied once all terminators have been rewritten.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("Running SimplifyComparisonIntegral on {:?}", body.source);

        let typing_env = body.typing_env(tcx);
        let ssa = SsaLocals::new(tcx, body, typing_env);
        let helper = OptimizationFinder { body };
        let opts = helper.find_optimizations(&ssa);
        // `(target block, StorageDead statement)` pairs to prepend to the switch targets.
        let mut storage_deads_to_insert = vec![];
        // `(statement index, block)` pairs of `StorageDead` statements to turn into nops.
        let mut storage_deads_to_remove: Vec<(usize, BasicBlock)> = vec![];
        for opt in opts {
            trace!("SUCCESS: Applying {:?}", opt);
            // replace terminator with a switchInt that switches on the integer directly
            let bbs = &mut body.basic_blocks_mut();
            let bb = &mut bbs[opt.bb_idx];
            // The switch value is the raw bit pattern of the constant, sized by the layout of
            // the constant's type.
            let new_value = match opt.branch_value_scalar {
                Scalar::Int(int) => {
                    let layout = tcx
                        .layout_of(typing_env.as_query_input(opt.branch_value_ty))
                        .expect("if we have an evaluated constant we must know the layout");
                    int.to_bits(layout.size)
                }
                // Only integer scalars are converted; candidates with a pointer scalar are left
                // untouched.
                Scalar::Ptr(..) => continue,
            };
            // Switch value that the boolean `false` is compared against below.
            const FALSE: u128 = 0;

            // Reorder the original targets so that index 0 is the block to jump to when the
            // integer equals the constant, and index 1 is the block for every other value.
            let mut new_targets = opt.targets;
            let first_value = new_targets.iter().next().unwrap().0;
            let first_is_false_target = first_value == FALSE;
            match opt.op {
                BinOp::Eq => {
                    // if the assignment was Eq we want the true case to be first
                    if first_is_false_target {
                        new_targets.all_targets_mut().swap(0, 1);
                    }
                }
                BinOp::Ne => {
                    // if the assignment was Ne we want the false case to be first
                    if !first_is_false_target {
                        new_targets.all_targets_mut().swap(0, 1);
                    }
                }
                // `find_optimizations` only records `Eq` and `Ne` comparisons.
                _ => unreachable!(),
            }

            // delete the comparison statement if the value being switched on was moved, which
            // means it can not be used later on
            if opt.can_remove_bin_op_stmt {
                bb.statements[opt.bin_op_stmt_idx].make_nop(true);
            } else {
                // if the integer being compared to a const integral is being moved into the
                // comparison, e.g `_2 = Eq(move _3, const 'x');`
                // we want to avoid making a double move later on in the switchInt on _3.
                // So to avoid `switchInt(move _3) -> ['x': bb2, otherwise: bb1];`,
                // we convert the move in the comparison statement to a copy.

                // unwrap is safe as we know this statement is an assign
                let (_, rhs) = bb.statements[opt.bin_op_stmt_idx].kind.as_assign_mut().unwrap();

                use Operand::*;
                match rhs {
                    Rvalue::BinaryOp(_, box (left @ Move(_), Constant(_))) => {
                        *left = Copy(opt.to_switch_on);
                    }
                    Rvalue::BinaryOp(_, box (Constant(_), right @ Move(_))) => {
                        *right = Copy(opt.to_switch_on);
                    }
                    _ => (),
                }
            }

            // Only the terminator's `source_info` is needed here; it is reused for the
            // `StorageDead` statements created below.
            let terminator = bb.terminator();

            // Record every StorageDead of the compared local in this block for removal: the
            // rewritten terminator reads that local.
            for (stmt_idx, stmt) in bb.statements.iter().enumerate() {
                if !matches!(
                    stmt.kind,
                    StatementKind::StorageDead(local) if local == opt.to_switch_on.local
                ) {
                    continue;
                }
                storage_deads_to_remove.push((stmt_idx, opt.bb_idx));
                // if we have StorageDeads to remove then make sure to insert them at the top of
                // each target
                for bb_idx in new_targets.all_targets() {
                    storage_deads_to_insert.push((
                        *bb_idx,
                        Statement::new(
                            terminator.source_info,
                            StatementKind::StorageDead(opt.to_switch_on.local),
                        ),
                    ));
                }
            }

            // The original switch must have exactly one explicit value plus `otherwise`.
            let [bb_cond, bb_otherwise] = match new_targets.all_targets() {
                [a, b] => [*a, *b],
                e => bug!("expected 2 switch targets, got: {:?}", e),
            };

            let targets = SwitchTargets::new(iter::once((new_value, bb_cond)), bb_otherwise);

            // The discriminant is always a copy of the compared place, never a move.
            let terminator = bb.terminator_mut();
            terminator.kind =
                TerminatorKind::SwitchInt { discr: Operand::Copy(opt.to_switch_on), targets };
        }

        // Apply the recorded StorageDead changes. Removals replace statements with nops, and
        // they run before the insertions below shift statements in the target blocks.
        // NOTE(review): deferring these presumably keeps the statement indices recorded in
        // `opts` valid while the loop above runs — confirm.
        for (idx, bb_idx) in storage_deads_to_remove {
            body.basic_blocks_mut()[bb_idx].statements[idx].make_nop(true);
        }

        for (idx, stmt) in storage_deads_to_insert {
            body.basic_blocks_mut()[idx].statements.insert(0, stmt);
        }
    }

    /// Reports this pass as not required.
    fn is_required(&self) -> bool {
        false
    }
}
154
/// Read-only view of a MIR body used to search for blocks that
/// `SimplifyComparisonIntegral` can rewrite.
struct OptimizationFinder<'a, 'tcx> {
    /// The body whose basic blocks are scanned for candidates.
    body: &'a Body<'tcx>,
}
158
159impl<'tcx> OptimizationFinder<'_, 'tcx> {
160    fn find_optimizations(&self, ssa: &SsaLocals) -> Vec<OptimizationInfo<'tcx>> {
161        self.body
162            .basic_blocks
163            .iter_enumerated()
164            .filter_map(|(bb_idx, bb)| {
165                // find switch
166                let (discr, targets) = bb.terminator().kind.as_switch()?;
167                let place_switched_on = discr.place()?;
168                // Make sure that the place is not modified.
169                if !ssa.is_ssa(place_switched_on.local) || !place_switched_on.is_stable_offset() {
170                    return None;
171                }
172
173                // find the statement that assigns the place being switched on
174                bb.statements.iter().enumerate().rev().find_map(|(stmt_idx, stmt)| {
175                    match &stmt.kind {
176                        rustc_middle::mir::StatementKind::Assign(box (lhs, rhs))
177                            if *lhs == place_switched_on =>
178                        {
179                            match rhs {
180                                Rvalue::BinaryOp(
181                                    op @ (BinOp::Eq | BinOp::Ne),
182                                    box (left, right),
183                                ) => {
184                                    let (branch_value_scalar, branch_value_ty, to_switch_on) =
185                                        find_branch_value_info(left, right, ssa)?;
186
187                                    Some(OptimizationInfo {
188                                        bin_op_stmt_idx: stmt_idx,
189                                        bb_idx,
190                                        can_remove_bin_op_stmt: discr.is_move(),
191                                        to_switch_on,
192                                        branch_value_scalar,
193                                        branch_value_ty,
194                                        op: *op,
195                                        targets: targets.clone(),
196                                    })
197                                }
198                                _ => None,
199                            }
200                        }
201                        _ => None,
202                    }
203                })
204            })
205            .collect()
206    }
207}
208
209fn find_branch_value_info<'tcx>(
210    left: &Operand<'tcx>,
211    right: &Operand<'tcx>,
212    ssa: &SsaLocals,
213) -> Option<(Scalar, Ty<'tcx>, Place<'tcx>)> {
214    // check that either left or right is a constant.
215    // if any are, we can use the other to switch on, and the constant as a value in a switch
216    use Operand::*;
217    match (left, right) {
218        (Constant(branch_value), Copy(to_switch_on) | Move(to_switch_on))
219        | (Copy(to_switch_on) | Move(to_switch_on), Constant(branch_value)) => {
220            // Make sure that the place is not modified.
221            if !ssa.is_ssa(to_switch_on.local) || !to_switch_on.is_stable_offset() {
222                return None;
223            }
224            let branch_value_ty = branch_value.const_.ty();
225            // we only want to apply this optimization if we are matching on integrals (and chars),
226            // as it is not possible to switch on floats
227            if !branch_value_ty.is_integral() && !branch_value_ty.is_char() {
228                return None;
229            };
230            let branch_value_scalar = branch_value.const_.try_to_scalar()?;
231            Some((branch_value_scalar, branch_value_ty, *to_switch_on))
232        }
233        _ => None,
234    }
235}
236
/// Everything needed to rewrite one `switchInt` on an `Eq`/`Ne` result into a switch on the
/// compared integer place.
#[derive(Debug)]
struct OptimizationInfo<'tcx> {
    /// Basic block to apply the optimization
    bb_idx: BasicBlock,
    /// Index, within the statements of `bb_idx`, of the Eq/Ne assignment whose result is
    /// switched on
    bin_op_stmt_idx: usize,
    /// Whether the Eq/Ne assignment can be removed. Set when the switch discriminant is a move
    /// of the comparison result; otherwise the statement is kept
    can_remove_bin_op_stmt: bool,
    /// Place that needs to be switched on. This place is of type integral
    to_switch_on: Place<'tcx>,
    /// Constant to use in switch target value
    branch_value_scalar: Scalar,
    /// Type of the constant value
    branch_value_ty: Ty<'tcx>,
    /// Either Eq or Ne
    op: BinOp,
    /// Current targets used in the switch
    targets: SwitchTargets,
}