rustc_mir_transform/
unreachable_enum_branching.rs

1//! A pass that eliminates branches on uninhabited or unreachable enum variants.
2
3use rustc_abi::Variants;
4use rustc_data_structures::fx::FxHashSet;
5use rustc_middle::bug;
6use rustc_middle::mir::{
7    BasicBlock, BasicBlockData, BasicBlocks, Body, Local, Operand, Rvalue, StatementKind,
8    TerminatorKind,
9};
10use rustc_middle::ty::layout::TyAndLayout;
11use rustc_middle::ty::{Ty, TyCtxt};
12use tracing::trace;
13
14use crate::patch::MirPatch;
15
/// Marker type for the MIR pass that redirects `SwitchInt` arms on enum discriminants
/// which can never match (values of uninhabited or non-existent variants) to an
/// `unreachable` block. All of the logic lives in this type's `crate::MirPass` impl.
pub(super) struct UnreachableEnumBranching;
17
18fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> {
19    if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator {
20        p.as_local()
21    } else {
22        None
23    }
24}
25
26/// If the basic block terminates by switching on a discriminant, this returns the `Ty` the
27/// discriminant is read from. Otherwise, returns None.
28fn get_switched_on_type<'tcx>(
29    block_data: &BasicBlockData<'tcx>,
30    tcx: TyCtxt<'tcx>,
31    body: &Body<'tcx>,
32) -> Option<Ty<'tcx>> {
33    let terminator = block_data.terminator();
34
35    // Only bother checking blocks which terminate by switching on a local.
36    let local = get_discriminant_local(&terminator.kind)?;
37
38    let stmt_before_term = block_data.statements.last()?;
39
40    if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
41        && l.as_local() == Some(local)
42    {
43        let ty = place.ty(body, tcx).ty;
44        if ty.is_enum() {
45            return Some(ty);
46        }
47    }
48
49    None
50}
51
52fn variant_discriminants<'tcx>(
53    layout: &TyAndLayout<'tcx>,
54    ty: Ty<'tcx>,
55    tcx: TyCtxt<'tcx>,
56) -> FxHashSet<u128> {
57    match &layout.variants {
58        Variants::Empty => {
59            // Uninhabited, no valid discriminant.
60            FxHashSet::default()
61        }
62        Variants::Single { index } => {
63            let mut res = FxHashSet::default();
64            res.insert(
65                ty.discriminant_for_variant(tcx, *index)
66                    .map_or(index.as_u32() as u128, |discr| discr.val),
67            );
68            res
69        }
70        Variants::Multiple { variants, .. } => variants
71            .iter_enumerated()
72            .filter_map(|(idx, layout)| {
73                (!layout.is_uninhabited())
74                    .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val)
75            })
76            .collect(),
77    }
78}
79
80impl<'tcx> crate::MirPass<'tcx> for UnreachableEnumBranching {
81    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
82        sess.mir_opt_level() > 0
83    }
84
85    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
86        trace!("UnreachableEnumBranching starting for {:?}", body.source);
87
88        let mut unreachable_targets = Vec::new();
89        let mut patch = MirPatch::new(body);
90
91        for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
92            trace!("processing block {:?}", bb);
93
94            if bb_data.is_cleanup {
95                continue;
96            }
97
98            let Some(discriminant_ty) = get_switched_on_type(bb_data, tcx, body) else { continue };
99
100            let layout = tcx.layout_of(body.typing_env(tcx).as_query_input(discriminant_ty));
101
102            let mut allowed_variants = if let Ok(layout) = layout {
103                // Find allowed variants based on uninhabited.
104                variant_discriminants(&layout, discriminant_ty, tcx)
105            } else if let Some(variant_range) = discriminant_ty.variant_range(tcx) {
106                // If there are some generics, we can still get the allowed variants.
107                variant_range
108                    .map(|variant| {
109                        discriminant_ty.discriminant_for_variant(tcx, variant).unwrap().val
110                    })
111                    .collect()
112            } else {
113                continue;
114            };
115
116            trace!("allowed_variants = {:?}", allowed_variants);
117
118            unreachable_targets.clear();
119            let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
120                bug!()
121            };
122
123            for (index, (val, _)) in targets.iter().enumerate() {
124                if !allowed_variants.remove(&val) {
125                    unreachable_targets.push(index);
126                }
127            }
128            let otherwise_is_empty_unreachable =
129                body.basic_blocks[targets.otherwise()].is_empty_unreachable();
130            fn check_successors(basic_blocks: &BasicBlocks<'_>, bb: BasicBlock) -> bool {
131                // After resolving https://github.com/llvm/llvm-project/issues/78578,
132                // We can remove this check.
133                // The main issue here is that `early-tailduplication` causes compile time overhead
134                // and potential performance problems.
135                // Simply put, when encounter a switch (indirect branch) statement,
136                // `early-tailduplication` tries to duplicate the switch branch statement with BB
137                // into (each) predecessors. This makes CFG very complex.
138                // We can understand it as it transforms the following code
139                // ```rust
140                // match a { ... many cases };
141                // match b { ... many cases };
142                // ```
143                // into
144                // ```rust
145                // match a { ... many match b { goto BB cases } }
146                // ... BB cases
147                // ```
148                // Abandon this transformation when it is possible (the best effort)
149                // to encounter the problem.
150                let mut successors = basic_blocks[bb].terminator().successors();
151                let Some(first_successor) = successors.next() else { return true };
152                if successors.next().is_some() {
153                    return true;
154                }
155                if let TerminatorKind::SwitchInt { .. } =
156                    &basic_blocks[first_successor].terminator().kind
157                {
158                    return false;
159                };
160                true
161            }
162            // If and only if there is a variant that does not have a branch set, change the
163            // current of otherwise as the variant branch and set otherwise to unreachable. It
164            // transforms following code
165            // ```rust
166            // match c {
167            //     Ordering::Less => 1,
168            //     Ordering::Equal => 2,
169            //     _ => 3,
170            // }
171            // ```
172            // to
173            // ```rust
174            // match c {
175            //     Ordering::Less => 1,
176            //     Ordering::Equal => 2,
177            //     Ordering::Greater => 3,
178            // }
179            // ```
180            let otherwise_is_last_variant = !otherwise_is_empty_unreachable
181                && allowed_variants.len() == 1
182                // Despite the LLVM issue, we hope that small enum can still be transformed.
183                // This is valuable for both `a <= b` and `if let Some/Ok(v)`.
184                && (targets.all_targets().len() <= 3
185                    || check_successors(&body.basic_blocks, targets.otherwise()));
186            let replace_otherwise_to_unreachable = otherwise_is_last_variant
187                || (!otherwise_is_empty_unreachable && allowed_variants.is_empty());
188
189            if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
190                continue;
191            }
192
193            let unreachable_block = patch.unreachable_no_cleanup_block();
194            let mut targets = targets.clone();
195            if replace_otherwise_to_unreachable {
196                if otherwise_is_last_variant {
197                    // We have checked that `allowed_variants` has only one element.
198                    #[allow(rustc::potential_query_instability)]
199                    let last_variant = *allowed_variants.iter().next().unwrap();
200                    targets.add_target(last_variant, targets.otherwise());
201                }
202                unreachable_targets.push(targets.iter().count());
203            }
204            for index in unreachable_targets.iter() {
205                targets.all_targets_mut()[*index] = unreachable_block;
206            }
207            patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
208        }
209
210        patch.apply(body);
211    }
212
213    fn is_required(&self) -> bool {
214        false
215    }
216}