rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15/// (Some(_), Some(_)) => {0},
16/// (None, None) => {2},
17/// _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27/// match x {
28/// Some(_) => 0,
29/// None => 2,
30/// }
31/// } else {
32/// 1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
///     =================
///     |      BB1      |
///     |---------------|                  ============================
///     |     ...       |         /------> |            BBC           |
///     |---------------|         |        |--------------------------|
///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
///     |       c       | --------/        |--------------------------|
///     |       d       | -------\         |       switchInt(_cl)     |
///     |      ...      |        |         |            c             | ---> BBC.2
///     |    otherwise  | --\    |    /--- |         otherwise        |
///     =================   |    |    |    ============================
///                         |    |    |
///     =================   |    |    |
///     |      BBU      | <-|    |    |    ============================
///     |---------------|        \-------> |            BBD           |
///     |---------------|             |    |--------------------------|
///     |  unreachable  |             |    |   _dl = discriminant(P)  |
///     =================             |    |--------------------------|
///                                   |    |       switchInt(_dl)     |
///     =================             |    |            d             | ---> BBD.2
///     |      BB9      | <--------------- |         otherwise        |
///     |---------------|                  ============================
///     |      ...      |
///     =================
63/// ```
/// Where the `otherwise` branch on `BB1` is permitted to go either to `BBU` or, when no
/// discriminant has to be hoisted out of the children, to `BB9`. In the code:
66/// - `BB1` is `parent` and `BBC, BBD` are children
67/// - `P` is `child_place`
68/// - `child_ty` is the type of `_cl`.
69/// - `Q` is `parent_op`.
70/// - `parent_ty` is the type of `Q`.
71/// - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
///     =======================
///     |          BB1        |
///     |---------------------|                  ============================
///     |          ...        |         /------> |           BBEq           |
///     | _s = discriminant(P)|         |        |--------------------------|
///     | _t = Ne(Q, _s)      |         |        |--------------------------|
///     |---------------------|         |        |       switchInt(Q)       |
///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
///     |     false           | --------/        |            d             | ---> BBD.2
///     |       otherwise     |       /--------- |         otherwise        |
///     =======================       |          ============================
///                                   |
///     =================             |
///     |      BB9      | <-----------/
///     |---------------|
///     |      ...      |
///     =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97 sess.mir_opt_level() >= 2
98 }
99
100 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103 let mut should_cleanup = false;
104
105 // Also consider newly generated bbs in the same pass
106 for parent in body.basic_blocks.indices() {
107 let bbs = &*body.basic_blocks;
108 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
109
110 trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
111
112 should_cleanup = true;
113
114 let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
115 &bbs[parent].terminator().kind
116 else {
117 unreachable!()
118 };
119 // Always correct since we can only switch on `Copy` types
120 let parent_op = parent_op.to_copy();
121 let parent_ty = parent_op.ty(body.local_decls(), tcx);
122 let statements_before = bbs[parent].statements.len();
123 let parent_end = Location { block: parent, statement_index: statements_before };
124
125 let mut patch = MirPatch::new(body);
126
127 let second_operand = if opt_data.need_hoist_discriminant {
128 // create temp to store second discriminant in, `_s` in example above
129 let second_discriminant_temp =
130 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
131
132 // create assignment of discriminant
133 patch.add_assign(
134 parent_end,
135 Place::from(second_discriminant_temp),
136 Rvalue::Discriminant(opt_data.child_place),
137 );
138 Operand::Move(Place::from(second_discriminant_temp))
139 } else {
140 Operand::Copy(opt_data.child_place)
141 };
142
143 // create temp to store inequality comparison between the two discriminants, `_t` in
144 // example above
145 let nequal = BinOp::Ne;
146 let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
147 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
148
149 // create inequality comparison
150 let comp_rvalue =
151 Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
152 patch.add_statement(
153 parent_end,
154 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
155 );
156
157 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
158 let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
159 else {
160 unreachable!()
161 };
162 (value, targets.target_for_value(value))
163 });
164 // The otherwise either is the same target branch or an unreachable.
165 let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
166
167 // Create `bbEq` in example above
168 let eq_switch = BasicBlockData::new(
169 Some(Terminator {
170 source_info: bbs[parent].terminator().source_info,
171 kind: TerminatorKind::SwitchInt {
172 // switch on the first discriminant, so we can mark the second one as dead
173 discr: parent_op,
174 targets: eq_targets,
175 },
176 }),
177 bbs[parent].is_cleanup,
178 );
179
180 let eq_bb = patch.new_block(eq_switch);
181
182 // Jump to it on the basis of the inequality comparison
183 let true_case = opt_data.destination;
184 let false_case = eq_bb;
185 patch.patch_terminator(
186 parent,
187 TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
188 );
189
190 patch.apply(body);
191 }
192
193 // Since this optimization adds new basic blocks and invalidates others,
194 // clean up the cfg to make it nicer for other passes
195 if should_cleanup {
196 simplify_cfg(tcx, body);
197 }
198 }
199
200 fn is_required(&self) -> bool {
201 false
202 }
203}
204
205#[derive(Debug)]
206struct OptimizationData<'tcx> {
207 destination: BasicBlock,
208 child_place: Place<'tcx>,
209 child_ty: Ty<'tcx>,
210 child_source: SourceInfo,
211 need_hoist_discriminant: bool,
212}
213
214fn evaluate_candidate<'tcx>(
215 tcx: TyCtxt<'tcx>,
216 body: &Body<'tcx>,
217 parent: BasicBlock,
218) -> Option<OptimizationData<'tcx>> {
219 let bbs = &body.basic_blocks;
220 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
221 if bbs[parent].is_cleanup {
222 return None;
223 }
224 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
225 else {
226 return None;
227 };
228 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
229 let (_, child) = targets.iter().next()?;
230
231 let Terminator {
232 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
233 source_info,
234 } = bbs[child].terminator()
235 else {
236 return None;
237 };
238 let child_ty = child_discr.ty(body.local_decls(), tcx);
239 if child_ty != parent_ty {
240 return None;
241 }
242
243 // We only handle:
244 // ```
245 // bb4: {
246 // _8 = discriminant((_3.1: Enum1));
247 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
248 // }
249 // ```
250 // and
251 // ```
252 // bb2: {
253 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
254 // }
255 // ```
256 if bbs[child].statements.len() > 1 {
257 return None;
258 }
259
260 // When thie BB has exactly one statement, this statement should be discriminant.
261 let need_hoist_discriminant = bbs[child].statements.len() == 1;
262 let child_place = if need_hoist_discriminant {
263 if !bbs[targets.otherwise()].is_empty_unreachable() {
264 // Someone could write code like this:
265 // ```rust
266 // let Q = val;
267 // if discriminant(P) == otherwise {
268 // let ptr = &mut Q as *mut _ as *mut u8;
269 // // It may be difficult for us to effectively determine whether values are valid.
270 // // Invalid values can come from all sorts of corners.
271 // unsafe { *ptr = 10; }
272 // }
273 //
274 // match P {
275 // A => match Q {
276 // A => {
277 // // code
278 // }
279 // _ => {
280 // // don't use Q
281 // }
282 // }
283 // _ => {
284 // // don't use Q
285 // }
286 // };
287 // ```
288 //
289 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
290 // invalid value, which is UB.
291 // In order to fix this, **we would either need to show that the discriminant computation of
292 // `place` is computed in all branches**.
293 // FIXME(#95162) For the moment, we adopt a conservative approach and
294 // consider only the `otherwise` branch has no statements and an unreachable terminator.
295 return None;
296 }
297 // Handle:
298 // ```
299 // bb4: {
300 // _8 = discriminant((_3.1: Enum1));
301 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
302 // }
303 // ```
304 let [
305 Statement {
306 kind: StatementKind::Assign((_, Rvalue::Discriminant(child_place))), ..
307 },
308 ] = bbs[child].statements.as_slice()
309 else {
310 return None;
311 };
312 *child_place
313 } else {
314 // Handle:
315 // ```
316 // bb2: {
317 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
318 // }
319 // ```
320 let Operand::Copy(child_place) = child_discr else {
321 return None;
322 };
323 *child_place
324 };
325 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
326 {
327 child_targets.otherwise()
328 } else {
329 targets.otherwise()
330 };
331
332 // Verify that the optimization is legal for each branch
333 for (value, child) in targets.iter() {
334 if !verify_candidate_branch(
335 &bbs[child],
336 value,
337 child_place,
338 destination,
339 need_hoist_discriminant,
340 ) {
341 return None;
342 }
343 }
344 Some(OptimizationData {
345 destination,
346 child_place,
347 child_ty,
348 child_source: *source_info,
349 need_hoist_discriminant,
350 })
351}
352
353fn verify_candidate_branch<'tcx>(
354 branch: &BasicBlockData<'tcx>,
355 value: u128,
356 place: Place<'tcx>,
357 destination: BasicBlock,
358 need_hoist_discriminant: bool,
359) -> bool {
360 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
361 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
362 return false;
363 };
364 if need_hoist_discriminant {
365 // If we need hoist discriminant, the branch must have exactly one statement.
366 let [statement] = branch.statements.as_slice() else {
367 return false;
368 };
369 // The statement must assign the discriminant of `place`.
370 let StatementKind::Assign((discr_place, Rvalue::Discriminant(from_place))) = statement.kind
371 else {
372 return false;
373 };
374 if from_place != place {
375 return false;
376 }
377 // The assignment must invalidate a local that terminate on a `SwitchInt`.
378 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
379 return false;
380 }
381 } else {
382 // If we don't need hoist discriminant, the branch must not have any statements.
383 if !branch.statements.is_empty() {
384 return false;
385 }
386 // The place on `SwitchInt` must be the same.
387 if *switch_op != Operand::Copy(place) {
388 return false;
389 }
390 }
391 // It must fall through to `destination` if the switch misses.
392 if destination != targets.otherwise() {
393 return false;
394 }
395 // It must have exactly one branch for value `value` and have no more branches.
396 let mut iter = targets.iter();
397 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
398 return false;
399 };
400 target_value == value
401}