rustc_mir_transform/
unreachable_enum_branching.rs
1use rustc_abi::Variants;
4use rustc_data_structures::fx::FxHashSet;
5use rustc_middle::bug;
6use rustc_middle::mir::{
7 BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind,
8 TerminatorKind,
9};
10use rustc_middle::ty::layout::TyAndLayout;
11use rustc_middle::ty::{Ty, TyCtxt};
12use tracing::trace;
13
14use crate::patch::MirPatch;
15
/// MIR pass that simplifies `SwitchInt` terminators which branch on an enum discriminant.
///
/// Arms whose value is not a discriminant the enum's layout allows (for example, the value of
/// an uninhabited variant) are redirected to an unreachable block. When exactly one valid
/// discriminant is left to the `otherwise` arm, that arm may be turned into an explicit case
/// and `otherwise` made unreachable. The pass runs only at `mir_opt_level() > 0`.
pub(super) struct UnreachableEnumBranching;
17
18fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> {
19 if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator {
20 p.as_local()
21 } else {
22 None
23 }
24}
25
26fn get_switched_on_type<'tcx>(
29 block_data: &BasicBlockData<'tcx>,
30 tcx: TyCtxt<'tcx>,
31 body: &Body<'tcx>,
32) -> Option<Ty<'tcx>> {
33 let terminator = block_data.terminator();
34
35 let local = get_discriminant_local(&terminator.kind)?;
37
38 let stmt_before_term = block_data.statements.last()?;
39
40 if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
41 && l.as_local() == Some(local)
42 {
43 let ty = place.ty(body, tcx).ty;
44 if ty.is_enum() {
45 return Some(ty);
46 }
47 }
48
49 None
50}
51
52fn variant_discriminants<'tcx>(
53 layout: &TyAndLayout<'tcx>,
54 ty: Ty<'tcx>,
55 tcx: TyCtxt<'tcx>,
56) -> FxHashSet<u128> {
57 match &layout.variants {
58 Variants::Empty => {
59 FxHashSet::default()
61 }
62 Variants::Single { index } => {
63 let mut res = FxHashSet::default();
64 res.insert(
65 ty.discriminant_for_variant(tcx, *index)
66 .map_or(index.as_u32() as u128, |discr| discr.val),
67 );
68 res
69 }
70 Variants::Multiple { variants, .. } => variants
71 .iter_enumerated()
72 .filter_map(|(idx, layout)| {
73 (!layout.is_uninhabited())
74 .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
75 })
76 .collect(),
77 }
78}
79
impl<'tcx> crate::MirPass<'tcx> for UnreachableEnumBranching {
    /// Enabled whenever the MIR optimization level is above 0.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() > 0
    }

    /// For each non-cleanup block that switches on an enum discriminant:
    /// - redirects arms whose value the enum cannot hold to an unreachable block;
    /// - where possible, redirects the `otherwise` arm there as well.
    ///
    /// All edits are collected in a `MirPatch` and applied at the end.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("UnreachableEnumBranching starting for {:?}", body.source);

        // Indices (into `all_targets_mut()`) of arms to redirect. Reused across blocks and
        // cleared per block below.
        let mut unreachable_targets = Vec::new();
        // Edits are recorded here while `body` is only borrowed, then applied once at the end.
        let mut patch = MirPatch::new(body);

        for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
            trace!("processing block {:?}", bb);

            // Cleanup blocks are never rewritten.
            if bb_data.is_cleanup {
                continue;
            }

            let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue };

            let layout = tcx.layout_of(body.typing_env(tcx).as_query_input(discriminant_ty));

            // Set of discriminant values the switch can actually observe.
            // - If the layout is known, uninhabited variants are excluded.
            // - If the layout query fails (NOTE(review): presumably for types that still
            //   involve generic parameters), fall back to every declared variant.
            // - If neither is available, the block is skipped.
            let mut allowed_variants = if let Ok(layout) = layout {
                variant_discriminants(&layout, discriminant_ty, tcx)
            } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
                variant_range
                    .map(|variant| {
                        discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
                    })
                    .collect()
            } else {
                continue;
            };

            trace!("allowed_variants = {:?}", allowed_variants);

            unreachable_targets.clear();
            // `get_switched_on_type` only returns `Some` for `SwitchInt` terminators, so any
            // other kind here is a compiler bug.
            let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
                bug!()
            };

            // Arms whose value is not an allowed discriminant are dead. Removing each matched
            // value leaves `allowed_variants` holding only the valid discriminants that have no
            // explicit arm, i.e. those that fall through to `otherwise`.
            for (index, (val, _)) in targets.iter().enumerate() {
                if !allowed_variants.remove(&val) {
                    unreachable_targets.push(index);
                }
            }
            // If `otherwise` is already an empty unreachable block, there is nothing to change
            // for it.
            let otherwise_is_empty_unreachable =
                body.basic_blocks[targets.otherwise()].is_empty_unreachable();
            // Heuristic guard for the `otherwise` rewrite below. Returns `false` only when `bb`
            // has exactly one successor and that successor itself ends in a `SwitchInt`;
            // returns `true` in every other case (including when `bb` has no successors).
            // NOTE(review): the motivation is not recorded in this file (believed to be an LLVM
            // tail-duplication compile-time issue) — confirm before relaxing it.
            fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool {
                let mut successors = basic_blocks[bb].terminator().successors();
                let Some(first_successor) = successors.next() else { return true };
                if successors.next().is_some() {
                    return true;
                }
                if let TerminatorKind::SwitchInt { .. } =
                    &basic_blocks[first_successor].terminator().kind
                {
                    return false;
                };
                true
            }
            // Exactly one valid discriminant reaches a non-trivial `otherwise`. That arm can
            // become an explicit case for this value, after which `otherwise` is unreachable.
            // This is always done when `all_targets()` has at most 3 entries; for larger
            // switches, only when `check_successors` allows it.
            let otherwise_is_last_variant = !otherwise_is_empty_unreachable
                && allowed_variants.len() == 1
                && (targets.all_targets().len() <= 3
                    || check_successors(&body.basic_blocks, targets.otherwise()));
            // `otherwise` is also dead when no valid discriminant reaches it at all.
            let replace_otherwise_to_unreachable = otherwise_is_last_variant
                || (!otherwise_is_empty_unreachable && allowed_variants.is_empty());

            // Nothing to rewrite in this block.
            if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
                continue;
            }

            let unreachable_block = patch.unreachable_no_cleanup_block();
            let mut targets = targets.clone();
            if replace_otherwise_to_unreachable {
                if otherwise_is_last_variant {
                    // `allowed_variants` has exactly one element here (checked above), so the
                    // hash set's unspecified iteration order cannot affect the result.
                    #[allow(rustc::potential_query_instability)]
                    let last_variant = *allowed_variants.iter().next().unwrap();
                    targets.add_target(last_variant, targets.otherwise());
                }
                // Index of the `otherwise` slot. NOTE(review): this assumes `all_targets_mut()`
                // stores `otherwise` last, right after the `targets.iter().count()` explicit
                // arms — including any arm just added by `add_target`.
                unreachable_targets.push(targets.iter().count());
            }
            for index in unreachable_targets.iter() {
                targets.all_targets_mut()[*index] = unreachable_block;
            }
            patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
        }

        patch.apply(body);
    }

    /// Returns `false`: this pass is not marked as required. NOTE(review): presumably this lets
    /// the pass manager skip it when optimizations are off.
    fn is_required(&self) -> bool {
        false
    }
}