// rustc_mir_transform/early_otherwise_branch.rs

1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15///     (Some(_), Some(_)) => {0},
16///     (None, None) => {2},
17///     _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27///     match x {
28///         Some(_) => 0,
29///         None => 2,
30///     }
31/// } else {
32///     1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39///     =================
40///     |      BB1      |
41///     |---------------|                  ============================
42///     |     ...       |         /------> |            BBC           |
43///     |---------------|         |        |--------------------------|
44///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
45///     |       c       | --------/        |--------------------------|
46///     |       d       | -------\         |       switchInt(_cl)     |
47///     |      ...      |        |         |            c             | ---> BBC.2
48///     |    otherwise  | --\    |    /--- |         otherwise        |
49///     =================   |    |    |    ============================
50///                         |    |    |
51///     =================   |    |    |
52///     |      BBU      | <-|    |    |    ============================
53///     |---------------|        \-------> |            BBD           |
54///     |---------------|             |    |--------------------------|
55///     |  unreachable  |             |    |   _dl = discriminant(P)  |
56///     =================             |    |--------------------------|
57///                                   |    |       switchInt(_dl)     |
58///     =================             |    |            d             | ---> BBD.2
59///     |      BB9      | <--------------- |         otherwise        |
60///     |---------------|                  ============================
61///     |      ...      |
62///     =================
63/// ```
64/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
65/// code:
66///  - `BB1` is `parent` and `BBC, BBD` are children
67///  - `P` is `child_place`
68///  - `child_ty` is the type of `_cl`.
69///  - `Q` is `parent_op`.
70///  - `parent_ty` is the type of `Q`.
71///  - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75///     =======================
76///     |          BB1        |
77///     |---------------------|                  ============================
78///     |          ...        |         /------> |           BBEq           |
79///     | _s = discriminant(P)|         |        |--------------------------|
80///     | _t = Ne(Q, _s)      |         |        |--------------------------|
81///     |---------------------|         |        |       switchInt(Q)       |
82///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
83///     |        false        | --------/        |            d             | ---> BBD.2
84///     |       otherwise     |       /--------- |         otherwise        |
85///     =======================       |          ============================
86///                                   |
87///     =================             |
88///     |      BB9      | <-----------/
89///     |---------------|
90///     |      ...      |
91///     =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97        sess.mir_opt_level() >= 2
98    }
99
100    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101        trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103        let mut should_cleanup = false;
104
105        // Also consider newly generated bbs in the same pass
106        for parent in body.basic_blocks.indices() {
107            let bbs = &*body.basic_blocks;
108            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
109
110            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
111
112            should_cleanup = true;
113
114            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
115                &bbs[parent].terminator().kind
116            else {
117                unreachable!()
118            };
119            // Always correct since we can only switch on `Copy` types
120            let parent_op = parent_op.to_copy();
121            let parent_ty = parent_op.ty(body.local_decls(), tcx);
122            let statements_before = bbs[parent].statements.len();
123            let parent_end = Location { block: parent, statement_index: statements_before };
124
125            let mut patch = MirPatch::new(body);
126
127            let second_operand = if opt_data.need_hoist_discriminant {
128                // create temp to store second discriminant in, `_s` in example above
129                let second_discriminant_temp =
130                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
131
132                // create assignment of discriminant
133                patch.add_assign(
134                    parent_end,
135                    Place::from(second_discriminant_temp),
136                    Rvalue::Discriminant(opt_data.child_place),
137                );
138                Operand::Move(Place::from(second_discriminant_temp))
139            } else {
140                Operand::Copy(opt_data.child_place)
141            };
142
143            // create temp to store inequality comparison between the two discriminants, `_t` in
144            // example above
145            let nequal = BinOp::Ne;
146            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
147            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
148
149            // create inequality comparison
150            let comp_rvalue =
151                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
152            patch.add_statement(
153                parent_end,
154                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
155            );
156
157            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
158                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
159                else {
160                    unreachable!()
161                };
162                (value, targets.target_for_value(value))
163            });
164            // The otherwise either is the same target branch or an unreachable.
165            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
166
167            // Create `bbEq` in example above
168            let eq_switch = BasicBlockData::new(
169                Some(Terminator {
170                    source_info: bbs[parent].terminator().source_info,
171                    kind: TerminatorKind::SwitchInt {
172                        // switch on the first discriminant, so we can mark the second one as dead
173                        discr: parent_op,
174                        targets: eq_targets,
175                    },
176                }),
177                bbs[parent].is_cleanup,
178            );
179
180            let eq_bb = patch.new_block(eq_switch);
181
182            // Jump to it on the basis of the inequality comparison
183            let true_case = opt_data.destination;
184            let false_case = eq_bb;
185            patch.patch_terminator(
186                parent,
187                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
188            );
189
190            patch.apply(body);
191        }
192
193        // Since this optimization adds new basic blocks and invalidates others,
194        // clean up the cfg to make it nicer for other passes
195        if should_cleanup {
196            simplify_cfg(tcx, body);
197        }
198    }
199
200    fn is_required(&self) -> bool {
201        false
202    }
203}
204
205#[derive(Debug)]
206struct OptimizationData<'tcx> {
207    destination: BasicBlock,
208    child_place: Place<'tcx>,
209    child_ty: Ty<'tcx>,
210    child_source: SourceInfo,
211    need_hoist_discriminant: bool,
212}
213
214fn evaluate_candidate<'tcx>(
215    tcx: TyCtxt<'tcx>,
216    body: &Body<'tcx>,
217    parent: BasicBlock,
218) -> Option<OptimizationData<'tcx>> {
219    let bbs = &body.basic_blocks;
220    // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
221    if bbs[parent].is_cleanup {
222        return None;
223    }
224    let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
225    else {
226        return None;
227    };
228    let parent_ty = parent_discr.ty(body.local_decls(), tcx);
229    let (_, child) = targets.iter().next()?;
230
231    let Terminator {
232        kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
233        source_info,
234    } = bbs[child].terminator()
235    else {
236        return None;
237    };
238    let child_ty = child_discr.ty(body.local_decls(), tcx);
239    if child_ty != parent_ty {
240        return None;
241    }
242
243    // We only handle:
244    // ```
245    // bb4: {
246    //     _8 = discriminant((_3.1: Enum1));
247    //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
248    // }
249    // ```
250    // and
251    // ```
252    // bb2: {
253    //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
254    // }
255    // ```
256    if bbs[child].statements.len() > 1 {
257        return None;
258    }
259
260    // When thie BB has exactly one statement, this statement should be discriminant.
261    let need_hoist_discriminant = bbs[child].statements.len() == 1;
262    let child_place = if need_hoist_discriminant {
263        if !bbs[targets.otherwise()].is_empty_unreachable() {
264            // Someone could write code like this:
265            // ```rust
266            // let Q = val;
267            // if discriminant(P) == otherwise {
268            //     let ptr = &mut Q as *mut _ as *mut u8;
269            //     // It may be difficult for us to effectively determine whether values are valid.
270            //     // Invalid values can come from all sorts of corners.
271            //     unsafe { *ptr = 10; }
272            // }
273            //
274            // match P {
275            //    A => match Q {
276            //        A => {
277            //            // code
278            //        }
279            //        _ => {
280            //            // don't use Q
281            //        }
282            //    }
283            //    _ => {
284            //        // don't use Q
285            //    }
286            // };
287            // ```
288            //
289            // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
290            // invalid value, which is UB.
291            // In order to fix this, **we would either need to show that the discriminant computation of
292            // `place` is computed in all branches**.
293            // FIXME(#95162) For the moment, we adopt a conservative approach and
294            // consider only the `otherwise` branch has no statements and an unreachable terminator.
295            return None;
296        }
297        // Handle:
298        // ```
299        // bb4: {
300        //     _8 = discriminant((_3.1: Enum1));
301        //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
302        // }
303        // ```
304        let [
305            Statement {
306                kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
307                ..
308            },
309        ] = bbs[child].statements.as_slice()
310        else {
311            return None;
312        };
313        *child_place
314    } else {
315        // Handle:
316        // ```
317        // bb2: {
318        //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
319        // }
320        // ```
321        let Operand::Copy(child_place) = child_discr else {
322            return None;
323        };
324        *child_place
325    };
326    let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
327    {
328        child_targets.otherwise()
329    } else {
330        targets.otherwise()
331    };
332
333    // Verify that the optimization is legal for each branch
334    for (value, child) in targets.iter() {
335        if !verify_candidate_branch(
336            &bbs[child],
337            value,
338            child_place,
339            destination,
340            need_hoist_discriminant,
341        ) {
342            return None;
343        }
344    }
345    Some(OptimizationData {
346        destination,
347        child_place,
348        child_ty,
349        child_source: *source_info,
350        need_hoist_discriminant,
351    })
352}
353
354fn verify_candidate_branch<'tcx>(
355    branch: &BasicBlockData<'tcx>,
356    value: u128,
357    place: Place<'tcx>,
358    destination: BasicBlock,
359    need_hoist_discriminant: bool,
360) -> bool {
361    // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
362    let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
363        return false;
364    };
365    if need_hoist_discriminant {
366        // If we need hoist discriminant, the branch must have exactly one statement.
367        let [statement] = branch.statements.as_slice() else {
368            return false;
369        };
370        // The statement must assign the discriminant of `place`.
371        let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
372            statement.kind
373        else {
374            return false;
375        };
376        if from_place != place {
377            return false;
378        }
379        // The assignment must invalidate a local that terminate on a `SwitchInt`.
380        if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
381            return false;
382        }
383    } else {
384        // If we don't need hoist discriminant, the branch must not have any statements.
385        if !branch.statements.is_empty() {
386            return false;
387        }
388        // The place on `SwitchInt` must be the same.
389        if *switch_op != Operand::Copy(place) {
390            return false;
391        }
392    }
393    // It must fall through to `destination` if the switch misses.
394    if destination != targets.otherwise() {
395        return false;
396    }
397    // It must have exactly one branch for value `value` and have no more branches.
398    let mut iter = targets.iter();
399    let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
400        return false;
401    };
402    target_value == value
403}