rustc_mir_transform/jump_threading.rs

//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//!
//!
//! We proceed by walking the CFG backwards starting from each `SwitchInt` terminator,
//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
//!
//! The algorithm maintains a set of replacement conditions:
//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
//!   if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
//!   if assigning anything different from `value` to `place` turns the `SwitchInt`
//!   into `Goto { target }`.
//!
//! In this file, we denote as `place ?= value` the existence of a replacement condition
//! on `place` with given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! We then walk the CFG backwards transforming the set of conditions.
//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
//! All `ThreadingOpportunity`s are applied to the body, duplicating blocks if required.
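//!
//! For example (illustrative, not tied to any particular MIR dump): for
//! `switchInt(_1) -> [0: bb2, otherwise: bb3]`, the initial set contains
//! `Condition { value: 0, polarity: Eq, target: bb2 }` and
//! `Condition { value: 0, polarity: Ne, target: bb3 }` for `_1`. If the backward walk then
//! reaches an assignment `_1 = 0`, the first condition is fulfilled and we record a
//! `ThreadingOpportunity` with target `bb2`.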
//!
//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
//! each `SwitchInt` terminator. To manage the complexity, we:
//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
//! - only traverse `Goto` terminators.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.

use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::IndexVec;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::bug;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};

use crate::cost_checker::CostChecker;

pub(super) struct JumpThreading;

/// Maximum depth of the backward walk performed from each `SwitchInt` terminator.
const MAX_BACKTRACK: usize = 5;
/// Maximum accumulated statement cost of the blocks that may be duplicated.
const MAX_COST: usize = 100;
/// Maximum number of places tracked by the `Map` used for the value analysis.
const MAX_PLACES: usize = 100;

impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    #[instrument(skip_all, level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Optimizing coroutines creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let typing_env = body.typing_env(tcx);
        let arena = &DroplessArena::default();
        let mut finder = TOFinder {
            tcx,
            typing_env,
            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
            body,
            arena,
            map: Map::new(tcx, body, Some(MAX_PLACES)),
            loop_headers: loop_headers(body),
            opportunities: Vec::new(),
        };

        for bb in body.basic_blocks.indices() {
            finder.start_from_switch(bb);
        }

        let opportunities = finder.opportunities;
        debug!(?opportunities);
        if opportunities.is_empty() {
            return;
        }

        // Verify that we do not thread through a loop header.
        for to in opportunities.iter() {
            assert!(to.chain.iter().all(|&block| !finder.loop_headers.contains(block)));
        }
        OpportunitySet::new(body, opportunities).apply(body);
    }

    fn is_required(&self) -> bool {
        false
    }
}

#[derive(Debug)]
struct ThreadingOpportunity {
    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
    chain: Vec<BasicBlock>,
    /// The `SwitchInt` will be replaced by `Goto { target }`.
    target: BasicBlock,
}
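
// For example (illustrative block names): a `ThreadingOpportunity` with `chain = [bb4, bb2]`
// and `target = bb7` records that, along the edge bb4 -> bb2, the `SwitchInt` ending bb2 always
// branches to bb7. Applying it re-points bb4 to a (possibly duplicated) copy of bb2 whose
// terminator becomes `Goto { target: bb7 }`.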

struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    ecx: InterpCx<'tcx, DummyMachine>,
    body: &'a Body<'tcx>,
    map: Map<'tcx>,
    loop_headers: DenseBitSet<BasicBlock>,
    /// We use an arena to avoid cloning the slices when cloning `state`.
    arena: &'a DroplessArena,
    opportunities: Vec<ThreadingOpportunity>,
}

/// Represents the statement: if we can prove that the current local is equal/not-equal
/// to `value`, jump to `target`.
#[derive(Copy, Clone, Debug)]
struct Condition {
    value: ScalarInt,
    polarity: Polarity,
    target: BasicBlock,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}

impl Condition {
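    /// Whether assigning `value` to the tracked place fulfills this condition: an `Eq` condition
    /// matches exactly `self.value`, while a `Ne` condition matches any other value. For example
    /// (illustrative), `Condition { value: 0, polarity: Ne, target }` matches 1 but not 0.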
    fn matches(&self, value: ScalarInt) -> bool {
        (self.value == value) == (self.polarity == Polarity::Eq)
    }

    fn inv(mut self) -> Self {
        self.polarity = match self.polarity {
            Polarity::Eq => Polarity::Ne,
            Polarity::Ne => Polarity::Eq,
        };
        self
    }
}

#[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]);

impl HasBottom for ConditionSet<'_> {
    const BOTTOM: Self = ConditionSet(&[]);

    fn is_bottom(&self) -> bool {
        self.0.is_empty()
    }
}

impl<'a> ConditionSet<'a> {
    fn iter(self) -> impl Iterator<Item = Condition> + 'a {
        self.0.iter().copied()
    }

    fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> + 'a {
        self.iter().filter(move |c| c.matches(value))
    }

    fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
        ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
    }
}

impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
        state.all_bottom()
    }

    /// Recursion entry point to find threading opportunities.
    #[instrument(level = "trace", skip(self))]
    fn start_from_switch(&mut self, bb: BasicBlock) {
        let bbdata = &self.body[bb];
        if bbdata.is_cleanup || self.loop_headers.contains(bb) {
            return;
        }
        let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
        let Some(discr) = discr.place() else { return };
        debug!(?discr, ?bb);

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
            return;
        };

        let Some(discr) = self.map.find(discr.as_ref()) else { return };
        debug!(?discr);

        let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
        let mut state = State::new_reachable();

        let conds = if let Some((value, then, else_)) = targets.as_static_if() {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            self.arena.alloc_from_iter([
                Condition { value, polarity: Polarity::Eq, target: then },
                Condition { value, polarity: Polarity::Ne, target: else_ },
            ])
        } else {
            self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
                let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
                Some(Condition { value, polarity: Polarity::Eq, target })
            }))
        };
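        // For example (illustrative): `switchInt(_1) -> [0: bb1, otherwise: bb2]` yields the two
        // conditions `_1 == 0 => bb1` and `_1 != 0 => bb2`, while a multi-way switch only yields
        // one `Eq` condition per explicit target, and nothing for the `otherwise` edge.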
        let conds = ConditionSet(conds);
        state.insert_value_idx(discr, conds, &self.map);

        self.find_opportunity(bb, state, cost, 0);
    }

    /// Recursively walk statements backwards from this bb's terminator to find threading
    /// opportunities.
    #[instrument(level = "trace", skip(self, cost), ret)]
    fn find_opportunity(
        &mut self,
        bb: BasicBlock,
        mut state: State<ConditionSet<'a>>,
        mut cost: CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        // Do not thread through loop headers.
        if self.loop_headers.contains(bb) {
            return;
        }

        debug!(cost = ?cost.cost());
        for (statement_index, stmt) in
            self.body.basic_blocks[bb].statements.iter().enumerate().rev()
        {
            if self.is_empty(&state) {
                return;
            }

            cost.visit_statement(stmt, Location { block: bb, statement_index });
            if cost.cost() > MAX_COST {
                return;
            }

            // Attempt to turn the current condition on `lhs` into a condition on another place.
            self.process_statement(bb, stmt, &mut state);

            // When a statement mutates a place, assignments to that place that happen
            // above the mutation cannot fulfill a condition.
            //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
            //   _1 = 6
            if let Some((lhs, tail)) = self.mutated_statement(stmt) {
                state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM);
            }
        }

        if self.is_empty(&state) || depth >= MAX_BACKTRACK {
            return;
        }

        let last_non_rec = self.opportunities.len();

        let predecessors = &self.body.basic_blocks.predecessors()[bb];
        if let &[pred] = &predecessors[..]
            && bb != START_BLOCK
        {
            let term = self.body.basic_blocks[pred].terminator();
            match term.kind {
                TerminatorKind::SwitchInt { ref discr, ref targets } => {
                    self.process_switch_int(discr, targets, bb, &mut state);
                    self.find_opportunity(pred, state, cost, depth + 1);
                }
                _ => self.recurse_through_terminator(pred, || state, &cost, depth),
            }
        } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
            for &pred in predecessors {
                self.recurse_through_terminator(pred, || state.clone(), &cost, depth);
            }
            self.recurse_through_terminator(last_pred, || state, &cost, depth);
        }

        let new_tos = &mut self.opportunities[last_non_rec..];
        debug!(?new_tos);

        // Try to deduplicate threading opportunities.
        if new_tos.len() > 1
            && new_tos.len() == predecessors.len()
            && predecessors
                .iter()
                .zip(new_tos.iter())
                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
        {
            // All predecessors have a threading opportunity, and they all point to the same block.
            debug!(?new_tos, "dedup");
            let first = &mut new_tos[0];
            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
            self.opportunities.truncate(last_non_rec + 1);
            return;
        }

        for op in self.opportunities[last_non_rec..].iter_mut() {
            op.chain.push(bb);
        }
    }

    /// Extract the mutated place from a statement.
    ///
    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
    ///     (_1 as Ok).0 = _5;
    ///     (_1 as Err).0 = _6;
    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
    /// the value may have been mangled by the second assignment.
    ///
330    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
331    /// stop at flooding the discriminant, and preserve the variant fields.
332    ///     (_1 as Some).0 = _6;
333    ///     SetDiscriminant(_1, 1);
334    ///     switchInt((_1 as Some).0)
335    #[instrument(level = "trace", skip(self), ret)]
336    fn mutated_statement(
337        &self,
338        stmt: &Statement<'tcx>,
339    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
340        match stmt.kind {
341            StatementKind::Assign(box (place, _))
342            | StatementKind::Deinit(box place) => Some((place, None)),
343            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
344                Some((place, Some(TrackElem::Discriminant)))
345            }
346            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
347                Some((Place::from(local), None))
348            }
349            StatementKind::Retag(..)
350            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // copy_nonoverlapping takes pointers and mutates the pointed-to value.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

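    /// If we expect `lhs ?= A` and `rhs` is a scalar immediate, register a
    /// `ThreadingOpportunity` for each condition fulfilled by that scalar.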
    #[instrument(level = "trace", skip(self))]
    fn process_immediate(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: ImmTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        if let Some(conditions) = state.try_get_idx(lhs, &self.map)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            conditions.iter_matches(int).for_each(register_opportunity);
        }
    }

    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    #[instrument(level = "trace", skip(self))]
    fn process_constant(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Some(conditions) = state.try_get_idx(place, &self.map)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    conditions.iter_matches(int).for_each(|c: Condition| {
                        self.opportunities
                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
                    })
                }
            },
        );
    }

    #[instrument(level = "trace", skip(self))]
    fn process_operand(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: &Operand<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        match rhs {
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Operand::Constant(constant) => {
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(bb, lhs, constant, state);
            }
            // Transfer the conditions on the copied rhs.
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_assign(
        &mut self,
        bb: BasicBlock,
        lhs_place: &Place<'tcx>,
        rhs: &Rvalue<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
        match rhs {
            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box ref kind, ref operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
                            && let Some(discr_value) = self
                                .ecx
                                .discriminant_for_variant(agg_ty, *variant_index)
                                .discard_err()
                        {
                            self.process_immediate(bb, discr_target, discr_value, state);
                        }
                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
                            idx
                        } else {
                            return;
                        }
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
                        self.process_operand(bb, field, operand, state);
                    }
                }
            }
            // Transfer the conditions on the copy rhs, after inverting the polarity.
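            // For example (illustrative): from an expected `lhs == true => bb1` and the
            // assignment `lhs = Not(_2)`, we derive `_2 != true => bb1`, which for a bool
            // is the same as `_2 == false => bb1`.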
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
                if !place.ty(self.body, self.tcx).ty.is_bool() {
                    // Constructing the conditions by inverting the polarity
                    // of equality is only correct for bools. That is to say,
                    // `!a == b` is not `a != b` for integers greater than 1 bit.
                    return;
                }
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                // FIXME: this could be generalized to non-bool types if we actually
                // performed a logical not on the condition's value.
                let conds = conditions.map(self.arena, Condition::inv);
                state.insert_value_idx(place, conds, &self.map);
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
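            // For example (illustrative): if we expect `lhs == true => bb1` and find
            // `lhs = Eq(_2, 42)`, we record `_2 == 42 => bb1`; for `lhs = Ne(_2, 42)` we
            // record `_2 != 42 => bb1`.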
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
            ) => {
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                let conds = conditions.map(self.arena, |c| Condition {
                    value,
                    polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
                    ..c
                });
                state.insert_value_idx(place, conds, &self.map);
            }

            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_statement(
        &mut self,
        bb: BasicBlock,
        stmt: &Statement<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        // Below, `lhs` is the return value of `mutated_statement`,
        // the place to which `conditions` apply.

        match &stmt.kind {
            // If we expect `discriminant(place) ?= A`,
            // we have an opportunity if `variant_index ?= A`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
                // Even if the discriminant write does nothing due to niches, it is UB to set the
                // discriminant when the data does not encode the desired discriminant.
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(bb, discr_target, discr, state);
            }
            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(bb, lhs_place, rhs, state);
            }
            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self, state, cost))]
    fn recurse_through_terminator(
        &mut self,
        bb: BasicBlock,
        // Pass a closure that may clone the state, as we don't want to do it each time.
        state: impl FnOnce() -> State<ConditionSet<'a>>,
        cost: &CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // We come from a target, so those are not possible.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => return,
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { .. } => return,
            // We can recurse, nothing particular to do.
            TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // Ignore, as this can be a no-op at codegen time.
            TerminatorKind::Assert { .. } => None,
        };

        // We can recurse through this terminator.
        let mut state = state();
        if let Some(place_to_flood) = place_to_flood {
            state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
        }
        self.find_opportunity(bb, state, cost.clone(), depth + 1);
    }

    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        target_bb: BasicBlock,
        state: &mut State<ConditionSet<'a>>,
    ) {
        debug_assert_ne!(target_bb, START_BLOCK);
        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);

        let Some(discr) = discr.place() else { return };
        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
            return;
        };
        let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return };

        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);

            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
            // through the `SwitchInt` before arriving here. Therefore, we know that
            // `discr == value`. If one condition can be fulfilled by `discr == value`,
            // that's an opportunity.
            for c in conditions.iter_matches(value) {
                debug!(?target_bb, ?c.target, "register");
                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
            }
        } else if let Some((value, _, else_bb)) = targets.as_static_if()
            && target_bb == else_bb
        {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };

            // We only know that `discr != value`. That's much weaker information than
            // the equality we had in the previous arm. All we can conclude is that
            // the replacement condition `discr != value` can be threaded, and nothing else.
            for c in conditions.iter() {
                if c.value == value && c.polarity == Polarity::Ne {
                    debug!(?target_bb, ?c.target, "register");
                    self.opportunities
                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
                }
            }
        }
    }
}

struct OpportunitySet {
    opportunities: Vec<ThreadingOpportunity>,
    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
    /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
692    /// Cache the number of predecessors for each block, as we clear the basic block cache..
    predecessors: IndexVec<BasicBlock, usize>,
}

impl OpportunitySet {
    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
        for (index, to) in opportunities.iter().enumerate() {
            for (ibb, &bb) in to.chain.iter().enumerate() {
                involving_tos[bb].push((index, ibb));
            }
            involving_tos[to.target].push((index, to.chain.len()));
        }
        let predecessors = predecessor_count(body);
        OpportunitySet { opportunities, involving_tos, predecessors }
    }

    /// Apply the opportunities on the graph.
    fn apply(&mut self, body: &mut Body<'_>) {
        for i in 0..self.opportunities.len() {
            self.apply_once(i, body);
        }
    }

    #[instrument(level = "trace", skip(self, body))]
    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
        debug!(?self.predecessors);
        debug!(?self.involving_tos);

        // Check that `predecessors` satisfies its invariant.
        debug_assert_eq!(self.predecessors, predecessor_count(body));

        // Take the chain out of this TO, so that we may modify the other TOs later.
        let op = &mut self.opportunities[index];
        debug!(?op);
        let op_chain = std::mem::take(&mut op.chain);
        let op_target = op.target;
        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());

        let Some((current, chain)) = op_chain.split_first() else { return };
        let basic_blocks = body.basic_blocks.as_mut();

        // Invariant: the control-flow is well-formed at the end of each iteration.
        let mut current = *current;
        for &succ in chain {
            debug!(?current, ?succ);

            // `succ` must be a successor of `current`. If it is not, this means this TO is not
            // satisfiable and a previous TO erased this edge, so we bail out.
            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
                debug!("impossible");
                return;
            }

            // Fast path: `succ` is only used once, so we can reuse it directly.
            if self.predecessors[succ] == 1 {
                debug!("single");
                current = succ;
                continue;
            }

            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
            debug!(?new_succ);

            // Replace `succ` by `new_succ` where it appears.
            let mut num_edges = 0;
            for s in basic_blocks[current].terminator_mut().successors_mut() {
                if *s == succ {
                    *s = new_succ;
                    num_edges += 1;
                }
            }

            // Update predecessors with the new block.
            let _new_succ = self.predecessors.push(num_edges);
            debug_assert_eq!(new_succ, _new_succ);
            self.predecessors[succ] -= num_edges;
            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);

            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
            let mut new_involved = Vec::new();
            for &(to_index, in_to_index) in &self.involving_tos[current] {
                // That TO has already been applied, do nothing.
                if to_index <= index {
                    continue;
                }

                let other_to = &mut self.opportunities[to_index];
                if other_to.chain.get(in_to_index) != Some(&current) {
                    continue;
                }
                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
                if *s == succ {
                    // `other_to` references the `current -> succ` edge, so replace `succ`.
                    *s = new_succ;
                    new_involved.push((to_index, in_to_index + 1));
                }
            }

            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
            // in case we need to duplicate an edge starting at `new_succ` later.
            let _new_succ = self.involving_tos.push(new_involved);
            debug_assert_eq!(new_succ, _new_succ);

            current = new_succ;
        }

        let current = &mut basic_blocks[current];
        self.update_predecessor_count(current.terminator(), Update::Decr);
        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
        self.predecessors[op_target] += 1;
    }

    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
        match incr {
            Update::Incr => {
                for s in terminator.successors() {
                    self.predecessors[s] += 1;
                }
            }
            Update::Decr => {
                for s in terminator.successors() {
                    self.predecessors[s] -= 1;
                }
            }
        }
    }
}

fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
    let mut predecessors: IndexVec<_, _> =
        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
    predecessors
}

enum Update {
    Incr,
    Decr,
}

/// Compute the set of loop headers in the given body. We define a loop header as a block which has
/// at least one predecessor which it dominates. This definition is only correct for reducible
/// CFGs. But if the CFG is already irreducible, there is no point in trying much harder.
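/// For instance (illustrative), in a CFG `bb0 -> bb1 -> bb2 -> bb1`, `bb1` dominates its
/// predecessor `bb2`, so `bb1` is marked as a loop header.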
fn loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
    let mut loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
    let dominators = body.basic_blocks.dominators();
    // Only visit reachable blocks.
    for (bb, bbdata) in traversal::preorder(body) {
        for succ in bbdata.terminator().successors() {
            if dominators.dominates(succ, bb) {
                loop_headers.insert(succ);
            }
        }
    }
    loop_headers
}