rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15/// (Some(_), Some(_)) => {0},
16/// (None, None) => {2},
17/// _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27/// match x {
28/// Some(_) => 0,
29/// None => 2,
30/// }
31/// } else {
32/// 1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39/// =================
40/// | BB1 |
41/// |---------------| ============================
42/// | ... | /------> | BBC |
43/// |---------------| | |--------------------------|
44/// | switchInt(Q) | | | _cl = discriminant(P) |
45/// | c | --------/ |--------------------------|
46/// | d | -------\ | switchInt(_cl) |
47/// | ... | | | c | ---> BBC.2
48/// | otherwise | --\ | /--- | otherwise |
49/// ================= | | | ============================
50/// | | |
51/// ================= | | |
52/// | BBU | <-| | | ============================
53/// |---------------| \-------> | BBD |
54/// |---------------| | |--------------------------|
55/// | unreachable | | | _dl = discriminant(P) |
56/// ================= | |--------------------------|
57/// | | switchInt(_dl) |
58/// ================= | | d | ---> BBD.2
59/// | BB9 | <--------------- | otherwise |
60/// |---------------| ============================
61/// | ... |
62/// =================
63/// ```
/// Where the `otherwise` branch on `BB1` is permitted to go either to `BBU` or to `BB9`. In the
/// code:
66/// - `BB1` is `parent` and `BBC, BBD` are children
67/// - `P` is `child_place`
68/// - `child_ty` is the type of `_cl`.
69/// - `Q` is `parent_op`.
70/// - `parent_ty` is the type of `Q`.
71/// - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75/// =======================
76/// | BB1 |
77/// |---------------------| ============================
78/// | ... | /------> | BBEq |
79/// | _s = discriminant(P)| | |--------------------------|
80/// | _t = Ne(Q, _s) | | |--------------------------|
81/// |---------------------| | | switchInt(Q) |
82/// | switchInt(_t) | | | c | ---> BBC.2
83/// | false | --------/ | d | ---> BBD.2
84/// | otherwise | /--------- | otherwise |
85/// ======================= | ============================
86/// |
87/// ================= |
88/// | BB9 | <-----------/
89/// |---------------|
90/// | ... |
91/// =================
92/// ```
///
/// The pass is only enabled at `mir-opt-level` 2 or higher and is not a required pass.
pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97 sess.mir_opt_level() >= 2
98 }
99
100 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103 let mut should_cleanup = false;
104
105 // Also consider newly generated bbs in the same pass
106 for i in 0..body.basic_blocks.len() {
107 let bbs = &*body.basic_blocks;
108 let parent = BasicBlock::from_usize(i);
109 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
110
111 trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
112
113 should_cleanup = true;
114
115 let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
116 &bbs[parent].terminator().kind
117 else {
118 unreachable!()
119 };
120 // Always correct since we can only switch on `Copy` types
121 let parent_op = match parent_op {
122 Operand::Move(x) => Operand::Copy(*x),
123 Operand::Copy(x) => Operand::Copy(*x),
124 Operand::Constant(x) => Operand::Constant(x.clone()),
125 };
126 let parent_ty = parent_op.ty(body.local_decls(), tcx);
127 let statements_before = bbs[parent].statements.len();
128 let parent_end = Location { block: parent, statement_index: statements_before };
129
130 let mut patch = MirPatch::new(body);
131
132 let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
133 // create temp to store second discriminant in, `_s` in example above
134 let second_discriminant_temp =
135 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
136
137 patch.add_statement(
138 parent_end,
139 StatementKind::StorageLive(second_discriminant_temp),
140 );
141
142 // create assignment of discriminant
143 patch.add_assign(
144 parent_end,
145 Place::from(second_discriminant_temp),
146 Rvalue::Discriminant(opt_data.child_place),
147 );
148 (
149 Some(second_discriminant_temp),
150 Operand::Move(Place::from(second_discriminant_temp)),
151 )
152 } else {
153 (None, Operand::Copy(opt_data.child_place))
154 };
155
156 // create temp to store inequality comparison between the two discriminants, `_t` in
157 // example above
158 let nequal = BinOp::Ne;
159 let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
160 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
161 patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
162
163 // create inequality comparison
164 let comp_rvalue =
165 Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
166 patch.add_statement(
167 parent_end,
168 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
169 );
170
171 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
172 let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
173 else {
174 unreachable!()
175 };
176 (value, targets.target_for_value(value))
177 });
178 // The otherwise either is the same target branch or an unreachable.
179 let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
180
181 // Create `bbEq` in example above
182 let eq_switch = BasicBlockData::new(
183 Some(Terminator {
184 source_info: bbs[parent].terminator().source_info,
185 kind: TerminatorKind::SwitchInt {
186 // switch on the first discriminant, so we can mark the second one as dead
187 discr: parent_op,
188 targets: eq_targets,
189 },
190 }),
191 bbs[parent].is_cleanup,
192 );
193
194 let eq_bb = patch.new_block(eq_switch);
195
196 // Jump to it on the basis of the inequality comparison
197 let true_case = opt_data.destination;
198 let false_case = eq_bb;
199 patch.patch_terminator(
200 parent,
201 TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
202 );
203
204 if let Some(second_discriminant_temp) = second_discriminant_temp {
205 // generate StorageDead for the second_discriminant_temp not in use anymore
206 patch.add_statement(
207 parent_end,
208 StatementKind::StorageDead(second_discriminant_temp),
209 );
210 }
211
212 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
213 // the switch
214 for bb in [false_case, true_case].iter() {
215 patch.add_statement(
216 Location { block: *bb, statement_index: 0 },
217 StatementKind::StorageDead(comp_temp),
218 );
219 }
220
221 patch.apply(body);
222 }
223
224 // Since this optimization adds new basic blocks and invalidates others,
225 // clean up the cfg to make it nicer for other passes
226 if should_cleanup {
227 simplify_cfg(body);
228 }
229 }
230
231 fn is_required(&self) -> bool {
232 false
233 }
234}
235
/// Description of one rewrite opportunity, produced by `evaluate_candidate` and consumed by
/// `EarlyOtherwiseBranch::run_pass`.
#[derive(Debug)]
struct OptimizationData<'tcx> {
    /// Block jumped to when the hoisted comparison reports "not equal"; every child's
    /// `otherwise` target must be this block.
    destination: BasicBlock,
    /// Place the children switch on, either directly or through its discriminant.
    child_place: Place<'tcx>,
    /// Type of the operand the children switch on; verified to equal the parent's operand type.
    child_ty: Ty<'tcx>,
    /// Source info of the first child's terminator; its span is used for the new temporaries.
    child_source: SourceInfo,
    /// `true` when each child consists of a single `_x = discriminant(child_place)` statement
    /// that has to be recomputed in the parent; `false` when the children switch on a copy of
    /// `child_place` directly.
    need_hoist_discriminant: bool,
}
244
245fn evaluate_candidate<'tcx>(
246 tcx: TyCtxt<'tcx>,
247 body: &Body<'tcx>,
248 parent: BasicBlock,
249) -> Option<OptimizationData<'tcx>> {
250 let bbs = &body.basic_blocks;
251 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
252 if bbs[parent].is_cleanup {
253 return None;
254 }
255 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
256 else {
257 return None;
258 };
259 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
260 let (_, child) = targets.iter().next()?;
261
262 let Terminator {
263 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
264 source_info,
265 } = bbs[child].terminator()
266 else {
267 return None;
268 };
269 let child_ty = child_discr.ty(body.local_decls(), tcx);
270 if child_ty != parent_ty {
271 return None;
272 }
273
274 // We only handle:
275 // ```
276 // bb4: {
277 // _8 = discriminant((_3.1: Enum1));
278 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
279 // }
280 // ```
281 // and
282 // ```
283 // bb2: {
284 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
285 // }
286 // ```
287 if bbs[child].statements.len() > 1 {
288 return None;
289 }
290
291 // When thie BB has exactly one statement, this statement should be discriminant.
292 let need_hoist_discriminant = bbs[child].statements.len() == 1;
293 let child_place = if need_hoist_discriminant {
294 if !bbs[targets.otherwise()].is_empty_unreachable() {
295 // Someone could write code like this:
296 // ```rust
297 // let Q = val;
298 // if discriminant(P) == otherwise {
299 // let ptr = &mut Q as *mut _ as *mut u8;
300 // // It may be difficult for us to effectively determine whether values are valid.
301 // // Invalid values can come from all sorts of corners.
302 // unsafe { *ptr = 10; }
303 // }
304 //
305 // match P {
306 // A => match Q {
307 // A => {
308 // // code
309 // }
310 // _ => {
311 // // don't use Q
312 // }
313 // }
314 // _ => {
315 // // don't use Q
316 // }
317 // };
318 // ```
319 //
320 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
321 // invalid value, which is UB.
322 // In order to fix this, **we would either need to show that the discriminant computation of
323 // `place` is computed in all branches**.
324 // FIXME(#95162) For the moment, we adopt a conservative approach and
325 // consider only the `otherwise` branch has no statements and an unreachable terminator.
326 return None;
327 }
328 // Handle:
329 // ```
330 // bb4: {
331 // _8 = discriminant((_3.1: Enum1));
332 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
333 // }
334 // ```
335 let [
336 Statement {
337 kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
338 ..
339 },
340 ] = bbs[child].statements.as_slice()
341 else {
342 return None;
343 };
344 *child_place
345 } else {
346 // Handle:
347 // ```
348 // bb2: {
349 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
350 // }
351 // ```
352 let Operand::Copy(child_place) = child_discr else {
353 return None;
354 };
355 *child_place
356 };
357 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
358 {
359 child_targets.otherwise()
360 } else {
361 targets.otherwise()
362 };
363
364 // Verify that the optimization is legal for each branch
365 for (value, child) in targets.iter() {
366 if !verify_candidate_branch(
367 &bbs[child],
368 value,
369 child_place,
370 destination,
371 need_hoist_discriminant,
372 ) {
373 return None;
374 }
375 }
376 Some(OptimizationData {
377 destination,
378 child_place,
379 child_ty,
380 child_source: *source_info,
381 need_hoist_discriminant,
382 })
383}
384
385fn verify_candidate_branch<'tcx>(
386 branch: &BasicBlockData<'tcx>,
387 value: u128,
388 place: Place<'tcx>,
389 destination: BasicBlock,
390 need_hoist_discriminant: bool,
391) -> bool {
392 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
393 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
394 return false;
395 };
396 if need_hoist_discriminant {
397 // If we need hoist discriminant, the branch must have exactly one statement.
398 let [statement] = branch.statements.as_slice() else {
399 return false;
400 };
401 // The statement must assign the discriminant of `place`.
402 let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
403 statement.kind
404 else {
405 return false;
406 };
407 if from_place != place {
408 return false;
409 }
410 // The assignment must invalidate a local that terminate on a `SwitchInt`.
411 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
412 return false;
413 }
414 } else {
415 // If we don't need hoist discriminant, the branch must not have any statements.
416 if !branch.statements.is_empty() {
417 return false;
418 }
419 // The place on `SwitchInt` must be the same.
420 if *switch_op != Operand::Copy(place) {
421 return false;
422 }
423 }
424 // It must fall through to `destination` if the switch misses.
425 if destination != targets.otherwise() {
426 return false;
427 }
428 // It must have exactly one branch for value `value` and have no more branches.
429 let mut iter = targets.iter();
430 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
431 return false;
432 };
433 target_value == value
434}