Skip to main content

rustc_mir_transform/
early_otherwise_branch.rs

1use std::fmt::Debug;
2
3use rustc_data_structures::thin_vec::ThinVec;
4use rustc_middle::mir::*;
5use rustc_middle::ty::{Ty, TyCtxt};
6use tracing::trace;
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10
11/// This pass optimizes something like
12/// ```ignore (syntax-highlighting-only)
13/// let x: Option<()>;
14/// let y: Option<()>;
15/// match (x,y) {
16///     (Some(_), Some(_)) => {0},
17///     (None, None) => {2},
18///     _ => {1}
19/// }
20/// ```
21/// into something like
22/// ```ignore (syntax-highlighting-only)
23/// let x: Option<()>;
24/// let y: Option<()>;
25/// let discriminant_x = std::mem::discriminant(x);
26/// let discriminant_y = std::mem::discriminant(y);
27/// if discriminant_x == discriminant_y {
28///     match x {
29///         Some(_) => 0,
30///         None => 2,
31///     }
32/// } else {
33///     1
34/// }
35/// ```
36///
37/// Specifically, it looks for instances of control flow like this:
38/// ```text
39///
40///     =================
41///     |      BB1      |
42///     |---------------|                  ============================
43///     |     ...       |         /------> |            BBC           |
44///     |---------------|         |        |--------------------------|
45///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
46///     |       c       | --------/        |--------------------------|
47///     |       d       | -------\         |       switchInt(_cl)     |
48///     |      ...      |        |         |            c             | ---> BBC.2
49///     |    otherwise  | --\    |    /--- |         otherwise        |
50///     =================   |    |    |    ============================
51///                         |    |    |
52///     =================   |    |    |
53///     |      BBU      | <-|    |    |    ============================
54///     |---------------|        \-------> |            BBD           |
55///     |---------------|             |    |--------------------------|
56///     |  unreachable  |             |    |   _dl = discriminant(P)  |
57///     =================             |    |--------------------------|
58///                                   |    |       switchInt(_dl)     |
59///     =================             |    |            d             | ---> BBD.2
60///     |      BB9      | <--------------- |         otherwise        |
61///     |---------------|                  ============================
62///     |      ...      |
63///     =================
64/// ```
65/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
66/// code:
67///  - `BB1` is `parent` and `BBC, BBD` are children
68///  - `P` is `child_place`
69///  - `child_ty` is the type of `_cl`.
70///  - `Q` is `parent_op`.
71///  - `parent_ty` is the type of `Q`.
72///  - `BB9` is `destination`
73/// All this is then transformed into:
74/// ```text
75///
76///     =======================
77///     |          BB1        |
78///     |---------------------|                  ============================
79///     |          ...        |         /------> |           BBEq           |
80///     | _s = discriminant(P)|         |        |--------------------------|
81///     | _t = Ne(Q, _s)      |         |        |--------------------------|
82///     |---------------------|         |        |       switchInt(Q)       |
83///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
84///     |        false        | --------/        |            d             | ---> BBD.2
85///     |       otherwise     |       /--------- |         otherwise        |
86///     =======================       |          ============================
87///                                   |
88///     =================             |
89///     |      BB9      | <-----------/
90///     |---------------|
91///     |      ...      |
92///     =================
93/// ```
94pub(super) struct EarlyOtherwiseBranch;
95
96impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
97    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98        sess.mir_opt_level() >= 2
99    }
100
101    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102        trace!("running EarlyOtherwiseBranch on {:?}", body.source);
103
104        let mut should_cleanup = false;
105
106        // Also consider newly generated bbs in the same pass
107        for parent in body.basic_blocks.indices() {
108            let bbs = &*body.basic_blocks;
109            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
110
111            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
112
113            should_cleanup = true;
114
115            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
116                &bbs[parent].terminator().kind
117            else {
118                unreachable!()
119            };
120            // Always correct since we can only switch on `Copy` types
121            let parent_op = parent_op.to_copy();
122            let parent_ty = parent_op.ty(body.local_decls(), tcx);
123            let statements_before = bbs[parent].statements.len();
124            let parent_end = Location { block: parent, statement_index: statements_before };
125
126            let mut patch = MirPatch::new(body);
127
128            let second_operand = if opt_data.need_hoist_discriminant {
129                // create temp to store second discriminant in, `_s` in example above
130                let second_discriminant_temp =
131                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
132
133                // create assignment of discriminant
134                patch.add_assign(
135                    parent_end,
136                    Place::from(second_discriminant_temp),
137                    Rvalue::Discriminant(opt_data.child_place),
138                );
139                Operand::Move(Place::from(second_discriminant_temp))
140            } else {
141                Operand::Copy(opt_data.child_place)
142            };
143
144            // create temp to store inequality comparison between the two discriminants, `_t` in
145            // example above
146            let nequal = BinOp::Ne;
147            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
148            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
149
150            // create inequality comparison
151            let comp_rvalue =
152                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
153            patch.add_statement(
154                parent_end,
155                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
156            );
157
158            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
159                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
160                else {
161                    unreachable!()
162                };
163                (value, targets.target_for_value(value))
164            });
165            // The otherwise either is the same target branch or an unreachable.
166            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
167
168            // Create `bbEq` in example above
169            let eq_switch = BasicBlockData::new(
170                Some(Terminator {
171                    source_info: bbs[parent].terminator().source_info,
172                    kind: TerminatorKind::SwitchInt {
173                        // switch on the first discriminant, so we can mark the second one as dead
174                        discr: parent_op,
175                        targets: eq_targets,
176                    },
177                    attributes: ThinVec::new(),
178                }),
179                bbs[parent].is_cleanup,
180            );
181
182            let eq_bb = patch.new_block(eq_switch);
183
184            // Jump to it on the basis of the inequality comparison
185            let true_case = opt_data.destination;
186            let false_case = eq_bb;
187            patch.patch_terminator(
188                parent,
189                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
190            );
191
192            patch.apply(body);
193        }
194
195        // Since this optimization adds new basic blocks and invalidates others,
196        // clean up the cfg to make it nicer for other passes
197        if should_cleanup {
198            simplify_cfg(tcx, body);
199        }
200    }
201
202    fn is_required(&self) -> bool {
203        false
204    }
205}
206
207#[derive(Debug)]
208struct OptimizationData<'tcx> {
209    destination: BasicBlock,
210    child_place: Place<'tcx>,
211    child_ty: Ty<'tcx>,
212    child_source: SourceInfo,
213    need_hoist_discriminant: bool,
214}
215
216fn evaluate_candidate<'tcx>(
217    tcx: TyCtxt<'tcx>,
218    body: &Body<'tcx>,
219    parent: BasicBlock,
220) -> Option<OptimizationData<'tcx>> {
221    let bbs = &body.basic_blocks;
222    // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
223    if bbs[parent].is_cleanup {
224        return None;
225    }
226    let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
227    else {
228        return None;
229    };
230    let parent_ty = parent_discr.ty(body.local_decls(), tcx);
231    let (_, child) = targets.iter().next()?;
232
233    let Terminator {
234        kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
235        source_info,
236        attributes: _,
237    } = bbs[child].terminator()
238    else {
239        return None;
240    };
241    let child_ty = child_discr.ty(body.local_decls(), tcx);
242    if child_ty != parent_ty {
243        return None;
244    }
245
246    // We only handle:
247    // ```
248    // bb4: {
249    //     _8 = discriminant((_3.1: Enum1));
250    //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
251    // }
252    // ```
253    // and
254    // ```
255    // bb2: {
256    //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
257    // }
258    // ```
259    if bbs[child].statements.len() > 1 {
260        return None;
261    }
262
263    // When thie BB has exactly one statement, this statement should be discriminant.
264    let need_hoist_discriminant = bbs[child].statements.len() == 1;
265    let child_place = if need_hoist_discriminant {
266        if !bbs[targets.otherwise()].is_empty_unreachable() {
267            // Someone could write code like this:
268            // ```rust
269            // let Q = val;
270            // if discriminant(P) == otherwise {
271            //     let ptr = &mut Q as *mut _ as *mut u8;
272            //     // It may be difficult for us to effectively determine whether values are valid.
273            //     // Invalid values can come from all sorts of corners.
274            //     unsafe { *ptr = 10; }
275            // }
276            //
277            // match P {
278            //    A => match Q {
279            //        A => {
280            //            // code
281            //        }
282            //        _ => {
283            //            // don't use Q
284            //        }
285            //    }
286            //    _ => {
287            //        // don't use Q
288            //    }
289            // };
290            // ```
291            //
292            // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
293            // invalid value, which is UB.
294            // In order to fix this, **we would either need to show that the discriminant computation of
295            // `place` is computed in all branches**.
296            // FIXME(#95162) For the moment, we adopt a conservative approach and
297            // consider only the `otherwise` branch has no statements and an unreachable terminator.
298            return None;
299        }
300        // Handle:
301        // ```
302        // bb4: {
303        //     _8 = discriminant((_3.1: Enum1));
304        //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
305        // }
306        // ```
307        let [
308            Statement {
309                kind: StatementKind::Assign((_, Rvalue::Discriminant(child_place))), ..
310            },
311        ] = bbs[child].statements.as_slice()
312        else {
313            return None;
314        };
315        *child_place
316    } else {
317        // Handle:
318        // ```
319        // bb2: {
320        //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
321        // }
322        // ```
323        let Operand::Copy(child_place) = child_discr else {
324            return None;
325        };
326        *child_place
327    };
328    let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
329    {
330        child_targets.otherwise()
331    } else {
332        targets.otherwise()
333    };
334
335    // Verify that the optimization is legal for each branch
336    for (value, child) in targets.iter() {
337        if !verify_candidate_branch(
338            &bbs[child],
339            value,
340            child_place,
341            destination,
342            need_hoist_discriminant,
343        ) {
344            return None;
345        }
346    }
347    Some(OptimizationData {
348        destination,
349        child_place,
350        child_ty,
351        child_source: *source_info,
352        need_hoist_discriminant,
353    })
354}
355
356fn verify_candidate_branch<'tcx>(
357    branch: &BasicBlockData<'tcx>,
358    value: u128,
359    place: Place<'tcx>,
360    destination: BasicBlock,
361    need_hoist_discriminant: bool,
362) -> bool {
363    // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
364    let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
365        return false;
366    };
367    if need_hoist_discriminant {
368        // If we need hoist discriminant, the branch must have exactly one statement.
369        let [statement] = branch.statements.as_slice() else {
370            return false;
371        };
372        // The statement must assign the discriminant of `place`.
373        let StatementKind::Assign((discr_place, Rvalue::Discriminant(from_place))) = statement.kind
374        else {
375            return false;
376        };
377        if from_place != place {
378            return false;
379        }
380        // The assignment must invalidate a local that terminate on a `SwitchInt`.
381        if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
382            return false;
383        }
384    } else {
385        // If we don't need hoist discriminant, the branch must not have any statements.
386        if !branch.statements.is_empty() {
387            return false;
388        }
389        // The place on `SwitchInt` must be the same.
390        if *switch_op != Operand::Copy(place) {
391            return false;
392        }
393    }
394    // It must fall through to `destination` if the switch misses.
395    if destination != targets.otherwise() {
396        return false;
397    }
398    // It must have exactly one branch for value `value` and have no more branches.
399    let mut iter = targets.iter();
400    let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
401        return false;
402    };
403    target_value == value
404}