rustc_mir_transform/
jump_threading.rs

1//! A jump threading optimization.
2//!
3//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps
4//!    X = 0                                      X = 0
5//! ------------\      /--------              ------------
6//!    X = 1     X----X SwitchInt(X)     =>       X = 1
7//! ------------/      \--------              ------------
8//!
9//!
10//! This implementation is heavily inspired by the work outlined in [libfirm].
11//!
12//! The general algorithm proceeds in two phases: (1) walk the CFG backwards to construct a
13//! graph of threading conditions, and (2) propagate fulfilled conditions forward by duplicating
14//! blocks.
15//!
16//! # 1. Condition graph construction
17//!
18//! In this file, we denote as `place ?= value` the existence of a replacement condition
19//! on `place` with given `value`, irrespective of the polarity and target of that
20//! replacement condition.
21//!
22//! Inside a block, we associate with each condition `c` a set of targets:
23//! - `Goto(target)` if fulfilling `c` changes the terminator into a `Goto { target }`;
24//! - `Chain(target, c2)` if fulfilling `c` means that `c2` is fulfilled inside `target`.
25//!
//! Before walking a block `bb`, we construct the exit set of conditions from its successors.
27//! For each condition `c` in a successor `s`, we record that fulfilling `c` in `bb` will fulfill
28//! `c` in `s`, as a `Chain(s, c)` condition.
29//!
30//! When encountering a `switchInt(place) -> [value: bb...]` terminator, we also record a
31//! `place == value` condition for each `value`, and associate a `Goto(target)` condition.
32//!
33//! Then, we walk the statements backwards, transforming the set of conditions along the way,
34//! resulting in a set of conditions at the block entry.
35//!
36//! We try to avoid creating irreducible control-flow by not threading through a loop header.
37//!
38//! Applying the optimisation can create a lot of new MIR, so we bound the instruction
39//! cost by `MAX_COST`.
40//!
41//! # 2. Block duplication
42//!
43//! We now have the set of fulfilled conditions inside each block and their targets.
44//!
45//! For each block `bb` in reverse postorder, we apply in turn the target associated with each
46//! fulfilled condition:
47//! - for `Goto(target)`, change the terminator of `bb` into a `Goto { target }`;
48//! - for `Chain(target, cond)`, duplicate `target` into a new block which fulfills the same
49//! conditions and also fulfills `cond`. This is made efficient by maintaining a map of duplicates,
50//! `duplicate[(target, cond)]` to avoid cloning blocks multiple times.
51//!
52//! [libfirm]: <https://pp.ipd.kit.edu/uploads/publikationen/priesner17masterarbeit.pdf>
53
54use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{
66    Map, PlaceCollectionMode, PlaceIndex, TrackElem, ValueIndex,
67};
68use rustc_span::DUMMY_SP;
69use tracing::{debug, instrument, trace};
70
71use crate::cost_checker::CostChecker;
72
/// The jump threading MIR pass. See the module-level documentation for the algorithm.
pub(super) struct JumpThreading;

/// Bound on the instruction cost of the MIR this optimization is allowed to create
/// (see "Condition graph construction" in the module documentation).
const MAX_COST: u8 = 100;
76
impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    /// The pass only runs when the MIR optimization level is at least 2.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    /// Run jump threading on `body`.
    ///
    /// The first part computes, for every non-cleanup block, the set of conditions that hold
    /// at block entry (phase 1 of the module documentation). The resulting sets are then
    /// post-processed and handed over to `OpportunitySet`, which applies them to `body`.
    #[instrument(skip_all level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Optimizing coroutines creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let typing_env = body.typing_env(tcx);
        // Analysis context. Every block starts with an empty entry state, and places are
        // registered in `map` on demand.
        let mut finder = TOFinder {
            tcx,
            typing_env,
            ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
            body,
            map: Map::new(tcx, body, PlaceCollectionMode::OnDemand),
            maybe_loop_headers: loops::maybe_loop_headers(body),
            entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
        };

        // Walk the CFG backwards: in postorder, the successors of `bb` are visited before `bb`
        // itself (except along back-edges), so their entry states are already computed.
        for (bb, bbdata) in traversal::postorder(body) {
            // Cleanup blocks are left alone.
            if bbdata.is_cleanup {
                continue;
            }

            // Start from the conditions inherited from the successors' entry states...
            let mut state = finder.populate_from_outgoing_edges(bb);
            trace!("output_states[{bb:?}] = {state:?}");

            // ...then account for the terminator of `bb`...
            finder.process_terminator(bb, &mut state);
            trace!("pre_terminator_states[{bb:?}] = {state:?}");

            // ...and finally walk the statements from last to first.
            for stmt in bbdata.statements.iter().rev() {
                // Without any active condition, the remaining statements have nothing to act on.
                if state.is_empty() {
                    break;
                }

                finder.process_statement(stmt, &mut state);

                // When a statement mutates a place, assignments to that place that happen
                // above the mutation cannot fulfill a condition.
                //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
                //   _1 = 6
                if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
                    finder.flood_state(lhs, tail, &mut state);
                }
            }

            trace!("entry_states[{bb:?}] = {state:?}");
            finder.entry_states[bb] = state;
        }

        // The finder is no longer needed: take the entry states out of it.
        let mut entry_states = finder.entry_states;
        // NOTE(review): both helpers are defined outside this excerpt. Judging by their names
        // they prune the condition sets, the second one by cost (see `MAX_COST`) -- confirm.
        simplify_conditions(body, &mut entry_states);
        remove_costly_conditions(tcx, typing_env, body, &mut entry_states);

        // `OpportunitySet::new` returns `None` when there is nothing to apply.
        if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
            opportunities.apply();
        }
    }

    /// This pass is reported as not required.
    fn is_required(&self) -> bool {
        false
    }
}
148
/// Context of the backward walk that computes the conditions at the entry of each block
/// (phase 1 of the module documentation).
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    /// Interpreter context, used to evaluate constants and to compute on their values.
    ecx: InterpCx<'tcx, DummyMachine>,
    /// The MIR body being analyzed.
    body: &'a Body<'tcx>,
    /// Map of the tracked places and values. Entries are registered on demand during the walk.
    map: Map<'tcx>,
    /// Blocks that may be loop headers. No condition is inherited from such a successor.
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// This stores the state of each visited block on entry,
    /// and the current state of the block being visited.
    // Invariant: for each `bb`, each condition in `entry_states[bb]` has a `chain` that
    // starts with `bb`.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
}
162
// Index of a condition inside a `ConditionSet`. It is the key of `ConditionSet::targets`, and
// it is what `EdgeEffect::Chain` uses to designate a condition of a successor block.
rustc_index::newtype_index! {
    #[derive(Ord, PartialOrd)]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
168
/// A replacement condition: the statement that the tracked value `place` is equal to `value`
/// (`Polarity::Eq`) or different from `value` (`Polarity::Ne`).
///
/// What happens when the condition is fulfilled is not stored here: the effects are kept in
/// `ConditionSet::targets`, under the `ConditionIndex` paired with this condition.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
struct Condition {
    /// The tracked value the condition is about.
    place: ValueIndex,
    /// The constant that `place` is compared against.
    value: ScalarInt,
    /// Whether the condition requires equality or inequality.
    polarity: Polarity,
}
177
/// Kind of comparison required by a `Condition`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    /// The condition holds when the place is different from the value.
    Ne,
    /// The condition holds when the place is equal to the value.
    Eq,
}
183
184impl Condition {
185    fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
186        self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
187    }
188}
189
/// Represent the effect of fulfilling a condition.
///
/// A condition may have several effects: they are stored as a list in
/// `ConditionSet::targets`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// If the condition is fulfilled, replace the current block's terminator by a single goto.
    Goto { target: BasicBlock },
    /// If the condition is fulfilled, fulfill the condition `succ_condition` in `succ_block`.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
198
199impl EdgeEffect {
200    fn block(self) -> BasicBlock {
201        match self {
202            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
203        }
204    }
205
206    fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
207        match self {
208            EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
209                if *bb == target {
210                    *bb = new_target
211                }
212            }
213        }
214    }
215}
216
/// The conditions known at some point of a block, together with their effects.
#[derive(Clone, Debug, Default)]
struct ConditionSet {
    /// Conditions that are still looking for an assignment fulfilling them, with their index.
    active: Vec<(ConditionIndex, Condition)>,
    /// Indices of the conditions that have been found to be fulfilled.
    fulfilled: Vec<ConditionIndex>,
    /// For each condition index, the effects of fulfilling that condition.
    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
}
223
224impl ConditionSet {
225    fn is_empty(&self) -> bool {
226        self.active.is_empty()
227    }
228
229    #[tracing::instrument(level = "trace", skip(self))]
230    fn push_condition(&mut self, c: Condition, target: BasicBlock) {
231        let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
232        self.active.push((index, c));
233    }
234
235    /// Register fulfilled condition and remove it from the set.
236    fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
237        self.active.retain(|&(index, condition)| {
238            let targets = &self.targets[index];
239            if f(condition, targets) {
240                trace!(?index, ?condition, "fulfill");
241                self.fulfilled.push(index);
242                false
243            } else {
244                true
245            }
246        })
247    }
248
249    /// Register fulfilled condition and remove them from the set.
250    fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
251        self.fulfill_if(|c, _| c.matches(place, value))
252    }
253
254    fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
255        self.active.retain(|&(_, c)| f(c))
256    }
257
258    fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
259        self.active.retain_mut(|(_, c)| {
260            if let Some(new) = f(*c) {
261                *c = new;
262                true
263            } else {
264                false
265            }
266        })
267    }
268
269    fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
270        for (_, c) in &mut self.active {
271            f(c)
272        }
273    }
274}
275
276impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    /// Register `place`, optionally extended by the `tail` element, in the map and return its
    /// index. Returns `None` when the map does not produce an index for it.
    fn place(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<PlaceIndex> {
        self.map.register_place(self.tcx, self.body, place, tail)
    }
280
    /// Register a value for the already registered `place` and return its index, if any.
    fn value(&mut self, place: PlaceIndex) -> Option<ValueIndex> {
        self.map.register_value(self.tcx, self.typing_env, place)
    }
284
285    fn place_value(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<ValueIndex> {
286        let place = self.place(place, tail)?;
287        self.value(place)
288    }
289
    /// Construct the condition set for `bb` from the terminator, without executing its effect.
    ///
    /// Each active condition in the entry state of a successor `succ` is imported with a
    /// `Chain { succ_block: succ, .. }` effect. A condition that appears in several successors
    /// gets a single index in the returned set, with one `Chain` effect per successor.
    /// Successors that may be loop headers do not contribute any condition.
    #[instrument(level = "trace", skip(self))]
    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
        let bbdata = &self.body[bb];

        // This should be the first time we populate `entry_states[bb]`.
        debug_assert!(self.entry_states[bb].is_empty());

        // Upper bound on the number of imported conditions, only used to size the collections.
        let state_len =
            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
        let mut state = ConditionSet {
            active: Vec::with_capacity(state_len),
            targets: IndexVec::with_capacity(state_len),
            fulfilled: Vec::new(),
        };

        // Use an index-set to deduplicate conditions coming from different successor blocks.
        let mut known_conditions =
            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
        let mut insert = |condition, succ_block, succ_condition| {
            // The position of `condition` in the index-set is its `ConditionIndex` in `state`.
            let (index, new) = known_conditions.insert_full(condition);
            let index = ConditionIndex::from_usize(index);
            if new {
                state.active.push((index, condition));
                let _index = state.targets.push(Vec::new());
                debug_assert_eq!(_index, index);
            }
            let target = EdgeEffect::Chain { succ_block, succ_condition };
            debug_assert!(
                !state.targets[index].contains(&target),
                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
                &state.targets[index],
            );
            state.targets[index].push(target);
        };

        // A given block may have several times the same successor.
        let mut seen = FxHashSet::default();
        for succ in bbdata.terminator().successors() {
            if !seen.insert(succ) {
                continue;
            }

            // Do not thread through loop headers.
            if self.maybe_loop_headers.contains(succ) {
                continue;
            }

            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
                insert(cond, succ, succ_index);
            }
        }

        let num_conditions = known_conditions.len();
        debug_assert_eq!(num_conditions, state.active.len());
        debug_assert_eq!(num_conditions, state.targets.len());
        // At most every imported condition can end up fulfilled.
        state.fulfilled.reserve(num_conditions);

        state
    }
350
351    /// Remove all conditions in the state that alias given place.
352    fn flood_state(
353        &self,
354        place: Place<'tcx>,
355        extra_elem: Option<TrackElem>,
356        state: &mut ConditionSet,
357    ) {
358        if state.is_empty() {
359            return;
360        }
361        let mut places_to_exclude = FxHashSet::default();
362        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
363            places_to_exclude.insert(vi);
364        });
365        trace!(?places_to_exclude, "flood_state");
366        if places_to_exclude.is_empty() {
367            return;
368        }
369        state.retain(|c| !places_to_exclude.contains(&c.place));
370    }
371
    /// Extract the mutated place from a statement.
    ///
    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
    ///     (_1 as Ok).0 = _5;
    ///     (_1 as Err).0 = _6;
    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
    /// the value may have been mangled by the second assignment.
    ///
    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
    /// stop at flooding the discriminant, and preserve the variant fields.
    ///     (_1 as Some).0 = _6;
    ///     SetDiscriminant(_1, 1);
    ///     switchInt((_1 as Some).0)
    ///
    /// Returns `None` for the statements that are treated as not mutating any place.
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _)) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            // Storage markers are reported as a mutation of the whole local.
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            StatementKind::Retag(..)
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // copy_nonoverlapping takes pointers and mutates the pointed-to value.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }
411
412    #[instrument(level = "trace", skip(self, state))]
413    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
414        if let Some(lhs) = self.value(lhs)
415            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
416        {
417            state.fulfill_matches(lhs, int)
418        }
419    }
420
    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    ///
    /// The constant is walked along the tracked projections of `lhs`, and each scalar integer
    /// read from it fulfills the matching conditions on the corresponding value.
    #[instrument(level = "trace", skip(self, state))]
    fn process_constant(
        &mut self,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut ConditionSet,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            // Project the constant `op` along the tracked element `elem`. Any interpreter error
            // is discarded and turned into `None`.
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            // Fulfill the conditions on `place` using the value read from `op`. Only values
            // that the map already tracks and that read as a scalar integer are considered.
            &mut |place, op| {
                if let Some(place) = self.map.value(place)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    state.fulfill_matches(place, int)
                }
            },
        );
    }
459
460    #[instrument(level = "trace", skip(self, state))]
461    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
462        let mut renames = FxHashMap::default();
463        self.map.register_copy_tree(
464            lhs, // tree to copy
465            rhs, // tree to build
466            &mut |lhs, rhs| {
467                renames.insert(lhs, rhs);
468            },
469        );
470        state.for_each_mut(|c| {
471            if let Some(rhs) = renames.get(&c.place) {
472                c.place = *rhs
473            }
474        });
475    }
476
477    #[instrument(level = "trace", skip(self, state))]
478    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
479        match rhs {
480            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
481            Operand::Constant(constant) => {
482                let Some(constant) =
483                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
484                else {
485                    return;
486                };
487                self.process_constant(lhs, constant, state);
488            }
489            // Transfer the conditions on the copied rhs.
490            Operand::Move(rhs) | Operand::Copy(rhs) => {
491                let Some(rhs) = self.place(*rhs, None) else { return };
492                self.process_copy(lhs, rhs, state)
493            }
494            Operand::RuntimeChecks(_) => {}
495        }
496    }
497
    /// Process the assignment `lhs_place = rvalue`, walking backwards.
    ///
    /// Depending on the rvalue, the conditions on `lhs_place` are either fulfilled (constants,
    /// aggregates of constants) or rewritten into conditions on the operands (copies,
    /// discriminant reads, `!`, `==` and `!=` against a constant). Other rvalues are ignored
    /// here.
    #[instrument(level = "trace", skip(self, state))]
    fn process_assign(
        &mut self,
        lhs_place: &Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut ConditionSet,
    ) {
        let Some(lhs) = self.place(*lhs_place, None) else { return };
        match rvalue {
            Rvalue::Use(operand) => self.process_operand(lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.place(*rhs, Some(TrackElem::Discriminant)) else { return };
                self.process_copy(lhs, rhs, state)
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                // The place that holds the fields: the variant for an enum, `lhs` otherwise.
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        // Building an enum value also sets its discriminant.
                        let discr_ty = agg_ty.discriminant_ty(self.tcx);
                        let discr_target =
                            self.map.register_place_index(discr_ty, lhs, TrackElem::Discriminant);
                        if let Some(discr_value) =
                            self.ecx.discriminant_for_variant(agg_ty, *variant_index).discard_err()
                        {
                            self.process_immediate(discr_target, discr_value, state);
                        }
                        self.map.register_place_index(
                            agg_ty,
                            lhs,
                            TrackElem::Variant(*variant_index),
                        )
                    }
                    _ => lhs,
                };
                // Each operand is assigned to the field with the same index.
                for (field_index, operand) in operands.iter_enumerated() {
                    let operand_ty = operand.ty(self.body, self.tcx);
                    let field = self.map.register_place_index(
                        operand_ty,
                        lhs,
                        TrackElem::Field(field_index),
                    );
                    self.process_operand(field, operand, state);
                }
            }
            // Transfer the conditions on the copy rhs, after inverting the value of the condition.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                state.retain_mut(|mut c| {
                    if c.place == lhs {
                        // If the negated value cannot be computed, `?` returns `None` and the
                        // condition is dropped.
                        let value = self
                            .ecx
                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
                            .discard_err()?
                            .to_scalar_int()
                            .discard_err()?;
                        c.place = operand;
                        c.value = value;
                    }
                    Some(c)
                });
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
            ) => {
                // `equals` is the value taken by `lhs` when `operand == value`.
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                state.for_each_mut(|c| {
                    if c.place == lhs {
                        // The condition on `lhs` holds when `operand == value` exactly if it
                        // matches `lhs == equals`; otherwise it holds when `operand != value`.
                        let polarity =
                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
                        c.place = operand;
                        c.value = value;
                        c.polarity = polarity;
                    }
                });
            }

            _ => {}
        }
    }
604
605    #[instrument(level = "trace", skip(self, state))]
606    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
607        // Below, `lhs` is the return value of `mutated_statement`,
608        // the place to which `conditions` apply.
609
610        match &stmt.kind {
611            // If we expect `discriminant(place) ?= A`,
612            // we have an opportunity if `variant_index ?= A`.
613            StatementKind::SetDiscriminant { box place, variant_index } => {
614                let Some(discr_target) = self.place(*place, Some(TrackElem::Discriminant)) else {
615                    return;
616                };
617                let enum_ty = place.ty(self.body, self.tcx).ty;
618                // `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
619                // Even if the discriminant write does nothing due to niches, it is UB to set the
620                // discriminant when the data does not encode the desired discriminant.
621                let Some(discr) =
622                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
623                else {
624                    return;
625                };
626                self.process_immediate(discr_target, discr, state)
627            }
628            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
629            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
630                Operand::Copy(place) | Operand::Move(place),
631            )) => {
632                let Some(place) = self.place_value(*place, None) else { return };
633                state.fulfill_matches(place, ScalarInt::TRUE);
634            }
635            StatementKind::Assign(box (lhs_place, rhs)) => {
636                self.process_assign(lhs_place, rhs, state)
637            }
638            _ => {}
639        }
640    }
641
    /// Execute the terminator for block `bb` into `state`, the condition set being built for
    /// `bb`.
    ///
    /// `SwitchInt` is delegated to `process_switch_int`. Inline asm clears all the active
    /// conditions. For the other terminators, the conditions on the place they overwrite, if
    /// any, are removed.
    #[instrument(level = "trace", skip(self, state))]
    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => {
                state.active.clear();
                return;
            }
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { ref discr, ref targets } => {
                return self.process_switch_int(discr, targets, state);
            }
            // These do not modify memory.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop
            // Assertions can be no-op at codegen time, so treat them as such.
            | TerminatorKind::Assert { .. }
            | TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // A tail call is treated as overwriting the return place.
            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
        };

        // This terminator modifies `place_to_flood`, cleanup the associated conditions.
        if let Some(place_to_flood) = place_to_flood {
            self.flood_state(place_to_flood, None, state);
        }
    }
680
681    #[instrument(level = "trace", skip(self))]
682    fn process_switch_int(
683        &mut self,
684        discr: &Operand<'tcx>,
685        targets: &SwitchTargets,
686        state: &mut ConditionSet,
687    ) {
688        let Some(discr) = discr.place() else { return };
689        let Some(discr_idx) = self.place_value(discr, None) else { return };
690
691        let discr_ty = discr.ty(self.body, self.tcx).ty;
692        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
693
694        // Attempt to fulfill a condition using an outgoing branch's condition.
695        // Only support the case where there are no duplicated outgoing edges.
696        if targets.is_distinct() {
697            for &(index, c) in state.active.iter() {
698                if c.place != discr_idx {
699                    continue;
700                }
701
702                // Set of blocks `t` such that the edge `bb -> t` fulfills `c`.
703                let mut edges_fulfilling_condition = FxHashSet::default();
704
705                // On edge `bb -> tgt`, we know that `discr_idx == branch`.
706                for (branch, tgt) in targets.iter() {
707                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
708                        && c.matches(discr_idx, branch)
709                    {
710                        edges_fulfilling_condition.insert(tgt);
711                    }
712                }
713
714                // On edge `bb -> otherwise`, we only know that `discr` is different from all the
715                // constants in the switch. That's much weaker information than the equality we
716                // had in the previous arm. All we can conclude is that the replacement condition
717                // `discr != value` can be threaded, and nothing else.
718                if c.polarity == Polarity::Ne
719                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
720                    && targets.all_values().contains(&value.into())
721                {
722                    edges_fulfilling_condition.insert(targets.otherwise());
723                }
724
725                // Register that jumping to a `t` fulfills condition `c`.
726                // This does *not* mean that `c` is fulfilled in this block: inserting `index` in
727                // `fulfilled` is wrong if we have targets that jump to other blocks.
728                let condition_targets = &state.targets[index];
729
730                let new_edges: Vec<_> = condition_targets
731                    .iter()
732                    .copied()
733                    .filter(|&target| match target {
734                        EdgeEffect::Goto { .. } => false,
735                        EdgeEffect::Chain { succ_block, .. } => {
736                            edges_fulfilling_condition.contains(&succ_block)
737                        }
738                    })
739                    .collect();
740
741                if new_edges.len() == condition_targets.len() {
742                    // If `new_edges == condition_targets`, do not bother creating a new
743                    // `ConditionIndex`, we can use the existing one.
744                    state.fulfilled.push(index);
745                } else {
746                    // Fulfilling `index` may thread conditions that we do not want,
747                    // so create a brand new index to immediately mark fulfilled.
748                    let index = state.targets.push(new_edges);
749                    state.fulfilled.push(index);
750                }
751            }
752        }
753
754        // Introduce additional conditions of the form `discr ?= value` for each value in targets.
755        let mut mk_condition = |value, polarity, target| {
756            let c = Condition { place: discr_idx, value, polarity };
757            state.push_condition(c, target);
758        };
759        if let Some((value, then_, else_)) = targets.as_static_if() {
760            // We have an `if`, generate both `discr == value` and `discr != value`.
761            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
762            mk_condition(value, Polarity::Eq, then_);
763            mk_condition(value, Polarity::Ne, else_);
764        } else {
765            // We have a general switch and we cannot express `discr != value0 && discr != value1`,
766            // so we only generate equality predicates.
767            for (value, target) in targets.iter() {
768                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
769                    mk_condition(value, Polarity::Eq, target);
770                }
771            }
772        }
773    }
774}
775
776/// Propagate fulfilled conditions forward in the CFG to reduce the amount of duplication.
777#[instrument(level = "debug", skip(body, entry_states))]
778fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
779    let basic_blocks = &body.basic_blocks;
780    let reverse_postorder = basic_blocks.reverse_postorder();
781
782    // Start by computing the number of *incoming edges* for each block.
783    // We do not use the cached `basic_blocks.predecessors` as we only want reachable predecessors.
784    let mut predecessors = IndexVec::from_elem(0, &entry_states);
785    predecessors[START_BLOCK] = 1; // Account for the implicit entry edge.
786    for &bb in reverse_postorder {
787        let term = basic_blocks[bb].terminator();
788        for s in term.successors() {
789            predecessors[s] += 1;
790        }
791    }
792
793    // Compute the number of edges into each block that carry each condition.
794    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
795        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
796        entry_states.len(),
797    );
798
799    // By traversing in RPO, we increase the likelihood to visit predecessors before successors.
800    for &bb in reverse_postorder {
801        let preds = predecessors[bb];
802        trace!(?bb, ?preds);
803
804        // We have removed all the input edges towards this block. Just skip visiting it.
805        if preds == 0 {
806            continue;
807        }
808
809        let state = &mut entry_states[bb];
810        trace!(?state);
811
812        // Conditions that are fulfilled in all the predecessors, are fulfilled in `bb`.
813        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
814        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
815            if cond_preds == preds {
816                trace!(?condition);
817                state.fulfilled.push(condition);
818            }
819        }
820
821        // We want to count how many times each condition is fulfilled,
822        // so ensure we are not counting the same edge twice.
823        let mut targets: Vec<_> = state
824            .fulfilled
825            .iter()
826            .flat_map(|&index| state.targets[index].iter().copied())
827            .collect();
828        targets.sort();
829        targets.dedup();
830        trace!(?targets);
831
832        // We may modify the set of successors by applying edges, so track them here.
833        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();
834
835        targets.reverse();
836        while let Some(target) = targets.pop() {
837            match target {
838                EdgeEffect::Goto { target } => {
839                    // We update the count of predecessors. If target or any successor has not been
840                    // processed yet, this increases the likelihood we find something relevant.
841                    predecessors[target] += 1;
842                    for &s in successors.iter() {
843                        predecessors[s] -= 1;
844                    }
845                    // Only process edges that still exist.
846                    targets.retain(|t| t.block() == target);
847                    successors.clear();
848                    successors.push(target);
849                }
850                EdgeEffect::Chain { succ_block, succ_condition } => {
851                    // `predecessors` is the number of incoming *edges* in each block.
852                    // Count the number of edges that apply `succ_condition` into `succ_block`.
853                    let count = successors.iter().filter(|&&s| s == succ_block).count();
854                    fulfill_in_pred_count[succ_block][succ_condition] += count;
855                }
856            }
857        }
858    }
859}
860
861#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
862fn remove_costly_conditions<'tcx>(
863    tcx: TyCtxt<'tcx>,
864    typing_env: ty::TypingEnv<'tcx>,
865    body: &Body<'tcx>,
866    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
867) {
868    let basic_blocks = &body.basic_blocks;
869
870    let mut costs = IndexVec::from_elem(None, basic_blocks);
871    let mut cost = |bb: BasicBlock| -> u8 {
872        let c = *costs[bb].get_or_insert_with(|| {
873            let bbdata = &basic_blocks[bb];
874            let mut cost = CostChecker::new(tcx, typing_env, None, body);
875            cost.visit_basic_block_data(bb, bbdata);
876            cost.cost().try_into().unwrap_or(MAX_COST)
877        });
878        trace!("cost[{bb:?}] = {c}");
879        c
880    };
881
882    // Initialize costs with `MAX_COST`: if we have a cycle, the cyclic `bb` has infinite costs.
883    let mut condition_cost = IndexVec::from_fn_n(
884        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
885        entry_states.len(),
886    );
887
888    let reverse_postorder = basic_blocks.reverse_postorder();
889
890    for &bb in reverse_postorder.iter().rev() {
891        let state = &entry_states[bb];
892        trace!(?bb, ?state);
893
894        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
895
896        for (condition, targets) in state.targets.iter_enumerated() {
897            for &target in targets {
898                match target {
899                    // A `Goto` has cost 0.
900                    EdgeEffect::Goto { .. } => {}
901                    // Chaining into an already-fulfilled condition is nop.
902                    EdgeEffect::Chain { succ_block, succ_condition }
903                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
904                    // When chaining, use `cost[succ_block][succ_condition] + cost(succ_block)`.
905                    EdgeEffect::Chain { succ_block, succ_condition } => {
906                        // Cost associated with duplicating `succ_block`.
907                        let duplication_cost = cost(succ_block);
908                        // Cost associated with the rest of the chain.
909                        let target_cost =
910                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
911                        let cost = current_costs[condition]
912                            .saturating_add(duplication_cost)
913                            .saturating_add(target_cost);
914                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
915                        current_costs[condition] = cost;
916                    }
917                }
918            }
919        }
920
921        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
922        condition_cost[bb] = current_costs;
923    }
924
925    trace!(?condition_cost);
926
927    for &bb in reverse_postorder {
928        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
929            if condition_cost[bb][index] >= MAX_COST {
930                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
931                targets.clear()
932            }
933        }
934    }
935}
936
937struct OpportunitySet<'a, 'tcx> {
938    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
939    entry_states: IndexVec<BasicBlock, ConditionSet>,
940    /// Cache duplicated block. When cloning a basic block `bb` to fulfill a condition `c`,
941    /// record the target of this `bb with c` edge.
942    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
943}
944
945impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
946    fn new(
947        body: &'a mut Body<'tcx>,
948        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
949    ) -> Option<OpportunitySet<'a, 'tcx>> {
950        trace!(def_id = ?body.source.def_id(), "apply");
951
952        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
953            return None;
954        }
955
956        // Free some memory, because we will need to clone condition sets.
957        for state in entry_states.iter_mut() {
958            state.active = Default::default();
959        }
960        let duplicates = Default::default();
961        let basic_blocks = body.basic_blocks.as_mut();
962        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
963    }
964
965    /// Apply the opportunities on the graph.
966    #[instrument(level = "debug", skip(self))]
967    fn apply(mut self) {
968        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
969        worklist.push(START_BLOCK);
970
971        // Use a `GrowableBitSet` and not a `DenseBitSet` as we are adding blocks.
972        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());
973
974        while let Some(bb) = worklist.pop() {
975            if !visited.insert(bb) {
976                continue;
977            }
978
979            self.apply_once(bb);
980
981            // `apply_once` may have modified the terminator of `bb`.
982            // Only visit actual successors.
983            worklist.extend(self.basic_blocks[bb].terminator().successors());
984        }
985    }
986
987    /// Apply the opportunities on `bb`.
988    #[instrument(level = "debug", skip(self))]
989    fn apply_once(&mut self, bb: BasicBlock) {
990        let state = &mut self.entry_states[bb];
991        trace!(?state);
992
993        // We are modifying the `bb` in-place. Once a `EdgeEffect` has been applied,
994        // it does not need to be applied again.
995        let mut targets: Vec<_> = state
996            .fulfilled
997            .iter()
998            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
999            .collect();
1000        targets.sort();
1001        targets.dedup();
1002        trace!(?targets);
1003
1004        // Use a while-pop to allow modifying `targets` from inside the loop.
1005        targets.reverse();
1006        while let Some(target) = targets.pop() {
1007            debug!(?target);
1008            trace!(term = ?self.basic_blocks[bb].terminator().kind);
1009
1010            // By construction, `target.block()` is a successor of `bb`.
1011            // When applying targets, we may change the set of successors.
1012            // The match below updates the set of targets for consistency.
1013            debug_assert!(
1014                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
1015                "missing {target:?} in successors for {bb:?}, term={:?}",
1016                self.basic_blocks[bb].terminator(),
1017            );
1018
1019            match target {
1020                EdgeEffect::Goto { target } => {
1021                    self.apply_goto(bb, target);
1022
1023                    // We now have `target` as single successor. Drop all other target blocks.
1024                    targets.retain(|t| t.block() == target);
1025                    // Also do this on targets that may be applied by a duplicate of `bb`.
1026                    for ts in self.entry_states[bb].targets.iter_mut() {
1027                        ts.retain(|t| t.block() == target);
1028                    }
1029                }
1030                EdgeEffect::Chain { succ_block, succ_condition } => {
1031                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);
1032
1033                    // We have a new name for `target`, ensure it is correctly applied.
1034                    if let Some(new_succ_block) = new_succ_block {
1035                        for t in targets.iter_mut() {
1036                            t.replace_block(succ_block, new_succ_block)
1037                        }
1038                        // Also do this on targets that may be applied by a duplicate of `bb`.
1039                        for t in
1040                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
1041                        {
1042                            t.replace_block(succ_block, new_succ_block)
1043                        }
1044                    }
1045                }
1046            }
1047
1048            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
1049        }
1050    }
1051
1052    #[instrument(level = "debug", skip(self))]
1053    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
1054        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
1055    }
1056
1057    #[instrument(level = "debug", skip(self), ret)]
1058    fn apply_chain(
1059        &mut self,
1060        bb: BasicBlock,
1061        target: BasicBlock,
1062        condition: ConditionIndex,
1063    ) -> Option<BasicBlock> {
1064        if self.entry_states[target].fulfilled.contains(&condition) {
1065            // `target` already fulfills `condition`, so we do not need to thread anything.
1066            trace!("fulfilled");
1067            return None;
1068        }
1069
1070        // We may be tempted to modify `target` in-place to avoid a clone. This is wrong.
1071        // We may still have edges from other blocks to `target` that have not been created yet.
1072        // For instance because we may be threading an edge coming from `bb`,
1073        // or `target` may be a block duplicate for which we may still create predecessors.
1074
1075        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
1076            // If we already have a duplicate of `target` which fulfills `condition`, reuse it.
1077            // Otherwise, we clone a new bb to such ends.
1078            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
1079            trace!(?target, ?new_target, ?condition, "clone");
1080
1081            // By definition, `new_target` fulfills the same condition as `target`, with
1082            // `condition` added.
1083            let mut condition_set = self.entry_states[target].clone();
1084            condition_set.fulfilled.push(condition);
1085            let _new_target = self.entry_states.push(condition_set);
1086            debug_assert_eq!(new_target, _new_target);
1087
1088            new_target
1089        });
1090        trace!(?target, ?new_target, ?condition, "reuse");
1091
1092        // Replace `target` by `new_target` where it appears.
1093        // This changes exactly `direct_count` edges.
1094        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
1095            if *s == target {
1096                *s = new_target;
1097            }
1098        });
1099
1100        Some(new_target)
1101    }
1102}