rustc_mir_transform/
early_otherwise_branch.rs

1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15///     (Some(_), Some(_)) => {0},
16///     (None, None) => {2},
17///     _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27///     match x {
28///         Some(_) => 0,
29///         None => 2,
30///     }
31/// } else {
32///     1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39///     =================
40///     |      BB1      |
41///     |---------------|                  ============================
42///     |     ...       |         /------> |            BBC           |
43///     |---------------|         |        |--------------------------|
44///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
45///     |       c       | --------/        |--------------------------|
46///     |       d       | -------\         |       switchInt(_cl)     |
47///     |      ...      |        |         |            c             | ---> BBC.2
48///     |    otherwise  | --\    |    /--- |         otherwise        |
49///     =================   |    |    |    ============================
50///                         |    |    |
51///     =================   |    |    |
52///     |      BBU      | <-|    |    |    ============================
53///     |---------------|        \-------> |            BBD           |
54///     |---------------|             |    |--------------------------|
55///     |  unreachable  |             |    |   _dl = discriminant(P)  |
56///     =================             |    |--------------------------|
57///                                   |    |       switchInt(_dl)     |
58///     =================             |    |            d             | ---> BBD.2
59///     |      BB9      | <--------------- |         otherwise        |
60///     |---------------|                  ============================
61///     |      ...      |
62///     =================
63/// ```
64/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
65/// code:
66///  - `BB1` is `parent` and `BBC, BBD` are children
67///  - `P` is `child_place`
68///  - `child_ty` is the type of `_cl`.
69///  - `Q` is `parent_op`.
70///  - `parent_ty` is the type of `Q`.
71///  - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75///     =======================
76///     |          BB1        |
77///     |---------------------|                  ============================
78///     |          ...        |         /------> |           BBEq           |
79///     | _s = discriminant(P)|         |        |--------------------------|
80///     | _t = Ne(Q, _s)      |         |        |--------------------------|
81///     |---------------------|         |        |       switchInt(Q)       |
82///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
83///     |        false        | --------/        |            d             | ---> BBD.2
84///     |       otherwise     |       /--------- |         otherwise        |
85///     =======================       |          ============================
86///                                   |
87///     =================             |
88///     |      BB9      | <-----------/
89///     |---------------|
90///     |      ...      |
91///     =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97        sess.mir_opt_level() >= 2
98    }
99
100    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101        trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103        let mut should_cleanup = false;
104
105        // Also consider newly generated bbs in the same pass
106        for i in 0..body.basic_blocks.len() {
107            let bbs = &*body.basic_blocks;
108            let parent = BasicBlock::from_usize(i);
109            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
110
111            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
112
113            should_cleanup = true;
114
115            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
116                &bbs[parent].terminator().kind
117            else {
118                unreachable!()
119            };
120            // Always correct since we can only switch on `Copy` types
121            let parent_op = match parent_op {
122                Operand::Move(x) => Operand::Copy(*x),
123                Operand::Copy(x) => Operand::Copy(*x),
124                Operand::Constant(x) => Operand::Constant(x.clone()),
125            };
126            let parent_ty = parent_op.ty(body.local_decls(), tcx);
127            let statements_before = bbs[parent].statements.len();
128            let parent_end = Location { block: parent, statement_index: statements_before };
129
130            let mut patch = MirPatch::new(body);
131
132            let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
133                // create temp to store second discriminant in, `_s` in example above
134                let second_discriminant_temp =
135                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
136
137                patch.add_statement(
138                    parent_end,
139                    StatementKind::StorageLive(second_discriminant_temp),
140                );
141
142                // create assignment of discriminant
143                patch.add_assign(
144                    parent_end,
145                    Place::from(second_discriminant_temp),
146                    Rvalue::Discriminant(opt_data.child_place),
147                );
148                (
149                    Some(second_discriminant_temp),
150                    Operand::Move(Place::from(second_discriminant_temp)),
151                )
152            } else {
153                (None, Operand::Copy(opt_data.child_place))
154            };
155
156            // create temp to store inequality comparison between the two discriminants, `_t` in
157            // example above
158            let nequal = BinOp::Ne;
159            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
160            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
161            patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
162
163            // create inequality comparison
164            let comp_rvalue =
165                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
166            patch.add_statement(
167                parent_end,
168                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
169            );
170
171            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
172                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
173                else {
174                    unreachable!()
175                };
176                (value, targets.target_for_value(value))
177            });
178            // The otherwise either is the same target branch or an unreachable.
179            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
180
181            // Create `bbEq` in example above
182            let eq_switch = BasicBlockData::new(
183                Some(Terminator {
184                    source_info: bbs[parent].terminator().source_info,
185                    kind: TerminatorKind::SwitchInt {
186                        // switch on the first discriminant, so we can mark the second one as dead
187                        discr: parent_op,
188                        targets: eq_targets,
189                    },
190                }),
191                bbs[parent].is_cleanup,
192            );
193
194            let eq_bb = patch.new_block(eq_switch);
195
196            // Jump to it on the basis of the inequality comparison
197            let true_case = opt_data.destination;
198            let false_case = eq_bb;
199            patch.patch_terminator(
200                parent,
201                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
202            );
203
204            if let Some(second_discriminant_temp) = second_discriminant_temp {
205                // generate StorageDead for the second_discriminant_temp not in use anymore
206                patch.add_statement(
207                    parent_end,
208                    StatementKind::StorageDead(second_discriminant_temp),
209                );
210            }
211
212            // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
213            // the switch
214            for bb in [false_case, true_case].iter() {
215                patch.add_statement(
216                    Location { block: *bb, statement_index: 0 },
217                    StatementKind::StorageDead(comp_temp),
218                );
219            }
220
221            patch.apply(body);
222        }
223
224        // Since this optimization adds new basic blocks and invalidates others,
225        // clean up the cfg to make it nicer for other passes
226        if should_cleanup {
227            simplify_cfg(body);
228        }
229    }
230
231    fn is_required(&self) -> bool {
232        false
233    }
234}
235
236#[derive(Debug)]
237struct OptimizationData<'tcx> {
238    destination: BasicBlock,
239    child_place: Place<'tcx>,
240    child_ty: Ty<'tcx>,
241    child_source: SourceInfo,
242    need_hoist_discriminant: bool,
243}
244
245fn evaluate_candidate<'tcx>(
246    tcx: TyCtxt<'tcx>,
247    body: &Body<'tcx>,
248    parent: BasicBlock,
249) -> Option<OptimizationData<'tcx>> {
250    let bbs = &body.basic_blocks;
251    // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
252    if bbs[parent].is_cleanup {
253        return None;
254    }
255    let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
256    else {
257        return None;
258    };
259    let parent_ty = parent_discr.ty(body.local_decls(), tcx);
260    let (_, child) = targets.iter().next()?;
261
262    let Terminator {
263        kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
264        source_info,
265    } = bbs[child].terminator()
266    else {
267        return None;
268    };
269    let child_ty = child_discr.ty(body.local_decls(), tcx);
270    if child_ty != parent_ty {
271        return None;
272    }
273
274    // We only handle:
275    // ```
276    // bb4: {
277    //     _8 = discriminant((_3.1: Enum1));
278    //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
279    // }
280    // ```
281    // and
282    // ```
283    // bb2: {
284    //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
285    // }
286    // ```
287    if bbs[child].statements.len() > 1 {
288        return None;
289    }
290
291    // When thie BB has exactly one statement, this statement should be discriminant.
292    let need_hoist_discriminant = bbs[child].statements.len() == 1;
293    let child_place = if need_hoist_discriminant {
294        if !bbs[targets.otherwise()].is_empty_unreachable() {
295            // Someone could write code like this:
296            // ```rust
297            // let Q = val;
298            // if discriminant(P) == otherwise {
299            //     let ptr = &mut Q as *mut _ as *mut u8;
300            //     // It may be difficult for us to effectively determine whether values are valid.
301            //     // Invalid values can come from all sorts of corners.
302            //     unsafe { *ptr = 10; }
303            // }
304            //
305            // match P {
306            //    A => match Q {
307            //        A => {
308            //            // code
309            //        }
310            //        _ => {
311            //            // don't use Q
312            //        }
313            //    }
314            //    _ => {
315            //        // don't use Q
316            //    }
317            // };
318            // ```
319            //
320            // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
321            // invalid value, which is UB.
322            // In order to fix this, **we would either need to show that the discriminant computation of
323            // `place` is computed in all branches**.
324            // FIXME(#95162) For the moment, we adopt a conservative approach and
325            // consider only the `otherwise` branch has no statements and an unreachable terminator.
326            return None;
327        }
328        // Handle:
329        // ```
330        // bb4: {
331        //     _8 = discriminant((_3.1: Enum1));
332        //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
333        // }
334        // ```
335        let [
336            Statement {
337                kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
338                ..
339            },
340        ] = bbs[child].statements.as_slice()
341        else {
342            return None;
343        };
344        *child_place
345    } else {
346        // Handle:
347        // ```
348        // bb2: {
349        //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
350        // }
351        // ```
352        let Operand::Copy(child_place) = child_discr else {
353            return None;
354        };
355        *child_place
356    };
357    let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
358    {
359        child_targets.otherwise()
360    } else {
361        targets.otherwise()
362    };
363
364    // Verify that the optimization is legal for each branch
365    for (value, child) in targets.iter() {
366        if !verify_candidate_branch(
367            &bbs[child],
368            value,
369            child_place,
370            destination,
371            need_hoist_discriminant,
372        ) {
373            return None;
374        }
375    }
376    Some(OptimizationData {
377        destination,
378        child_place,
379        child_ty,
380        child_source: *source_info,
381        need_hoist_discriminant,
382    })
383}
384
385fn verify_candidate_branch<'tcx>(
386    branch: &BasicBlockData<'tcx>,
387    value: u128,
388    place: Place<'tcx>,
389    destination: BasicBlock,
390    need_hoist_discriminant: bool,
391) -> bool {
392    // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
393    let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
394        return false;
395    };
396    if need_hoist_discriminant {
397        // If we need hoist discriminant, the branch must have exactly one statement.
398        let [statement] = branch.statements.as_slice() else {
399            return false;
400        };
401        // The statement must assign the discriminant of `place`.
402        let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
403            statement.kind
404        else {
405            return false;
406        };
407        if from_place != place {
408            return false;
409        }
410        // The assignment must invalidate a local that terminate on a `SwitchInt`.
411        if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
412            return false;
413        }
414    } else {
415        // If we don't need hoist discriminant, the branch must not have any statements.
416        if !branch.statements.is_empty() {
417            return false;
418        }
419        // The place on `SwitchInt` must be the same.
420        if *switch_op != Operand::Copy(place) {
421            return false;
422        }
423    }
424    // It must fall through to `destination` if the switch misses.
425    if destination != targets.otherwise() {
426        return false;
427    }
428    // It must have exactly one branch for value `value` and have no more branches.
429    let mut iter = targets.iter();
430    let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
431        return false;
432    };
433    target_value == value
434}