
//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns with straight jumps:
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//!
//! We proceed by walking the CFG backwards starting from each `SwitchInt` terminator,
//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
//!
//! The algorithm maintains a set of replacement conditions:
//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
//!   if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
//!   if assigning anything different from `value` to `place` turns the `SwitchInt`
//!   into `Goto { target }`.
//!
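//! For example, given MIR like the following (illustrative syntax):
//!     bb0: { _1 = const 0_u8; goto -> bb1; }
//!     bb1: { switchInt(_1) -> [0: bb2, otherwise: bb3]; }
//! the `SwitchInt` in `bb1` yields `conditions[_1] = [Condition { value: 0, polarity: Eq,
//! target: bb2 }, Condition { value: 0, polarity: Ne, target: bb3 }]`. Walking backwards into
//! `bb0`, the assignment `_1 = 0` fulfills the `Eq` condition, so `bb0` can jump directly
//! to `bb2`.
//!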
//! In this file, we denote as `place ?= value` the existence of a replacement condition
//! on `place` with given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! We then walk the CFG backwards transforming the set of conditions.
//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
//! All `ThreadingOpportunity`s are applied to the body, duplicating blocks if required.
//!
//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
//! each `SwitchInt` terminator. To manage the complexity, we:
//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
//! - only traverse `Goto` terminators.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::IndexVec;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::bug;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};

use crate::cost_checker::CostChecker;

pub(super) struct JumpThreading;

const MAX_BACKTRACK: usize = 5;
const MAX_COST: usize = 100;
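/// Bound on the number of places tracked by the value analysis `Map`.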
const MAX_PLACES: usize = 100;

impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    #[instrument(skip_all, level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Optimizing coroutines creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let typing_env = body.typing_env(tcx);
        let arena = &DroplessArena::default();
        let mut finder = TOFinder {
            tcx,
            typing_env,
            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
            body,
            arena,
            map: Map::new(tcx, body, Some(MAX_PLACES)),
            loop_headers: loop_headers(body),
            opportunities: Vec::new(),
        };

        for bb in body.basic_blocks.indices() {
            finder.start_from_switch(bb);
        }

        let opportunities = finder.opportunities;
        debug!(?opportunities);
        if opportunities.is_empty() {
            return;
        }

        // Verify that we do not thread through a loop header.
        for to in opportunities.iter() {
            assert!(to.chain.iter().all(|&block| !finder.loop_headers.contains(block)));
        }
        OpportunitySet::new(body, opportunities).apply(body);
    }

    fn is_required(&self) -> bool {
        false
    }
}

#[derive(Debug)]
struct ThreadingOpportunity {
    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
    chain: Vec<BasicBlock>,
    /// The `SwitchInt` will be replaced by `Goto { target }`.
    target: BasicBlock,
}

struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    ecx: InterpCx<'tcx, DummyMachine>,
    body: &'a Body<'tcx>,
    map: Map<'tcx>,
    loop_headers: DenseBitSet<BasicBlock>,
    /// We use an arena to avoid cloning the slices when cloning `state`.
    arena: &'a DroplessArena,
    opportunities: Vec<ThreadingOpportunity>,
}

/// Represents the following statement: if we can prove that the current local is equal/not equal
/// to `value`, jump to `target`.
#[derive(Copy, Clone, Debug)]
struct Condition {
    value: ScalarInt,
    polarity: Polarity,
    target: BasicBlock,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}

impl Condition {
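    /// Whether an assignment of `value` fulfills this condition: for `Eq` polarity the values
    /// must be equal, for `Ne` they must differ.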
    fn matches(&self, value: ScalarInt) -> bool {
        (self.value == value) == (self.polarity == Polarity::Eq)
    }
}

#[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]);

impl HasBottom for ConditionSet<'_> {
    const BOTTOM: Self = ConditionSet(&[]);

    fn is_bottom(&self) -> bool {
        self.0.is_empty()
    }
}

impl<'a> ConditionSet<'a> {
    fn iter(self) -> impl Iterator<Item = Condition> {
        self.0.iter().copied()
    }

    fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> {
        self.iter().filter(move |c| c.matches(value))
    }

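    /// Apply `f` to each condition, allocating the result in `arena`. Returns `None` if `f`
    /// fails on any condition, as a partially-mapped set would be unsound to use.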
    fn map(
        self,
        arena: &'a DroplessArena,
        f: impl Fn(Condition) -> Option<Condition>,
    ) -> Option<ConditionSet<'a>> {
        let set = arena.try_alloc_from_iter(self.iter().map(|c| f(c).ok_or(()))).ok()?;
        Some(ConditionSet(set))
    }
}

impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
        state.all_bottom()
    }

    /// Recursion entry point to find threading opportunities.
    #[instrument(level = "trace", skip(self))]
    fn start_from_switch(&mut self, bb: BasicBlock) {
        let bbdata = &self.body[bb];
        if bbdata.is_cleanup || self.loop_headers.contains(bb) {
            return;
        }
        let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
        let Some(discr) = discr.place() else { return };
        debug!(?discr, ?bb);

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        let Some(discr) = self.map.find(discr.as_ref()) else { return };
        debug!(?discr);

        let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
        let mut state = State::new_reachable();

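        // A two-way `if` gives both an `Eq` and an `Ne` condition. A many-way `SwitchInt`
        // only gives `Eq` conditions: the `otherwise` target gets no condition, since `Eq`
        // cannot express "any other value".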
        let conds = if let Some((value, then, else_)) = targets.as_static_if() {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            self.arena.alloc_from_iter([
                Condition { value, polarity: Polarity::Eq, target: then },
                Condition { value, polarity: Polarity::Ne, target: else_ },
            ])
        } else {
            self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
                let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
                Some(Condition { value, polarity: Polarity::Eq, target })
            }))
        };
        let conds = ConditionSet(conds);
        state.insert_value_idx(discr, conds, &self.map);

        self.find_opportunity(bb, state, cost, 0)
    }

    /// Recursively walk statements backwards from this bb's terminator to find threading
    /// opportunities.
    #[instrument(level = "trace", skip(self, cost), ret)]
    fn find_opportunity(
        &mut self,
        bb: BasicBlock,
        mut state: State<ConditionSet<'a>>,
        mut cost: CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        // Do not thread through loop headers.
        if self.loop_headers.contains(bb) {
            return;
        }

        debug!(cost = ?cost.cost());
        for (statement_index, stmt) in
            self.body.basic_blocks[bb].statements.iter().enumerate().rev()
        {
            if self.is_empty(&state) {
                return;
            }

            cost.visit_statement(stmt, Location { block: bb, statement_index });
            if cost.cost() > MAX_COST {
                return;
            }

            // Attempt to turn the conditions on `lhs` into conditions on another place.
            self.process_statement(bb, stmt, &mut state);

            // When a statement mutates a place, assignments to that place that happen
            // above the mutation cannot fulfill a condition.
            //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
            //   _1 = 6
            if let Some((lhs, tail)) = self.mutated_statement(stmt) {
                state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM);
            }
        }

        if self.is_empty(&state) || depth >= MAX_BACKTRACK {
            return;
        }

        let last_non_rec = self.opportunities.len();

        let predecessors = &self.body.basic_blocks.predecessors()[bb];
        if let &[pred] = &predecessors[..]
            && bb != START_BLOCK
        {
            let term = self.body.basic_blocks[pred].terminator();
            match term.kind {
                TerminatorKind::SwitchInt { ref discr, ref targets } => {
                    self.process_switch_int(discr, targets, bb, &mut state);
                    self.find_opportunity(pred, state, cost, depth + 1);
                }
                _ => self.recurse_through_terminator(pred, || state, &cost, depth),
            }
        } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
            for &pred in predecessors {
                self.recurse_through_terminator(pred, || state.clone(), &cost, depth);
            }
            self.recurse_through_terminator(last_pred, || state, &cost, depth);
        }

        let new_tos = &mut self.opportunities[last_non_rec..];
        debug!(?new_tos);

        // Try to deduplicate threading opportunities.
        if new_tos.len() > 1
            && new_tos.len() == predecessors.len()
            && predecessors
                .iter()
                .zip(new_tos.iter())
                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
        {
            // All predecessors have a threading opportunity, and they all point to the same block.
            debug!(?new_tos, "dedup");
            let first = &mut new_tos[0];
            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
            self.opportunities.truncate(last_non_rec + 1);
            return;
        }

        for op in self.opportunities[last_non_rec..].iter_mut() {
            op.chain.push(bb);
        }
    }

    /// Extract the mutated place from a statement.
    ///
    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
    ///     (_1 as Ok).0 = _5;
    ///     (_1 as Err).0 = _6;
    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
    /// the value may have been mangled by the second assignment.
    ///
    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we
    /// only flood the discriminant and preserve the variant fields.
    ///     (_1 as Some).0 = _6;
    ///     SetDiscriminant(_1, 1);
    ///     switchInt((_1 as Some).0)
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _))
            | StatementKind::Deinit(box place) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            StatementKind::Retag(..)
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // `copy_nonoverlapping` takes pointers and mutates the pointed-to value.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_immediate(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: ImmTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        if let Some(conditions) = state.try_get_idx(lhs, &self.map)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            conditions.iter_matches(int).for_each(register_opportunity);
        }
    }

    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    #[instrument(level = "trace", skip(self))]
    fn process_constant(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
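                // `DerefLen` stands for the length of the pointed-to place, e.g. a slice.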
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Some(conditions) = state.try_get_idx(place, &self.map)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    conditions.iter_matches(int).for_each(|c: Condition| {
                        self.opportunities
                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
                    })
                }
            },
        );
    }

    #[instrument(level = "trace", skip(self))]
    fn process_operand(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: &Operand<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        match rhs {
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Operand::Constant(constant) => {
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(bb, lhs, constant, state);
            }
            // Transfer the conditions on the copied rhs.
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_assign(
        &mut self,
        bb: BasicBlock,
        lhs_place: &Place<'tcx>,
        rhs: &Rvalue<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
        match rhs {
            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
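            // Transfer the conditions on `lhs` to the discriminant of `rhs`:
            // proving `discriminant(rhs) ?= A` proves `lhs ?= A`.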
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
                            && let Some(discr_value) = self
                                .ecx
                                .discriminant_for_variant(agg_ty, *variant_index)
                                .discard_err()
                        {
                            self.process_immediate(bb, discr_target, discr_value, state);
                        }
                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
                            idx
                        } else {
                            return;
                        }
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
                        self.process_operand(bb, field, operand, state);
                    }
                }
            }
            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
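            // For example, if we need `lhs == true` and see `lhs = Not(_2)`,
            // we now need `_2 == false`.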
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
                let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let Some(conds) = conditions.map(self.arena, |mut cond| {
                    cond.value = self
                        .ecx
                        .unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
                        .discard_err()?
                        .to_scalar_int()
                        .discard_err()?;
                    Some(cond)
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
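            // For example, if we need `lhs == true` and see `lhs = Eq(_2, 5)`,
            // we now need `_2 == 5`; for `lhs == false`, we need `_2 != 5`.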
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
            ) => {
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                let Some(conds) = conditions.map(self.arena, |c| {
                    Some(Condition {
                        value,
                        polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
                        ..c
                    })
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }

            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_statement(
        &mut self,
        bb: BasicBlock,
        stmt: &Statement<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        // Below, `lhs` is the return value of `mutated_statement`,
        // the place to which `conditions` apply.

        match &stmt.kind {
            // If we expect `discriminant(place) ?= A`,
            // we have an opportunity if `variant_index ?= A`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
                // Even if the discriminant write does nothing due to niches, it is UB to set the
                // discriminant when the data does not encode the desired discriminant.
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(bb, discr_target, discr, state)
            }
            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity)
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(bb, lhs_place, rhs, state)
            }
            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self, state, cost))]
    fn recurse_through_terminator(
        &mut self,
        bb: BasicBlock,
        // Pass a closure that may clone the state, as we don't want to do it each time.
        state: impl FnOnce() -> State<ConditionSet<'a>>,
        cost: &CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // We come from one of this terminator's targets, so terminators without
            // successors are not possible here.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => return,
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { .. } => return,
            // We can recurse; nothing particular to do.
            TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // Ignore, as this can be a no-op at codegen time.
            TerminatorKind::Assert { .. } => None,
        };

        // We can recurse through this terminator.
        let mut state = state();
        if let Some(place_to_flood) = place_to_flood {
            state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
        }
        self.find_opportunity(bb, state, cost.clone(), depth + 1)
    }

    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        target_bb: BasicBlock,
        state: &mut State<ConditionSet<'a>>,
    ) {
        debug_assert_ne!(target_bb, START_BLOCK);
        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);

        let Some(discr) = discr.place() else { return };
        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
            return;
        };
        let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return };

        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);

            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
            // through the `SwitchInt` before arriving here. Therefore, we know that
            // `discr == value`. If one condition can be fulfilled by `discr == value`,
            // that's an opportunity.
            for c in conditions.iter_matches(value) {
                debug!(?target_bb, ?c.target, "register");
                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
            }
        } else if let Some((value, _, else_bb)) = targets.as_static_if()
            && target_bb == else_bb
        {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };

            // We only know that `discr != value`. That's much weaker information than
            // the equality we had in the previous arm. All we can conclude is that
            // the replacement condition `discr != value` can be threaded, and nothing else.
            for c in conditions.iter() {
                if c.value == value && c.polarity == Polarity::Ne {
                    debug!(?target_bb, ?c.target, "register");
                    self.opportunities
                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
                }
            }
        }
    }
}

struct OpportunitySet {
    opportunities: Vec<ThreadingOpportunity>,
    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
    /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
    /// Cache the number of predecessors for each block, as we clear the basic block cache.
    predecessors: IndexVec<BasicBlock, usize>,
}

impl OpportunitySet {
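    /// Compute, for each basic block, the list of TOs in which it appears. This reverse index
    /// lets `apply_once` update later TOs when it duplicates a block.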
    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
        for (index, to) in opportunities.iter().enumerate() {
            for (ibb, &bb) in to.chain.iter().enumerate() {
                involving_tos[bb].push((index, ibb));
            }
            involving_tos[to.target].push((index, to.chain.len()));
        }
        let predecessors = predecessor_count(body);
        OpportunitySet { opportunities, involving_tos, predecessors }
    }

    /// Apply the opportunities on the graph.
    fn apply(&mut self, body: &mut Body<'_>) {
        for i in 0..self.opportunities.len() {
            self.apply_once(i, body);
        }
    }

    #[instrument(level = "trace", skip(self, body))]
    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
        debug!(?self.predecessors);
        debug!(?self.involving_tos);

        // Check that `predecessors` satisfies its invariant.
        debug_assert_eq!(self.predecessors, predecessor_count(body));

        // Remove the TO from the vector to allow modifying the other ones later.
        let op = &mut self.opportunities[index];
        debug!(?op);
        let op_chain = std::mem::take(&mut op.chain);
        let op_target = op.target;
        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());

        let Some((current, chain)) = op_chain.split_first() else { return };
        let basic_blocks = body.basic_blocks.as_mut();

        // Invariant: the control-flow is well-formed at the end of each iteration.
        let mut current = *current;
        for &succ in chain {
            debug!(?current, ?succ);

            // `succ` must be a successor of `current`. If it is not, this means this TO is not
            // satisfiable and a previous TO erased this edge, so we bail out.
            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
                debug!("impossible");
                return;
            }

            // Fast path: `succ` is only used once, so we can reuse it directly.
            if self.predecessors[succ] == 1 {
                debug!("single");
                current = succ;
                continue;
            }

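            // Slow path: `succ` has other predecessors, so we cannot modify it in place.
            // Duplicate it, and redirect our `current -> succ` edges to the copy.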
            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
            debug!(?new_succ);

            // Replace `succ` by `new_succ` where it appears.
            let mut num_edges = 0;
            basic_blocks[current].terminator_mut().successors_mut(|s| {
                if *s == succ {
                    *s = new_succ;
                    num_edges += 1;
                }
            });

            // Update predecessors with the new block.
            let _new_succ = self.predecessors.push(num_edges);
            debug_assert_eq!(new_succ, _new_succ);
            self.predecessors[succ] -= num_edges;
            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);

            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
            let mut new_involved = Vec::new();
            for &(to_index, in_to_index) in &self.involving_tos[current] {
                // That TO has already been applied, do nothing.
                if to_index <= index {
                    continue;
                }

                let other_to = &mut self.opportunities[to_index];
                if other_to.chain.get(in_to_index) != Some(&current) {
                    continue;
                }
                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
                if *s == succ {
                    // `other_to` references the `current -> succ` edge, so replace `succ`.
                    *s = new_succ;
                    new_involved.push((to_index, in_to_index + 1));
                }
            }

            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
            // in case we need to duplicate an edge starting at `new_succ` later.
            let _new_succ = self.involving_tos.push(new_involved);
            debug_assert_eq!(new_succ, _new_succ);

            current = new_succ;
        }

        let current = &mut basic_blocks[current];
        self.update_predecessor_count(current.terminator(), Update::Decr);
        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
        self.predecessors[op_target] += 1;
    }

    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
        match incr {
            Update::Incr => {
                for s in terminator.successors() {
                    self.predecessors[s] += 1;
                }
            }
            Update::Decr => {
                for s in terminator.successors() {
                    self.predecessors[s] -= 1;
                }
            }
        }
    }
}

fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
    let mut predecessors: IndexVec<_, _> =
        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
    predecessors
}

enum Update {
    Incr,
    Decr,
}

/// Compute the set of loop headers in the given body. We define a loop header as a block which
/// has at least one predecessor that it dominates. This definition is only correct for reducible
/// CFGs. But if the CFG is already irreducible, there is no point in trying much harder.
fn loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
    let mut loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
    let dominators = body.basic_blocks.dominators();
    // Only visit reachable blocks.
    for (bb, bbdata) in traversal::preorder(body) {
        for succ in bbdata.terminator().successors() {
            if dominators.dominates(succ, bb) {
                loop_headers.insert(succ);
            }
        }
    }
    loop_headers
}