rustc_mir_transform/
jump_threading.rs

//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//!
//!
//! We proceed by walking the CFG backwards starting from each `SwitchInt` terminator,
//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
//!
//! The algorithm maintains a set of replacement conditions:
//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
//!   if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
//!   if assigning anything different from `value` to `place` turns the `SwitchInt`
//!   into `Goto { target }`.
//!
//! In this file, we denote as `place ?= value` the existence of a replacement condition
//! on `place` with the given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! We then walk the CFG backwards transforming the set of conditions.
//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
//! All `ThreadingOpportunity`s are applied to the body, by duplicating blocks if required.
//!
//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
//! each `SwitchInt` terminator. To manage the complexity, we:
//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
//! - only traverse `Goto` terminators.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Likewise, applying the optimization can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.
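//!
//! For instance, the pass applies to the following pattern (a sketch in surface Rust; the
//! actual MIR lowering differs in its details):
//!
//! ```rust,ignore
//! fn f(c: bool) -> u8 {
//!     // The two arms of the `if` assign `x`, then control flow joins and we switch on `x`.
//!     let x = if c { 1 } else { 2 };
//!     // This `match` lowers to a `SwitchInt(x)`. Walking backwards from it, the assignment
//!     // `x = 1` fulfills the condition `x ?= 1` (polarity `Eq`) and `x = 2` fulfills
//!     // `x ?= 1` (polarity `Ne`), so each `if` arm can jump straight to its `match` arm.
//!     match x {
//!         1 => 10,
//!         _ => 20,
//!     }
//! }
//! ```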

use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::IndexVec;
use rustc_index::bit_set::DenseBitSet;
use rustc_middle::bug;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use tracing::{debug, instrument, trace};

use crate::cost_checker::CostChecker;

pub(super) struct JumpThreading;

const MAX_BACKTRACK: usize = 5;
const MAX_COST: usize = 100;
const MAX_PLACES: usize = 100;

impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    #[instrument(skip_all, level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Optimizing coroutines creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let typing_env = body.typing_env(tcx);
        let arena = &DroplessArena::default();
        let mut finder = TOFinder {
            tcx,
            typing_env,
            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
            body,
            arena,
            map: Map::new(tcx, body, Some(MAX_PLACES)),
            maybe_loop_headers: loops::maybe_loop_headers(body),
            opportunities: Vec::new(),
        };

        for (bb, _) in traversal::preorder(body) {
            finder.start_from_switch(bb);
        }

        let opportunities = finder.opportunities;
        debug!(?opportunities);
        if opportunities.is_empty() {
            return;
        }

        // Verify that we do not thread through a loop header.
        for to in opportunities.iter() {
            assert!(to.chain.iter().all(|&block| !finder.maybe_loop_headers.contains(block)));
        }
        OpportunitySet::new(body, opportunities).apply(body);
    }

    fn is_required(&self) -> bool {
        false
    }
}

#[derive(Debug)]
struct ThreadingOpportunity {
    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
    chain: Vec<BasicBlock>,
    /// The `SwitchInt` will be replaced by `Goto { target }`.
    target: BasicBlock,
}

struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    ecx: InterpCx<'tcx, DummyMachine>,
    body: &'a Body<'tcx>,
    map: Map<'tcx>,
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// We use an arena to avoid cloning the slices when cloning `state`.
    arena: &'a DroplessArena,
    opportunities: Vec<ThreadingOpportunity>,
}

/// Represents the following statement: if we can prove that the current local is equal/not-equal
/// to `value`, jump to `target`.
#[derive(Copy, Clone, Debug)]
struct Condition {
    value: ScalarInt,
    polarity: Polarity,
    target: BasicBlock,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}

impl Condition {
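    /// Returns whether assigning `value` to the tracked place fulfills this condition.
    /// For example (illustrative values): a condition with `value: 5, polarity: Eq` matches
    /// only the value 5, while one with `value: 5, polarity: Ne` matches everything but 5.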
    fn matches(&self, value: ScalarInt) -> bool {
        (self.value == value) == (self.polarity == Polarity::Eq)
    }
}

#[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]);

impl HasBottom for ConditionSet<'_> {
    const BOTTOM: Self = ConditionSet(&[]);

    fn is_bottom(&self) -> bool {
        self.0.is_empty()
    }
}

impl<'a> ConditionSet<'a> {
    fn iter(self) -> impl Iterator<Item = Condition> {
        self.0.iter().copied()
    }

    fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> {
        self.iter().filter(move |c| c.matches(value))
    }

    fn map(
        self,
        arena: &'a DroplessArena,
        f: impl Fn(Condition) -> Option<Condition>,
    ) -> Option<ConditionSet<'a>> {
        let set = arena.try_alloc_from_iter(self.iter().map(|c| f(c).ok_or(()))).ok()?;
        Some(ConditionSet(set))
    }
}

impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
        state.all_bottom()
    }

    /// Recursion entry point to find threading opportunities.
    #[instrument(level = "trace", skip(self))]
    fn start_from_switch(&mut self, bb: BasicBlock) {
        let bbdata = &self.body[bb];
        if bbdata.is_cleanup || self.maybe_loop_headers.contains(bb) {
            return;
        }
        let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
        let Some(discr) = discr.place() else { return };
        debug!(?discr, ?bb);

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        let Some(discr) = self.map.find(discr.as_ref()) else { return };
        debug!(?discr);

        let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
        let mut state = State::new_reachable();

        let conds = if let Some((value, then, else_)) = targets.as_static_if() {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            self.arena.alloc_from_iter([
                Condition { value, polarity: Polarity::Eq, target: then },
                Condition { value, polarity: Polarity::Ne, target: else_ },
            ])
        } else {
            self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
                let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
                Some(Condition { value, polarity: Polarity::Eq, target })
            }))
        };
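        // Note: for a two-way switch, we recorded both the `Eq` condition for the `then` target
        // and the `Ne` condition for the `else_` target. For a multi-way switch, we only record
        // `Eq` conditions: reaching the `otherwise` target would require knowing that `discr`
        // differs from every listed value, which a single `Condition` cannot express.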
        let conds = ConditionSet(conds);
        state.insert_value_idx(discr, conds, &self.map);

        self.find_opportunity(bb, state, cost, 0)
    }

    /// Recursively walk statements backwards from this bb's terminator to find threading
    /// opportunities.
    #[instrument(level = "trace", skip(self, cost), ret)]
    fn find_opportunity(
        &mut self,
        bb: BasicBlock,
        mut state: State<ConditionSet<'a>>,
        mut cost: CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        // Do not thread through loop headers.
        if self.maybe_loop_headers.contains(bb) {
            return;
        }

        debug!(cost = ?cost.cost());
        for (statement_index, stmt) in
            self.body.basic_blocks[bb].statements.iter().enumerate().rev()
        {
            if self.is_empty(&state) {
                return;
            }

            cost.visit_statement(stmt, Location { block: bb, statement_index });
            if cost.cost() > MAX_COST {
                return;
            }

            // Attempt to turn the `current_condition` on `lhs` into a condition on another place.
            self.process_statement(bb, stmt, &mut state);

            // When a statement mutates a place, assignments to that place that happen
            // above the mutation cannot fulfill a condition.
            //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
            //   _1 = 6
            if let Some((lhs, tail)) = self.mutated_statement(stmt) {
                state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM);
            }
        }

        if self.is_empty(&state) || depth >= MAX_BACKTRACK {
            return;
        }

        let last_non_rec = self.opportunities.len();

        let predecessors = &self.body.basic_blocks.predecessors()[bb];
        if let &[pred] = &predecessors[..]
            && bb != START_BLOCK
        {
            let term = self.body.basic_blocks[pred].terminator();
            match term.kind {
                TerminatorKind::SwitchInt { ref discr, ref targets } => {
                    self.process_switch_int(discr, targets, bb, &mut state);
                    self.find_opportunity(pred, state, cost, depth + 1);
                }
                _ => self.recurse_through_terminator(pred, || state, &cost, depth),
            }
        } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
            for &pred in predecessors {
                self.recurse_through_terminator(pred, || state.clone(), &cost, depth);
            }
            self.recurse_through_terminator(last_pred, || state, &cost, depth);
        }

        let new_tos = &mut self.opportunities[last_non_rec..];
        debug!(?new_tos);

        // Try to deduplicate threading opportunities.
        if new_tos.len() > 1
            && new_tos.len() == predecessors.len()
            && predecessors
                .iter()
                .zip(new_tos.iter())
                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
        {
            // All predecessors have a threading opportunity, and they all point to the same block.
            debug!(?new_tos, "dedup");
            let first = &mut new_tos[0];
            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
            self.opportunities.truncate(last_non_rec + 1);
            return;
        }

        for op in self.opportunities[last_non_rec..].iter_mut() {
            op.chain.push(bb);
        }
    }
315
316    /// Extract the mutated place from a statement.
317    ///
318    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
319    ///     (_1 as Ok).0 = _5;
320    ///     (_1 as Err).0 = _6;
321    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
322    /// the value may have been mangled by the second assignment.
323    ///
324    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
325    /// stop at flooding the discriminant, and preserve the variant fields.
326    ///     (_1 as Some).0 = _6;
327    ///     SetDiscriminant(_1, 1);
328    ///     switchInt((_1 as Some).0)
329    #[instrument(level = "trace", skip(self), ret)]
330    fn mutated_statement(
331        &self,
332        stmt: &Statement<'tcx>,
333    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
334        match stmt.kind {
335            StatementKind::Assign(box (place, _)) => Some((place, None)),
336            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
337                Some((place, Some(TrackElem::Discriminant)))
338            }
339            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
340                Some((Place::from(local), None))
341            }
342            StatementKind::Retag(..)
343            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
344            // copy_nonoverlapping takes pointers and mutated the pointed-to value.
345            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
346            | StatementKind::AscribeUserType(..)
347            | StatementKind::Coverage(..)
348            | StatementKind::FakeRead(..)
349            | StatementKind::ConstEvalCounter
350            | StatementKind::PlaceMention(..)
351            | StatementKind::BackwardIncompatibleDropHint { .. }
352            | StatementKind::Nop => None,
353        }
354    }
355
356    #[instrument(level = "trace", skip(self))]
357    fn process_immediate(
358        &mut self,
359        bb: BasicBlock,
360        lhs: PlaceIndex,
361        rhs: ImmTy<'tcx>,
362        state: &mut State<ConditionSet<'a>>,
363    ) {
364        let register_opportunity = |c: Condition| {
365            debug!(?bb, ?c.target, "register");
366            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
367        };
368
369        if let Some(conditions) = state.try_get_idx(lhs, &self.map)
370            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
371        {
372            conditions.iter_matches(int).for_each(register_opportunity);
373        }
374    }
375
376    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
377    #[instrument(level = "trace", skip(self))]
378    fn process_constant(
379        &mut self,
380        bb: BasicBlock,
381        lhs: PlaceIndex,
382        constant: OpTy<'tcx>,
383        state: &mut State<ConditionSet<'a>>,
384    ) {
385        self.map.for_each_projection_value(
386            lhs,
387            constant,
388            &mut |elem, op| match elem {
389                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
390                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
391                TrackElem::Discriminant => {
392                    let variant = self.ecx.read_discriminant(op).discard_err()?;
393                    let discr_value =
394                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
395                    Some(discr_value.into())
396                }
397                TrackElem::DerefLen => {
398                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
399                    let len_usize = op.len(&self.ecx).discard_err()?;
400                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
401                    Some(ImmTy::from_uint(len_usize, layout).into())
402                }
403            },
404            &mut |place, op| {
405                if let Some(conditions) = state.try_get_idx(place, &self.map)
406                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
407                    && let Some(imm) = imm.right()
408                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
409                {
410                    conditions.iter_matches(int).for_each(|c: Condition| {
411                        self.opportunities
412                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
413                    })
414                }
415            },
416        );
417    }
418
419    #[instrument(level = "trace", skip(self))]
420    fn process_operand(
421        &mut self,
422        bb: BasicBlock,
423        lhs: PlaceIndex,
424        rhs: &Operand<'tcx>,
425        state: &mut State<ConditionSet<'a>>,
426    ) {
427        match rhs {
428            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
429            Operand::Constant(constant) => {
430                let Some(constant) =
431                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
432                else {
433                    return;
434                };
435                self.process_constant(bb, lhs, constant, state);
436            }
437            // Transfer the conditions on the copied rhs.
438            Operand::Move(rhs) | Operand::Copy(rhs) => {
439                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
440                state.insert_place_idx(rhs, lhs, &self.map);
441            }
442        }
443    }
444
445    #[instrument(level = "trace", skip(self))]
446    fn process_assign(
447        &mut self,
448        bb: BasicBlock,
449        lhs_place: &Place<'tcx>,
450        rhs: &Rvalue<'tcx>,
451        state: &mut State<ConditionSet<'a>>,
452    ) {
453        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
454        match rhs {
455            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
456            // Transfer the conditions on the copy rhs.
457            Rvalue::Discriminant(rhs) => {
458                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
459                state.insert_place_idx(rhs, lhs, &self.map);
460            }
461            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
462            Rvalue::Aggregate(box kind, operands) => {
463                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
464                let lhs = match kind {
465                    // Do not support unions.
466                    AggregateKind::Adt(.., Some(_)) => return,
467                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
468                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
469                            && let Some(discr_value) = self
470                                .ecx
471                                .discriminant_for_variant(agg_ty, *variant_index)
472                                .discard_err()
473                        {
474                            self.process_immediate(bb, discr_target, discr_value, state);
475                        }
476                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
477                            idx
478                        } else {
479                            return;
480                        }
481                    }
482                    _ => lhs,
483                };
484                for (field_index, operand) in operands.iter_enumerated() {
485                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
486                        self.process_operand(bb, field, operand, state);
487                    }
488                }
489            }
490            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
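            // For example (illustrative condition), if we expect `lhs ?= TRUE` with target `bb`
            // and see `lhs = Not(place)`, then `place ?= FALSE` with the same polarity and
            // target `bb` is an equivalent condition, so we transfer it onto `place`.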
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
                let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let Some(conds) = conditions.map(self.arena, |mut cond| {
                    cond.value = self
                        .ecx
                        .unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
                        .discard_err()?
                        .to_scalar_int()
                        .discard_err()?;
                    Some(cond)
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
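            // For example (illustrative condition), if we expect `lhs ?= TRUE` with target `bb`
            // and see `lhs = Eq(rhs, 7)`, assuming `rhs == 7` makes `lhs` true, so we record
            // `rhs ?= 7` with polarity `Eq` and target `bb`. Conditions expecting `lhs` to be
            // false get the flipped polarity `Ne`.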
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
            ) => {
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                let Some(conds) = conditions.map(self.arena, |c| {
                    Some(Condition {
                        value,
                        polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
                        ..c
                    })
                }) else {
                    return;
                };
                state.insert_value_idx(place, conds, &self.map);
            }

            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self))]
    fn process_statement(
        &mut self,
        bb: BasicBlock,
        stmt: &Statement<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        // Below, `lhs` is the return value of `mutated_statement`,
        // the place to which `conditions` apply.

        match &stmt.kind {
            // If we expect `discriminant(place) ?= A`,
            // we have an opportunity if `variant_index ?= A`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
                // Even if the discriminant write does nothing due to niches, it is UB to set the
                // discriminant when the data does not encode the desired discriminant.
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(bb, discr_target, discr, state)
            }
            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity)
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(bb, lhs_place, rhs, state)
            }
            _ => {}
        }
    }

    #[instrument(level = "trace", skip(self, state, cost))]
    fn recurse_through_terminator(
        &mut self,
        bb: BasicBlock,
        // Pass a closure that may clone the state, as we don't want to do it each time.
        state: impl FnOnce() -> State<ConditionSet<'a>>,
        cost: &CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // We come from a target, so those are not possible.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no successors"),
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => return,
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { .. } => return,
            // We can recurse, nothing particular to do.
            TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // Ignore, as this can be a no-op at codegen time.
            TerminatorKind::Assert { .. } => None,
        };

        // We can recurse through this terminator.
        let mut state = state();
        if let Some(place_to_flood) = place_to_flood {
            state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
        }
        self.find_opportunity(bb, state, cost.clone(), depth + 1)
    }

    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        target_bb: BasicBlock,
        state: &mut State<ConditionSet<'a>>,
    ) {
        debug_assert_ne!(target_bb, START_BLOCK);
        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);

        let Some(discr) = discr.place() else { return };
        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
            return;
        };
        let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return };

        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);

            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
            // through the `SwitchInt` before arriving here. Therefore, we know that
            // `discr == value`. If one condition can be fulfilled by `discr == value`,
            // that's an opportunity.
            for c in conditions.iter_matches(value) {
                debug!(?target_bb, ?c.target, "register");
                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
            }
        } else if let Some((value, _, else_bb)) = targets.as_static_if()
            && target_bb == else_bb
        {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };

            // We only know that `discr != value`. That's much weaker information than
            // the equality we had in the previous arm. All we can conclude is that
            // the replacement condition `discr != value` can be threaded, and nothing else.
            for c in conditions.iter() {
                if c.value == value && c.polarity == Polarity::Ne {
                    debug!(?target_bb, ?c.target, "register");
                    self.opportunities
                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
                }
            }
        }
    }
}

struct OpportunitySet {
    opportunities: Vec<ThreadingOpportunity>,
    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
    /// in `opportunities` and the index in `ThreadingOpportunity::chain`.
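    /// For example (illustrative indices), a TO at `index` with `chain = [bb1, bb4]` and
    /// `target = bb7` is recorded as `(index, 0)` in `involving_tos[bb1]`, `(index, 1)` in
    /// `involving_tos[bb4]`, and `(index, 2)` in `involving_tos[bb7]`.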
    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
    /// Cache the number of predecessors for each block, as we clear the basic block cache.
    predecessors: IndexVec<BasicBlock, usize>,
}

impl OpportunitySet {
    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
        for (index, to) in opportunities.iter().enumerate() {
            for (ibb, &bb) in to.chain.iter().enumerate() {
                involving_tos[bb].push((index, ibb));
            }
            involving_tos[to.target].push((index, to.chain.len()));
        }
        let predecessors = predecessor_count(body);
        OpportunitySet { opportunities, involving_tos, predecessors }
    }

    /// Apply the opportunities to the graph.
    fn apply(&mut self, body: &mut Body<'_>) {
        for i in 0..self.opportunities.len() {
            self.apply_once(i, body);
        }
    }

    #[instrument(level = "trace", skip(self, body))]
    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
        debug!(?self.predecessors);
        debug!(?self.involving_tos);

        // Check that `predecessors` satisfies its invariant.
        debug_assert_eq!(self.predecessors, predecessor_count(body));

        // Remove the TO from the vector to allow modifying the other ones later.
        let op = &mut self.opportunities[index];
        debug!(?op);
        let op_chain = std::mem::take(&mut op.chain);
        let op_target = op.target;
        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());

        let Some((current, chain)) = op_chain.split_first() else { return };
        let basic_blocks = body.basic_blocks.as_mut();

        // Invariant: the control-flow is well-formed at the end of each iteration.
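        // For example (illustrative blocks), when threading `chain = [bb0, bb1]` into
        // `target = bb5` where `bb1` has several predecessors: we clone `bb1` into a fresh
        // block, redirect the `bb0 -> bb1` edge to that clone, and finally replace the clone's
        // `SwitchInt` terminator by `Goto { target: bb5 }`.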
        let mut current = *current;
        for &succ in chain {
            debug!(?current, ?succ);

            // `succ` must be a successor of `current`. If it is not, this means this TO is not
            // satisfiable and a previous TO erased this edge, so we bail out.
            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
                debug!("impossible");
                return;
            }

            // Fast path: `succ` is only used once, so we can reuse it directly.
            if self.predecessors[succ] == 1 {
                debug!("single");
                current = succ;
                continue;
            }

            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
            debug!(?new_succ);

            // Replace `succ` by `new_succ` where it appears.
            let mut num_edges = 0;
            basic_blocks[current].terminator_mut().successors_mut(|s| {
                if *s == succ {
                    *s = new_succ;
                    num_edges += 1;
                }
            });

            // Update predecessors with the new block.
            let _new_succ = self.predecessors.push(num_edges);
            debug_assert_eq!(new_succ, _new_succ);
            self.predecessors[succ] -= num_edges;
            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);

            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
            let mut new_involved = Vec::new();
            for &(to_index, in_to_index) in &self.involving_tos[current] {
                // That TO has already been applied, do nothing.
                if to_index <= index {
                    continue;
                }

                let other_to = &mut self.opportunities[to_index];
                if other_to.chain.get(in_to_index) != Some(&current) {
                    continue;
                }
                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
                if *s == succ {
                    // `other_to` references the `current -> succ` edge, so replace `succ`.
                    *s = new_succ;
                    new_involved.push((to_index, in_to_index + 1));
                }
            }

            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
            // in case we need to duplicate an edge starting at `new_succ` later.
            let _new_succ = self.involving_tos.push(new_involved);
            debug_assert_eq!(new_succ, _new_succ);

            current = new_succ;
        }

        let current = &mut basic_blocks[current];
        self.update_predecessor_count(current.terminator(), Update::Decr);
        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
        self.predecessors[op_target] += 1;
    }

    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
        match incr {
            Update::Incr => {
                for s in terminator.successors() {
                    self.predecessors[s] += 1;
                }
            }
            Update::Decr => {
                for s in terminator.successors() {
                    self.predecessors[s] -= 1;
                }
            }
        }
    }
}

fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
    let mut predecessors: IndexVec<_, _> =
        body.basic_blocks.predecessors().iter().map(|ps| ps.len()).collect();
    predecessors[START_BLOCK] += 1; // Account for the implicit entry edge.
    predecessors
}

enum Update {
    Incr,
    Decr,
}