// rustc_mir_transform/match_branches.rs

1use rustc_abi::Integer;
2use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
3use rustc_middle::mir::*;
4use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
5use rustc_middle::ty::util::Discr;
6use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10use crate::unreachable_prop::remove_successors_from_switch;
11
12/// Unifies all targets into one basic block if each statement can have the same statement.
13pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17        // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable.
18        sess.mir_opt_level() >= 2
19    }
20
21    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22        let typing_env = body.typing_env(tcx);
23        let mut changed = false;
24        for bb in body.basic_blocks.indices() {
25            if !candidate_match(body, bb) {
26                continue;
27            };
28            changed |= simplify_match(tcx, typing_env, body, bb)
29        }
30
31        if changed {
32            simplify_cfg(tcx, body);
33        }
34    }
35
36    fn is_required(&self) -> bool {
37        false
38    }
39}
40
41struct SimplifyMatch<'tcx, 'a> {
42    tcx: TyCtxt<'tcx>,
43    typing_env: ty::TypingEnv<'tcx>,
44    patch: MirPatch<'tcx>,
45    body: &'a Body<'tcx>,
46    switch_bb: BasicBlock,
47    discr: &'a Operand<'tcx>,
48    discr_local: Option<Local>,
49    discr_ty: Ty<'tcx>,
50}
51
52impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> {
53    fn discr_local(&mut self) -> Local {
54        *self.discr_local.get_or_insert_with(|| {
55            // Introduce a temporary for the discriminant value.
56            let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info;
57            self.patch.new_temp(self.discr_ty, source_info.span)
58        })
59    }
60
61    /// Unifies the assignments if all rvalues are constants and equal.
62    fn unify_if_equal_const(
63        &self,
64        dest: Place<'tcx>,
65        consts: &[(u128, &ConstOperand<'tcx>)],
66        otherwise: Option<&ConstOperand<'tcx>>,
67    ) -> Option<StatementKind<'tcx>> {
68        let (_, first_const, mut others) = split_first_case(consts, otherwise);
69        let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
70        if others.all(|const_| {
71            const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int)
72        }) {
73            Some(StatementKind::Assign(Box::new((
74                dest,
75                Rvalue::Use(Operand::Constant(Box::new(first_const.clone()))),
76            ))))
77        } else {
78            None
79        }
80    }
81
82    /// If a source block is found that switches between two blocks that are exactly
83    /// the same modulo const bool assignments (e.g., one assigns true another false
84    /// to the same place), unify a target block statements into the source block,
85    /// using Eq / Ne comparison with switch value where const bools value differ.
86    ///
87    /// For example:
88    ///
89    /// ```ignore (MIR)
90    /// bb0: {
91    ///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
92    /// }
93    ///
94    /// bb1: {
95    ///     _2 = const true;
96    ///     goto -> bb3;
97    /// }
98    ///
99    /// bb2: {
100    ///     _2 = const false;
101    ///     goto -> bb3;
102    /// }
103    /// ```
104    ///
105    /// into:
106    ///
107    /// ```ignore (MIR)
108    /// bb0: {
109    ///    _2 = Eq(move _3, const 42_isize);
110    ///    goto -> bb3;
111    /// }
112    /// ```
113    fn unify_by_eq_op(
114        &mut self,
115        dest: Place<'tcx>,
116        consts: &[(u128, &ConstOperand<'tcx>)],
117        otherwise: Option<&ConstOperand<'tcx>>,
118    ) -> Option<StatementKind<'tcx>> {
119        // FIXME: extend to any case.
120        let (first_case, first_const, mut others) = split_first_case(consts, otherwise);
121        if !first_const.ty().is_bool() {
122            return None;
123        }
124        let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?;
125        if others.all(|const_| {
126            const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool)
127        }) {
128            // Make value conditional on switch condition.
129            let size =
130                self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size;
131            let const_cmp = Operand::const_from_scalar(
132                self.tcx,
133                self.discr_ty,
134                rustc_const_eval::interpret::Scalar::from_uint(first_case, size),
135                rustc_span::DUMMY_SP,
136            );
137            let op = if first_bool { BinOp::Eq } else { BinOp::Ne };
138            let rval = Rvalue::BinaryOp(
139                op,
140                Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)),
141            );
142            Some(StatementKind::Assign(Box::new((dest, rval))))
143        } else {
144            None
145        }
146    }
147
148    /// Unifies the assignments if all rvalues can be cast from the discriminant value by IntToInt.
149    ///
150    /// For example:
151    ///
152    /// ```ignore (MIR)
153    /// bb0: {
154    ///     switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
155    /// }
156    ///
157    /// bb1: {
158    ///     unreachable;
159    /// }
160    ///
161    /// bb2: {
162    ///     _0 = const 1_i16;
163    ///     goto -> bb5;
164    /// }
165    ///
166    /// bb3: {
167    ///     _0 = const 2_i16;
168    ///     goto -> bb5;
169    /// }
170    ///
171    /// bb4: {
172    ///     _0 = const 3_i16;
173    ///     goto -> bb5;
174    /// }
175    /// ```
176    ///
177    /// into:
178    ///
179    /// ```ignore (MIR)
180    /// bb0: {
181    ///    _0 = _1 as i16 (IntToInt);
182    ///    goto -> bb5;
183    /// }
184    /// ```
185    fn unify_by_int_to_int(
186        &mut self,
187        dest: Place<'tcx>,
188        consts: &[(u128, &ConstOperand<'tcx>)],
189    ) -> Option<StatementKind<'tcx>> {
190        let (_, first_const) = consts[0];
191        if !first_const.ty().is_integral() {
192            return None;
193        }
194        let discr_layout =
195            self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap();
196        if consts.iter().all(|&(case, const_)| {
197            let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env)
198            else {
199                return false;
200            };
201            can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int)
202        }) {
203            let operand = Operand::Copy(Place::from(self.discr_local()));
204            let rval = if first_const.ty() == self.discr_ty {
205                Rvalue::Use(operand)
206            } else {
207                Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty())
208            };
209            Some(StatementKind::Assign(Box::new((dest, rval))))
210        } else {
211            None
212        }
213    }
214
215    /// This is primarily used to unify these copy statements that simplified the canonical enum clone method by GVN.
216    /// The GVN simplified
217    /// ```ignore (syntax-highlighting-only)
218    /// match a {
219    ///     Foo::A(x) => Foo::A(*x),
220    ///     Foo::B => Foo::B
221    /// }
222    /// ```
223    /// to
224    /// ```ignore (syntax-highlighting-only)
225    /// match a {
226    ///     Foo::A(_x) => a, // copy a
227    ///     Foo::B => Foo::B
228    /// }
229    /// ```
230    /// This will simplify into a copy statement.
231    fn unify_by_copy(
232        &self,
233        dest: Place<'tcx>,
234        rvals: &[(u128, &Rvalue<'tcx>)],
235    ) -> Option<StatementKind<'tcx>> {
236        let bbs = &self.body.basic_blocks;
237        // Check if the copy source matches the following pattern.
238        // _2 = discriminant(*_1); // "*_1" is the expected the copy source.
239        // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
240        let &Statement {
241            kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))),
242            ..
243        } = bbs[self.switch_bb].statements.last()?
244        else {
245            return None;
246        };
247        if self.discr.place() != Some(discr_place) {
248            return None;
249        }
250        let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx);
251        if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
252            return None;
253        }
254        let dest_ty = dest.ty(self.body.local_decls(), self.tcx);
255        if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
256            return None;
257        }
258        let ty::Adt(def, _) = dest_ty.ty.kind() else {
259            return None;
260        };
261
262        for &(case, rvalue) in rvals.iter() {
263            match rvalue {
264                // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
265                Rvalue::Use(Operand::Constant(box constant))
266                    if let Const::Val(const_, ty) = constant.const_ =>
267                {
268                    let (ecx, op) = mk_eval_cx_for_const_val(
269                        self.tcx.at(constant.span),
270                        self.typing_env,
271                        const_,
272                        ty,
273                    )?;
274                    let variant = ecx.read_discriminant(&op).discard_err()?;
275                    if !def.variants()[variant].fields.is_empty() {
276                        return None;
277                    }
278                    let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?;
279                    if val != case {
280                        return None;
281                    }
282                }
283                Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {}
284                // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
285                Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
286                    if fields.is_empty()
287                        && let Some(Discr { val, .. }) =
288                            src_ty.ty.discriminant_for_variant(self.tcx, *variant_index)
289                        && val == case => {}
290                _ => return None,
291            }
292        }
293        Some(StatementKind::Assign(Box::new((dest, Rvalue::Use(Operand::Copy(copy_src_place))))))
294    }
295
296    /// Returns a new statement if we can use the statement replace all statements.
297    fn try_unify_stmts(
298        &mut self,
299        index: usize,
300        stmts: &[(u128, &StatementKind<'tcx>)],
301        otherwise: Option<&StatementKind<'tcx>>,
302    ) -> Option<StatementKind<'tcx>> {
303        if let Some(new_stmt) = identical_stmts(stmts, otherwise) {
304            return Some(new_stmt);
305        }
306
307        let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?;
308        if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) {
309            if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) {
310                return Some(new_stmt);
311            }
312            if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) {
313                return Some(new_stmt);
314            }
315            // Requires the otherwise is unreachable.
316            if otherwise.is_none()
317                && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts)
318            {
319                return Some(new_stmt);
320            }
321        }
322
323        // We only know the first statement is safe to introduce new dereferences.
324        if index == 0
325            // We cannot create overlapping assignments.
326            && dest.is_stable_offset()
327            // Requires the otherwise is unreachable.
328            && otherwise.is_none()
329            && let Some(new_stmt) = self.unify_by_copy(dest, &rvals)
330        {
331            return Some(new_stmt);
332        }
333        None
334    }
335}
336
337/// Returns the first case target if all targets have an equal number of statements and identical destination.
338fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool {
339    use itertools::Itertools;
340    let targets = match &body.basic_blocks[switch_bb].terminator().kind {
341        TerminatorKind::SwitchInt {
342            discr: Operand::Copy(_) | Operand::Move(_), targets, ..
343        } => targets,
344        // Only optimize switch int statements
345        _ => return false,
346    };
347    // We require that the possible target blocks don't contain this block.
348    if targets.all_targets().contains(&switch_bb) {
349        return false;
350    }
351    // We require that the possible target blocks all be distinct.
352    if !targets.is_distinct() {
353        return false;
354    }
355    // Check that destinations are identical, and if not, then don't optimize this block
356    targets
357        .all_targets()
358        .iter()
359        .map(|&bb| &body.basic_blocks[bb])
360        .filter(|bb| !bb.is_empty_unreachable())
361        .map(|bb| (bb.statements.len(), &bb.terminator().kind))
362        .all_equal()
363}
364
365fn simplify_match<'tcx>(
366    tcx: TyCtxt<'tcx>,
367    typing_env: ty::TypingEnv<'tcx>,
368    body: &mut Body<'tcx>,
369    switch_bb: BasicBlock,
370) -> bool {
371    let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind {
372        TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets),
373        _ => unreachable!(),
374    };
375    let mut simplify_match = SimplifyMatch {
376        tcx,
377        typing_env,
378        patch: MirPatch::new(body),
379        body,
380        switch_bb,
381        discr,
382        discr_local: None,
383        discr_ty: discr.ty(body.local_decls(), tcx),
384    };
385    let reachable_cases: Vec<_> =
386        targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect();
387    let mut new_stmts = Vec::new();
388    let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() {
389        None
390    } else {
391        Some(targets.otherwise())
392    };
393    // We can patch the terminator to goto because there is a single target.
394    match (reachable_cases.len(), otherwise.is_none()) {
395        (1, true) | (0, false) => {
396            let mut patch = simplify_match.patch;
397            remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| {
398                body.basic_blocks[bb].is_empty_unreachable()
399            });
400            patch.apply(body);
401            return true;
402        }
403        _ => {}
404    }
405    let Some(&(_, first_case_bb)) = reachable_cases.first() else {
406        return false;
407    };
408    let stmt_len = body.basic_blocks[first_case_bb].statements.len();
409    let mut cases = Vec::with_capacity(stmt_len);
410    // Check at each position in the basic blocks whether these statements can be unified.
411    for index in 0..stmt_len {
412        cases.clear();
413        let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind);
414        for &(case, bb) in &reachable_cases {
415            cases.push((case, &body.basic_blocks[bb].statements[index].kind));
416        }
417        let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else {
418            return false;
419        };
420        new_stmts.push(new_stmt);
421    }
422    // Take ownership of items now that we know we can optimize.
423    let discr = discr.clone();
424
425    let statement_index = body.basic_blocks[switch_bb].statements.len();
426    let parent_end = Location { block: switch_bb, statement_index };
427    let mut patch = simplify_match.patch;
428    if let Some(discr_local) = simplify_match.discr_local {
429        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
430        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
431    }
432    for new_stmt in new_stmts {
433        patch.add_statement(parent_end, new_stmt);
434    }
435    if let Some(discr_local) = simplify_match.discr_local {
436        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
437    }
438    patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone());
439    patch.apply(body);
440    true
441}
442
443/// Check if the cast constant using `IntToInt` is equal to the target constant.
444fn can_cast(
445    tcx: TyCtxt<'_>,
446    src_val: impl Into<u128>,
447    src_layout: TyAndLayout<'_>,
448    cast_ty: Ty<'_>,
449    target_scalar: ScalarInt,
450) -> bool {
451    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
452    let v = match src_layout.ty.kind() {
453        ty::Uint(_) => from_scalar.to_uint(src_layout.size),
454        ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
455        // We can also transform the values of other integer representations (such as char),
456        // although this may not be practical in real-world scenarios.
457        _ => return false,
458    };
459    let size = match *cast_ty.kind() {
460        ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
461        ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
462        _ => return false,
463    };
464    let v = size.truncate(v);
465    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
466    cast_scalar == target_scalar
467}
468
469fn candidate_assign<'tcx, 'a>(
470    stmts: &'a [(u128, &'a StatementKind<'tcx>)],
471    otherwise: Option<&'a StatementKind<'tcx>>,
472) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> {
473    let (_, first_stmt) = stmts[0];
474    let (dest, _) = first_stmt.as_assign()?;
475    let otherwise = if let Some(otherwise) = otherwise {
476        let Some((otherwise_dest, rval)) = otherwise.as_assign() else {
477            return None;
478        };
479        if otherwise_dest != dest {
480            return None;
481        }
482        Some(rval)
483    } else {
484        None
485    };
486    let rvals = stmts
487        .into_iter()
488        .map(|&(case, stmt)| {
489            let (other_dest, rval) = stmt.as_assign()?;
490            if other_dest != dest {
491                return None;
492            }
493            Some((case, rval))
494        })
495        .try_collect()?;
496    Some((*dest, rvals, otherwise))
497}
498
499// Returns all ConstOperands if all Rvalues are ConstOperands.
500fn candidate_const<'tcx, 'a>(
501    rvals: &'a [(u128, &'a Rvalue<'tcx>)],
502    otherwise: Option<&'a Rvalue<'tcx>>,
503) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> {
504    let otherwise = if let Some(otherwise) = otherwise {
505        let Rvalue::Use(Operand::Constant(box const_)) = otherwise else {
506            return None;
507        };
508        Some(const_)
509    } else {
510        None
511    };
512    let consts = rvals
513        .into_iter()
514        .map(|&(case, rval)| {
515            let Rvalue::Use(Operand::Constant(box const_)) = rval else { return None };
516            Some((case, const_))
517        })
518        .try_collect()?;
519    Some((consts, otherwise))
520}
521
/// Splits `stmts` into its first entry and the rest.
///
/// Returns the first case value, the first item, and an iterator over the remaining items
/// followed by `otherwise` if it is present.
///
/// Panics if `stmts` is empty.
fn split_first_case<'a, T>(
    stmts: &'a [(u128, &'a T)],
    otherwise: Option<&'a T>,
) -> (u128, &'a T, impl Iterator<Item = &'a T>) {
    let (first_case, first) = stmts[0];
    let rest = stmts[1..].iter().map(|&(_, val)| val).chain(otherwise);
    (first_case, first, rest)
}
530
531// If all statements are identical, we can optimize.
532fn identical_stmts<'tcx>(
533    stmts: &[(u128, &StatementKind<'tcx>)],
534    otherwise: Option<&StatementKind<'tcx>>,
535) -> Option<StatementKind<'tcx>> {
536    use itertools::Itertools;
537    let (_, first_stmt, others) = split_first_case(stmts, otherwise);
538    if std::iter::once(first_stmt).chain(others).all_equal() {
539        return Some(first_stmt.clone());
540    }
541    None
542}