rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_data_structures::thin_vec::ThinVec;
4use rustc_middle::mir::*;
5use rustc_middle::ty::{Ty, TyCtxt};
6use tracing::trace;
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10
11/// This pass optimizes something like
12/// ```ignore (syntax-highlighting-only)
13/// let x: Option<()>;
14/// let y: Option<()>;
15/// match (x,y) {
16/// (Some(_), Some(_)) => {0},
17/// (None, None) => {2},
18/// _ => {1}
19/// }
20/// ```
21/// into something like
22/// ```ignore (syntax-highlighting-only)
23/// let x: Option<()>;
24/// let y: Option<()>;
25/// let discriminant_x = std::mem::discriminant(x);
26/// let discriminant_y = std::mem::discriminant(y);
27/// if discriminant_x == discriminant_y {
28/// match x {
29/// Some(_) => 0,
30/// None => 2,
31/// }
32/// } else {
33/// 1
34/// }
35/// ```
36///
37/// Specifically, it looks for instances of control flow like this:
38/// ```text
39///
40/// =================
41/// | BB1 |
42/// |---------------| ============================
43/// | ... | /------> | BBC |
44/// |---------------| | |--------------------------|
45/// | switchInt(Q) | | | _cl = discriminant(P) |
46/// | c | --------/ |--------------------------|
47/// | d | -------\ | switchInt(_cl) |
48/// | ... | | | c | ---> BBC.2
49/// | otherwise | --\ | /--- | otherwise |
50/// ================= | | | ============================
51/// | | |
52/// ================= | | |
53/// | BBU | <-| | | ============================
54/// |---------------| \-------> | BBD |
55/// |---------------| | |--------------------------|
56/// | unreachable | | | _dl = discriminant(P) |
57/// ================= | |--------------------------|
58/// | | switchInt(_dl) |
59/// ================= | | d | ---> BBD.2
60/// | BB9 | <--------------- | otherwise |
61/// |---------------| ============================
62/// | ... |
63/// =================
64/// ```
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
/// code:
67/// - `BB1` is `parent` and `BBC, BBD` are children
68/// - `P` is `child_place`
69/// - `child_ty` is the type of `_cl`.
70/// - `Q` is `parent_op`.
71/// - `parent_ty` is the type of `Q`.
/// - `BB9` is `destination`
///
/// All this is then transformed into:
74/// ```text
75///
76/// =======================
77/// | BB1 |
78/// |---------------------| ============================
79/// | ... | /------> | BBEq |
80/// | _s = discriminant(P)| | |--------------------------|
81/// | _t = Ne(Q, _s) | | |--------------------------|
82/// |---------------------| | | switchInt(Q) |
83/// | switchInt(_t) | | | c | ---> BBC.2
84/// | false | --------/ | d | ---> BBD.2
85/// | otherwise | /--------- | otherwise |
86/// ======================= | ============================
87/// |
88/// ================= |
89/// | BB9 | <-----------/
90/// |---------------|
91/// | ... |
92/// =================
93/// ```
///
/// The pass is only enabled at `mir-opt-level` 2 and above and is not a required pass.
pub(super) struct EarlyOtherwiseBranch;
95
96impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
97 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98 sess.mir_opt_level() >= 2
99 }
100
101 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
103
104 let mut should_cleanup = false;
105
106 // Also consider newly generated bbs in the same pass
107 for parent in body.basic_blocks.indices() {
108 let bbs = &*body.basic_blocks;
109 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
110
111 trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
112
113 should_cleanup = true;
114
115 let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
116 &bbs[parent].terminator().kind
117 else {
118 unreachable!()
119 };
120 // Always correct since we can only switch on `Copy` types
121 let parent_op = parent_op.to_copy();
122 let parent_ty = parent_op.ty(body.local_decls(), tcx);
123 let statements_before = bbs[parent].statements.len();
124 let parent_end = Location { block: parent, statement_index: statements_before };
125
126 let mut patch = MirPatch::new(body);
127
128 let second_operand = if opt_data.need_hoist_discriminant {
129 // create temp to store second discriminant in, `_s` in example above
130 let second_discriminant_temp =
131 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
132
133 // create assignment of discriminant
134 patch.add_assign(
135 parent_end,
136 Place::from(second_discriminant_temp),
137 Rvalue::Discriminant(opt_data.child_place),
138 );
139 Operand::Move(Place::from(second_discriminant_temp))
140 } else {
141 Operand::Copy(opt_data.child_place)
142 };
143
144 // create temp to store inequality comparison between the two discriminants, `_t` in
145 // example above
146 let nequal = BinOp::Ne;
147 let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
148 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
149
150 // create inequality comparison
151 let comp_rvalue =
152 Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
153 patch.add_statement(
154 parent_end,
155 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
156 );
157
158 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
159 let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
160 else {
161 unreachable!()
162 };
163 (value, targets.target_for_value(value))
164 });
165 // The otherwise either is the same target branch or an unreachable.
166 let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
167
168 // Create `bbEq` in example above
169 let eq_switch = BasicBlockData::new(
170 Some(Terminator {
171 source_info: bbs[parent].terminator().source_info,
172 kind: TerminatorKind::SwitchInt {
173 // switch on the first discriminant, so we can mark the second one as dead
174 discr: parent_op,
175 targets: eq_targets,
176 },
177 attributes: ThinVec::new(),
178 }),
179 bbs[parent].is_cleanup,
180 );
181
182 let eq_bb = patch.new_block(eq_switch);
183
184 // Jump to it on the basis of the inequality comparison
185 let true_case = opt_data.destination;
186 let false_case = eq_bb;
187 patch.patch_terminator(
188 parent,
189 TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
190 );
191
192 patch.apply(body);
193 }
194
195 // Since this optimization adds new basic blocks and invalidates others,
196 // clean up the cfg to make it nicer for other passes
197 if should_cleanup {
198 simplify_cfg(tcx, body);
199 }
200 }
201
202 fn is_required(&self) -> bool {
203 false
204 }
205}
206
/// Description of one rewrite opportunity, produced by `evaluate_candidate` and consumed by
/// `EarlyOtherwiseBranch::run_pass`.
#[derive(Debug)]
struct OptimizationData<'tcx> {
    /// Block the rewritten parent jumps to when the compared values differ; every child must
    /// already fall through to this block in its `otherwise` arm.
    destination: BasicBlock,
    /// Place inspected by the children: either the place whose discriminant they read, or the
    /// place they switch on directly.
    child_place: Place<'tcx>,
    /// Type of the operand the children switch on (checked to equal the parent's operand type).
    child_ty: Ty<'tcx>,
    /// Source info of the first child's terminator; its span is used for the new temporaries.
    child_source: SourceInfo,
    /// `true` when each child consists of a single `discriminant(child_place)` assignment that
    /// has to be recomputed in the parent; `false` when the children have no statements.
    need_hoist_discriminant: bool,
}
215
216fn evaluate_candidate<'tcx>(
217 tcx: TyCtxt<'tcx>,
218 body: &Body<'tcx>,
219 parent: BasicBlock,
220) -> Option<OptimizationData<'tcx>> {
221 let bbs = &body.basic_blocks;
222 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
223 if bbs[parent].is_cleanup {
224 return None;
225 }
226 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
227 else {
228 return None;
229 };
230 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
231 let (_, child) = targets.iter().next()?;
232
233 let Terminator {
234 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
235 source_info,
236 attributes: _,
237 } = bbs[child].terminator()
238 else {
239 return None;
240 };
241 let child_ty = child_discr.ty(body.local_decls(), tcx);
242 if child_ty != parent_ty {
243 return None;
244 }
245
246 // We only handle:
247 // ```
248 // bb4: {
249 // _8 = discriminant((_3.1: Enum1));
250 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
251 // }
252 // ```
253 // and
254 // ```
255 // bb2: {
256 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
257 // }
258 // ```
259 if bbs[child].statements.len() > 1 {
260 return None;
261 }
262
263 // When thie BB has exactly one statement, this statement should be discriminant.
264 let need_hoist_discriminant = bbs[child].statements.len() == 1;
265 let child_place = if need_hoist_discriminant {
266 if !bbs[targets.otherwise()].is_empty_unreachable() {
267 // Someone could write code like this:
268 // ```rust
269 // let Q = val;
270 // if discriminant(P) == otherwise {
271 // let ptr = &mut Q as *mut _ as *mut u8;
272 // // It may be difficult for us to effectively determine whether values are valid.
273 // // Invalid values can come from all sorts of corners.
274 // unsafe { *ptr = 10; }
275 // }
276 //
277 // match P {
278 // A => match Q {
279 // A => {
280 // // code
281 // }
282 // _ => {
283 // // don't use Q
284 // }
285 // }
286 // _ => {
287 // // don't use Q
288 // }
289 // };
290 // ```
291 //
292 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
293 // invalid value, which is UB.
294 // In order to fix this, **we would either need to show that the discriminant computation of
295 // `place` is computed in all branches**.
296 // FIXME(#95162) For the moment, we adopt a conservative approach and
297 // consider only the `otherwise` branch has no statements and an unreachable terminator.
298 return None;
299 }
300 // Handle:
301 // ```
302 // bb4: {
303 // _8 = discriminant((_3.1: Enum1));
304 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
305 // }
306 // ```
307 let [
308 Statement {
309 kind: StatementKind::Assign((_, Rvalue::Discriminant(child_place))), ..
310 },
311 ] = bbs[child].statements.as_slice()
312 else {
313 return None;
314 };
315 *child_place
316 } else {
317 // Handle:
318 // ```
319 // bb2: {
320 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
321 // }
322 // ```
323 let Operand::Copy(child_place) = child_discr else {
324 return None;
325 };
326 *child_place
327 };
328 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
329 {
330 child_targets.otherwise()
331 } else {
332 targets.otherwise()
333 };
334
335 // Verify that the optimization is legal for each branch
336 for (value, child) in targets.iter() {
337 if !verify_candidate_branch(
338 &bbs[child],
339 value,
340 child_place,
341 destination,
342 need_hoist_discriminant,
343 ) {
344 return None;
345 }
346 }
347 Some(OptimizationData {
348 destination,
349 child_place,
350 child_ty,
351 child_source: *source_info,
352 need_hoist_discriminant,
353 })
354}
355
356fn verify_candidate_branch<'tcx>(
357 branch: &BasicBlockData<'tcx>,
358 value: u128,
359 place: Place<'tcx>,
360 destination: BasicBlock,
361 need_hoist_discriminant: bool,
362) -> bool {
363 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
364 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
365 return false;
366 };
367 if need_hoist_discriminant {
368 // If we need hoist discriminant, the branch must have exactly one statement.
369 let [statement] = branch.statements.as_slice() else {
370 return false;
371 };
372 // The statement must assign the discriminant of `place`.
373 let StatementKind::Assign((discr_place, Rvalue::Discriminant(from_place))) = statement.kind
374 else {
375 return false;
376 };
377 if from_place != place {
378 return false;
379 }
380 // The assignment must invalidate a local that terminate on a `SwitchInt`.
381 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
382 return false;
383 }
384 } else {
385 // If we don't need hoist discriminant, the branch must not have any statements.
386 if !branch.statements.is_empty() {
387 return false;
388 }
389 // The place on `SwitchInt` must be the same.
390 if *switch_op != Operand::Copy(place) {
391 return false;
392 }
393 }
394 // It must fall through to `destination` if the switch misses.
395 if destination != targets.otherwise() {
396 return false;
397 }
398 // It must have exactly one branch for value `value` and have no more branches.
399 let mut iter = targets.iter();
400 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
401 return false;
402 };
403 target_value == value
404}