1use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{
66 Map, PlaceCollectionMode, PlaceIndex, TrackElem, ValueIndex,
67};
68use rustc_span::DUMMY_SP;
69use tracing::{debug, instrument, trace};
70
71use crate::cost_checker::CostChecker;
72
/// MIR pass that threads jumps: when the value tested by a `SwitchInt` is known along some
/// incoming path, that path is rewritten (duplicating intermediate blocks where needed) so it
/// jumps straight to the branch the switch would take.
pub(super) struct JumpThreading;
74
/// Duplication-cost budget used by `remove_costly_conditions`: a condition whose accumulated
/// cost reaches this value has its threading effects discarded. Also used as the pessimistic
/// default cost and as the fallback when a block's cost does not fit in a `u8`.
const MAX_COST: u8 = 100;
76
77impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
78 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
79 if sess.target.is_like_gpu {
80 return false;
86 }
87 sess.mir_opt_level() >= 2
88 }
89
90 #[instrument(skip_all level = "debug")]
91 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
92 let def_id = body.source.def_id();
93 debug!(?def_id);
94
95 if tcx.is_coroutine(def_id) {
97 trace!("Skipped for coroutine {:?}", def_id);
98 return;
99 }
100
101 let typing_env = body.typing_env(tcx);
102 let mut finder = TOFinder {
103 tcx,
104 typing_env,
105 ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
106 body,
107 map: Map::new(tcx, body, PlaceCollectionMode::OnDemand),
108 maybe_loop_headers: maybe_loop_headers(body),
109 entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
110 };
111
112 for (bb, bbdata) in traversal::postorder(body) {
113 if bbdata.is_cleanup {
114 continue;
115 }
116
117 let mut state = finder.populate_from_outgoing_edges(bb);
118 trace!("output_states[{bb:?}] = {state:?}");
119
120 finder.process_terminator(bb, &mut state);
121 trace!("pre_terminator_states[{bb:?}] = {state:?}");
122
123 for stmt in bbdata.statements.iter().rev() {
124 if state.is_empty() {
125 break;
126 }
127
128 finder.process_statement(stmt, &mut state);
129
130 if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
135 finder.flood_state(lhs, tail, &mut state);
136 }
137 }
138
139 trace!("entry_states[{bb:?}] = {state:?}");
140 finder.entry_states[bb] = state;
141 }
142
143 let mut entry_states = finder.entry_states;
144 simplify_conditions(body, &mut entry_states);
145 remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
146
147 if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
148 opportunities.apply();
149 }
150 }
151
152 fn is_required(&self) -> bool {
153 false
154 }
155}
156
/// State of the backward analysis that finds threading opportunities (presumably what "TO"
/// abbreviates). It walks the body in postorder and records, per block, the conditions whose
/// truth on entry would allow a jump to be threaded.
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Typing environment of the analyzed body, used for value registration and constant
    /// evaluation.
    typing_env: ty::TypingEnv<'tcx>,
    /// Interpreter context used to evaluate constants, projections, discriminants and unary ops.
    ecx: InterpCx<'tcx, DummyMachine>,
    /// The body being analyzed (read-only during the analysis).
    body: &'a Body<'tcx>,
    /// Places and values tracked by the analysis, registered on demand via `place`/`value`.
    map: Map<'tcx>,
    /// Blocks whose entry conditions are not propagated into their predecessors.
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// Condition set at the entry of each block, filled in as blocks are processed.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
}
170
rustc_index::newtype_index! {
    /// Index of a condition within one block's `ConditionSet` (into its `targets`). Indices
    /// are local to a block: the same number means different conditions in different blocks.
    #[orderable]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
176
/// A predicate on a tracked scalar: the value at `place` equals `value` when `polarity` is
/// `Eq`, or differs from it when `polarity` is `Ne` (see `Condition::matches`).
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
struct Condition {
    place: ValueIndex,
    value: ScalarInt,
    polarity: Polarity,
}
185
/// Whether a `Condition` requires its place to differ from (`Ne`) or equal (`Eq`) its value.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    Ne,
    Eq,
}
191
192impl Condition {
193 fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
194 self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
195 }
196}
197
/// What fulfilling a condition does to the control-flow graph.
///
/// The derived `Ord` follows declaration order, so `Goto` sorts before `Chain`. The rewriting
/// code sorts effects and consumes them in ascending order, which makes `Goto` effects apply
/// first. Do not reorder the variants.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// Replace the terminator of the current block with `goto target`.
    Goto { target: BasicBlock },
    /// Redirect the edge(s) to `succ_block` to a copy of that block in which `succ_condition`
    /// is marked as fulfilled.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
206
207impl EdgeEffect {
208 fn block(self) -> BasicBlock {
209 match self {
210 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
211 }
212 }
213
214 fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
215 match self {
216 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
217 if *bb == target {
218 *bb = new_target
219 }
220 }
221 }
222 }
223}
224
/// Conditions tracked at one program point, together with what to do once each one holds.
#[derive(Clone, Debug, Default)]
struct ConditionSet {
    /// Conditions that are still waiting to be established; the backward walk refines them.
    active: Vec<(ConditionIndex, Condition)>,
    /// Indices of conditions that have been established; their `targets` are what the CFG
    /// rewrite applies.
    fulfilled: Vec<ConditionIndex>,
    /// Effects of each condition, indexed by `ConditionIndex`. This may contain indices that
    /// never appear in `active` (e.g. ones pushed directly as fulfilled).
    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
}
231
232impl ConditionSet {
233 fn is_empty(&self) -> bool {
234 self.active.is_empty()
235 }
236
237 #[tracing::instrument(level = "trace", skip(self))]
238 fn push_condition(&mut self, c: Condition, target: BasicBlock) {
239 let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
240 self.active.push((index, c));
241 }
242
243 fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
245 self.active.retain(|&(index, condition)| {
246 let targets = &self.targets[index];
247 if f(condition, targets) {
248 trace!(?index, ?condition, "fulfill");
249 self.fulfilled.push(index);
250 false
251 } else {
252 true
253 }
254 })
255 }
256
257 fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
259 self.fulfill_if(|c, _| c.matches(place, value))
260 }
261
262 fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
263 self.active.retain(|&(_, c)| f(c))
264 }
265
266 fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
267 self.active.retain_mut(|(_, c)| {
268 if let Some(new) = f(*c) {
269 *c = new;
270 true
271 } else {
272 false
273 }
274 })
275 }
276
277 fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
278 for (_, c) in &mut self.active {
279 f(c)
280 }
281 }
282}
283
impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    /// Registers `place`, optionally extended by the `tail` projection, in `self.map` and
    /// returns its index. `None` means the map did not register it.
    fn place(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<PlaceIndex> {
        self.map.register_place(self.tcx, self.body, place, tail)
    }

    /// Registers a value slot for the tracked place `place` in `self.map`, if it gets one.
    fn value(&mut self, place: PlaceIndex) -> Option<ValueIndex> {
        self.map.register_value(self.tcx, self.typing_env, place)
    }

    /// [`Self::place`] followed by [`Self::value`].
    fn place_value(&mut self, place: Place<'tcx>, tail: Option<TrackElem>) -> Option<ValueIndex> {
        let place = self.place(place, tail)?;
        self.value(place)
    }

    /// Builds the condition set at the *end* of `bb` from the entry states of its successors.
    ///
    /// Each distinct `Condition` found in a successor's active set gets one fresh
    /// `ConditionIndex` here. Its targets are `EdgeEffect::Chain` entries naming every successor
    /// that expects the condition, together with that successor's own index for it. A successor
    /// reached by several edges is visited once. Successors in `maybe_loop_headers` are skipped.
    #[instrument(level = "trace", skip(self))]
    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
        let bbdata = &self.body[bb];

        // `bb` is processed once, so its entry state has not been written yet.
        debug_assert!(self.entry_states[bb].is_empty());

        // Upper bound on the number of conditions; used only to preallocate.
        let state_len =
            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
        let mut state = ConditionSet {
            active: Vec::with_capacity(state_len),
            targets: IndexVec::with_capacity(state_len),
            fulfilled: Vec::new(),
        };

        // Deduplicates conditions. A condition's position in this set doubles as its
        // `ConditionIndex` in `state`, because `active` and `targets` are pushed together.
        let mut known_conditions =
            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
        let mut insert = |condition, succ_block, succ_condition| {
            let (index, new) = known_conditions.insert_full(condition);
            let index = ConditionIndex::from_usize(index);
            if new {
                state.active.push((index, condition));
                let _index = state.targets.push(Vec::new());
                debug_assert_eq!(_index, index);
            }
            let target = EdgeEffect::Chain { succ_block, succ_condition };
            debug_assert!(
                !state.targets[index].contains(&target),
                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
                &state.targets[index],
            );
            state.targets[index].push(target);
        };

        // Several edges may lead to the same successor; only visit it once.
        let mut seen = FxHashSet::default();
        for succ in bbdata.terminator().successors() {
            if !seen.insert(succ) {
                continue;
            }

            // Conditions are not pulled back out of possible loop headers.
            if self.maybe_loop_headers.contains(succ) {
                continue;
            }

            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
                insert(cond, succ, succ_index);
            }
        }

        let num_conditions = known_conditions.len();
        debug_assert_eq!(num_conditions, state.active.len());
        debug_assert_eq!(num_conditions, state.targets.len());
        state.fulfilled.reserve(num_conditions);

        state
    }

    /// Drops every active condition on a value that aliases `place` (extended by
    /// `extra_elem`). Callers use this after a write to `place`, since the written value makes
    /// earlier knowledge about those values irrelevant.
    fn flood_state(
        &self,
        place: Place<'tcx>,
        extra_elem: Option<TrackElem>,
        state: &mut ConditionSet,
    ) {
        if state.is_empty() {
            return;
        }
        let mut places_to_exclude = FxHashSet::default();
        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
            places_to_exclude.insert(vi);
        });
        trace!(?places_to_exclude, "flood_state");
        if places_to_exclude.is_empty() {
            return;
        }
        state.retain(|c| !places_to_exclude.contains(&c.place));
    }

    /// Returns the place written by `stmt` plus an optional projection tail, so the caller can
    /// flood it.
    ///
    /// - `Assign` reports its destination.
    /// - `SetDiscriminant` reports the place's discriminant.
    /// - `StorageLive`/`StorageDead` report the whole local.
    /// - Every other listed kind returns `None`.
    ///
    /// NOTE(review): `CopyNonOverlapping` writes through a pointer but is not flooded here;
    /// presumably the map never tracks places reachable that way — confirm.
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign((place, _)) => Some((place, None)),
            StatementKind::SetDiscriminant { ref place, variant_index: _ } => {
                Some((**place, Some(TrackElem::Discriminant)))
            }
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            | StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(..))
            | StatementKind::Intrinsic(NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

    /// `lhs = rhs` with a known immediate: if `lhs` has a value slot and `rhs` is a scalar
    /// integer, fulfills every condition that this value satisfies.
    #[instrument(level = "trace", skip(self, state))]
    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
        if let Some(lhs) = self.value(lhs)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            state.fulfill_matches(lhs, int)
        }
    }

    /// `lhs = constant`: walks the projections of `lhs` known to the map (fields, variants,
    /// discriminant, and length behind a reference). For each projected value that reads as a
    /// scalar integer, the conditions it satisfies are fulfilled.
    #[instrument(level = "trace", skip(self, state))]
    fn process_constant(
        &mut self,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut ConditionSet,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            // Evaluates one projection step of the constant; `None` stops that branch.
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            // Called for each tracked place with its evaluated operand.
            &mut |place, op| {
                if let Some(place) = self.map.value(place)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    state.fulfill_matches(place, int)
                }
            },
        );
    }

    /// `lhs = rhs` between places: conditions on values under `lhs` are rewritten to the
    /// corresponding values under `rhs`, because they must hold on `rhs` before the copy.
    #[instrument(level = "trace", skip(self, state))]
    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
        let mut renames = FxHashMap::default();
        self.map.register_copy_tree(
            lhs,
            rhs,
            &mut |lhs, rhs| {
                renames.insert(lhs, rhs);
            },
        );
        state.for_each_mut(|c| {
            if let Some(rhs) = renames.get(&c.place) {
                c.place = *rhs
            }
        });
    }

    /// `lhs = operand`: constants are evaluated and handled by `process_constant`, and places
    /// by `process_copy`. Runtime-check operands are ignored.
    #[instrument(level = "trace", skip(self, state))]
    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
        match rhs {
            Operand::Constant(constant) => {
                // Constants that fail to evaluate contribute nothing.
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(lhs, constant, state);
            }
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.place(*rhs, None) else { return };
                self.process_copy(lhs, rhs, state)
            }
            Operand::RuntimeChecks(_) => {}
        }
    }

    /// Refines `state` across the assignment `lhs_place = rvalue`: fulfills conditions the
    /// rvalue establishes, or rewrites conditions on `lhs_place` into conditions on the rvalue's
    /// inputs. Conditions that are not handled here are flooded by the caller afterwards.
    #[instrument(level = "trace", skip(self, state))]
    fn process_assign(
        &mut self,
        lhs_place: &Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut ConditionSet,
    ) {
        let Some(lhs) = self.place(*lhs_place, None) else { return };
        match rvalue {
            Rvalue::Use(operand, _) => self.process_operand(lhs, operand, state),
            // `lhs = discriminant(rhs)` behaves like a copy from `rhs`'s discriminant slot.
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.place(*rhs, Some(TrackElem::Discriminant)) else { return };
                self.process_copy(lhs, rhs, state)
            }
            Rvalue::Aggregate(kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // ADT aggregate with `Some(_)` in this position (presumably a union's
                    // active field): not handled.
                    AggregateKind::Adt(.., Some(_)) => return,
                    // Enum variant: the aggregate sets the discriminant, and the operands are
                    // the fields of that variant.
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        let discr_ty = agg_ty.discriminant_ty(self.tcx);
                        let discr_target =
                            self.map.register_place_index(discr_ty, lhs, TrackElem::Discriminant);
                        if let Some(discr_value) =
                            self.ecx.discriminant_for_variant(agg_ty, *variant_index).discard_err()
                        {
                            self.process_immediate(discr_target, discr_value, state);
                        }
                        self.map.register_place_index(
                            agg_ty,
                            lhs,
                            TrackElem::Variant(*variant_index),
                        )
                    }
                    _ => lhs,
                };
                // Each operand is assigned to the matching field of the (variant) place.
                for (field_index, operand) in operands.iter_enumerated() {
                    let operand_ty = operand.ty(self.body, self.tcx);
                    let field = self.map.register_place_index(
                        operand_ty,
                        lhs,
                        TrackElem::Field(field_index),
                    );
                    self.process_operand(field, operand, state);
                }
            }
            // `lhs = !operand`: `lhs == v` becomes `operand == !v`, with the same polarity.
            // A condition is dropped if its negated value cannot be computed.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                state.retain_mut(|mut c| {
                    if c.place == lhs {
                        let value = self
                            .ecx
                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
                            .discard_err()?
                            .to_scalar_int()
                            .discard_err()?;
                        c.place = operand;
                        c.value = value;
                    }
                    Some(c)
                });
            }
            // `lhs = operand == constant` (or `!=`, in either operand order): a condition on the
            // boolean `lhs` becomes an equality/inequality condition on `operand`.
            Rvalue::BinaryOp(
                op,
                (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
                | (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
            ) => {
                // The `lhs` value that means "operand equals the constant".
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                // Float `==` is not bitwise equality (e.g. NaN, signed zeros), so a bitwise
                // `ScalarInt` comparison would be wrong.
                if value.const_.ty().is_floating_point() {
                    return;
                }
                let Some(lhs) = self.value(lhs) else { return };
                let Some(operand) = self.place_value(*operand, None) else { return };
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                state.for_each_mut(|c| {
                    if c.place == lhs {
                        let polarity =
                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
                        c.place = operand;
                        c.value = value;
                        c.polarity = polarity;
                    }
                });
            }

            _ => {}
        }
    }

    /// Refines `state` across a single statement. Flooding of the written place is done
    /// separately by the caller (see `mutated_statement`).
    #[instrument(level = "trace", skip(self, state))]
    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
        match &stmt.kind {
            // Setting a discriminant fulfills conditions on the discriminant value.
            StatementKind::SetDiscriminant { place, variant_index } => {
                let Some(discr_target) = self.place(**place, Some(TrackElem::Discriminant)) else {
                    return;
                };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(discr_target, discr, state)
            }
            // After `assume(place)`, `place == true` can be relied upon.
            StatementKind::Intrinsic(NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(place) = self.place_value(*place, None) else { return };
                state.fulfill_matches(place, ScalarInt::TRUE);
            }
            StatementKind::Assign((lhs_place, rhs)) => self.process_assign(lhs_place, rhs, state),
            _ => {}
        }
    }

    /// Refines `state` (the conditions at the end of `bb`) across `bb`'s terminator.
    ///
    /// - `SwitchInt` is handled by `process_switch_int`.
    /// - `InlineAsm` discards all active conditions.
    /// - `Drop`, `Call` and `TailCall` flood the place they write.
    /// - The remaining kinds leave `state` unchanged.
    /// - Kinds that must not appear at this point hit `bug!`.
    #[instrument(level = "trace", skip(self, state))]
    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            TerminatorKind::InlineAsm { .. } => {
                state.active.clear();
                return;
            }
            TerminatorKind::SwitchInt { ref discr, ref targets } => {
                return self.process_switch_int(discr, targets, state);
            }
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::Assert { .. }
            | TerminatorKind::Goto { .. } => None,
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
        };

        if let Some(place_to_flood) = place_to_flood {
            self.flood_state(place_to_flood, None, state);
        }
    }

    /// Handles `switchInt(discr)` at the end of a block. There are two steps.
    ///
    /// 1. When all switch targets are distinct, each existing active condition on `discr` is
    ///    checked against every edge. If the switch guarantees the condition when it takes an
    ///    edge, that edge's `Chain` effects become fulfilled. A value's own branch and, for
    ///    `Ne` conditions on a listed value, the `otherwise` branch both count.
    /// 2. New conditions on `discr` are registered, each with a `Goto` to the branch it
    ///    selects. A two-way "if" switch gets an `Eq`/`Ne` pair; any other switch gets one
    ///    `Eq` condition per listed value.
    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        state: &mut ConditionSet,
    ) {
        let Some(discr) = discr.place() else { return };
        let Some(discr_idx) = self.place_value(discr, None) else { return };

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        // With distinct targets, reaching a target block identifies the switch value(s) for
        // that edge unambiguously.
        if targets.is_distinct() {
            for &(index, c) in state.active.iter() {
                if c.place != discr_idx {
                    continue;
                }

                // Successor blocks whose edge guarantees `c`.
                let mut edges_fulfilling_condition = FxHashSet::default();

                for (branch, tgt) in targets.iter() {
                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
                        && c.matches(discr_idx, branch)
                    {
                        edges_fulfilling_condition.insert(tgt);
                    }
                }

                // `discr != v` with `v` among the listed values: the `otherwise` edge is only
                // taken when `discr` differs from every listed value, so it fulfills `c`.
                if c.polarity == Polarity::Ne
                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
                    && targets.all_values().contains(&value.into())
                {
                    edges_fulfilling_condition.insert(targets.otherwise());
                }

                let condition_targets = &state.targets[index];

                // Keep only the chained effects that go through a fulfilling edge.
                let new_edges: Vec<_> = condition_targets
                    .iter()
                    .copied()
                    .filter(|&target| match target {
                        EdgeEffect::Goto { .. } => false,
                        EdgeEffect::Chain { succ_block, .. } => {
                            edges_fulfilling_condition.contains(&succ_block)
                        }
                    })
                    .collect();

                if new_edges.len() == condition_targets.len() {
                    // Every effect survives: reuse the existing index.
                    state.fulfilled.push(index);
                } else {
                    // Only a subset survives: record it under a fresh index that is fulfilled
                    // immediately, leaving `index` itself active.
                    let index = state.targets.push(new_edges);
                    state.fulfilled.push(index);
                }
            }
        }

        let mut mk_condition = |value, polarity, target| {
            let c = Condition { place: discr_idx, value, polarity };
            state.push_condition(c, target);
        };
        if let Some((value, then_, else_)) = targets.as_static_if() {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            mk_condition(value, Polarity::Eq, then_);
            mk_condition(value, Polarity::Ne, else_);
        } else {
            for (value, target) in targets.iter() {
                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
                    mk_condition(value, Polarity::Eq, target);
                }
            }
        }
    }
}
780
/// Forward pass over the CFG that marks more entry conditions as fulfilled.
///
/// A condition at the entry of a block is fulfilled when every incoming edge carries a
/// fulfilled condition that chains into it. Incoming edges are counted over reachable
/// predecessors only, plus one implicit edge into `START_BLOCK`. Fulfilled `Goto` effects are
/// simulated by adjusting the incoming-edge counts, because they will turn the terminator into
/// a single `goto`.
#[instrument(level = "debug", skip(body, entry_states))]
fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
    let basic_blocks = &body.basic_blocks;
    let reverse_postorder = basic_blocks.reverse_postorder();

    // Number of incoming edges of each block, counted from reachable blocks only.
    let mut predecessors = IndexVec::from_elem(0, &entry_states);
    // The entry block has one implicit incoming edge.
    predecessors[START_BLOCK] = 1;
    for &bb in reverse_postorder {
        let term = basic_blocks[bb].terminator();
        for s in term.successors() {
            predecessors[s] += 1;
        }
    }

    // `fulfill_in_pred_count[bb][c]` counts the incoming edges of `bb` along which condition
    // `c` of `bb` is known to hold.
    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    // Reverse postorder visits a block after its forward-edge predecessors, so their
    // contributions are already counted. Edges from later blocks (back edges) are not counted
    // yet, which keeps such conditions unfulfilled.
    for &bb in reverse_postorder {
        let preds = predecessors[bb];
        trace!(?bb, ?preds);

        // No incoming edge is left (e.g. a simulated `Goto` removed all of them).
        if preds == 0 {
            continue;
        }

        let state = &mut entry_states[bb];
        trace!(?state);

        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
        // A condition that holds along every incoming edge holds on entry to `bb`.
        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
            if cond_preds == preds {
                trace!(?condition);
                state.fulfilled.push(condition);
            }
        }

        // Effects of all fulfilled conditions, deduplicated. The derived `Ord` on `EdgeEffect`
        // sorts `Goto` before `Chain`.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| state.targets[index].iter().copied())
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Successor edges of `bb`, kept up to date as `Goto` effects are simulated.
        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();

        // Reverse so that `pop` yields effects in ascending order (`Goto` first).
        targets.reverse();
        while let Some(target) = targets.pop() {
            match target {
                EdgeEffect::Goto { target } => {
                    // The terminator becomes `goto target`: `target` gains this edge and every
                    // current successor edge (including any to `target`) goes away.
                    predecessors[target] += 1;
                    for &s in successors.iter() {
                        predecessors[s] -= 1;
                    }
                    // Only effects that go through `target` remain applicable.
                    targets.retain(|t| t.block() == target);
                    successors.clear();
                    successors.push(target);
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    // Every edge from `bb` to `succ_block` carries `succ_condition`.
                    let count = successors.iter().filter(|&&s| s == succ_block).count();
                    fulfill_in_pred_count[succ_block][succ_condition] += count;
                }
            }
        }
    }
}
865
/// Discards threading opportunities that would duplicate too much code.
///
/// For every condition, the cost is accumulated over its `Chain` effects whose successor
/// condition is not already fulfilled. Each such effect adds the successor block's
/// `CostChecker` cost plus the cost of the successor condition itself. All arithmetic is `u8`
/// and saturates. Conditions whose cost reaches `MAX_COST` have their effects cleared.
#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
fn remove_costly_conditions<'tcx>(
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    body: &Body<'tcx>,
    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
) {
    let basic_blocks = &body.basic_blocks;

    // Memoized `CostChecker` cost of each block. A cost that does not fit in a `u8` becomes
    // `MAX_COST`; values that do fit are kept as is.
    let mut costs = IndexVec::from_elem(None, basic_blocks);
    let mut cost = |bb: BasicBlock| -> u8 {
        let c = *costs[bb].get_or_insert_with(|| {
            let bbdata = &basic_blocks[bb];
            let mut cost = CostChecker::new(tcx, typing_env, None, body);
            cost.visit_basic_block_data(bb, bbdata);
            cost.cost().try_into().unwrap_or(MAX_COST)
        });
        trace!("cost[{bb:?}] = {c}");
        c
    };

    // Pessimistic default: an entry that has not been computed when it is queried (e.g. a
    // successor reached through a back edge) counts as `MAX_COST`.
    let mut condition_cost = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    let reverse_postorder = basic_blocks.reverse_postorder();

    // Postorder, so successors (except back-edge targets) are costed before their predecessors.
    for &bb in reverse_postorder.iter().rev() {
        let state = &entry_states[bb];
        trace!(?bb, ?state);

        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);

        for (condition, targets) in state.targets.iter_enumerated() {
            for &target in targets {
                match target {
                    // Rewriting the terminator to a `goto` duplicates nothing.
                    EdgeEffect::Goto { .. } => {}
                    // An already-fulfilled successor condition needs no copy of the successor.
                    EdgeEffect::Chain { succ_block, succ_condition }
                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
                    EdgeEffect::Chain { succ_block, succ_condition } => {
                        // Cost of copying `succ_block` itself.
                        let duplication_cost = cost(succ_block);
                        // Cost of whatever threading the copy of `succ_block` triggers in turn.
                        let target_cost =
                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
                        let cost = current_costs[condition]
                            .saturating_add(duplication_cost)
                            .saturating_add(target_cost);
                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
                        current_costs[condition] = cost;
                    }
                }
            }
        }

        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
        condition_cost[bb] = current_costs;
    }

    trace!(?condition_cost);

    for &bb in reverse_postorder {
        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
            if condition_cost[bb][index] >= MAX_COST {
                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
                targets.clear()
            }
        }
    }
}
941
/// Applies the threading opportunities found by the analysis to a body's basic blocks.
struct OpportunitySet<'a, 'tcx> {
    /// Basic blocks being rewritten; duplicated blocks are appended here.
    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
    /// Per-block condition sets; extended together with `basic_blocks` when a block is
    /// duplicated.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
    /// Memoized duplicates: `(bb, c)` maps to the copy of `bb` in which condition `c` is
    /// fulfilled.
    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
}
949
impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
    /// Prepares the rewrite. Returns `None` when no block has a fulfilled condition, i.e. there
    /// is nothing to apply.
    fn new(
        body: &'a mut Body<'tcx>,
        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
    ) -> Option<OpportunitySet<'a, 'tcx>> {
        trace!(def_id = ?body.source.def_id(), "apply");

        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
            return None;
        }

        // Only `fulfilled` and `targets` are read from here on; free the active lists.
        for state in entry_states.iter_mut() {
            state.active = Default::default();
        }
        let duplicates = Default::default();
        let basic_blocks = body.basic_blocks.as_mut();
        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
    }

    /// Visits every block reachable from `START_BLOCK` once, following the *rewritten*
    /// terminators, so blocks duplicated along the way are visited too.
    #[instrument(level = "debug", skip(self))]
    fn apply(mut self) {
        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
        worklist.push(START_BLOCK);

        // Growable because `apply_chain` appends new blocks during the walk.
        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());

        while let Some(bb) = worklist.pop() {
            if !visited.insert(bb) {
                continue;
            }

            self.apply_once(bb);

            // Successors are read after the rewrite, so redirected edges are followed.
            worklist.extend(self.basic_blocks[bb].terminator().successors());
        }
    }

    /// Applies the effects of every fulfilled condition of `bb` to its terminator.
    #[instrument(level = "debug", skip(self))]
    fn apply_once(&mut self, bb: BasicBlock) {
        let state = &mut self.entry_states[bb];
        trace!(?state);

        // Effects are taken out of `state` so that a later copy of this state (see
        // `apply_chain`) does not apply them a second time.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Reverse so that `pop` yields effects in ascending order (`Goto` before `Chain`).
        targets.reverse();
        while let Some(target) = targets.pop() {
            debug!(?target);
            trace!(term = ?self.basic_blocks[bb].terminator().kind);

            debug_assert!(
                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
                "missing {target:?} in successors for {bb:?}, term={:?}",
                self.basic_blocks[bb].terminator(),
            );

            match target {
                EdgeEffect::Goto { target } => {
                    self.apply_goto(bb, target);

                    // `target` is now the only successor: drop effects on any other block.
                    targets.retain(|t| t.block() == target);
                    // Do the same for the remaining effects kept in `bb`'s state, which are
                    // copied if `bb` is duplicated later.
                    for ts in self.entry_states[bb].targets.iter_mut() {
                        ts.retain(|t| t.block() == target);
                    }
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);

                    // Edges to `succ_block` now go to the copy; retarget pending effects.
                    if let Some(new_succ_block) = new_succ_block {
                        for t in targets.iter_mut() {
                            t.replace_block(succ_block, new_succ_block)
                        }
                        for t in
                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
                        {
                            t.replace_block(succ_block, new_succ_block)
                        }
                    }
                }
            }

            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
        }
    }

    /// Replaces `bb`'s terminator with an unconditional `goto target`.
    #[instrument(level = "debug", skip(self))]
    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
    }

    /// Redirects `bb`'s edges to `target` onto a copy of `target` in which `condition` is
    /// fulfilled, and returns that copy.
    ///
    /// Returns `None` and changes nothing when `condition` is already fulfilled in `target`
    /// itself. Copies are memoized per `(target, condition)` and shared between predecessors.
    #[instrument(level = "debug", skip(self), ret)]
    fn apply_chain(
        &mut self,
        bb: BasicBlock,
        target: BasicBlock,
        condition: ConditionIndex,
    ) -> Option<BasicBlock> {
        if self.entry_states[target].fulfilled.contains(&condition) {
            trace!("fulfilled");
            return None;
        }

        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
            // Copy the block, including its (possibly already rewritten) terminator.
            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
            trace!(?target, ?new_target, ?condition, "clone");

            // The copy starts from `target`'s state with `condition` additionally fulfilled, so
            // its effects are applied when the copy is visited.
            let mut condition_set = self.entry_states[target].clone();
            condition_set.fulfilled.push(condition);
            let _new_target = self.entry_states.push(condition_set);
            debug_assert_eq!(new_target, _new_target);

            new_target
        });
        trace!(?target, ?new_target, ?condition, "reuse");

        // Every edge from `bb` to `target` now goes to the copy.
        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
            if *s == target {
                *s = new_target;
            }
        });

        Some(new_target)
    }
}
1108
1109fn maybe_loop_headers(body: &Body<'_>) -> DenseBitSet<BasicBlock> {
1115 let mut maybe_loop_headers = DenseBitSet::new_empty(body.basic_blocks.len());
1116 let mut visited = DenseBitSet::new_empty(body.basic_blocks.len());
1117 for (bb, bbdata) in traversal::postorder(body) {
1118 for succ in bbdata.terminator().successors() {
1121 if !visited.contains(succ) {
1122 maybe_loop_headers.insert(succ);
1123 }
1124 }
1125
1126 let _new = visited.insert(bb);
1129 debug_assert!(_new);
1130 }
1131
1132 maybe_loop_headers
1133}