rustc_mir_transform/
match_branches.rs

1use std::iter;
2
3use rustc_abi::Integer;
4use rustc_index::IndexSlice;
5use rustc_middle::mir::*;
6use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
7use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
8use rustc_type_ir::TyKind::*;
9use tracing::instrument;
10
11use super::simplify::simplify_cfg;
12use crate::patch::MirPatch;
13
14pub(super) struct MatchBranchSimplification;
15
16impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
17    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
18        sess.mir_opt_level() >= 1
19    }
20
21    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22        let typing_env = body.typing_env(tcx);
23        let mut should_cleanup = false;
24        for i in 0..body.basic_blocks.len() {
25            let bbs = &*body.basic_blocks;
26            let bb_idx = BasicBlock::from_usize(i);
27            match bbs[bb_idx].terminator().kind {
28                TerminatorKind::SwitchInt {
29                    discr: ref _discr @ (Operand::Copy(_) | Operand::Move(_)),
30                    ref targets,
31                    ..
32                    // We require that the possible target blocks don't contain this block.
33                } if !targets.all_targets().contains(&bb_idx) => {}
34                // Only optimize switch int statements
35                _ => continue,
36            };
37
38            if SimplifyToIf.simplify(tcx, body, bb_idx, typing_env).is_some() {
39                should_cleanup = true;
40                continue;
41            }
42            if SimplifyToExp::default().simplify(tcx, body, bb_idx, typing_env).is_some() {
43                should_cleanup = true;
44                continue;
45            }
46        }
47
48        if should_cleanup {
49            simplify_cfg(body);
50        }
51    }
52
53    fn is_required(&self) -> bool {
54        false
55    }
56}
57
58trait SimplifyMatch<'tcx> {
59    /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None`
60    /// otherwise. Generic code is written here, and we generally don't need a custom
61    /// implementation.
62    fn simplify(
63        &mut self,
64        tcx: TyCtxt<'tcx>,
65        body: &mut Body<'tcx>,
66        switch_bb_idx: BasicBlock,
67        typing_env: ty::TypingEnv<'tcx>,
68    ) -> Option<()> {
69        let bbs = &body.basic_blocks;
70        let (discr, targets) = match bbs[switch_bb_idx].terminator().kind {
71            TerminatorKind::SwitchInt { ref discr, ref targets, .. } => (discr, targets),
72            _ => unreachable!(),
73        };
74
75        let discr_ty = discr.ty(body.local_decls(), tcx);
76        self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?;
77
78        let mut patch = MirPatch::new(body);
79
80        // Take ownership of items now that we know we can optimize.
81        let discr = discr.clone();
82
83        // Introduce a temporary for the discriminant value.
84        let source_info = bbs[switch_bb_idx].terminator().source_info;
85        let discr_local = patch.new_temp(discr_ty, source_info.span);
86
87        let (_, first) = targets.iter().next().unwrap();
88        let statement_index = bbs[switch_bb_idx].statements.len();
89        let parent_end = Location { block: switch_bb_idx, statement_index };
90        patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
91        patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
92        self.new_stmts(
93            tcx,
94            targets,
95            typing_env,
96            &mut patch,
97            parent_end,
98            bbs,
99            discr_local,
100            discr_ty,
101        );
102        patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
103        patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone());
104        patch.apply(body);
105        Some(())
106    }
107
108    /// Check that the BBs to be simplified satisfies all distinct and
109    /// that the terminator are the same.
110    /// There are also conditions for different ways of simplification.
111    fn can_simplify(
112        &mut self,
113        tcx: TyCtxt<'tcx>,
114        targets: &SwitchTargets,
115        typing_env: ty::TypingEnv<'tcx>,
116        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
117        discr_ty: Ty<'tcx>,
118    ) -> Option<()>;
119
120    fn new_stmts(
121        &self,
122        tcx: TyCtxt<'tcx>,
123        targets: &SwitchTargets,
124        typing_env: ty::TypingEnv<'tcx>,
125        patch: &mut MirPatch<'tcx>,
126        parent_end: Location,
127        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
128        discr_local: Local,
129        discr_ty: Ty<'tcx>,
130    );
131}
132
133struct SimplifyToIf;
134
135/// If a source block is found that switches between two blocks that are exactly
136/// the same modulo const bool assignments (e.g., one assigns true another false
137/// to the same place), merge a target block statements into the source block,
138/// using Eq / Ne comparison with switch value where const bools value differ.
139///
140/// For example:
141///
142/// ```ignore (MIR)
143/// bb0: {
144///     switchInt(move _3) -> [42_isize: bb1, otherwise: bb2];
145/// }
146///
147/// bb1: {
148///     _2 = const true;
149///     goto -> bb3;
150/// }
151///
152/// bb2: {
153///     _2 = const false;
154///     goto -> bb3;
155/// }
156/// ```
157///
158/// into:
159///
160/// ```ignore (MIR)
161/// bb0: {
162///    _2 = Eq(move _3, const 42_isize);
163///    goto -> bb3;
164/// }
165/// ```
166impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
167    #[instrument(level = "debug", skip(self, tcx), ret)]
168    fn can_simplify(
169        &mut self,
170        tcx: TyCtxt<'tcx>,
171        targets: &SwitchTargets,
172        typing_env: ty::TypingEnv<'tcx>,
173        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
174        _discr_ty: Ty<'tcx>,
175    ) -> Option<()> {
176        let (first, second) = match targets.all_targets() {
177            &[first, otherwise] => (first, otherwise),
178            &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second),
179            _ => {
180                return None;
181            }
182        };
183
184        // We require that the possible target blocks all be distinct.
185        if first == second {
186            return None;
187        }
188        // Check that destinations are identical, and if not, then don't optimize this block
189        if bbs[first].terminator().kind != bbs[second].terminator().kind {
190            return None;
191        }
192
193        // Check that blocks are assignments of consts to the same place or same statement,
194        // and match up 1-1, if not don't optimize this block.
195        let first_stmts = &bbs[first].statements;
196        let second_stmts = &bbs[second].statements;
197        if first_stmts.len() != second_stmts.len() {
198            return None;
199        }
200        for (f, s) in iter::zip(first_stmts, second_stmts) {
201            match (&f.kind, &s.kind) {
202                // If two statements are exactly the same, we can optimize.
203                (f_s, s_s) if f_s == s_s => {}
204
205                // If two statements are const bool assignments to the same place, we can optimize.
206                (
207                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
208                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
209                ) if lhs_f == lhs_s
210                    && f_c.const_.ty().is_bool()
211                    && s_c.const_.ty().is_bool()
212                    && f_c.const_.try_eval_bool(tcx, typing_env).is_some()
213                    && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {}
214
215                // Otherwise we cannot optimize. Try another block.
216                _ => return None,
217            }
218        }
219        Some(())
220    }
221
222    fn new_stmts(
223        &self,
224        tcx: TyCtxt<'tcx>,
225        targets: &SwitchTargets,
226        typing_env: ty::TypingEnv<'tcx>,
227        patch: &mut MirPatch<'tcx>,
228        parent_end: Location,
229        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
230        discr_local: Local,
231        discr_ty: Ty<'tcx>,
232    ) {
233        let ((val, first), second) = match (targets.all_targets(), targets.all_values()) {
234            (&[first, otherwise], &[val]) => ((val, first), otherwise),
235            (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => {
236                ((val, first), second)
237            }
238            _ => unreachable!(),
239        };
240
241        // We already checked that first and second are different blocks,
242        // and bb_idx has a different terminator from both of them.
243        let first = &bbs[first];
244        let second = &bbs[second];
245        for (f, s) in iter::zip(&first.statements, &second.statements) {
246            match (&f.kind, &s.kind) {
247                (f_s, s_s) if f_s == s_s => {
248                    patch.add_statement(parent_end, f.kind.clone());
249                }
250
251                (
252                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
253                    StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))),
254                ) => {
255                    // From earlier loop we know that we are dealing with bool constants only:
256                    let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap();
257                    let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap();
258                    if f_b == s_b {
259                        // Same value in both blocks. Use statement as is.
260                        patch.add_statement(parent_end, f.kind.clone());
261                    } else {
262                        // Different value between blocks. Make value conditional on switch
263                        // condition.
264                        let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size;
265                        let const_cmp = Operand::const_from_scalar(
266                            tcx,
267                            discr_ty,
268                            rustc_const_eval::interpret::Scalar::from_uint(val, size),
269                            rustc_span::DUMMY_SP,
270                        );
271                        let op = if f_b { BinOp::Eq } else { BinOp::Ne };
272                        let rhs = Rvalue::BinaryOp(
273                            op,
274                            Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)),
275                        );
276                        patch.add_assign(parent_end, *lhs, rhs);
277                    }
278                }
279
280                _ => unreachable!(),
281            }
282        }
283    }
284}
285
286/// Check if the cast constant using `IntToInt` is equal to the target constant.
287fn can_cast(
288    tcx: TyCtxt<'_>,
289    src_val: impl Into<u128>,
290    src_layout: TyAndLayout<'_>,
291    cast_ty: Ty<'_>,
292    target_scalar: ScalarInt,
293) -> bool {
294    let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
295    let v = match src_layout.ty.kind() {
296        Uint(_) => from_scalar.to_uint(src_layout.size),
297        Int(_) => from_scalar.to_int(src_layout.size) as u128,
298        _ => unreachable!("invalid int"),
299    };
300    let size = match *cast_ty.kind() {
301        Int(t) => Integer::from_int_ty(&tcx, t).size(),
302        Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
303        _ => unreachable!("invalid int"),
304    };
305    let v = size.truncate(v);
306    let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
307    cast_scalar == target_scalar
308}
309
310#[derive(Default)]
311struct SimplifyToExp {
312    transform_kinds: Vec<TransformKind>,
313}
314
315#[derive(Clone, Copy, Debug)]
316enum ExpectedTransformKind<'a, 'tcx> {
317    /// Identical statements.
318    Same(&'a StatementKind<'tcx>),
319    /// Assignment statements have the same value.
320    SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt },
321    /// Enum variant comparison type.
322    Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> },
323}
324
/// How a statement of the first target block is carried over into the
/// switching block.
enum TransformKind {
    /// Copy the statement unchanged.
    Same,
    /// Replace the constant assignment by a use/cast of the switch value.
    Cast,
}
329
330impl From<ExpectedTransformKind<'_, '_>> for TransformKind {
331    fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self {
332        match compare_type {
333            ExpectedTransformKind::Same(_) => TransformKind::Same,
334            ExpectedTransformKind::SameByEq { .. } => TransformKind::Same,
335            ExpectedTransformKind::Cast { .. } => TransformKind::Cast,
336        }
337    }
338}
339
340/// If we find that the value of match is the same as the assignment,
341/// merge a target block statements into the source block,
342/// using cast to transform different integer types.
343///
344/// For example:
345///
346/// ```ignore (MIR)
347/// bb0: {
348///     switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1];
349/// }
350///
351/// bb1: {
352///     unreachable;
353/// }
354///
355/// bb2: {
356///     _0 = const 1_i16;
357///     goto -> bb5;
358/// }
359///
360/// bb3: {
361///     _0 = const 2_i16;
362///     goto -> bb5;
363/// }
364///
365/// bb4: {
366///     _0 = const 3_i16;
367///     goto -> bb5;
368/// }
369/// ```
370///
371/// into:
372///
373/// ```ignore (MIR)
374/// bb0: {
375///    _0 = _3 as i16 (IntToInt);
376///    goto -> bb5;
377/// }
378/// ```
379impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
380    #[instrument(level = "debug", skip(self, tcx), ret)]
381    fn can_simplify(
382        &mut self,
383        tcx: TyCtxt<'tcx>,
384        targets: &SwitchTargets,
385        typing_env: ty::TypingEnv<'tcx>,
386        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
387        discr_ty: Ty<'tcx>,
388    ) -> Option<()> {
389        if targets.iter().len() < 2 || targets.iter().len() > 64 {
390            return None;
391        }
392        // We require that the possible target blocks all be distinct.
393        if !targets.is_distinct() {
394            return None;
395        }
396        if !bbs[targets.otherwise()].is_empty_unreachable() {
397            return None;
398        }
399        let mut target_iter = targets.iter();
400        let (first_case_val, first_target) = target_iter.next().unwrap();
401        let first_terminator_kind = &bbs[first_target].terminator().kind;
402        // Check that destinations are identical, and if not, then don't optimize this block
403        if !targets
404            .iter()
405            .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind)
406        {
407            return None;
408        }
409
410        let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap();
411        let first_stmts = &bbs[first_target].statements;
412        let (second_case_val, second_target) = target_iter.next().unwrap();
413        let second_stmts = &bbs[second_target].statements;
414        if first_stmts.len() != second_stmts.len() {
415            return None;
416        }
417
418        // We first compare the two branches, and then the other branches need to fulfill the same
419        // conditions.
420        let mut expected_transform_kinds = Vec::new();
421        for (f, s) in iter::zip(first_stmts, second_stmts) {
422            let compare_type = match (&f.kind, &s.kind) {
423                // If two statements are exactly the same, we can optimize.
424                (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s),
425
426                // If two statements are assignments with the match values to the same place, we
427                // can optimize.
428                (
429                    StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))),
430                    StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
431                ) if lhs_f == lhs_s
432                    && f_c.const_.ty() == s_c.const_.ty()
433                    && f_c.const_.ty().is_integral() =>
434                {
435                    match (
436                        f_c.const_.try_eval_scalar_int(tcx, typing_env),
437                        s_c.const_.try_eval_scalar_int(tcx, typing_env),
438                    ) {
439                        (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq {
440                            place: lhs_f,
441                            ty: f_c.const_.ty(),
442                            scalar: f,
443                        },
444                        // Enum variants can also be simplified to an assignment statement,
445                        // if we can use `IntToInt` cast to get an equal value.
446                        (Some(f), Some(s))
447                            if (can_cast(
448                                tcx,
449                                first_case_val,
450                                discr_layout,
451                                f_c.const_.ty(),
452                                f,
453                            ) && can_cast(
454                                tcx,
455                                second_case_val,
456                                discr_layout,
457                                f_c.const_.ty(),
458                                s,
459                            )) =>
460                        {
461                            ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() }
462                        }
463                        _ => {
464                            return None;
465                        }
466                    }
467                }
468
469                // Otherwise we cannot optimize. Try another block.
470                _ => return None,
471            };
472            expected_transform_kinds.push(compare_type);
473        }
474
475        // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
476        for (other_val, other_target) in target_iter {
477            let other_stmts = &bbs[other_target].statements;
478            if expected_transform_kinds.len() != other_stmts.len() {
479                return None;
480            }
481            for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) {
482                match (*f, &s.kind) {
483                    (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {}
484                    (
485                        ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar },
486                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
487                    ) if lhs_f == lhs_s
488                        && s_c.const_.ty() == f_ty
489                        && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {}
490                    (
491                        ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty },
492                        StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))),
493                    ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env)
494                        && lhs_f == lhs_s
495                        && s_c.const_.ty() == f_ty
496                        && can_cast(tcx, other_val, discr_layout, f_ty, f) => {}
497                    _ => return None,
498                }
499            }
500        }
501        self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect();
502        Some(())
503    }
504
505    fn new_stmts(
506        &self,
507        _tcx: TyCtxt<'tcx>,
508        targets: &SwitchTargets,
509        _typing_env: ty::TypingEnv<'tcx>,
510        patch: &mut MirPatch<'tcx>,
511        parent_end: Location,
512        bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
513        discr_local: Local,
514        discr_ty: Ty<'tcx>,
515    ) {
516        let (_, first) = targets.iter().next().unwrap();
517        let first = &bbs[first];
518
519        for (t, s) in iter::zip(&self.transform_kinds, &first.statements) {
520            match (t, &s.kind) {
521                (TransformKind::Same, _) => {
522                    patch.add_statement(parent_end, s.kind.clone());
523                }
524                (
525                    TransformKind::Cast,
526                    StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))),
527                ) => {
528                    let operand = Operand::Copy(Place::from(discr_local));
529                    let r_val = if f_c.const_.ty() == discr_ty {
530                        Rvalue::Use(operand)
531                    } else {
532                        Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty())
533                    };
534                    patch.add_assign(parent_end, *lhs, r_val);
535                }
536                _ => unreachable!(),
537            }
538        }
539    }
540}