// rustc_mir_transform/jump_threading.rs

1//! A jump threading optimization.
2//!
3//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
4//!    X = 0                                      X = 0
5//! ------------\      /--------              ------------
6//!    X = 1     X----X SwitchInt(X)     =>       X = 1
7//! ------------/      \--------              ------------
8//!
9//!
10//! This implementation is heavily inspired by the work outlined in [libfirm].
11//!
12//! The general algorithm proceeds in two phases: (1) walk the CFG backwards to construct a
13//! graph of threading conditions, and (2) propagate fulfilled conditions forward by duplicating
14//! blocks.
15//!
16//! # 1. Condition graph construction
17//!
18//! In this file, we denote as `place ?= value` the existence of a replacement condition
19//! on `place` with given `value`, irrespective of the polarity and target of that
20//! replacement condition.
21//!
22//! Inside a block, we associate with each condition `c` a set of targets:
23//! - `Goto(target)` if fulfilling `c` changes the terminator into a `Goto { target }`;
24//! - `Chain(target, c2)` if fulfilling `c` means that `c2` is fulfilled inside `target`.
25//!
//! Before walking a block `bb`, we construct its exit set of conditions from its successors.
//! For each condition `c` in a successor `s`, we record that fulfilling `c` in `bb` will fulfill
//! `c` in `s`, as a `Chain(s, c)` target.
29//!
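//! When encountering a `SwitchInt(place) -> [value: bb...]` terminator, we also record a
//! `place == value` condition for each `value`, associated with a `Goto(bb)` target for the
//! corresponding branch. For a two-way switch, we also record `place != value` for the
//! `otherwise` branch.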
32//!
33//! Then, we walk the statements backwards, transforming the set of conditions along the way,
34//! resulting in a set of conditions at the block entry.
35//!
36//! We try to avoid creating irreducible control-flow by not threading through a loop header.
37//!
38//! Applying the optimisation can create a lot of new MIR, so we bound the instruction
39//! cost by `MAX_COST`.
40//!
41//! # 2. Block duplication
42//!
43//! We now have the set of fulfilled conditions inside each block and their targets.
44//!
45//! For each block `bb` in reverse postorder, we apply in turn the target associated with each
46//! fulfilled condition:
47//! - for `Goto(target)`, change the terminator of `bb` into a `Goto { target }`;
48//! - for `Chain(target, cond)`, duplicate `target` into a new block which fulfills the same
49//! conditions and also fulfills `cond`. This is made efficient by maintaining a map of duplicates,
50//! `duplicate[(target, cond)]` to avoid cloning blocks multiple times.
51//!
52//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
53
54use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, TrackElem, ValueIndex};
66use rustc_span::DUMMY_SP;
67use tracing::{debug, instrument, trace};
68
69use crate::cost_checker::CostChecker;
70
/// The jump threading MIR pass (see the module documentation for the algorithm).
pub(super) struct JumpThreading;

/// Bound on the instruction cost incurred when applying the optimisation, since applying it
/// can create a lot of new MIR.
const MAX_COST: u8 = 100;
/// Limit passed to `Map::new` in `run_pass`; presumably caps how many places/values the
/// analysis tracks — TODO confirm against `Map::new`.
const MAX_PLACES: usize = 100;
75
76impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
77    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
78        sess.mir_opt_level() >= 2
79    }
80
81    #[instrument(skip_all level = "debug")]
82    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
83        let def_id = body.source.def_id();
84        debug!(?def_id);
85
86        // Optimizing coroutines creates query cycles.
87        if tcx.is_coroutine(def_id) {
88            trace!("Skipped for coroutine {:?}", def_id);
89            return;
90        }
91
92        let typing_env = body.typing_env(tcx);
93        let mut finder = TOFinder {
94            tcx,
95            typing_env,
96            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
97            body,
98            map: Map::new(tcx, body, Some(MAX_PLACES)),
99            maybe_loop_headers: loops::maybe_loop_headers(body),
100            entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
101        };
102
103        for (bb, bbdata) in traversal::postorder(body) {
104            if bbdata.is_cleanup {
105                continue;
106            }
107
108            let mut state = finder.populate_from_outgoing_edges(bb);
109            trace!("output_states[{bb:?}] = {state:?}");
110
111            finder.process_terminator(bb, &mut state);
112            trace!("pre_terminator_states[{bb:?}] = {state:?}");
113
114            for stmt in bbdata.statements.iter().rev() {
115                if state.is_empty() {
116                    break;
117                }
118
119                finder.process_statement(stmt, &mut state);
120
121                // When a statement mutates a place, assignments to that place that happen
122                // above the mutation cannot fulfill a condition.
123                //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
124                //   _1 = 6
125                if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
126                    finder.flood_state(lhs, tail, &mut state);
127                }
128            }
129
130            trace!("entry_states[{bb:?}] = {state:?}");
131            finder.entry_states[bb] = state;
132        }
133
134        let mut entry_states = finder.entry_states;
135        simplify_conditions(body, &mut entry_states);
136        remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
137
138        if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
139            opportunities.apply();
140        }
141    }
142
143    fn is_required(&self) -> bool {
144        false
145    }
146}
147
/// State of the backward, condition-collecting phase of jump threading.
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    /// Interpreter used to evaluate constants, project fields, read discriminants and apply
    /// unary operators to known values.
    ecx: InterpCx<'tcx, DummyMachine>,
    body: &'a Body<'tcx>,
    /// Tracked places and values; every `Condition::place` is a `ValueIndex` of this map.
    map: Map<'tcx>,
    /// Blocks that may be loop headers; conditions are never threaded into them.
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// The state of each visited block on entry. Blocks that are not visited yet, and cleanup
    /// blocks (which are skipped), hold an empty `ConditionSet`.
    // Invariant: as written by `run_pass`, every `EdgeEffect` in `entry_states[bb].targets`
    // points at a successor of `bb`.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
}
161
rustc_index::newtype_index! {
    /// Index of a condition inside a `ConditionSet`: it keys `ConditionSet::targets`, and is
    /// the `succ_condition` referenced by `EdgeEffect::Chain`.
    #[derive(Ord, PartialOrd)]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
167
/// A predicate on a tracked value: `place == value` when `polarity` is `Eq`, or
/// `place != value` when `polarity` is `Ne`.
///
/// What fulfilling the condition achieves is not stored here. It is recorded separately, as a
/// list of `EdgeEffect`s in `ConditionSet::targets`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
struct Condition {
    /// The tracked value the predicate is about.
    place: ValueIndex,
    /// The constant `place` is compared against.
    value: ScalarInt,
    /// Whether the predicate is an equality or an inequality.
    polarity: Polarity,
}
176
/// Whether a `Condition` asserts inequality (`Ne`) or equality (`Eq`) with its value.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}
182
183impl Condition {
184    fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
185        self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
186    }
187}
188
/// Represent the effect of fulfilling a condition.
///
/// A condition may carry several effects, one per outgoing edge it can be threaded through.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// If the condition is fulfilled, replace the current block's terminator by a single goto.
    Goto { target: BasicBlock },
    /// If the condition is fulfilled, fulfill the condition `succ_condition` in `succ_block`.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
197
198impl EdgeEffect {
199    fn block(self) -> BasicBlock {
200        match self {
201            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
202        }
203    }
204
205    fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
206        match self {
207            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
208                if *bb == target {
209                    *bb = new_target
210                }
211            }
212        }
213    }
214}
215
/// The conditions tracked at one program point, together with the effect of fulfilling each.
#[derive(Clone, Debug, Default)]
struct ConditionSet {
    /// Conditions not yet known to be fulfilled, each paired with its index into `targets`.
    active: Vec<(ConditionIndex, Condition)>,
    /// Indices (into `targets`) of conditions found to be fulfilled.
    fulfilled: Vec<ConditionIndex>,
    /// Effects of fulfilling each condition, keyed by `ConditionIndex`.
    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
}
222
223impl ConditionSet {
224    fn is_empty(&self) -> bool {
225        self.active.is_empty()
226    }
227
228    #[tracing::instrument(level = "trace", skip(self))]
229    fn push_condition(&mut self, c: Condition, target: BasicBlock) {
230        let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
231        self.active.push((index, c));
232    }
233
234    /// Register fulfilled condition and remove it from the set.
235    fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
236        self.active.retain(|&(index, condition)| {
237            let targets = &self.targets[index];
238            if f(condition, targets) {
239                trace!(?index, ?condition, "fulfill");
240                self.fulfilled.push(index);
241                false
242            } else {
243                true
244            }
245        })
246    }
247
248    /// Register fulfilled condition and remove them from the set.
249    fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
250        self.fulfill_if(|c, _| c.matches(place, value))
251    }
252
253    fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
254        self.active.retain(|&(_, c)| f(c))
255    }
256
257    fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
258        self.active.retain_mut(|(_, c)| {
259            if let Some(new) = f(*c) {
260                *c = new;
261                true
262            } else {
263                false
264            }
265        })
266    }
267
268    fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
269        for (_, c) in &mut self.active {
270            f(c)
271        }
272    }
273}
274
275impl<'a, 'tcx> TOFinder<'a, 'tcx> {
276    /// Construct the condition set for `bb` from the terminator, without executing its effect.
277    #[instrument(level = "trace", skip(self))]
278    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
279        let bbdata = &self.body[bb];
280
281        // This should be the first time we populate `entry_states[bb]`.
282        debug_assert!(self.entry_states[bb].is_empty());
283
284        let state_len =
285            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
286        let mut state = ConditionSet {
287            active: Vec::with_capacity(state_len),
288            targets: IndexVec::with_capacity(state_len),
289            fulfilled: Vec::new(),
290        };
291
292        // Use an index-set to deduplicate conditions coming from different successor blocks.
293        let mut known_conditions =
294            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
295        let mut insert = |condition, succ_block, succ_condition| {
296            let (index, new) = known_conditions.insert_full(condition);
297            let index = ConditionIndex::from_usize(index);
298            if new {
299                state.active.push((index, condition));
300                let _index = state.targets.push(Vec::new());
301                debug_assert_eq!(_index, index);
302            }
303            let target = EdgeEffect::Chain { succ_block, succ_condition };
304            debug_assert!(
305                !state.targets[index].contains(&target),
306                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
307                &state.targets[index],
308            );
309            state.targets[index].push(target);
310        };
311
312        // A given block may have several times the same successor.
313        let mut seen = FxHashSet::default();
314        for succ in bbdata.terminator().successors() {
315            if !seen.insert(succ) {
316                continue;
317            }
318
319            // Do not thread through loop headers.
320            if self.maybe_loop_headers.contains(succ) {
321                continue;
322            }
323
324            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
325                insert(cond, succ, succ_index);
326            }
327        }
328
329        let num_conditions = known_conditions.len();
330        debug_assert_eq!(num_conditions, state.active.len());
331        debug_assert_eq!(num_conditions, state.targets.len());
332        state.fulfilled.reserve(num_conditions);
333
334        state
335    }
336
337    /// Remove all conditions in the state that alias given place.
338    fn flood_state(
339        &self,
340        place: Place<'tcx>,
341        extra_elem: Option<TrackElem>,
342        state: &mut ConditionSet,
343    ) {
344        if state.is_empty() {
345            return;
346        }
347        let mut places_to_exclude = FxHashSet::default();
348        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
349            places_to_exclude.insert(vi);
350        });
351        trace!(?places_to_exclude, "flood_state");
352        if places_to_exclude.is_empty() {
353            return;
354        }
355        state.retain(|c| !places_to_exclude.contains(&c.place));
356    }
357
358    /// Extract the mutated place from a statement.
359    ///
360    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
361    ///     (_1 as Ok).0 = _5;
362    ///     (_1 as Err).0 = _6;
363    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
364    /// the value may have been mangled by the second assignment.
365    ///
366    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
367    /// stop at flooding the discriminant, and preserve the variant fields.
368    ///     (_1 as Some).0 = _6;
369    ///     SetDiscriminant(_1, 1);
370    ///     switchInt((_1 as Some).0)
371    #[instrument(level = "trace", skip(self), ret)]
372    fn mutated_statement(
373        &self,
374        stmt: &Statement<'tcx>,
375    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
376        match stmt.kind {
377            StatementKind::Assign(box (place, _)) => Some((place, None)),
378            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
379                Some((place, Some(TrackElem::Discriminant)))
380            }
381            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
382                Some((Place::from(local), None))
383            }
384            StatementKind::Retag(..)
385            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
386            // copy_nonoverlapping takes pointers and mutated the pointed-to value.
387            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
388            | StatementKind::AscribeUserType(..)
389            | StatementKind::Coverage(..)
390            | StatementKind::FakeRead(..)
391            | StatementKind::ConstEvalCounter
392            | StatementKind::PlaceMention(..)
393            | StatementKind::BackwardIncompatibleDropHint { .. }
394            | StatementKind::Nop => None,
395        }
396    }
397
398    #[instrument(level = "trace", skip(self, state))]
399    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
400        if let Some(lhs) = self.map.value(lhs)
401            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
402        {
403            state.fulfill_matches(lhs, int)
404        }
405    }
406
407    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
408    #[instrument(level = "trace", skip(self, state))]
409    fn process_constant(
410        &mut self,
411        lhs: PlaceIndex,
412        constant: OpTy<'tcx>,
413        state: &mut ConditionSet,
414    ) {
415        let values_inside = self.map.values_inside(lhs);
416        if !state.active.iter().any(|&(_, cond)| values_inside.contains(&cond.place)) {
417            return;
418        }
419        self.map.for_each_projection_value(
420            lhs,
421            constant,
422            &mut |elem, op| match elem {
423                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
424                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
425                TrackElem::Discriminant => {
426                    let variant = self.ecx.read_discriminant(op).discard_err()?;
427                    let discr_value =
428                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
429                    Some(discr_value.into())
430                }
431                TrackElem::DerefLen => {
432                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
433                    let len_usize = op.len(&self.ecx).discard_err()?;
434                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
435                    Some(ImmTy::from_uint(len_usize, layout).into())
436                }
437            },
438            &mut |place, op| {
439                if let Some(place) = self.map.value(place)
440                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
441                    && let Some(imm) = imm.right()
442                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
443                {
444                    state.fulfill_matches(place, int)
445                }
446            },
447        );
448    }
449
450    #[instrument(level = "trace", skip(self, state))]
451    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
452        let mut renames = FxHashMap::default();
453        self.map.for_each_value_pair(rhs, lhs, &mut |rhs, lhs| {
454            renames.insert(lhs, rhs);
455        });
456        state.for_each_mut(|c| {
457            if let Some(rhs) = renames.get(&c.place) {
458                c.place = *rhs
459            }
460        });
461    }
462
463    #[instrument(level = "trace", skip(self, state))]
464    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
465        match rhs {
466            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
467            Operand::Constant(constant) => {
468                let Some(constant) =
469                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
470                else {
471                    return;
472                };
473                self.process_constant(lhs, constant, state);
474            }
475            // Transfer the conditions on the copied rhs.
476            Operand::Move(rhs) | Operand::Copy(rhs) => {
477                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
478                self.process_copy(lhs, rhs, state)
479            }
480        }
481    }
482
483    #[instrument(level = "trace", skip(self, state))]
484    fn process_assign(
485        &mut self,
486        lhs_place: &Place<'tcx>,
487        rvalue: &Rvalue<'tcx>,
488        state: &mut ConditionSet,
489    ) {
490        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
491        match rvalue {
492            Rvalue::Use(operand) => self.process_operand(lhs, operand, state),
493            // Transfer the conditions on the copy rhs.
494            Rvalue::Discriminant(rhs) => {
495                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
496                self.process_copy(lhs, rhs, state)
497            }
498            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
499            Rvalue::Aggregate(box kind, operands) => {
500                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
501                let lhs = match kind {
502                    // Do not support unions.
503                    AggregateKind::Adt(.., Some(_)) => return,
504                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
505                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
506                            && let Some(discr_value) = self
507                                .ecx
508                                .discriminant_for_variant(agg_ty, *variant_index)
509                                .discard_err()
510                        {
511                            self.process_immediate(discr_target, discr_value, state);
512                        }
513                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
514                            idx
515                        } else {
516                            return;
517                        }
518                    }
519                    _ => lhs,
520                };
521                for (field_index, operand) in operands.iter_enumerated() {
522                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
523                        self.process_operand(field, operand, state);
524                    }
525                }
526            }
527            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
528            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
529                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
530                let Some(lhs) = self.map.value(lhs) else { return };
531                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
532                state.retain_mut(|mut c| {
533                    if c.place == lhs {
534                        let value = self
535                            .ecx
536                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
537                            .discard_err()?
538                            .to_scalar_int()
539                            .discard_err()?;
540                        c.place = operand;
541                        c.value = value;
542                    }
543                    Some(c)
544                });
545            }
546            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
547            // Create a condition on `rhs ?= B`.
548            Rvalue::BinaryOp(
549                op,
550                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
551                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
552            ) => {
553                let equals = match op {
554                    BinOp::Eq => ScalarInt::TRUE,
555                    BinOp::Ne => ScalarInt::FALSE,
556                    _ => return,
557                };
558                if value.const_.ty().is_floating_point() {
559                    // Floating point equality does not follow bit-patterns.
560                    // -0.0 and NaN both have special rules for equality,
561                    // and therefore we cannot use integer comparisons for them.
562                    // Avoid handling them, though this could be extended in the future.
563                    return;
564                }
565                let Some(lhs) = self.map.value(lhs) else { return };
566                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
567                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
568                else {
569                    return;
570                };
571                state.for_each_mut(|c| {
572                    if c.place == lhs {
573                        let polarity =
574                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
575                        c.place = operand;
576                        c.value = value;
577                        c.polarity = polarity;
578                    }
579                });
580            }
581
582            _ => {}
583        }
584    }
585
586    #[instrument(level = "trace", skip(self, state))]
587    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
588        // Below, `lhs` is the return value of `mutated_statement`,
589        // the place to which `conditions` apply.
590
591        match &stmt.kind {
592            // If we expect `discriminant(place) ?= A`,
593            // we have an opportunity if `variant_index ?= A`.
594            StatementKind::SetDiscriminant { box place, variant_index } => {
595                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
596                let enum_ty = place.ty(self.body, self.tcx).ty;
597                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
598                // Even if the discriminant write does nothing due to niches, it is UB to set the
599                // discriminant when the data does not encode the desired discriminant.
600                let Some(discr) =
601                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
602                else {
603                    return;
604                };
605                self.process_immediate(discr_target, discr, state)
606            }
607            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
608            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
609                Operand::Copy(place) | Operand::Move(place),
610            )) => {
611                let Some(place) = self.map.find_value(place.as_ref()) else { return };
612                state.fulfill_matches(place, ScalarInt::TRUE);
613            }
614            StatementKind::Assign(box (lhs_place, rhs)) => {
615                self.process_assign(lhs_place, rhs, state)
616            }
617            _ => {}
618        }
619    }
620
621    /// Execute the terminator for block `bb` into state `entry_states[bb]`.
622    #[instrument(level = "trace", skip(self, state))]
623    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
624        let term = self.body.basic_blocks[bb].terminator();
625        let place_to_flood = match term.kind {
626            // Disallowed during optimizations.
627            TerminatorKind::FalseEdge { .. }
628            | TerminatorKind::FalseUnwind { .. }
629            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
630            // Cannot reason about inline asm.
631            TerminatorKind::InlineAsm { .. } => {
632                state.active.clear();
633                return;
634            }
635            // `SwitchInt` is handled specially.
636            TerminatorKind::SwitchInt { ref discr, ref targets } => {
637                return self.process_switch_int(discr, targets, state);
638            }
639            // These do not modify memory.
640            TerminatorKind::UnwindResume
641            | TerminatorKind::UnwindTerminate(_)
642            | TerminatorKind::Return
643            | TerminatorKind::Unreachable
644            | TerminatorKind::CoroutineDrop
645            // Assertions can be no-op at codegen time, so treat them as such.
646            | TerminatorKind::Assert { .. }
647            | TerminatorKind::Goto { .. } => None,
648            // Flood the overwritten place, and progress through.
649            TerminatorKind::Drop { place: destination, .. }
650            | TerminatorKind::Call { destination, .. } => Some(destination),
651            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
652        };
653
654        // This terminator modifies `place_to_flood`, cleanup the associated conditions.
655        if let Some(place_to_flood) = place_to_flood {
656            self.flood_state(place_to_flood, None, state);
657        }
658    }
659
660    #[instrument(level = "trace", skip(self))]
661    fn process_switch_int(
662        &mut self,
663        discr: &Operand<'tcx>,
664        targets: &SwitchTargets,
665        state: &mut ConditionSet,
666    ) {
667        let Some(discr) = discr.place() else { return };
668        let Some(discr_idx) = self.map.find_value(discr.as_ref()) else { return };
669
670        let discr_ty = discr.ty(self.body, self.tcx).ty;
671        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
672
673        // Attempt to fulfill a condition using an outgoing branch's condition.
674        // Only support the case where there are no duplicated outgoing edges.
675        if targets.is_distinct() {
676            for &(index, c) in state.active.iter() {
677                if c.place != discr_idx {
678                    continue;
679                }
680
681                // Set of blocks `t` such that the edge `bb -> t` fulfills `c`.
682                let mut edges_fulfilling_condition = FxHashSet::default();
683
684                // On edge `bb -> tgt`, we know that `discr_idx == branch`.
685                for (branch, tgt) in targets.iter() {
686                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
687                        && c.matches(discr_idx, branch)
688                    {
689                        edges_fulfilling_condition.insert(tgt);
690                    }
691                }
692
693                // On edge `bb -> otherwise`, we only know that `discr` is different from all the
694                // constants in the switch. That's much weaker information than the equality we
695                // had in the previous arm. All we can conclude is that the replacement condition
696                // `discr != value` can be threaded, and nothing else.
697                if c.polarity == Polarity::Ne
698                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
699                    && targets.all_values().contains(&value.into())
700                {
701                    edges_fulfilling_condition.insert(targets.otherwise());
702                }
703
704                // Register that jumping to a `t` fulfills condition `c`.
705                // This does *not* mean that `c` is fulfilled in this block: inserting `index` in
706                // `fulfilled` is wrong if we have targets that jump to other blocks.
707                let condition_targets = &state.targets[index];
708
709                let new_edges: Vec<_> = condition_targets
710                    .iter()
711                    .copied()
712                    .filter(|&target| match target {
713                        EdgeEffect::Goto { .. } => false,
714                        EdgeEffect::Chain { succ_block, .. } => {
715                            edges_fulfilling_condition.contains(&succ_block)
716                        }
717                    })
718                    .collect();
719
720                if new_edges.len() == condition_targets.len() {
721                    // If `new_edges == condition_targets`, do not bother creating a new
722                    // `ConditionIndex`, we can use the existing one.
723                    state.fulfilled.push(index);
724                } else {
725                    // Fulfilling `index` may thread conditions that we do not want,
726                    // so create a brand new index to immediately mark fulfilled.
727                    let index = state.targets.push(new_edges);
728                    state.fulfilled.push(index);
729                }
730            }
731        }
732
733        // Introduce additional conditions of the form `discr ?= value` for each value in targets.
734        let mut mk_condition = |value, polarity, target| {
735            let c = Condition { place: discr_idx, value, polarity };
736            state.push_condition(c, target);
737        };
738        if let Some((value, then_, else_)) = targets.as_static_if() {
739            // We have an `if`, generate both `discr == value` and `discr != value`.
740            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
741            mk_condition(value, Polarity::Eq, then_);
742            mk_condition(value, Polarity::Ne, else_);
743        } else {
744            // We have a general switch and we cannot express `discr != value0 && discr != value1`,
745            // so we only generate equality predicates.
746            for (value, target) in targets.iter() {
747                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
748                    mk_condition(value, Polarity::Eq, target);
749                }
750            }
751        }
752    }
753}
754
755/// Propagate fulfilled conditions forward in the CFG to reduce the amount of duplication.
756#[instrument(level = "debug", skip(body, entry_states))]
757fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
758    let basic_blocks = &body.basic_blocks;
759    let reverse_postorder = basic_blocks.reverse_postorder();
760
761    // Start by computing the number of *incoming edges* for each block.
762    // We do not use the cached `basic_blocks.predecessors` as we only want reachable predecessors.
763    let mut predecessors = IndexVec::from_elem(0, &entry_states);
764    predecessors[START_BLOCK] = 1; // Account for the implicit entry edge.
765    for &bb in reverse_postorder {
766        let term = basic_blocks[bb].terminator();
767        for s in term.successors() {
768            predecessors[s] += 1;
769        }
770    }
771
772    // Compute the number of edges into each block that carry each condition.
773    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
774        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
775        entry_states.len(),
776    );
777
778    // By traversing in RPO, we increase the likelihood to visit predecessors before successors.
779    for &bb in reverse_postorder {
780        let preds = predecessors[bb];
781        trace!(?bb, ?preds);
782
783        // We have removed all the input edges towards this block. Just skip visiting it.
784        if preds == 0 {
785            continue;
786        }
787
788        let state = &mut entry_states[bb];
789        trace!(?state);
790
791        // Conditions that are fulfilled in all the predecessors, are fulfilled in `bb`.
792        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
793        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
794            if cond_preds == preds {
795                trace!(?condition);
796                state.fulfilled.push(condition);
797            }
798        }
799
800        // We want to count how many times each condition is fulfilled,
801        // so ensure we are not counting the same edge twice.
802        let mut targets: Vec<_> = state
803            .fulfilled
804            .iter()
805            .flat_map(|&index| state.targets[index].iter().copied())
806            .collect();
807        targets.sort();
808        targets.dedup();
809        trace!(?targets);
810
811        // We may modify the set of successors by applying edges, so track them here.
812        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();
813
814        targets.reverse();
815        while let Some(target) = targets.pop() {
816            match target {
817                EdgeEffect::Goto { target } => {
818                    // We update the count of predecessors. If target or any successor has not been
819                    // processed yet, this increases the likelihood we find something relevant.
820                    predecessors[target] += 1;
821                    for &s in successors.iter() {
822                        predecessors[s] -= 1;
823                    }
824                    // Only process edges that still exist.
825                    targets.retain(|t| t.block() == target);
826                    successors.clear();
827                    successors.push(target);
828                }
829                EdgeEffect::Chain { succ_block, succ_condition } => {
830                    // `predecessors` is the number of incoming *edges* in each block.
831                    // Count the number of edges that apply `succ_condition` into `succ_block`.
832                    let count = successors.iter().filter(|&&s| s == succ_block).count();
833                    fulfill_in_pred_count[succ_block][succ_condition] += count;
834                }
835            }
836        }
837    }
838}
839
840#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
841fn remove_costly_conditions<'tcx>(
842    tcx: TyCtxt<'tcx>,
843    typing_env: ty::TypingEnv<'tcx>,
844    body: &Body<'tcx>,
845    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
846) {
847    let basic_blocks = &body.basic_blocks;
848
849    let mut costs = IndexVec::from_elem(None, basic_blocks);
850    let mut cost = |bb: BasicBlock| -> u8 {
851        let c = *costs[bb].get_or_insert_with(|| {
852            let bbdata = &basic_blocks[bb];
853            let mut cost = CostChecker::new(tcx, typing_env, None, body);
854            cost.visit_basic_block_data(bb, bbdata);
855            cost.cost().try_into().unwrap_or(MAX_COST)
856        });
857        trace!("cost[{bb:?}] = {c}");
858        c
859    };
860
861    // Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
862    let mut condition_cost = IndexVec::from_fn_n(
863        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
864        entry_states.len(),
865    );
866
867    let reverse_postorder = basic_blocks.reverse_postorder();
868
869    for &bb in reverse_postorder.iter().rev() {
870        let state = &entry_states[bb];
871        trace!(?bb, ?state);
872
873        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
874
875        for (condition, targets) in state.targets.iter_enumerated() {
876            for &target in targets {
877                match target {
878                    // A `Goto` has cost 0.
879                    EdgeEffect::Goto { .. } => {}
880                    // Chaining into an already-fulfilled condition is nop.
881                    EdgeEffect::Chain { succ_block, succ_condition }
882                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
883                    // When chaining, use `cost[succ_block][succ_condition] + cost(succ_block)`.
884                    EdgeEffect::Chain { succ_block, succ_condition } => {
885                        // Cost associated with duplicating `succ_block`.
886                        let duplication_cost = cost(succ_block);
887                        // Cost associated with the rest of the chain.
888                        let target_cost =
889                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
890                        let cost = current_costs[condition]
891                            .saturating_add(duplication_cost)
892                            .saturating_add(target_cost);
893                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
894                        current_costs[condition] = cost;
895                    }
896                }
897            }
898        }
899
900        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
901        condition_cost[bb] = current_costs;
902    }
903
904    trace!(?condition_cost);
905
906    for &bb in reverse_postorder {
907        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
908            if condition_cost[bb][index] >= MAX_COST {
909                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
910                targets.clear()
911            }
912        }
913    }
914}
915
/// State used to apply the threading opportunities: the body's blocks, the entry condition
/// set of each block, and a cache of the block duplicates created so far.
struct OpportunitySet<'a, 'tcx> {
    /// The body's basic blocks, rewritten in place; duplicated blocks are appended here.
    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
    /// Condition set on entry to each block. It is kept parallel to `basic_blocks`: every
    /// duplicated block gets its own entry pushed here.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
    /// Cache duplicated block. When cloning a basic block `bb` to fulfill a condition `c`,
    /// record the target of this `bb with c` edge.
    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
}
923
impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
    /// Prepares to apply the threading opportunities described by `entry_states` on `body`.
    ///
    /// Returns `None` when no block has any fulfilled condition, as there is then nothing to
    /// apply.
    fn new(
        body: &'a mut Body<'tcx>,
        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
    ) -> Option<OpportunitySet<'a, 'tcx>> {
        trace!(def_id = ?body.source.def_id(), "apply");

        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
            return None;
        }

        // Free some memory, because we will need to clone condition sets.
        // The methods of this impl only read `fulfilled` and `targets`, so `active` can go.
        for state in entry_states.iter_mut() {
            state.active = Default::default();
        }
        let duplicates = Default::default();
        let basic_blocks = body.basic_blocks.as_mut();
        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
    }

    /// Apply the opportunities on the graph.
    ///
    /// Visits each block reachable from `START_BLOCK` once, depth-first (the worklist is a
    /// stack), applying the conditions fulfilled on its entry. Successors are read only after
    /// `apply_once` ran, so removed edges are not followed and freshly created duplicates are
    /// visited.
    #[instrument(level = "debug", skip(self))]
    fn apply(mut self) {
        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
        worklist.push(START_BLOCK);

        // Use a `GrowableBitSet` and not a `DenseBitSet` as we are adding blocks.
        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());

        while let Some(bb) = worklist.pop() {
            if !visited.insert(bb) {
                continue;
            }

            self.apply_once(bb);

            // `apply_once` may have modified the terminator of `bb`.
            // Only visit actual successors.
            worklist.extend(self.basic_blocks[bb].terminator().successors());
        }
    }

    /// Apply the opportunities on `bb`.
    ///
    /// Takes the edge effects of every condition fulfilled on entry to `bb` and applies them.
    /// The taken target lists are left empty. A `Goto` replaces the terminator. A `Chain`
    /// redirects the edges to a duplicate of the successor that fulfills the chained condition.
    /// The effects still stored in `entry_states[bb]` may later be applied by duplicates of
    /// `bb`, so they are rewritten to stay consistent with the new terminator.
    #[instrument(level = "debug", skip(self))]
    fn apply_once(&mut self, bb: BasicBlock) {
        let state = &mut self.entry_states[bb];
        trace!(?state);

        // We are modifying the `bb` in-place. Once a `EdgeEffect` has been applied,
        // it does not need to be applied again.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Use a while-pop to allow modifying `targets` from inside the loop.
        // Reversing first makes the pops visit the effects in sorted order.
        targets.reverse();
        while let Some(target) = targets.pop() {
            debug!(?target);
            trace!(term = ?self.basic_blocks[bb].terminator().kind);

            // By construction, `target.block()` is a successor of `bb`.
            // When applying targets, we may change the set of successors.
            // The match below updates the set of targets for consistency.
            debug_assert!(
                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
                "missing {target:?} in successors for {bb:?}, term={:?}",
                self.basic_blocks[bb].terminator(),
            );

            match target {
                EdgeEffect::Goto { target } => {
                    self.apply_goto(bb, target);

                    // We now have `target` as single successor. Drop all other target blocks.
                    targets.retain(|t| t.block() == target);
                    // Also do this on targets that may be applied by a duplicate of `bb`.
                    for ts in self.entry_states[bb].targets.iter_mut() {
                        ts.retain(|t| t.block() == target);
                    }
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);

                    // We have a new name for `target`, ensure it is correctly applied.
                    if let Some(new_succ_block) = new_succ_block {
                        for t in targets.iter_mut() {
                            t.replace_block(succ_block, new_succ_block)
                        }
                        // Also do this on targets that may be applied by a duplicate of `bb`.
                        for t in
                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
                        {
                            t.replace_block(succ_block, new_succ_block)
                        }
                    }
                }
            }

            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
        }
    }

    /// Replace the terminator of `bb` with an unconditional jump to `target`.
    #[instrument(level = "debug", skip(self))]
    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
    }

    /// Redirect the edges `bb -> target` to a block that fulfills `condition` on entry.
    ///
    /// Returns `None`, changing nothing, when `target` already fulfills `condition`. Otherwise
    /// returns the block now used in place of `target`. That block is a clone of `target`
    /// whose entry state also lists `condition` as fulfilled. It is created on the first
    /// request for `(target, condition)` and reused afterwards through `duplicates`.
    #[instrument(level = "debug", skip(self), ret)]
    fn apply_chain(
        &mut self,
        bb: BasicBlock,
        target: BasicBlock,
        condition: ConditionIndex,
    ) -> Option<BasicBlock> {
        if self.entry_states[target].fulfilled.contains(&condition) {
            // `target` already fulfills `condition`, so we do not need to thread anything.
            trace!("fulfilled");
            return None;
        }

        // We may be tempted to modify `target` in-place to avoid a clone. This is wrong.
        // We may still have edges from other blocks to `target` that have not been created yet.
        // For instance because we may be threading an edge coming from `bb`,
        // or `target` may be a block duplicate for which we may still create predecessors.

        // If we already have a duplicate of `target` which fulfills `condition`, reuse it.
        // Otherwise, we clone a new bb to such ends.
        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
            trace!(?target, ?new_target, ?condition, "clone");

            // By definition, `new_target` fulfills the same condition as `target`, with
            // `condition` added.
            let mut condition_set = self.entry_states[target].clone();
            condition_set.fulfilled.push(condition);
            // `entry_states` stays parallel to `basic_blocks`: both get the same new index.
            let _new_target = self.entry_states.push(condition_set);
            debug_assert_eq!(new_target, _new_target);

            new_target
        });
        trace!(?target, ?new_target, ?condition, "reuse");

        // Replace `target` by `new_target` wherever it appears among the successors of `bb`.
        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
            if *s == target {
                *s = new_target;
            }
        });

        Some(new_target)
    }
}