1use rustc_abi::Integer;
2use rustc_const_eval::const_eval::mk_eval_cx_for_const_val;
3use rustc_middle::mir::*;
4use rustc_middle::ty::layout::{IntegerExt, TyAndLayout};
5use rustc_middle::ty::util::Discr;
6use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt};
7
8use super::simplify::simplify_cfg;
9use crate::patch::MirPatch;
10use crate::unreachable_prop::remove_successors_from_switch;
11
12pub(super) struct MatchBranchSimplification;
14
15impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
16 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
17 sess.mir_opt_level() >= 2
19 }
20
21 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22 let typing_env = body.typing_env(tcx);
23 let mut changed = false;
24 for bb in body.basic_blocks.indices() {
25 if !candidate_match(body, bb) {
26 continue;
27 };
28 changed |= simplify_match(tcx, typing_env, body, bb)
29 }
30
31 if changed {
32 simplify_cfg(tcx, body);
33 }
34 }
35
36 fn is_required(&self) -> bool {
37 false
38 }
39}
40
41struct SimplifyMatch<'tcx, 'a> {
42 tcx: TyCtxt<'tcx>,
43 typing_env: ty::TypingEnv<'tcx>,
44 patch: MirPatch<'tcx>,
45 body: &'a Body<'tcx>,
46 switch_bb: BasicBlock,
47 discr: &'a Operand<'tcx>,
48 discr_local: Option<Local>,
49 discr_ty: Ty<'tcx>,
50}
51
52impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> {
53 fn discr_local(&mut self) -> Local {
54 *self.discr_local.get_or_insert_with(|| {
55 let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info;
57 self.patch.new_temp(self.discr_ty, source_info.span)
58 })
59 }
60
61 fn unify_if_equal_const(
63 &self,
64 dest: Place<'tcx>,
65 consts: &[(u128, &ConstOperand<'tcx>)],
66 otherwise: Option<&ConstOperand<'tcx>>,
67 ) -> Option<StatementKind<'tcx>> {
68 let (_, first_const, mut others) = split_first_case(consts, otherwise);
69 let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
70 if others.all(|const_| {
71 const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int)
72 }) {
73 Some(StatementKind::Assign(Box::new((
74 dest,
75 Rvalue::Use(Operand::Constant(Box::new(first_const.clone()))),
76 ))))
77 } else {
78 None
79 }
80 }
81
82 fn unify_by_eq_op(
114 &mut self,
115 dest: Place<'tcx>,
116 consts: &[(u128, &ConstOperand<'tcx>)],
117 otherwise: Option<&ConstOperand<'tcx>>,
118 ) -> Option<StatementKind<'tcx>> {
119 let (first_case, first_const, mut others) = split_first_case(consts, otherwise);
121 if !first_const.ty().is_bool() {
122 return None;
123 }
124 let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?;
125 if others.all(|const_| {
126 const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool)
127 }) {
128 let size =
130 self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size;
131 let const_cmp = Operand::const_from_scalar(
132 self.tcx,
133 self.discr_ty,
134 rustc_const_eval::interpret::Scalar::from_uint(first_case, size),
135 rustc_span::DUMMY_SP,
136 );
137 let op = if first_bool { BinOp::Eq } else { BinOp::Ne };
138 let rval = Rvalue::BinaryOp(
139 op,
140 Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)),
141 );
142 Some(StatementKind::Assign(Box::new((dest, rval))))
143 } else {
144 None
145 }
146 }
147
148 fn unify_by_int_to_int(
186 &mut self,
187 dest: Place<'tcx>,
188 consts: &[(u128, &ConstOperand<'tcx>)],
189 ) -> Option<StatementKind<'tcx>> {
190 let (_, first_const) = consts[0];
191 if !first_const.ty().is_integral() {
192 return None;
193 }
194 let discr_layout =
195 self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap();
196 if consts.iter().all(|&(case, const_)| {
197 let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env)
198 else {
199 return false;
200 };
201 can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int)
202 }) {
203 let operand = Operand::Copy(Place::from(self.discr_local()));
204 let rval = if first_const.ty() == self.discr_ty {
205 Rvalue::Use(operand)
206 } else {
207 Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty())
208 };
209 Some(StatementKind::Assign(Box::new((dest, rval))))
210 } else {
211 None
212 }
213 }
214
215 fn unify_by_copy(
232 &self,
233 dest: Place<'tcx>,
234 rvals: &[(u128, &Rvalue<'tcx>)],
235 ) -> Option<StatementKind<'tcx>> {
236 let bbs = &self.body.basic_blocks;
237 let &Statement {
241 kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))),
242 ..
243 } = bbs[self.switch_bb].statements.last()?
244 else {
245 return None;
246 };
247 if self.discr.place() != Some(discr_place) {
248 return None;
249 }
250 let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx);
251 if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
252 return None;
253 }
254 let dest_ty = dest.ty(self.body.local_decls(), self.tcx);
255 if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
256 return None;
257 }
258 let ty::Adt(def, _) = dest_ty.ty.kind() else {
259 return None;
260 };
261
262 for &(case, rvalue) in rvals.iter() {
263 match rvalue {
264 Rvalue::Use(Operand::Constant(box constant))
266 if let Const::Val(const_, ty) = constant.const_ =>
267 {
268 let (ecx, op) = mk_eval_cx_for_const_val(
269 self.tcx.at(constant.span),
270 self.typing_env,
271 const_,
272 ty,
273 )?;
274 let variant = ecx.read_discriminant(&op).discard_err()?;
275 if !def.variants()[variant].fields.is_empty() {
276 return None;
277 }
278 let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?;
279 if val != case {
280 return None;
281 }
282 }
283 Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {}
284 Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
286 if fields.is_empty()
287 && let Some(Discr { val, .. }) =
288 src_ty.ty.discriminant_for_variant(self.tcx, *variant_index)
289 && val == case => {}
290 _ => return None,
291 }
292 }
293 Some(StatementKind::Assign(Box::new((dest, Rvalue::Use(Operand::Copy(copy_src_place))))))
294 }
295
296 fn try_unify_stmts(
298 &mut self,
299 index: usize,
300 stmts: &[(u128, &StatementKind<'tcx>)],
301 otherwise: Option<&StatementKind<'tcx>>,
302 ) -> Option<StatementKind<'tcx>> {
303 if let Some(new_stmt) = identical_stmts(stmts, otherwise) {
304 return Some(new_stmt);
305 }
306
307 let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?;
308 if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) {
309 if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) {
310 return Some(new_stmt);
311 }
312 if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) {
313 return Some(new_stmt);
314 }
315 if otherwise.is_none()
317 && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts)
318 {
319 return Some(new_stmt);
320 }
321 }
322
323 if index == 0
325 && dest.is_stable_offset()
327 && otherwise.is_none()
329 && let Some(new_stmt) = self.unify_by_copy(dest, &rvals)
330 {
331 return Some(new_stmt);
332 }
333 None
334 }
335}
336
337fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool {
339 use itertools::Itertools;
340 let targets = match &body.basic_blocks[switch_bb].terminator().kind {
341 TerminatorKind::SwitchInt {
342 discr: Operand::Copy(_) | Operand::Move(_), targets, ..
343 } => targets,
344 _ => return false,
346 };
347 if targets.all_targets().contains(&switch_bb) {
349 return false;
350 }
351 if !targets.is_distinct() {
353 return false;
354 }
355 targets
357 .all_targets()
358 .iter()
359 .map(|&bb| &body.basic_blocks[bb])
360 .filter(|bb| !bb.is_empty_unreachable())
361 .map(|bb| (bb.statements.len(), &bb.terminator().kind))
362 .all_equal()
363}
364
365fn simplify_match<'tcx>(
366 tcx: TyCtxt<'tcx>,
367 typing_env: ty::TypingEnv<'tcx>,
368 body: &mut Body<'tcx>,
369 switch_bb: BasicBlock,
370) -> bool {
371 let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind {
372 TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets),
373 _ => unreachable!(),
374 };
375 let mut simplify_match = SimplifyMatch {
376 tcx,
377 typing_env,
378 patch: MirPatch::new(body),
379 body,
380 switch_bb,
381 discr,
382 discr_local: None,
383 discr_ty: discr.ty(body.local_decls(), tcx),
384 };
385 let reachable_cases: Vec<_> =
386 targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect();
387 let mut new_stmts = Vec::new();
388 let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() {
389 None
390 } else {
391 Some(targets.otherwise())
392 };
393 match (reachable_cases.len(), otherwise.is_none()) {
395 (1, true) | (0, false) => {
396 let mut patch = simplify_match.patch;
397 remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| {
398 body.basic_blocks[bb].is_empty_unreachable()
399 });
400 patch.apply(body);
401 return true;
402 }
403 _ => {}
404 }
405 let Some(&(_, first_case_bb)) = reachable_cases.first() else {
406 return false;
407 };
408 let stmt_len = body.basic_blocks[first_case_bb].statements.len();
409 let mut cases = Vec::with_capacity(stmt_len);
410 for index in 0..stmt_len {
412 cases.clear();
413 let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind);
414 for &(case, bb) in &reachable_cases {
415 cases.push((case, &body.basic_blocks[bb].statements[index].kind));
416 }
417 let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else {
418 return false;
419 };
420 new_stmts.push(new_stmt);
421 }
422 let discr = discr.clone();
424
425 let statement_index = body.basic_blocks[switch_bb].statements.len();
426 let parent_end = Location { block: switch_bb, statement_index };
427 let mut patch = simplify_match.patch;
428 if let Some(discr_local) = simplify_match.discr_local {
429 patch.add_statement(parent_end, StatementKind::StorageLive(discr_local));
430 patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr));
431 }
432 for new_stmt in new_stmts {
433 patch.add_statement(parent_end, new_stmt);
434 }
435 if let Some(discr_local) = simplify_match.discr_local {
436 patch.add_statement(parent_end, StatementKind::StorageDead(discr_local));
437 }
438 patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone());
439 patch.apply(body);
440 true
441}
442
443fn can_cast(
445 tcx: TyCtxt<'_>,
446 src_val: impl Into<u128>,
447 src_layout: TyAndLayout<'_>,
448 cast_ty: Ty<'_>,
449 target_scalar: ScalarInt,
450) -> bool {
451 let from_scalar = ScalarInt::try_from_uint(src_val.into(), src_layout.size).unwrap();
452 let v = match src_layout.ty.kind() {
453 ty::Uint(_) => from_scalar.to_uint(src_layout.size),
454 ty::Int(_) => from_scalar.to_int(src_layout.size) as u128,
455 _ => return false,
458 };
459 let size = match *cast_ty.kind() {
460 ty::Int(t) => Integer::from_int_ty(&tcx, t).size(),
461 ty::Uint(t) => Integer::from_uint_ty(&tcx, t).size(),
462 _ => return false,
463 };
464 let v = size.truncate(v);
465 let cast_scalar = ScalarInt::try_from_uint(v, size).unwrap();
466 cast_scalar == target_scalar
467}
468
469fn candidate_assign<'tcx, 'a>(
470 stmts: &'a [(u128, &'a StatementKind<'tcx>)],
471 otherwise: Option<&'a StatementKind<'tcx>>,
472) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> {
473 let (_, first_stmt) = stmts[0];
474 let (dest, _) = first_stmt.as_assign()?;
475 let otherwise = if let Some(otherwise) = otherwise {
476 let Some((otherwise_dest, rval)) = otherwise.as_assign() else {
477 return None;
478 };
479 if otherwise_dest != dest {
480 return None;
481 }
482 Some(rval)
483 } else {
484 None
485 };
486 let rvals = stmts
487 .into_iter()
488 .map(|&(case, stmt)| {
489 let (other_dest, rval) = stmt.as_assign()?;
490 if other_dest != dest {
491 return None;
492 }
493 Some((case, rval))
494 })
495 .try_collect()?;
496 Some((*dest, rvals, otherwise))
497}
498
499fn candidate_const<'tcx, 'a>(
501 rvals: &'a [(u128, &'a Rvalue<'tcx>)],
502 otherwise: Option<&'a Rvalue<'tcx>>,
503) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> {
504 let otherwise = if let Some(otherwise) = otherwise {
505 let Rvalue::Use(Operand::Constant(box const_)) = otherwise else {
506 return None;
507 };
508 Some(const_)
509 } else {
510 None
511 };
512 let consts = rvals
513 .into_iter()
514 .map(|&(case, rval)| {
515 let Rvalue::Use(Operand::Constant(box const_)) = rval else { return None };
516 Some((case, const_))
517 })
518 .try_collect()?;
519 Some((consts, otherwise))
520}
521
/// Splits the cases into the first case value, the first item, and an iterator over all
/// remaining items: the items of the later cases followed by `otherwise`, if present.
///
/// Panics if `stmts` is empty.
fn split_first_case<'a, T>(
    stmts: &'a [(u128, &'a T)],
    otherwise: Option<&'a T>,
) -> (u128, &'a T, impl Iterator<Item = &'a T>) {
    let (head_case, head_item) = stmts[0];
    let remaining = stmts[1..].iter().map(|entry| entry.1).chain(otherwise);
    (head_case, head_item, remaining)
}
530
531fn identical_stmts<'tcx>(
533 stmts: &[(u128, &StatementKind<'tcx>)],
534 otherwise: Option<&StatementKind<'tcx>>,
535) -> Option<StatementKind<'tcx>> {
536 use itertools::Itertools;
537 let (_, first_stmt, others) = split_first_case(stmts, otherwise);
538 if std::iter::once(first_stmt).chain(others).all_equal() {
539 return Some(first_stmt.clone());
540 }
541 None
542}