rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15/// (Some(_), Some(_)) => {0},
16/// (None, None) => {2},
17/// _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27/// match x {
28/// Some(_) => 0,
29/// None => 2,
30/// }
31/// } else {
32/// 1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39/// =================
40/// | BB1 |
41/// |---------------| ============================
42/// | ... | /------> | BBC |
43/// |---------------| | |--------------------------|
44/// | switchInt(Q) | | | _cl = discriminant(P) |
45/// | c | --------/ |--------------------------|
46/// | d | -------\ | switchInt(_cl) |
47/// | ... | | | c | ---> BBC.2
48/// | otherwise | --\ | /--- | otherwise |
49/// ================= | | | ============================
50/// | | |
51/// ================= | | |
52/// | BBU | <-| | | ============================
53/// |---------------| \-------> | BBD |
54/// |---------------| | |--------------------------|
55/// | unreachable | | | _dl = discriminant(P) |
56/// ================= | |--------------------------|
57/// | | switchInt(_dl) |
58/// ================= | | d | ---> BBD.2
59/// | BB9 | <--------------- | otherwise |
60/// |---------------| ============================
61/// | ... |
62/// =================
63/// ```
64/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
65/// code:
66/// - `BB1` is `parent` and `BBC, BBD` are children
67/// - `P` is `child_place`
68/// - `child_ty` is the type of `_cl`.
69/// - `Q` is `parent_op`.
70/// - `parent_ty` is the type of `Q`.
71/// - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75/// =======================
76/// | BB1 |
77/// |---------------------| ============================
78/// | ... | /------> | BBEq |
79/// | _s = discriminant(P)| | |--------------------------|
80/// | _t = Ne(Q, _s) | | |--------------------------|
81/// |---------------------| | | switchInt(Q) |
82/// | switchInt(_t) | | | c | ---> BBC.2
83/// | false | --------/ | d | ---> BBD.2
84/// | otherwise | /--------- | otherwise |
85/// ======================= | ============================
86/// |
87/// ================= |
88/// | BB9 | <-----------/
89/// |---------------|
90/// | ... |
91/// =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97 sess.mir_opt_level() >= 2
98 }
99
100 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103 let mut should_cleanup = false;
104
105 // Also consider newly generated bbs in the same pass
106 for parent in body.basic_blocks.indices() {
107 let bbs = &*body.basic_blocks;
108 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
109
110 trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
111
112 should_cleanup = true;
113
114 let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
115 &bbs[parent].terminator().kind
116 else {
117 unreachable!()
118 };
119 // Always correct since we can only switch on `Copy` types
120 let parent_op = parent_op.to_copy();
121 let parent_ty = parent_op.ty(body.local_decls(), tcx);
122 let statements_before = bbs[parent].statements.len();
123 let parent_end = Location { block: parent, statement_index: statements_before };
124
125 let mut patch = MirPatch::new(body);
126
127 let second_operand = if opt_data.need_hoist_discriminant {
128 // create temp to store second discriminant in, `_s` in example above
129 let second_discriminant_temp =
130 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
131
132 // create assignment of discriminant
133 patch.add_assign(
134 parent_end,
135 Place::from(second_discriminant_temp),
136 Rvalue::Discriminant(opt_data.child_place),
137 );
138 Operand::Move(Place::from(second_discriminant_temp))
139 } else {
140 Operand::Copy(opt_data.child_place)
141 };
142
143 // create temp to store inequality comparison between the two discriminants, `_t` in
144 // example above
145 let nequal = BinOp::Ne;
146 let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
147 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
148
149 // create inequality comparison
150 let comp_rvalue =
151 Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
152 patch.add_statement(
153 parent_end,
154 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
155 );
156
157 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
158 let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
159 else {
160 unreachable!()
161 };
162 (value, targets.target_for_value(value))
163 });
164 // The otherwise either is the same target branch or an unreachable.
165 let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
166
167 // Create `bbEq` in example above
168 let eq_switch = BasicBlockData::new(
169 Some(Terminator {
170 source_info: bbs[parent].terminator().source_info,
171 kind: TerminatorKind::SwitchInt {
172 // switch on the first discriminant, so we can mark the second one as dead
173 discr: parent_op,
174 targets: eq_targets,
175 },
176 }),
177 bbs[parent].is_cleanup,
178 );
179
180 let eq_bb = patch.new_block(eq_switch);
181
182 // Jump to it on the basis of the inequality comparison
183 let true_case = opt_data.destination;
184 let false_case = eq_bb;
185 patch.patch_terminator(
186 parent,
187 TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
188 );
189
190 patch.apply(body);
191 }
192
193 // Since this optimization adds new basic blocks and invalidates others,
194 // clean up the cfg to make it nicer for other passes
195 if should_cleanup {
196 simplify_cfg(tcx, body);
197 }
198 }
199
200 fn is_required(&self) -> bool {
201 false
202 }
203}
204
205#[derive(Debug)]
206struct OptimizationData<'tcx> {
207 destination: BasicBlock,
208 child_place: Place<'tcx>,
209 child_ty: Ty<'tcx>,
210 child_source: SourceInfo,
211 need_hoist_discriminant: bool,
212}
213
214fn evaluate_candidate<'tcx>(
215 tcx: TyCtxt<'tcx>,
216 body: &Body<'tcx>,
217 parent: BasicBlock,
218) -> Option<OptimizationData<'tcx>> {
219 let bbs = &body.basic_blocks;
220 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
221 if bbs[parent].is_cleanup {
222 return None;
223 }
224 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
225 else {
226 return None;
227 };
228 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
229 let (_, child) = targets.iter().next()?;
230
231 let Terminator {
232 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
233 source_info,
234 } = bbs[child].terminator()
235 else {
236 return None;
237 };
238 let child_ty = child_discr.ty(body.local_decls(), tcx);
239 if child_ty != parent_ty {
240 return None;
241 }
242
243 // We only handle:
244 // ```
245 // bb4: {
246 // _8 = discriminant((_3.1: Enum1));
247 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
248 // }
249 // ```
250 // and
251 // ```
252 // bb2: {
253 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
254 // }
255 // ```
256 if bbs[child].statements.len() > 1 {
257 return None;
258 }
259
260 // When thie BB has exactly one statement, this statement should be discriminant.
261 let need_hoist_discriminant = bbs[child].statements.len() == 1;
262 let child_place = if need_hoist_discriminant {
263 if !bbs[targets.otherwise()].is_empty_unreachable() {
264 // Someone could write code like this:
265 // ```rust
266 // let Q = val;
267 // if discriminant(P) == otherwise {
268 // let ptr = &mut Q as *mut _ as *mut u8;
269 // // It may be difficult for us to effectively determine whether values are valid.
270 // // Invalid values can come from all sorts of corners.
271 // unsafe { *ptr = 10; }
272 // }
273 //
274 // match P {
275 // A => match Q {
276 // A => {
277 // // code
278 // }
279 // _ => {
280 // // don't use Q
281 // }
282 // }
283 // _ => {
284 // // don't use Q
285 // }
286 // };
287 // ```
288 //
289 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
290 // invalid value, which is UB.
291 // In order to fix this, **we would either need to show that the discriminant computation of
292 // `place` is computed in all branches**.
293 // FIXME(#95162) For the moment, we adopt a conservative approach and
294 // consider only the `otherwise` branch has no statements and an unreachable terminator.
295 return None;
296 }
297 // Handle:
298 // ```
299 // bb4: {
300 // _8 = discriminant((_3.1: Enum1));
301 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
302 // }
303 // ```
304 let [
305 Statement {
306 kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
307 ..
308 },
309 ] = bbs[child].statements.as_slice()
310 else {
311 return None;
312 };
313 *child_place
314 } else {
315 // Handle:
316 // ```
317 // bb2: {
318 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
319 // }
320 // ```
321 let Operand::Copy(child_place) = child_discr else {
322 return None;
323 };
324 *child_place
325 };
326 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
327 {
328 child_targets.otherwise()
329 } else {
330 targets.otherwise()
331 };
332
333 // Verify that the optimization is legal for each branch
334 for (value, child) in targets.iter() {
335 if !verify_candidate_branch(
336 &bbs[child],
337 value,
338 child_place,
339 destination,
340 need_hoist_discriminant,
341 ) {
342 return None;
343 }
344 }
345 Some(OptimizationData {
346 destination,
347 child_place,
348 child_ty,
349 child_source: *source_info,
350 need_hoist_discriminant,
351 })
352}
353
354fn verify_candidate_branch<'tcx>(
355 branch: &BasicBlockData<'tcx>,
356 value: u128,
357 place: Place<'tcx>,
358 destination: BasicBlock,
359 need_hoist_discriminant: bool,
360) -> bool {
361 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
362 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
363 return false;
364 };
365 if need_hoist_discriminant {
366 // If we need hoist discriminant, the branch must have exactly one statement.
367 let [statement] = branch.statements.as_slice() else {
368 return false;
369 };
370 // The statement must assign the discriminant of `place`.
371 let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
372 statement.kind
373 else {
374 return false;
375 };
376 if from_place != place {
377 return false;
378 }
379 // The assignment must invalidate a local that terminate on a `SwitchInt`.
380 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
381 return false;
382 }
383 } else {
384 // If we don't need hoist discriminant, the branch must not have any statements.
385 if !branch.statements.is_empty() {
386 return false;
387 }
388 // The place on `SwitchInt` must be the same.
389 if *switch_op != Operand::Copy(place) {
390 return false;
391 }
392 }
393 // It must fall through to `destination` if the switch misses.
394 if destination != targets.otherwise() {
395 return false;
396 }
397 // It must have exactly one branch for value `value` and have no more branches.
398 let mut iter = targets.iter();
399 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
400 return false;
401 };
402 target_value == value
403}