1use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{
66 Map, PlaceCollectionMode, PlaceIndex, TrackElem, ValueIndex,
67};
68use rustc_span::DUMMY_SP;
69use tracing::{debug, instrument, trace};
70
71use crate::cost_checker::CostChecker;
72
73pub(super) struct JumpThreading;
74
75const MAX_COST: u8 = 100;
76
77impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
78 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
79 sess.mir_opt_level() >= 2
80 }
81
82 #[instrument(skip_all level = "debug")]
83 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
84 let def_id = body.source.def_id();
85 debug!(?def_id);
86
87 if tcx.is_coroutine(def_id) {
89 trace!("Skipped for coroutine {:?}", def_id);
90 return;
91 }
92
93 let typing_env = body.typing_env(tcx);
94 let mut finder = TOFinder {
95 tcx,
96 typing_env,
97 ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
98 body,
99 map: Map::new(tcx, body, PlaceCollectionMode::OnDemand),
100 maybe_loop_headers: loops::maybe_loop_headers(body),
101 entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
102 };
103
104 for (bb, bbdata) in traversal::postorder(body) {
105 if bbdata.is_cleanup {
106 continue;
107 }
108
109 let mut state = finder.populate_from_outgoing_edges(bb);
110 trace!("output_states[{bb:?}] = {state:?}");
111
112 finder.process_terminator(bb, &mut state);
113 trace!("pre_terminator_states[{bb:?}] = {state:?}");
114
115 for stmt in bbdata.statements.iter().rev() {
116 if state.is_empty() {
117 break;
118 }
119
120 finder.process_statement(stmt, &mut state);
121
122 if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
127 finder.flood_state(lhs, tail, &mut state);
128 }
129 }
130
131 trace!("entry_states[{bb:?}] = {state:?}");
132 finder.entry_states[bb] = state;
133 }
134
135 let mut entry_states = finder.entry_states;
136 simplify_conditions(body, &mut entry_states);
137 remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
138
139 if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
140 opportunities.apply();
141 }
142 }
143
144 fn is_required(&self) -> bool {
145 false
146 }
147}
148
149struct TOFinder<'a, 'tcx> {
150 tcx: TyCtxt<'tcx>,
151 typing_env: ty::TypingEnv<'tcx>,
152 ecx: InterpCx<'tcx, DummyMachine>,
153 body: &'a Body<'tcx>,
154 map: Map<'tcx>,
155 maybe_loop_headers: DenseBitSet<BasicBlock>,
156 entry_states: IndexVec<BasicBlock, ConditionSet>,
161}
162
163rustc_index::newtype_index! {
164 #[derive(Ord, PartialOrd)]
165 #[debug_format = "_c{}"]
166 struct ConditionIndex {}
167}
168
169#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
172struct Condition {
173 place: ValueIndex,
174 value: ScalarInt,
175 polarity: Polarity,
176}
177
/// Direction of the comparison carried by a `Condition`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    /// The condition holds when the value differs from the expected scalar.
    Ne,
    /// The condition holds when the value equals the expected scalar.
    Eq,
}
183
184impl Condition {
185 fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
186 self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
187 }
188}
189
190#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
192enum EdgeEffect {
193 Goto { target: BasicBlock },
195 Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
197}
198
199impl EdgeEffect {
200 fn block(self) -> BasicBlock {
201 match self {
202 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
203 }
204 }
205
206 fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
207 match self {
208 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
209 if *bb == target {
210 *bb = new_target
211 }
212 }
213 }
214 }
215}
216
217#[derive(Clone, Debug, Default)]
218struct ConditionSet {
219 active: Vec<(ConditionIndex, Condition)>,
220 fulfilled: Vec<ConditionIndex>,
221 targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
222}
223
224impl ConditionSet {
225 fn is_empty(&self) -> bool {
226 self.active.is_empty()
227 }
228
229 #[tracing::instrument(level = "trace", skip(self))]
230 fn push_condition(&mut self, c: Condition, target: BasicBlock) {
231 let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
232 self.active.push((index, c));
233 }
234
235 fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
237 self.active.retain(|&(index, condition)| {
238 let targets = &self.targets[index];
239 if f(condition, targets) {
240 trace!(?index, ?condition, "fulfill");
241 self.fulfilled.push(index);
242 false
243 } else {
244 true
245 }
246 })
247 }
248
249 fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
251 self.fulfill_if(|c, _| c.matches(place, value))
252 }
253
254 fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
255 self.active.retain(|&(_, c)| f(c))
256 }
257
258 fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
259 self.active.retain_mut(|(_, c)| {
260 if let Some(new) = f(*c) {
261 *c = new;
262 true
263 } else {
264 false
265 }
266 })
267 }
268
269 fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
270 for (_, c) in &mut self.active {
271 f(c)
272 }
273 }
274}
275
276impl<'a, 'tcx> TOFinder<'a, 'tcx> {
277 fn place(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<PlaceIndex> {
278 self.map.register_place(self.tcx, self.body, place, tail)
279 }
280
281 fn value(&mut self, place: PlaceIndex) -> Option<ValueIndex> {
282 self.map.register_value(self.tcx, self.typing_env, place)
283 }
284
285 fn place_value(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<ValueIndex> {
286 let place = self.place(place, tail)?;
287 self.value(place)
288 }
289
290 #[instrument(level = "trace", skip(self))]
292 fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
293 let bbdata = &self.body[bb];
294
295 debug_assert!(self.entry_states[bb].is_empty());
297
298 let state_len =
299 bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
300 let mut state = ConditionSet {
301 active: Vec::with_capacity(state_len),
302 targets: IndexVec::with_capacity(state_len),
303 fulfilled: Vec::new(),
304 };
305
306 let mut known_conditions =
308 FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
309 let mut insert = |condition, succ_block, succ_condition| {
310 let (index, new) = known_conditions.insert_full(condition);
311 let index = ConditionIndex::from_usize(index);
312 if new {
313 state.active.push((index, condition));
314 let _index = state.targets.push(Vec::new());
315 debug_assert_eq!(_index, index);
316 }
317 let target = EdgeEffect::Chain { succ_block, succ_condition };
318 debug_assert!(
319 !state.targets[index].contains(&target),
320 "duplicate targets for index={index:?} as {target:?} targets={:#?}",
321 &state.targets[index],
322 );
323 state.targets[index].push(target);
324 };
325
326 let mut seen = FxHashSet::default();
328 for succ in bbdata.terminator().successors() {
329 if !seen.insert(succ) {
330 continue;
331 }
332
333 if self.maybe_loop_headers.contains(succ) {
335 continue;
336 }
337
338 for &(succ_index, cond) in self.entry_states[succ].active.iter() {
339 insert(cond, succ, succ_index);
340 }
341 }
342
343 let num_conditions = known_conditions.len();
344 debug_assert_eq!(num_conditions, state.active.len());
345 debug_assert_eq!(num_conditions, state.targets.len());
346 state.fulfilled.reserve(num_conditions);
347
348 state
349 }
350
351 fn flood_state(
353 &self,
354 place: Place<'tcx>,
355 extra_elem: Option<TrackElem>,
356 state: &mut ConditionSet,
357 ) {
358 if state.is_empty() {
359 return;
360 }
361 let mut places_to_exclude = FxHashSet::default();
362 self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
363 places_to_exclude.insert(vi);
364 });
365 trace!(?places_to_exclude, "flood_state");
366 if places_to_exclude.is_empty() {
367 return;
368 }
369 state.retain(|c| !places_to_exclude.contains(&c.place));
370 }
371
372 #[instrument(level = "trace", skip(self), ret)]
386 fn mutated_statement(
387 &self,
388 stmt: &Statement<'tcx>,
389 ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
390 match stmt.kind {
391 StatementKind::Assign(box (place, _)) => Some((place, None)),
392 StatementKind::SetDiscriminant { box place, variant_index: _ } => {
393 Some((place, Some(TrackElem::Discriminant)))
394 }
395 StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
396 Some((Place::from(local), None))
397 }
398 StatementKind::Retag(..)
399 | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
400 | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
402 | StatementKind::AscribeUserType(..)
403 | StatementKind::Coverage(..)
404 | StatementKind::FakeRead(..)
405 | StatementKind::ConstEvalCounter
406 | StatementKind::PlaceMention(..)
407 | StatementKind::BackwardIncompatibleDropHint { .. }
408 | StatementKind::Nop => None,
409 }
410 }
411
412 #[instrument(level = "trace", skip(self, state))]
413 fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
414 if let Some(lhs) = self.value(lhs)
415 && let Immediate::Scalar(Scalar::Int(int)) = *rhs
416 {
417 state.fulfill_matches(lhs, int)
418 }
419 }
420
421 #[instrument(level = "trace", skip(self, state))]
423 fn process_constant(
424 &mut self,
425 lhs: PlaceIndex,
426 constant: OpTy<'tcx>,
427 state: &mut ConditionSet,
428 ) {
429 self.map.for_each_projection_value(
430 lhs,
431 constant,
432 &mut |elem, op| match elem {
433 TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
434 TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
435 TrackElem::Discriminant => {
436 let variant = self.ecx.read_discriminant(op).discard_err()?;
437 let discr_value =
438 self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
439 Some(discr_value.into())
440 }
441 TrackElem::DerefLen => {
442 let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
443 let len_usize = op.len(&self.ecx).discard_err()?;
444 let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
445 Some(ImmTy::from_uint(len_usize, layout).into())
446 }
447 },
448 &mut |place, op| {
449 if let Some(place) = self.map.value(place)
450 && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
451 && let Some(imm) = imm.right()
452 && let Immediate::Scalar(Scalar::Int(int)) = *imm
453 {
454 state.fulfill_matches(place, int)
455 }
456 },
457 );
458 }
459
460 #[instrument(level = "trace", skip(self, state))]
461 fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
462 let mut renames = FxHashMap::default();
463 self.map.register_copy_tree(
464 lhs, rhs, &mut |lhs, rhs| {
467 renames.insert(lhs, rhs);
468 },
469 );
470 state.for_each_mut(|c| {
471 if let Some(rhs) = renames.get(&c.place) {
472 c.place = *rhs
473 }
474 });
475 }
476
477 #[instrument(level = "trace", skip(self, state))]
478 fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
479 match rhs {
480 Operand::Constant(constant) => {
482 let Some(constant) =
483 self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
484 else {
485 return;
486 };
487 self.process_constant(lhs, constant, state);
488 }
489 Operand::Move(rhs) | Operand::Copy(rhs) => {
491 let Some(rhs) = self.place(*rhs, None) else { return };
492 self.process_copy(lhs, rhs, state)
493 }
494 Operand::RuntimeChecks(_) => {}
495 }
496 }
497
498 #[instrument(level = "trace", skip(self, state))]
499 fn process_assign(
500 &mut self,
501 lhs_place: &Place<'tcx>,
502 rvalue: &Rvalue<'tcx>,
503 state: &mut ConditionSet,
504 ) {
505 let Some(lhs) = self.place(*lhs_place, None) else { return };
506 match rvalue {
507 Rvalue::Use(operand) => self.process_operand(lhs, operand, state),
508 Rvalue::Discriminant(rhs) => {
510 let Some(rhs) = self.place(*rhs, Some(TrackElem::Discriminant)) else { return };
511 self.process_copy(lhs, rhs, state)
512 }
513 Rvalue::Aggregate(box kind, operands) => {
515 let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
516 let lhs = match kind {
517 AggregateKind::Adt(.., Some(_)) => return,
519 AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
520 let discr_ty = agg_ty.discriminant_ty(self.tcx);
521 let discr_target =
522 self.map.register_place_index(discr_ty, lhs, TrackElem::Discriminant);
523 if let Some(discr_value) =
524 self.ecx.discriminant_for_variant(agg_ty, *variant_index).discard_err()
525 {
526 self.process_immediate(discr_target, discr_value, state);
527 }
528 self.map.register_place_index(
529 agg_ty,
530 lhs,
531 TrackElem::Variant(*variant_index),
532 )
533 }
534 _ => lhs,
535 };
536 for (field_index, operand) in operands.iter_enumerated() {
537 let operand_ty = operand.ty(self.body, self.tcx);
538 let field = self.map.register_place_index(
539 operand_ty,
540 lhs,
541 TrackElem::Field(field_index),
542 );
543 self.process_operand(field, operand, state);
544 }
545 }
546 Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
548 let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
549 let Some(lhs) = self.value(lhs) else { return };
550 let Some(operand) = self.place_value(*operand, None) else { return };
551 state.retain_mut(|mut c| {
552 if c.place == lhs {
553 let value = self
554 .ecx
555 .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
556 .discard_err()?
557 .to_scalar_int()
558 .discard_err()?;
559 c.place = operand;
560 c.value = value;
561 }
562 Some(c)
563 });
564 }
565 Rvalue::BinaryOp(
568 op,
569 box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
570 | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
571 ) => {
572 let equals = match op {
573 BinOp::Eq => ScalarInt::TRUE,
574 BinOp::Ne => ScalarInt::FALSE,
575 _ => return,
576 };
577 if value.const_.ty().is_floating_point() {
578 return;
583 }
584 let Some(lhs) = self.value(lhs) else { return };
585 let Some(operand) = self.place_value(*operand, None) else { return };
586 let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
587 else {
588 return;
589 };
590 state.for_each_mut(|c| {
591 if c.place == lhs {
592 let polarity =
593 if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
594 c.place = operand;
595 c.value = value;
596 c.polarity = polarity;
597 }
598 });
599 }
600
601 _ => {}
602 }
603 }
604
605 #[instrument(level = "trace", skip(self, state))]
606 fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
607 match &stmt.kind {
611 StatementKind::SetDiscriminant { box place, variant_index } => {
614 let Some(discr_target) = self.place(*place, Some(TrackElem::Discriminant)) else {
615 return;
616 };
617 let enum_ty = place.ty(self.body, self.tcx).ty;
618 let Some(discr) =
622 self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
623 else {
624 return;
625 };
626 self.process_immediate(discr_target, discr, state)
627 }
628 StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
630 Operand::Copy(place) | Operand::Move(place),
631 )) => {
632 let Some(place) = self.place_value(*place, None) else { return };
633 state.fulfill_matches(place, ScalarInt::TRUE);
634 }
635 StatementKind::Assign(box (lhs_place, rhs)) => {
636 self.process_assign(lhs_place, rhs, state)
637 }
638 _ => {}
639 }
640 }
641
642 #[instrument(level = "trace", skip(self, state))]
644 fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
645 let term = self.body.basic_blocks[bb].terminator();
646 let place_to_flood = match term.kind {
647 TerminatorKind::FalseEdge { .. }
649 | TerminatorKind::FalseUnwind { .. }
650 | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
651 TerminatorKind::InlineAsm { .. } => {
653 state.active.clear();
654 return;
655 }
656 TerminatorKind::SwitchInt { ref discr, ref targets } => {
658 return self.process_switch_int(discr, targets, state);
659 }
660 TerminatorKind::UnwindResume
662 | TerminatorKind::UnwindTerminate(_)
663 | TerminatorKind::Return
664 | TerminatorKind::Unreachable
665 | TerminatorKind::CoroutineDrop
666 | TerminatorKind::Assert { .. }
668 | TerminatorKind::Goto { .. } => None,
669 TerminatorKind::Drop { place: destination, .. }
671 | TerminatorKind::Call { destination, .. } => Some(destination),
672 TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
673 };
674
675 if let Some(place_to_flood) = place_to_flood {
677 self.flood_state(place_to_flood, None, state);
678 }
679 }
680
681 #[instrument(level = "trace", skip(self))]
682 fn process_switch_int(
683 &mut self,
684 discr: &Operand<'tcx>,
685 targets: &SwitchTargets,
686 state: &mut ConditionSet,
687 ) {
688 let Some(discr) = discr.place() else { return };
689 let Some(discr_idx) = self.place_value(discr, None) else { return };
690
691 let discr_ty = discr.ty(self.body, self.tcx).ty;
692 let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
693
694 if targets.is_distinct() {
697 for &(index, c) in state.active.iter() {
698 if c.place != discr_idx {
699 continue;
700 }
701
702 let mut edges_fulfilling_condition = FxHashSet::default();
704
705 for (branch, tgt) in targets.iter() {
707 if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
708 && c.matches(discr_idx, branch)
709 {
710 edges_fulfilling_condition.insert(tgt);
711 }
712 }
713
714 if c.polarity == Polarity::Ne
719 && let Ok(value) = c.value.try_to_bits(discr_layout.size)
720 && targets.all_values().contains(&value.into())
721 {
722 edges_fulfilling_condition.insert(targets.otherwise());
723 }
724
725 let condition_targets = &state.targets[index];
729
730 let new_edges: Vec<_> = condition_targets
731 .iter()
732 .copied()
733 .filter(|&target| match target {
734 EdgeEffect::Goto { .. } => false,
735 EdgeEffect::Chain { succ_block, .. } => {
736 edges_fulfilling_condition.contains(&succ_block)
737 }
738 })
739 .collect();
740
741 if new_edges.len() == condition_targets.len() {
742 state.fulfilled.push(index);
745 } else {
746 let index = state.targets.push(new_edges);
749 state.fulfilled.push(index);
750 }
751 }
752 }
753
754 let mut mk_condition = |value, polarity, target| {
756 let c = Condition { place: discr_idx, value, polarity };
757 state.push_condition(c, target);
758 };
759 if let Some((value, then_, else_)) = targets.as_static_if() {
760 let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
762 mk_condition(value, Polarity::Eq, then_);
763 mk_condition(value, Polarity::Ne, else_);
764 } else {
765 for (value, target) in targets.iter() {
768 if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
769 mk_condition(value, Polarity::Eq, target);
770 }
771 }
772 }
773 }
774}
775
776#[instrument(level = "debug", skip(body, entry_states))]
778fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
779 let basic_blocks = &body.basic_blocks;
780 let reverse_postorder = basic_blocks.reverse_postorder();
781
782 let mut predecessors = IndexVec::from_elem(0, &entry_states);
785 predecessors[START_BLOCK] = 1; for &bb in reverse_postorder {
787 let term = basic_blocks[bb].terminator();
788 for s in term.successors() {
789 predecessors[s] += 1;
790 }
791 }
792
793 let mut fulfill_in_pred_count = IndexVec::from_fn_n(
795 |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
796 entry_states.len(),
797 );
798
799 for &bb in reverse_postorder {
801 let preds = predecessors[bb];
802 trace!(?bb, ?preds);
803
804 if preds == 0 {
806 continue;
807 }
808
809 let state = &mut entry_states[bb];
810 trace!(?state);
811
812 trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
814 for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
815 if cond_preds == preds {
816 trace!(?condition);
817 state.fulfilled.push(condition);
818 }
819 }
820
821 let mut targets: Vec<_> = state
824 .fulfilled
825 .iter()
826 .flat_map(|&index| state.targets[index].iter().copied())
827 .collect();
828 targets.sort();
829 targets.dedup();
830 trace!(?targets);
831
832 let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();
834
835 targets.reverse();
836 while let Some(target) = targets.pop() {
837 match target {
838 EdgeEffect::Goto { target } => {
839 predecessors[target] += 1;
842 for &s in successors.iter() {
843 predecessors[s] -= 1;
844 }
845 targets.retain(|t| t.block() == target);
847 successors.clear();
848 successors.push(target);
849 }
850 EdgeEffect::Chain { succ_block, succ_condition } => {
851 let count = successors.iter().filter(|&&s| s == succ_block).count();
854 fulfill_in_pred_count[succ_block][succ_condition] += count;
855 }
856 }
857 }
858 }
859}
860
861#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
862fn remove_costly_conditions<'tcx>(
863 tcx: TyCtxt<'tcx>,
864 typing_env: ty::TypingEnv<'tcx>,
865 body: &Body<'tcx>,
866 entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
867) {
868 let basic_blocks = &body.basic_blocks;
869
870 let mut costs = IndexVec::from_elem(None, basic_blocks);
871 let mut cost = |bb: BasicBlock| -> u8 {
872 let c = *costs[bb].get_or_insert_with(|| {
873 let bbdata = &basic_blocks[bb];
874 let mut cost = CostChecker::new(tcx, typing_env, None, body);
875 cost.visit_basic_block_data(bb, bbdata);
876 cost.cost().try_into().unwrap_or(MAX_COST)
877 });
878 trace!("cost[{bb:?}] = {c}");
879 c
880 };
881
882 let mut condition_cost = IndexVec::from_fn_n(
884 |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
885 entry_states.len(),
886 );
887
888 let reverse_postorder = basic_blocks.reverse_postorder();
889
890 for &bb in reverse_postorder.iter().rev() {
891 let state = &entry_states[bb];
892 trace!(?bb, ?state);
893
894 let mut current_costs = IndexVec::from_elem(0u8, &state.targets);
895
896 for (condition, targets) in state.targets.iter_enumerated() {
897 for &target in targets {
898 match target {
899 EdgeEffect::Goto { .. } => {}
901 EdgeEffect::Chain { succ_block, succ_condition }
903 if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
904 EdgeEffect::Chain { succ_block, succ_condition } => {
906 let duplication_cost = cost(succ_block);
908 let target_cost =
910 *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
911 let cost = current_costs[condition]
912 .saturating_add(duplication_cost)
913 .saturating_add(target_cost);
914 trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
915 current_costs[condition] = cost;
916 }
917 }
918 }
919 }
920
921 trace!("condition_cost[{bb:?}] = {:?}", current_costs);
922 condition_cost[bb] = current_costs;
923 }
924
925 trace!(?condition_cost);
926
927 for &bb in reverse_postorder {
928 for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
929 if condition_cost[bb][index] >= MAX_COST {
930 trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
931 targets.clear()
932 }
933 }
934 }
935}
936
937struct OpportunitySet<'a, 'tcx> {
938 basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
939 entry_states: IndexVec<BasicBlock, ConditionSet>,
940 duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
943}
944
945impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
946 fn new(
947 body: &'a mut Body<'tcx>,
948 mut entry_states: IndexVec<BasicBlock, ConditionSet>,
949 ) -> Option<OpportunitySet<'a, 'tcx>> {
950 trace!(def_id = ?body.source.def_id(), "apply");
951
952 if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
953 return None;
954 }
955
956 for state in entry_states.iter_mut() {
958 state.active = Default::default();
959 }
960 let duplicates = Default::default();
961 let basic_blocks = body.basic_blocks.as_mut();
962 Some(OpportunitySet { basic_blocks, entry_states, duplicates })
963 }
964
965 #[instrument(level = "debug", skip(self))]
967 fn apply(mut self) {
968 let mut worklist = Vec::with_capacity(self.basic_blocks.len());
969 worklist.push(START_BLOCK);
970
971 let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());
973
974 while let Some(bb) = worklist.pop() {
975 if !visited.insert(bb) {
976 continue;
977 }
978
979 self.apply_once(bb);
980
981 worklist.extend(self.basic_blocks[bb].terminator().successors());
984 }
985 }
986
987 #[instrument(level = "debug", skip(self))]
989 fn apply_once(&mut self, bb: BasicBlock) {
990 let state = &mut self.entry_states[bb];
991 trace!(?state);
992
993 let mut targets: Vec<_> = state
996 .fulfilled
997 .iter()
998 .flat_map(|&index| std::mem::take(&mut state.targets[index]))
999 .collect();
1000 targets.sort();
1001 targets.dedup();
1002 trace!(?targets);
1003
1004 targets.reverse();
1006 while let Some(target) = targets.pop() {
1007 debug!(?target);
1008 trace!(term = ?self.basic_blocks[bb].terminator().kind);
1009
1010 debug_assert!(
1014 self.basic_blocks[bb].terminator().successors().contains(&target.block()),
1015 "missing {target:?} in successors for {bb:?}, term={:?}",
1016 self.basic_blocks[bb].terminator(),
1017 );
1018
1019 match target {
1020 EdgeEffect::Goto { target } => {
1021 self.apply_goto(bb, target);
1022
1023 targets.retain(|t| t.block() == target);
1025 for ts in self.entry_states[bb].targets.iter_mut() {
1027 ts.retain(|t| t.block() == target);
1028 }
1029 }
1030 EdgeEffect::Chain { succ_block, succ_condition } => {
1031 let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);
1032
1033 if let Some(new_succ_block) = new_succ_block {
1035 for t in targets.iter_mut() {
1036 t.replace_block(succ_block, new_succ_block)
1037 }
1038 for t in
1040 self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
1041 {
1042 t.replace_block(succ_block, new_succ_block)
1043 }
1044 }
1045 }
1046 }
1047
1048 trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
1049 }
1050 }
1051
1052 #[instrument(level = "debug", skip(self))]
1053 fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
1054 self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
1055 }
1056
1057 #[instrument(level = "debug", skip(self), ret)]
1058 fn apply_chain(
1059 &mut self,
1060 bb: BasicBlock,
1061 target: BasicBlock,
1062 condition: ConditionIndex,
1063 ) -> Option<BasicBlock> {
1064 if self.entry_states[target].fulfilled.contains(&condition) {
1065 trace!("fulfilled");
1067 return None;
1068 }
1069
1070 let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
1076 let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
1079 trace!(?target, ?new_target, ?condition, "clone");
1080
1081 let mut condition_set = self.entry_states[target].clone();
1084 condition_set.fulfilled.push(condition);
1085 let _new_target = self.entry_states.push(condition_set);
1086 debug_assert_eq!(new_target, _new_target);
1087
1088 new_target
1089 });
1090 trace!(?target, ?new_target, ?condition, "reuse");
1091
1092 self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
1095 if *s == target {
1096 *s = new_target;
1097 }
1098 });
1099
1100 Some(new_target)
1101 }
1102}