// rustc_mir_transform/jump_threading.rs

//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps:
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//!
//!
//! This implementation is heavily inspired by the work outlined in [libfirm].
//!
//! The general algorithm proceeds in two phases: (1) walk the CFG backwards to construct a
//! graph of threading conditions, and (2) propagate fulfilled conditions forward by duplicating
//! blocks.
//!
//! # 1. Condition graph construction
//!
//! In this file, we denote as `place ?= value` the existence of a replacement condition
//! on `place` with given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! Inside a block, we associate with each condition `c` a set of targets:
//! - `Goto(target)` if fulfilling `c` changes the terminator into a `Goto { target }`;
//! - `Chain(target, c2)` if fulfilling `c` means that `c2` is fulfilled inside `target`.
//!
//! Before walking a block `bb`, we construct the exit set of conditions from its successors.
//! For each condition `c` in a successor `s`, we record that fulfilling `c` in `bb` will fulfill
//! `c` in `s`, as a `Chain(s, c)` target.
//!
//! When encountering a `switchInt(place) -> [value: bb...]` terminator, we also record a
//! `place == value` condition for each `value`, and associate a `Goto(target)` target with it.
//! When the switch only tests a single value (an `if`), we additionally record a
//! `place != value` condition whose target is the `otherwise` branch.
//!
//! Then, we walk the statements backwards, transforming the set of conditions along the way,
//! resulting in a set of conditions at the block entry.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Applying the optimisation can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.
//!
//! # 2. Block duplication
//!
//! We now have the set of fulfilled conditions inside each block and their targets.
//!
//! For each block `bb` in reverse postorder, we apply in turn the target associated with each
//! fulfilled condition:
//! - for `Goto(target)`, change the terminator of `bb` into a `Goto { target }`;
//! - for `Chain(target, cond)`, duplicate `target` into a new block which fulfills the same
//!   conditions and also fulfills `cond`. This is made efficient by maintaining a map of
//!   duplicates, `duplicate[(target, cond)]`, to avoid cloning blocks multiple times.
//!
//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
53
54use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, TrackElem, ValueIndex};
66use rustc_span::DUMMY_SP;
67use tracing::{debug, instrument, trace};
68
69use crate::cost_checker::CostChecker;
70
/// The jump threading MIR pass; the algorithm is described in the module-level documentation.
pub(super) struct JumpThreading;
72
/// Upper bound on the instruction cost of the MIR this optimization may create
/// (duplicated blocks), as described in the module documentation.
const MAX_COST: u8 = 100;

/// Place limit handed to `Map::new` (as `Some(MAX_PLACES)`) when building the tracked-place map.
const MAX_PLACES: usize = 100;
75
76impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
77    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
78        sess.mir_opt_level() >= 2
79    }
80
81    #[instrument(skip_all level = "debug")]
82    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
83        let def_id = body.source.def_id();
84        debug!(?def_id);
85
86        // Optimizing coroutines creates query cycles.
87        if tcx.is_coroutine(def_id) {
88            trace!("Skipped for coroutine {:?}", def_id);
89            return;
90        }
91
92        let typing_env = body.typing_env(tcx);
93        let mut finder = TOFinder {
94            tcx,
95            typing_env,
96            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
97            body,
98            map: Map::new(tcx, body, Some(MAX_PLACES)),
99            maybe_loop_headers: loops::maybe_loop_headers(body),
100            entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
101        };
102
103        for (bb, bbdata) in traversal::postorder(body) {
104            if bbdata.is_cleanup {
105                continue;
106            }
107
108            let mut state = finder.populate_from_outgoing_edges(bb);
109            trace!("output_states[{bb:?}] = {state:?}");
110
111            finder.process_terminator(bb, &mut state);
112            trace!("pre_terminator_states[{bb:?}] = {state:?}");
113
114            for stmt in bbdata.statements.iter().rev() {
115                if state.is_empty() {
116                    break;
117                }
118
119                finder.process_statement(stmt, &mut state);
120
121                // When a statement mutates a place, assignments to that place that happen
122                // above the mutation cannot fulfill a condition.
123                //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
124                //   _1 = 6
125                if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
126                    finder.flood_state(lhs, tail, &mut state);
127                }
128            }
129
130            trace!("entry_states[{bb:?}] = {state:?}");
131            finder.entry_states[bb] = state;
132        }
133
134        let mut entry_states = finder.entry_states;
135        simplify_conditions(body, &mut entry_states);
136        remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
137
138        if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
139            opportunities.apply();
140        }
141    }
142
143    fn is_required(&self) -> bool {
144        false
145    }
146}
147
148struct TOFinder<'a, 'tcx> {
149    tcx: TyCtxt<'tcx>,
150    typing_env: ty::TypingEnv<'tcx>,
151    ecx: InterpCx<'tcx, DummyMachine>,
152    body: &'a Body<'tcx>,
153    map: Map<'tcx>,
154    maybe_loop_headers: DenseBitSet<BasicBlock>,
155    /// This stores the state of each visited block on entry,
156    /// and the current state of the block being visited.
157    // Invariant: for each `bb`, each condition in `entry_states[bb]` has a `chain` that
158    // starts with `bb`.
159    entry_states: IndexVec<BasicBlock, ConditionSet>,
160}
161
rustc_index::newtype_index! {
    /// Index of a condition in a `ConditionSet`, used to look up its `targets`.
    ///
    /// Each `ConditionSet` numbers its conditions independently; a `Chain` effect refers to a
    /// condition through the index space of its `succ_block`'s set.
    #[derive(Ord, PartialOrd)]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
167
168/// Represent the following statement. If we can prove that the current local is equal/not-equal
169/// to `value`, jump to `target`.
170#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
171struct Condition {
172    place: ValueIndex,
173    value: ScalarInt,
174    polarity: Polarity,
175}
176
/// Whether a `Condition` asserts an equality or an inequality.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum Polarity {
    /// The condition is `place != value`.
    Ne,
    /// The condition is `place == value`.
    Eq,
}
182
183impl Condition {
184    fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
185        self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
186    }
187}
188
189/// Represent the effect of fulfilling a condition.
190#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
191enum EdgeEffect {
192    /// If the condition is fulfilled, replace the current block's terminator by a single goto.
193    Goto { target: BasicBlock },
194    /// If the condition is fulfilled, fulfill the condition `succ_condition` in `succ_block`.
195    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
196}
197
198impl EdgeEffect {
199    fn block(self) -> BasicBlock {
200        match self {
201            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
202        }
203    }
204
205    fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
206        match self {
207            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
208                if *bb == target {
209                    *bb = new_target
210                }
211            }
212        }
213    }
214}
215
216#[derive(Clone, Debug, Default)]
217struct ConditionSet {
218    active: Vec<(ConditionIndex, Condition)>,
219    fulfilled: Vec<ConditionIndex>,
220    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
221}
222
223impl ConditionSet {
224    fn is_empty(&self) -> bool {
225        self.active.is_empty()
226    }
227
228    #[tracing::instrument(level = "trace", skip(self))]
229    fn push_condition(&mut self, c: Condition, target: BasicBlock) {
230        let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
231        self.active.push((index, c));
232    }
233
234    /// Register fulfilled condition and remove it from the set.
235    fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
236        self.active.retain(|&(index, condition)| {
237            let targets = &self.targets[index];
238            if f(condition, targets) {
239                trace!(?index, ?condition, "fulfill");
240                self.fulfilled.push(index);
241                false
242            } else {
243                true
244            }
245        })
246    }
247
248    /// Register fulfilled condition and remove them from the set.
249    fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
250        self.fulfill_if(|c, _| c.matches(place, value))
251    }
252
253    fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
254        self.active.retain(|&(_, c)| f(c))
255    }
256
257    fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
258        self.active.retain_mut(|(_, c)| {
259            if let Some(new) = f(*c) {
260                *c = new;
261                true
262            } else {
263                false
264            }
265        })
266    }
267
268    fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
269        for (_, c) in &mut self.active {
270            f(c)
271        }
272    }
273}
274
275impl<'a, 'tcx> TOFinder<'a, 'tcx> {
276    /// Construct the condition set for `bb` from the terminator, without executing its effect.
277    #[instrument(level = "trace", skip(self))]
278    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
279        let bbdata = &self.body[bb];
280
281        // This should be the first time we populate `entry_states[bb]`.
282        debug_assert!(self.entry_states[bb].is_empty());
283
284        let state_len =
285            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
286        let mut state = ConditionSet {
287            active: Vec::with_capacity(state_len),
288            targets: IndexVec::with_capacity(state_len),
289            fulfilled: Vec::new(),
290        };
291
292        // Use an index-set to deduplicate conditions coming from different successor blocks.
293        let mut known_conditions =
294            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
295        let mut insert = |condition, succ_block, succ_condition| {
296            let (index, new) = known_conditions.insert_full(condition);
297            let index = ConditionIndex::from_usize(index);
298            if new {
299                state.active.push((index, condition));
300                let _index = state.targets.push(Vec::new());
301                debug_assert_eq!(_index, index);
302            }
303            let target = EdgeEffect::Chain { succ_block, succ_condition };
304            debug_assert!(
305                !state.targets[index].contains(&target),
306                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
307                &state.targets[index],
308            );
309            state.targets[index].push(target);
310        };
311
312        // A given block may have several times the same successor.
313        let mut seen = FxHashSet::default();
314        for succ in bbdata.terminator().successors() {
315            if !seen.insert(succ) {
316                continue;
317            }
318
319            // Do not thread through loop headers.
320            if self.maybe_loop_headers.contains(succ) {
321                continue;
322            }
323
324            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
325                insert(cond, succ, succ_index);
326            }
327        }
328
329        let num_conditions = known_conditions.len();
330        debug_assert_eq!(num_conditions, state.active.len());
331        debug_assert_eq!(num_conditions, state.targets.len());
332        state.fulfilled.reserve(num_conditions);
333
334        state
335    }
336
337    /// Remove all conditions in the state that alias given place.
338    fn flood_state(
339        &self,
340        place: Place<'tcx>,
341        extra_elem: Option<TrackElem>,
342        state: &mut ConditionSet,
343    ) {
344        if state.is_empty() {
345            return;
346        }
347        let mut places_to_exclude = FxHashSet::default();
348        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
349            places_to_exclude.insert(vi);
350        });
351        trace!(?places_to_exclude, "flood_state");
352        if places_to_exclude.is_empty() {
353            return;
354        }
355        state.retain(|c| !places_to_exclude.contains(&c.place));
356    }
357
358    /// Extract the mutated place from a statement.
359    ///
360    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
361    ///     (_1 as Ok).0 = _5;
362    ///     (_1 as Err).0 = _6;
363    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
364    /// the value may have been mangled by the second assignment.
365    ///
366    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
367    /// stop at flooding the discriminant, and preserve the variant fields.
368    ///     (_1 as Some).0 = _6;
369    ///     SetDiscriminant(_1, 1);
370    ///     switchInt((_1 as Some).0)
371    #[instrument(level = "trace", skip(self), ret)]
372    fn mutated_statement(
373        &self,
374        stmt: &Statement<'tcx>,
375    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
376        match stmt.kind {
377            StatementKind::Assign(box (place, _)) => Some((place, None)),
378            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
379                Some((place, Some(TrackElem::Discriminant)))
380            }
381            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
382                Some((Place::from(local), None))
383            }
384            StatementKind::Retag(..)
385            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
386            // copy_nonoverlapping takes pointers and mutated the pointed-to value.
387            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
388            | StatementKind::AscribeUserType(..)
389            | StatementKind::Coverage(..)
390            | StatementKind::FakeRead(..)
391            | StatementKind::ConstEvalCounter
392            | StatementKind::PlaceMention(..)
393            | StatementKind::BackwardIncompatibleDropHint { .. }
394            | StatementKind::Nop => None,
395        }
396    }
397
398    #[instrument(level = "trace", skip(self, state))]
399    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
400        if let Some(lhs) = self.map.value(lhs)
401            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
402        {
403            state.fulfill_matches(lhs, int)
404        }
405    }
406
407    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
408    #[instrument(level = "trace", skip(self, state))]
409    fn process_constant(
410        &mut self,
411        lhs: PlaceIndex,
412        constant: OpTy<'tcx>,
413        state: &mut ConditionSet,
414    ) {
415        let values_inside = self.map.values_inside(lhs);
416        if !state.active.iter().any(|&(_, cond)| values_inside.contains(&cond.place)) {
417            return;
418        }
419        self.map.for_each_projection_value(
420            lhs,
421            constant,
422            &mut |elem, op| match elem {
423                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
424                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
425                TrackElem::Discriminant => {
426                    let variant = self.ecx.read_discriminant(op).discard_err()?;
427                    let discr_value =
428                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
429                    Some(discr_value.into())
430                }
431                TrackElem::DerefLen => {
432                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
433                    let len_usize = op.len(&self.ecx).discard_err()?;
434                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
435                    Some(ImmTy::from_uint(len_usize, layout).into())
436                }
437            },
438            &mut |place, op| {
439                if let Some(place) = self.map.value(place)
440                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
441                    && let Some(imm) = imm.right()
442                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
443                {
444                    state.fulfill_matches(place, int)
445                }
446            },
447        );
448    }
449
450    #[instrument(level = "trace", skip(self, state))]
451    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
452        let mut renames = FxHashMap::default();
453        self.map.for_each_value_pair(rhs, lhs, &mut |rhs, lhs| {
454            renames.insert(lhs, rhs);
455        });
456        state.for_each_mut(|c| {
457            if let Some(rhs) = renames.get(&c.place) {
458                c.place = *rhs
459            }
460        });
461    }
462
463    #[instrument(level = "trace", skip(self, state))]
464    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
465        match rhs {
466            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
467            Operand::Constant(constant) => {
468                let Some(constant) =
469                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
470                else {
471                    return;
472                };
473                self.process_constant(lhs, constant, state);
474            }
475            // Transfer the conditions on the copied rhs.
476            Operand::Move(rhs) | Operand::Copy(rhs) => {
477                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
478                self.process_copy(lhs, rhs, state)
479            }
480            Operand::RuntimeChecks(_) => {}
481        }
482    }
483
484    #[instrument(level = "trace", skip(self, state))]
485    fn process_assign(
486        &mut self,
487        lhs_place: &Place<'tcx>,
488        rvalue: &Rvalue<'tcx>,
489        state: &mut ConditionSet,
490    ) {
491        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
492        match rvalue {
493            Rvalue::Use(operand) => self.process_operand(lhs, operand, state),
494            // Transfer the conditions on the copy rhs.
495            Rvalue::Discriminant(rhs) => {
496                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
497                self.process_copy(lhs, rhs, state)
498            }
499            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
500            Rvalue::Aggregate(box kind, operands) => {
501                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
502                let lhs = match kind {
503                    // Do not support unions.
504                    AggregateKind::Adt(.., Some(_)) => return,
505                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
506                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
507                            && let Some(discr_value) = self
508                                .ecx
509                                .discriminant_for_variant(agg_ty, *variant_index)
510                                .discard_err()
511                        {
512                            self.process_immediate(discr_target, discr_value, state);
513                        }
514                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
515                            idx
516                        } else {
517                            return;
518                        }
519                    }
520                    _ => lhs,
521                };
522                for (field_index, operand) in operands.iter_enumerated() {
523                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
524                        self.process_operand(field, operand, state);
525                    }
526                }
527            }
528            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
529            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
530                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
531                let Some(lhs) = self.map.value(lhs) else { return };
532                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
533                state.retain_mut(|mut c| {
534                    if c.place == lhs {
535                        let value = self
536                            .ecx
537                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
538                            .discard_err()?
539                            .to_scalar_int()
540                            .discard_err()?;
541                        c.place = operand;
542                        c.value = value;
543                    }
544                    Some(c)
545                });
546            }
547            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
548            // Create a condition on `rhs ?= B`.
549            Rvalue::BinaryOp(
550                op,
551                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
552                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
553            ) => {
554                let equals = match op {
555                    BinOp::Eq => ScalarInt::TRUE,
556                    BinOp::Ne => ScalarInt::FALSE,
557                    _ => return,
558                };
559                if value.const_.ty().is_floating_point() {
560                    // Floating point equality does not follow bit-patterns.
561                    // -0.0 and NaN both have special rules for equality,
562                    // and therefore we cannot use integer comparisons for them.
563                    // Avoid handling them, though this could be extended in the future.
564                    return;
565                }
566                let Some(lhs) = self.map.value(lhs) else { return };
567                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
568                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
569                else {
570                    return;
571                };
572                state.for_each_mut(|c| {
573                    if c.place == lhs {
574                        let polarity =
575                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
576                        c.place = operand;
577                        c.value = value;
578                        c.polarity = polarity;
579                    }
580                });
581            }
582
583            _ => {}
584        }
585    }
586
587    #[instrument(level = "trace", skip(self, state))]
588    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
589        // Below, `lhs` is the return value of `mutated_statement`,
590        // the place to which `conditions` apply.
591
592        match &stmt.kind {
593            // If we expect `discriminant(place) ?= A`,
594            // we have an opportunity if `variant_index ?= A`.
595            StatementKind::SetDiscriminant { box place, variant_index } => {
596                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
597                let enum_ty = place.ty(self.body, self.tcx).ty;
598                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
599                // Even if the discriminant write does nothing due to niches, it is UB to set the
600                // discriminant when the data does not encode the desired discriminant.
601                let Some(discr) =
602                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
603                else {
604                    return;
605                };
606                self.process_immediate(discr_target, discr, state)
607            }
608            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
609            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
610                Operand::Copy(place) | Operand::Move(place),
611            )) => {
612                let Some(place) = self.map.find_value(place.as_ref()) else { return };
613                state.fulfill_matches(place, ScalarInt::TRUE);
614            }
615            StatementKind::Assign(box (lhs_place, rhs)) => {
616                self.process_assign(lhs_place, rhs, state)
617            }
618            _ => {}
619        }
620    }
621
622    /// Execute the terminator for block `bb` into state `entry_states[bb]`.
623    #[instrument(level = "trace", skip(self, state))]
624    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
625        let term = self.body.basic_blocks[bb].terminator();
626        let place_to_flood = match term.kind {
627            // Disallowed during optimizations.
628            TerminatorKind::FalseEdge { .. }
629            | TerminatorKind::FalseUnwind { .. }
630            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
631            // Cannot reason about inline asm.
632            TerminatorKind::InlineAsm { .. } => {
633                state.active.clear();
634                return;
635            }
636            // `SwitchInt` is handled specially.
637            TerminatorKind::SwitchInt { ref discr, ref targets } => {
638                return self.process_switch_int(discr, targets, state);
639            }
640            // These do not modify memory.
641            TerminatorKind::UnwindResume
642            | TerminatorKind::UnwindTerminate(_)
643            | TerminatorKind::Return
644            | TerminatorKind::Unreachable
645            | TerminatorKind::CoroutineDrop
646            // Assertions can be no-op at codegen time, so treat them as such.
647            | TerminatorKind::Assert { .. }
648            | TerminatorKind::Goto { .. } => None,
649            // Flood the overwritten place, and progress through.
650            TerminatorKind::Drop { place: destination, .. }
651            | TerminatorKind::Call { destination, .. } => Some(destination),
652            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
653        };
654
655        // This terminator modifies `place_to_flood`, cleanup the associated conditions.
656        if let Some(place_to_flood) = place_to_flood {
657            self.flood_state(place_to_flood, None, state);
658        }
659    }
660
661    #[instrument(level = "trace", skip(self))]
662    fn process_switch_int(
663        &mut self,
664        discr: &Operand<'tcx>,
665        targets: &SwitchTargets,
666        state: &mut ConditionSet,
667    ) {
668        let Some(discr) = discr.place() else { return };
669        let Some(discr_idx) = self.map.find_value(discr.as_ref()) else { return };
670
671        let discr_ty = discr.ty(self.body, self.tcx).ty;
672        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
673
674        // Attempt to fulfill a condition using an outgoing branch's condition.
675        // Only support the case where there are no duplicated outgoing edges.
676        if targets.is_distinct() {
677            for &(index, c) in state.active.iter() {
678                if c.place != discr_idx {
679                    continue;
680                }
681
682                // Set of blocks `t` such that the edge `bb -> t` fulfills `c`.
683                let mut edges_fulfilling_condition = FxHashSet::default();
684
685                // On edge `bb -> tgt`, we know that `discr_idx == branch`.
686                for (branch, tgt) in targets.iter() {
687                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
688                        && c.matches(discr_idx, branch)
689                    {
690                        edges_fulfilling_condition.insert(tgt);
691                    }
692                }
693
694                // On edge `bb -> otherwise`, we only know that `discr` is different from all the
695                // constants in the switch. That's much weaker information than the equality we
696                // had in the previous arm. All we can conclude is that the replacement condition
697                // `discr != value` can be threaded, and nothing else.
698                if c.polarity == Polarity::Ne
699                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
700                    && targets.all_values().contains(&value.into())
701                {
702                    edges_fulfilling_condition.insert(targets.otherwise());
703                }
704
705                // Register that jumping to a `t` fulfills condition `c`.
706                // This does *not* mean that `c` is fulfilled in this block: inserting `index` in
707                // `fulfilled` is wrong if we have targets that jump to other blocks.
708                let condition_targets = &state.targets[index];
709
710                let new_edges: Vec<_> = condition_targets
711                    .iter()
712                    .copied()
713                    .filter(|&target| match target {
714                        EdgeEffect::Goto { .. } => false,
715                        EdgeEffect::Chain { succ_block, .. } => {
716                            edges_fulfilling_condition.contains(&succ_block)
717                        }
718                    })
719                    .collect();
720
721                if new_edges.len() == condition_targets.len() {
722                    // If `new_edges == condition_targets`, do not bother creating a new
723                    // `ConditionIndex`, we can use the existing one.
724                    state.fulfilled.push(index);
725                } else {
726                    // Fulfilling `index` may thread conditions that we do not want,
727                    // so create a brand new index to immediately mark fulfilled.
728                    let index = state.targets.push(new_edges);
729                    state.fulfilled.push(index);
730                }
731            }
732        }
733
734        // Introduce additional conditions of the form `discr ?= value` for each value in targets.
735        let mut mk_condition = |value, polarity, target| {
736            let c = Condition { place: discr_idx, value, polarity };
737            state.push_condition(c, target);
738        };
739        if let Some((value, then_, else_)) = targets.as_static_if() {
740            // We have an `if`, generate both `discr == value` and `discr != value`.
741            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
742            mk_condition(value, Polarity::Eq, then_);
743            mk_condition(value, Polarity::Ne, else_);
744        } else {
745            // We have a general switch and we cannot express `discr != value0 && discr != value1`,
746            // so we only generate equality predicates.
747            for (value, target) in targets.iter() {
748                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
749                    mk_condition(value, Polarity::Eq, target);
750                }
751            }
752        }
753    }
754}
755
756/// Propagate fulfilled conditions forward in the CFG to reduce the amount of duplication.
757#[instrument(level = "debug", skip(body, entry_states))]
758fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
759    let basic_blocks = &body.basic_blocks;
760    let reverse_postorder = basic_blocks.reverse_postorder();
761
762    // Start by computing the number of *incoming edges* for each block.
763    // We do not use the cached `basic_blocks.predecessors` as we only want reachable predecessors.
764    let mut predecessors = IndexVec::from_elem(0, &entry_states);
765    predecessors[START_BLOCK] = 1; // Account for the implicit entry edge.
766    for &bb in reverse_postorder {
767        let term = basic_blocks[bb].terminator();
768        for s in term.successors() {
769            predecessors[s] += 1;
770        }
771    }
772
773    // Compute the number of edges into each block that carry each condition.
774    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
775        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
776        entry_states.len(),
777    );
778
779    // By traversing in RPO, we increase the likelihood to visit predecessors before successors.
780    for &bb in reverse_postorder {
781        let preds = predecessors[bb];
782        trace!(?bb, ?preds);
783
784        // We have removed all the input edges towards this block. Just skip visiting it.
785        if preds == 0 {
786            continue;
787        }
788
789        let state = &mut entry_states[bb];
790        trace!(?state);
791
792        // Conditions that are fulfilled in all the predecessors, are fulfilled in `bb`.
793        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
794        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
795            if cond_preds == preds {
796                trace!(?condition);
797                state.fulfilled.push(condition);
798            }
799        }
800
801        // We want to count how many times each condition is fulfilled,
802        // so ensure we are not counting the same edge twice.
803        let mut targets: Vec<_> = state
804            .fulfilled
805            .iter()
806            .flat_map(|&index| state.targets[index].iter().copied())
807            .collect();
808        targets.sort();
809        targets.dedup();
810        trace!(?targets);
811
812        // We may modify the set of successors by applying edges, so track them here.
813        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();
814
815        targets.reverse();
816        while let Some(target) = targets.pop() {
817            match target {
818                EdgeEffect::Goto { target } => {
819                    // We update the count of predecessors. If target or any successor has not been
820                    // processed yet, this increases the likelihood we find something relevant.
821                    predecessors[target] += 1;
822                    for &s in successors.iter() {
823                        predecessors[s] -= 1;
824                    }
825                    // Only process edges that still exist.
826                    targets.retain(|t| t.block() == target);
827                    successors.clear();
828                    successors.push(target);
829                }
830                EdgeEffect::Chain { succ_block, succ_condition } => {
831                    // `predecessors` is the number of incoming *edges* in each block.
832                    // Count the number of edges that apply `succ_condition` into `succ_block`.
833                    let count = successors.iter().filter(|&&s| s == succ_block).count();
834                    fulfill_in_pred_count[succ_block][succ_condition] += count;
835                }
836            }
837        }
838    }
839}
840
841#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
842fn remove_costly_conditions<'tcx>(
843    tcx: TyCtxt<'tcx>,
844    typing_env: ty::TypingEnv<'tcx>,
845    body: &Body<'tcx>,
846    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
847) {
848    let basic_blocks = &body.basic_blocks;
849
850    let mut costs = IndexVec::from_elem(None, basic_blocks);
851    let mut cost = |bb: BasicBlock| -> u8 {
852        let c = *costs[bb].get_or_insert_with(|| {
853            let bbdata = &basic_blocks[bb];
854            let mut cost = CostChecker::new(tcx, typing_env, None, body);
855            cost.visit_basic_block_data(bb, bbdata);
856            cost.cost().try_into().unwrap_or(MAX_COST)
857        });
858        trace!("cost[{bb:?}] = {c}");
859        c
860    };
861
862    // Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
863    let mut condition_cost = IndexVec::from_fn_n(
864        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
865        entry_states.len(),
866    );
867
868    let reverse_postorder = basic_blocks.reverse_postorder();
869
870    for &bb in reverse_postorder.iter().rev() {
871        let state = &entry_states[bb];
872        trace!(?bb, ?state);
873
874        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
875
876        for (condition, targets) in state.targets.iter_enumerated() {
877            for &target in targets {
878                match target {
879                    // A `Goto` has cost 0.
880                    EdgeEffect::Goto { .. } => {}
881                    // Chaining into an already-fulfilled condition is nop.
882                    EdgeEffect::Chain { succ_block, succ_condition }
883                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
884                    // When chaining, use `cost[succ_block][succ_condition] + cost(succ_block)`.
885                    EdgeEffect::Chain { succ_block, succ_condition } => {
886                        // Cost associated with duplicating `succ_block`.
887                        let duplication_cost = cost(succ_block);
888                        // Cost associated with the rest of the chain.
889                        let target_cost =
890                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
891                        let cost = current_costs[condition]
892                            .saturating_add(duplication_cost)
893                            .saturating_add(target_cost);
894                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
895                        current_costs[condition] = cost;
896                    }
897                }
898            }
899        }
900
901        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
902        condition_cost[bb] = current_costs;
903    }
904
905    trace!(?condition_cost);
906
907    for &bb in reverse_postorder {
908        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
909            if condition_cost[bb][index] >= MAX_COST {
910                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
911                targets.clear()
912            }
913        }
914    }
915}
916
917struct OpportunitySet<'a, 'tcx> {
918    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
919    entry_states: IndexVec<BasicBlock, ConditionSet>,
920    /// Cache duplicated block. When cloning a basic block `bb` to fulfill a condition `c`,
921    /// record the target of this `bb with c` edge.
922    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
923}
924
925impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
926    fn new(
927        body: &'a mut Body<'tcx>,
928        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
929    ) -> Option<OpportunitySet<'a, 'tcx>> {
930        trace!(def_id = ?body.source.def_id(), "apply");
931
932        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
933            return None;
934        }
935
936        // Free some memory, because we will need to clone condition sets.
937        for state in entry_states.iter_mut() {
938            state.active = Default::default();
939        }
940        let duplicates = Default::default();
941        let basic_blocks = body.basic_blocks.as_mut();
942        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
943    }
944
945    /// Apply the opportunities on the graph.
946    #[instrument(level = "debug", skip(self))]
947    fn apply(mut self) {
948        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
949        worklist.push(START_BLOCK);
950
951        // Use a `GrowableBitSet` and not a `DenseBitSet` as we are adding blocks.
952        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());
953
954        while let Some(bb) = worklist.pop() {
955            if !visited.insert(bb) {
956                continue;
957            }
958
959            self.apply_once(bb);
960
961            // `apply_once` may have modified the terminator of `bb`.
962            // Only visit actual successors.
963            worklist.extend(self.basic_blocks[bb].terminator().successors());
964        }
965    }
966
967    /// Apply the opportunities on `bb`.
968    #[instrument(level = "debug", skip(self))]
969    fn apply_once(&mut self, bb: BasicBlock) {
970        let state = &mut self.entry_states[bb];
971        trace!(?state);
972
973        // We are modifying the `bb` in-place. Once a `EdgeEffect` has been applied,
974        // it does not need to be applied again.
975        let mut targets: Vec<_> = state
976            .fulfilled
977            .iter()
978            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
979            .collect();
980        targets.sort();
981        targets.dedup();
982        trace!(?targets);
983
984        // Use a while-pop to allow modifying `targets` from inside the loop.
985        targets.reverse();
986        while let Some(target) = targets.pop() {
987            debug!(?target);
988            trace!(term = ?self.basic_blocks[bb].terminator().kind);
989
990            // By construction, `target.block()` is a successor of `bb`.
991            // When applying targets, we may change the set of successors.
992            // The match below updates the set of targets for consistency.
993            debug_assert!(
994                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
995                "missing {target:?} in successors for {bb:?}, term={:?}",
996                self.basic_blocks[bb].terminator(),
997            );
998
999            match target {
1000                EdgeEffect::Goto { target } => {
1001                    self.apply_goto(bb, target);
1002
1003                    // We now have `target` as single successor. Drop all other target blocks.
1004                    targets.retain(|t| t.block() == target);
1005                    // Also do this on targets that may be applied by a duplicate of `bb`.
1006                    for ts in self.entry_states[bb].targets.iter_mut() {
1007                        ts.retain(|t| t.block() == target);
1008                    }
1009                }
1010                EdgeEffect::Chain { succ_block, succ_condition } => {
1011                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);
1012
1013                    // We have a new name for `target`, ensure it is correctly applied.
1014                    if let Some(new_succ_block) = new_succ_block {
1015                        for t in targets.iter_mut() {
1016                            t.replace_block(succ_block, new_succ_block)
1017                        }
1018                        // Also do this on targets that may be applied by a duplicate of `bb`.
1019                        for t in
1020                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
1021                        {
1022                            t.replace_block(succ_block, new_succ_block)
1023                        }
1024                    }
1025                }
1026            }
1027
1028            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
1029        }
1030    }
1031
1032    #[instrument(level = "debug", skip(self))]
1033    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
1034        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
1035    }
1036
1037    #[instrument(level = "debug", skip(self), ret)]
1038    fn apply_chain(
1039        &mut self,
1040        bb: BasicBlock,
1041        target: BasicBlock,
1042        condition: ConditionIndex,
1043    ) -> Option<BasicBlock> {
1044        if self.entry_states[target].fulfilled.contains(&condition) {
1045            // `target` already fulfills `condition`, so we do not need to thread anything.
1046            trace!("fulfilled");
1047            return None;
1048        }
1049
1050        // We may be tempted to modify `target` in-place to avoid a clone. This is wrong.
1051        // We may still have edges from other blocks to `target` that have not been created yet.
1052        // For instance because we may be threading an edge coming from `bb`,
1053        // or `target` may be a block duplicate for which we may still create predecessors.
1054
1055        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
1056            // If we already have a duplicate of `target` which fulfills `condition`, reuse it.
1057            // Otherwise, we clone a new bb to such ends.
1058            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
1059            trace!(?target, ?new_target, ?condition, "clone");
1060
1061            // By definition, `new_target` fulfills the same condition as `target`, with
1062            // `condition` added.
1063            let mut condition_set = self.entry_states[target].clone();
1064            condition_set.fulfilled.push(condition);
1065            let _new_target = self.entry_states.push(condition_set);
1066            debug_assert_eq!(new_target, _new_target);
1067
1068            new_target
1069        });
1070        trace!(?target, ?new_target, ?condition, "reuse");
1071
1072        // Replace `target` by `new_target` where it appears.
1073        // This changes exactly `direct_count` edges.
1074        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
1075            if *s == target {
1076                *s = new_target;
1077            }
1078        });
1079
1080        Some(new_target)
1081    }
1082}