1use itertools::Itertools as _;
55use rustc_const_eval::const_eval::DummyMachine;
56use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
57use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexSet};
58use rustc_index::IndexVec;
59use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
60use rustc_middle::bug;
61use rustc_middle::mir::interpret::Scalar;
62use rustc_middle::mir::visit::Visitor;
63use rustc_middle::mir::*;
64use rustc_middle::ty::{self, ScalarInt, TyCtxt};
65use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, TrackElem, ValueIndex};
66use rustc_span::DUMMY_SP;
67use tracing::{debug, instrument, trace};
68
69use crate::cost_checker::CostChecker;
70
/// MIR pass that threads jumps: when the outcome of a `SwitchInt` is already known on some
/// path into it, that path is redirected straight to the known branch, duplicating the
/// intermediate blocks when needed.
pub(super) struct JumpThreading;

/// Threshold on the accumulated duplication cost of a condition: conditions whose cost reaches
/// this value have their edge effects cleared by `remove_costly_conditions`.
const MAX_COST: u8 = 100;
/// Place limit handed to `Map::new` when building the tracked-place map in `run_pass`.
const MAX_PLACES: usize = 100;
75
76impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
77 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
78 sess.mir_opt_level() >= 2
79 }
80
81 #[instrument(skip_all level = "debug")]
82 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
83 let def_id = body.source.def_id();
84 debug!(?def_id);
85
86 if tcx.is_coroutine(def_id) {
88 trace!("Skipped for coroutine {:?}", def_id);
89 return;
90 }
91
92 let typing_env = body.typing_env(tcx);
93 let mut finder = TOFinder {
94 tcx,
95 typing_env,
96 ecx: InterpCx::new(tcx, DUMMY_SP, typing_env, DummyMachine),
97 body,
98 map: Map::new(tcx, body, Some(MAX_PLACES)),
99 maybe_loop_headers: loops::maybe_loop_headers(body),
100 entry_states: IndexVec::from_elem(ConditionSet::default(), &body.basic_blocks),
101 };
102
103 for (bb, bbdata) in traversal::postorder(body) {
104 if bbdata.is_cleanup {
105 continue;
106 }
107
108 let mut state = finder.populate_from_outgoing_edges(bb);
109 trace!("output_states[{bb:?}] = {state:?}");
110
111 finder.process_terminator(bb, &mut state);
112 trace!("pre_terminator_states[{bb:?}] = {state:?}");
113
114 for stmt in bbdata.statements.iter().rev() {
115 if state.is_empty() {
116 break;
117 }
118
119 finder.process_statement(stmt, &mut state);
120
121 if let Some((lhs, tail)) = finder.mutated_statement(stmt) {
126 finder.flood_state(lhs, tail, &mut state);
127 }
128 }
129
130 trace!("entry_states[{bb:?}] = {state:?}");
131 finder.entry_states[bb] = state;
132 }
133
134 let mut entry_states = finder.entry_states;
135 simplify_conditions(body, &mut entry_states);
136 remove_costly_conditions(tcx, typing_env, body, &mut entry_states);
137
138 if let Some(opportunities) = OpportunitySet::new(body, entry_states) {
139 opportunities.apply();
140 }
141 }
142
143 fn is_required(&self) -> bool {
144 false
145 }
146}
147
/// Backward analysis that collects, for each block, the conditions whose truth on entry would
/// let the block's predecessors jump directly to a known target.
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    /// Interpreter used to evaluate constants, discriminants and unary operations.
    ecx: InterpCx<'tcx, DummyMachine>,
    /// The body being analyzed.
    body: &'a Body<'tcx>,
    /// Places (and their value indices) tracked by the analysis.
    map: Map<'tcx>,
    /// Blocks that may be loop headers; successors in this set never contribute conditions.
    maybe_loop_headers: DenseBitSet<BasicBlock>,
    /// Conditions computed at the entry of each block.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
}
161
rustc_index::newtype_index! {
    /// Index of a condition within one block's `ConditionSet::targets`.
    #[derive(Ord, PartialOrd)]
    #[debug_format = "_c{}"]
    struct ConditionIndex {}
}
167
/// A predicate `place ?= value` on a tracked scalar value, where `?=` is `==` or `!=`
/// depending on `polarity`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
struct Condition {
    /// The tracked value the predicate is about.
    place: ValueIndex,
    /// The constant it is compared against.
    value: ScalarInt,
    /// Whether the predicate requires equality or disequality.
    polarity: Polarity,
}
176
/// Comparison used by a `Condition`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
enum Polarity {
    /// The condition holds when `place != value`.
    Ne,
    /// The condition holds when `place == value`.
    Eq,
}
182
183impl Condition {
184 fn matches(&self, place: ValueIndex, value: ScalarInt) -> bool {
185 self.place == place && (self.value == value) == (self.polarity == Polarity::Eq)
186 }
187}
188
/// Rewrite to perform on an outgoing edge of a block once a condition is known to hold there.
///
/// The derived `Ord` orders every `Goto` before every `Chain` (declaration order). The
/// sort-then-pop loops in `simplify_conditions` and `OpportunitySet::apply_once` therefore
/// handle gotos first.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum EdgeEffect {
    /// Replace the block's terminator by `goto target`.
    Goto { target: BasicBlock },
    /// Redirect the edge into `succ_block` to a copy of `succ_block` in which `succ_condition`
    /// is fulfilled.
    Chain { succ_block: BasicBlock, succ_condition: ConditionIndex },
}
197
198impl EdgeEffect {
199 fn block(self) -> BasicBlock {
200 match self {
201 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => bb,
202 }
203 }
204
205 fn replace_block(&mut self, target: BasicBlock, new_target: BasicBlock) {
206 match self {
207 EdgeEffect::Goto { target: bb } | EdgeEffect::Chain { succ_block: bb, .. } => {
208 if *bb == target {
209 *bb = new_target
210 }
211 }
212 }
213 }
214}
215
/// Conditions attached to one program point (the exit or entry of a block).
#[derive(Clone, Debug, Default)]
struct ConditionSet {
    /// Conditions still being propagated backwards, each paired with its index in `targets`.
    active: Vec<(ConditionIndex, Condition)>,
    /// Indices of conditions known to hold here, whose effects in `targets` should be applied.
    fulfilled: Vec<ConditionIndex>,
    /// Edge effects to perform when the condition with the given index holds.
    targets: IndexVec<ConditionIndex, Vec<EdgeEffect>>,
}
222
223impl ConditionSet {
224 fn is_empty(&self) -> bool {
225 self.active.is_empty()
226 }
227
228 #[tracing::instrument(level = "trace", skip(self))]
229 fn push_condition(&mut self, c: Condition, target: BasicBlock) {
230 let index = self.targets.push(vec![EdgeEffect::Goto { target }]);
231 self.active.push((index, c));
232 }
233
234 fn fulfill_if(&mut self, f: impl Fn(Condition, &Vec<EdgeEffect>) -> bool) {
236 self.active.retain(|&(index, condition)| {
237 let targets = &self.targets[index];
238 if f(condition, targets) {
239 trace!(?index, ?condition, "fulfill");
240 self.fulfilled.push(index);
241 false
242 } else {
243 true
244 }
245 })
246 }
247
248 fn fulfill_matches(&mut self, place: ValueIndex, value: ScalarInt) {
250 self.fulfill_if(|c, _| c.matches(place, value))
251 }
252
253 fn retain(&mut self, mut f: impl FnMut(Condition) -> bool) {
254 self.active.retain(|&(_, c)| f(c))
255 }
256
257 fn retain_mut(&mut self, mut f: impl FnMut(Condition) -> Option<Condition>) {
258 self.active.retain_mut(|(_, c)| {
259 if let Some(new) = f(*c) {
260 *c = new;
261 true
262 } else {
263 false
264 }
265 })
266 }
267
268 fn for_each_mut(&mut self, f: impl Fn(&mut Condition)) {
269 for (_, c) in &mut self.active {
270 f(c)
271 }
272 }
273}
274
impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    /// Builds the conditions at the exit of `bb` from the entry conditions of its successors.
    ///
    /// Each distinct condition gets one `ConditionIndex`. Its targets are `Chain` effects that
    /// point back at the successor condition(s) it was copied from. Successors that may be
    /// loop headers contribute no condition.
    #[instrument(level = "trace", skip(self))]
    fn populate_from_outgoing_edges(&mut self, bb: BasicBlock) -> ConditionSet {
        let bbdata = &self.body[bb];

        // Each block's entry state is populated only once.
        debug_assert!(self.entry_states[bb].is_empty());

        // Upper bound on the number of conditions, used to pre-size the buffers.
        let state_len =
            bbdata.terminator().successors().map(|succ| self.entry_states[succ].active.len()).sum();
        let mut state = ConditionSet {
            active: Vec::with_capacity(state_len),
            targets: IndexVec::with_capacity(state_len),
            fulfilled: Vec::new(),
        };

        // Deduplicate identical conditions coming from different successors. The position in
        // the index-set doubles as the `ConditionIndex`.
        let mut known_conditions =
            FxIndexSet::with_capacity_and_hasher(state_len, Default::default());
        let mut insert = |condition, succ_block, succ_condition| {
            let (index, new) = known_conditions.insert_full(condition);
            let index = ConditionIndex::from_usize(index);
            if new {
                state.active.push((index, condition));
                let _index = state.targets.push(Vec::new());
                debug_assert_eq!(_index, index);
            }
            let target = EdgeEffect::Chain { succ_block, succ_condition };
            debug_assert!(
                !state.targets[index].contains(&target),
                "duplicate targets for index={index:?} as {target:?} targets={:#?}",
                &state.targets[index],
            );
            state.targets[index].push(target);
        };

        // A terminator may list the same successor several times; visit each one only once.
        let mut seen = FxHashSet::default();
        for succ in bbdata.terminator().successors() {
            if !seen.insert(succ) {
                continue;
            }

            // Never thread through (possible) loop headers.
            if self.maybe_loop_headers.contains(succ) {
                continue;
            }

            for &(succ_index, cond) in self.entry_states[succ].active.iter() {
                insert(cond, succ, succ_index);
            }
        }

        let num_conditions = known_conditions.len();
        debug_assert_eq!(num_conditions, state.active.len());
        debug_assert_eq!(num_conditions, state.targets.len());
        state.fulfilled.reserve(num_conditions);

        state
    }

    /// Removes every active condition on a value that `Map::for_each_aliasing_place` reports
    /// for `place` (narrowed by `extra_elem`, if any).
    fn flood_state(
        &self,
        place: Place<'tcx>,
        extra_elem: Option<TrackElem>,
        state: &mut ConditionSet,
    ) {
        if state.is_empty() {
            return;
        }
        let mut places_to_exclude = FxHashSet::default();
        self.map.for_each_aliasing_place(place.as_ref(), extra_elem, &mut |vi| {
            places_to_exclude.insert(vi);
        });
        trace!(?places_to_exclude, "flood_state");
        if places_to_exclude.is_empty() {
            return;
        }
        state.retain(|c| !places_to_exclude.contains(&c.place));
    }

    /// Returns the place overwritten by `stmt`, together with an optional element that narrows
    /// the write (`Discriminant` for `SetDiscriminant`).
    ///
    /// Returns `None` for statements treated as not overwriting any place.
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _)) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            // Storage markers invalidate the whole local.
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            StatementKind::Retag(..)
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // NOTE(review): `CopyNonOverlapping` writes through a pointer, yet no place is
            // reported here; presumably pointees are not tracked by `Map` — confirm.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => None,
        }
    }

    /// Fulfils the conditions on `lhs` that are decided by the immediate `rhs`.
    ///
    /// Nothing happens when `lhs` has no value index or `rhs` is not a scalar integer.
    #[instrument(level = "trace", skip(self, state))]
    fn process_immediate(&mut self, lhs: PlaceIndex, rhs: ImmTy<'tcx>, state: &mut ConditionSet) {
        if let Some(lhs) = self.map.value(lhs)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            state.fulfill_matches(lhs, int)
        }
    }

    /// Handles `lhs = constant`: each tracked value inside `lhs` is matched with the
    /// corresponding projection of `constant` (fields, variants, discriminant, slice length).
    /// Conditions decided by a scalar-integer projection are then fulfilled.
    #[instrument(level = "trace", skip(self, state))]
    fn process_constant(
        &mut self,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut ConditionSet,
    ) {
        // Skip the interpreter work when no active condition concerns a value inside `lhs`.
        let values_inside = self.map.values_inside(lhs);
        if !state.active.iter().any(|&(_, cond)| values_inside.contains(&cond.place)) {
            return;
        }
        self.map.for_each_projection_value(
            lhs,
            constant,
            // Projects the constant along one tracked element; `None` stops that projection.
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx).discard_err(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).discard_err(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).discard_err()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).discard_err()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).discard_err()?.into();
                    let len_usize = op.len(&self.ecx).discard_err()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            // For each tracked value, read the projected constant and fulfil matching
            // conditions when it is a scalar integer.
            &mut |place, op| {
                if let Some(place) = self.map.value(place)
                    && let Some(imm) = self.ecx.read_immediate_raw(op).discard_err()
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    state.fulfill_matches(place, int)
                }
            },
        );
    }

    /// Handles `lhs = copy rhs`: conditions on a value of `lhs` become conditions on the
    /// corresponding value of `rhs`, as paired by `Map::for_each_value_pair`.
    #[instrument(level = "trace", skip(self, state))]
    fn process_copy(&mut self, lhs: PlaceIndex, rhs: PlaceIndex, state: &mut ConditionSet) {
        // Maps each value index of `lhs` to the matching value index of `rhs`.
        let mut renames = FxHashMap::default();
        self.map.for_each_value_pair(rhs, lhs, &mut |rhs, lhs| {
            renames.insert(lhs, rhs);
        });
        state.for_each_mut(|c| {
            if let Some(rhs) = renames.get(&c.place) {
                c.place = *rhs
            }
        });
    }

    /// Handles `lhs = rhs` for an operand: constants are evaluated, and copies/moves transfer
    /// the conditions to the source place.
    #[instrument(level = "trace", skip(self, state))]
    fn process_operand(&mut self, lhs: PlaceIndex, rhs: &Operand<'tcx>, state: &mut ConditionSet) {
        match rhs {
            // Constants can only be used if the interpreter manages to evaluate them.
            Operand::Constant(constant) => {
                let Some(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
                else {
                    return;
                };
                self.process_constant(lhs, constant, state);
            }
            // Transfer the conditions on the copied/moved place (if it is tracked).
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
                self.process_copy(lhs, rhs, state)
            }
        }
    }

    /// Handles `lhs_place = rvalue`. Conditions on `lhs_place` are fulfilled when the rvalue
    /// decides them, or rewritten into conditions on the rvalue's inputs. Rvalues not listed
    /// here are ignored.
    #[instrument(level = "trace", skip(self, state))]
    fn process_assign(
        &mut self,
        lhs_place: &Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut ConditionSet,
    ) {
        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
        match rvalue {
            Rvalue::Use(operand) => self.process_operand(lhs, operand, state),
            // `lhs = discriminant(rhs)`: transfer conditions onto the discriminant of `rhs`.
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
                self.process_copy(lhs, rhs, state)
            }
            // Aggregates: handle each field operand as an assignment to that field.
            Rvalue::Aggregate(box kind, operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // NOTE(review): the trailing `Some(_)` is presumably the active field of a
                    // union aggregate — unions are not handled.
                    AggregateKind::Adt(.., Some(_)) => return,
                    // Enum aggregates also set the discriminant, and their fields live under
                    // the `Variant` projection.
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
                            && let Some(discr_value) = self
                                .ecx
                                .discriminant_for_variant(agg_ty, *variant_index)
                                .discard_err()
                        {
                            self.process_immediate(discr_target, discr_value, state);
                        }
                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
                            idx
                        } else {
                            return;
                        }
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
                        self.process_operand(field, operand, state);
                    }
                }
            }
            // `lhs = !operand`: a condition `lhs ?= A` becomes `operand ?= !A`. Conditions whose
            // value cannot be negated by the interpreter are dropped.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(operand) | Operand::Copy(operand)) => {
                let layout = self.ecx.layout_of(operand.ty(self.body, self.tcx).ty).unwrap();
                let Some(lhs) = self.map.value(lhs) else { return };
                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
                state.retain_mut(|mut c| {
                    if c.place == lhs {
                        let value = self
                            .ecx
                            .unary_op(UnOp::Not, &ImmTy::from_scalar_int(c.value, layout))
                            .discard_err()?
                            .to_scalar_int()
                            .discard_err()?;
                        c.place = operand;
                        c.value = value;
                    }
                    Some(c)
                });
            }
            // `lhs = Eq/Ne(operand, constant)` (in either operand order): a condition
            // `lhs ?= A` becomes `operand == constant` or `operand != constant`.
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(operand) | Operand::Copy(operand), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(operand) | Operand::Copy(operand)),
            ) => {
                // Value `lhs` takes when `operand == value`.
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                // Floating-point equality is not bitwise equality (`-0.0 == 0.0`, `NaN != NaN`),
                // so it cannot be modelled with integer conditions.
                if value.const_.ty().is_floating_point() {
                    return;
                }
                let Some(lhs) = self.map.value(lhs) else { return };
                let Some(operand) = self.map.find_value(operand.as_ref()) else { return };
                let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
                else {
                    return;
                };
                state.for_each_mut(|c| {
                    if c.place == lhs {
                        // The condition holds exactly when `operand == value` (polarity `Eq`)
                        // or exactly when `operand != value` (polarity `Ne`).
                        let polarity =
                            if c.matches(lhs, equals) { Polarity::Eq } else { Polarity::Ne };
                        c.place = operand;
                        c.value = value;
                        c.polarity = polarity;
                    }
                });
            }

            _ => {}
        }
    }

    /// Updates `state` for a statement that may decide or transform conditions.
    ///
    /// Invalidation of overwritten places is handled separately via `mutated_statement`.
    #[instrument(level = "trace", skip(self, state))]
    fn process_statement(&mut self, stmt: &Statement<'tcx>, state: &mut ConditionSet) {
        match &stmt.kind {
            // After `SetDiscriminant`, the discriminant of `place` is that of `variant_index`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                let Some(discr) =
                    self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
                else {
                    return;
                };
                self.process_immediate(discr_target, discr, state)
            }
            // `assume(place)` lets us treat `place == true` as known.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(place) = self.map.find_value(place.as_ref()) else { return };
                state.fulfill_matches(place, ScalarInt::TRUE);
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(lhs_place, rhs, state)
            }
            _ => {}
        }
    }

    /// Updates `state` for the terminator of `bb`, before its statements are processed.
    #[instrument(level = "trace", skip(self, state))]
    fn process_terminator(&mut self, bb: BasicBlock, state: &mut ConditionSet) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // Reported as a compiler bug if encountered.
            // NOTE(review): presumably these are removed by earlier passes — confirm.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Inline assembly cannot be reasoned about: drop every active condition.
            TerminatorKind::InlineAsm { .. } => {
                state.active.clear();
                return;
            }
            // `SwitchInt` both decides existing conditions and introduces new ones.
            TerminatorKind::SwitchInt { ref discr, ref targets } => {
                return self.process_switch_int(discr, targets, state);
            }
            // Treated as not writing any place.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::Assert { .. }
            | TerminatorKind::Goto { .. } => None,
            // The dropped place / call destination is overwritten.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            TerminatorKind::TailCall { .. } => Some(RETURN_PLACE.into()),
        };

        // Conditions on the overwritten place cannot be decided above the terminator.
        if let Some(place_to_flood) = place_to_flood {
            self.flood_state(place_to_flood, None, state);
        }
    }

    /// Handles `switchInt(discr) -> targets` at the end of the current block.
    ///
    /// First, active conditions on `discr` that are decided by an outgoing edge are recorded as
    /// fulfilled, restricted to the edges that decide them. Then new conditions
    /// `discr ?= value` are pushed, each with a `Goto` to the branch it selects.
    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        state: &mut ConditionSet,
    ) {
        // Only a switch on a tracked place can be reasoned about.
        let Some(discr) = discr.place() else { return };
        let Some(discr_idx) = self.map.find_value(discr.as_ref()) else { return };

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        // Fulfil conditions using outgoing edges only when `targets.is_distinct()`.
        // NOTE(review): presumably this means no two switch values share a target — confirm.
        if targets.is_distinct() {
            for &(index, c) in state.active.iter() {
                if c.place != discr_idx {
                    continue;
                }

                // Blocks `t` such that taking the edge `bb -> t` fulfils `c`.
                let mut edges_fulfilling_condition = FxHashSet::default();

                // On the edge towards `tgt`, `discr == branch` holds.
                for (branch, tgt) in targets.iter() {
                    if let Some(branch) = ScalarInt::try_from_uint(branch, discr_layout.size)
                        && c.matches(discr_idx, branch)
                    {
                        edges_fulfilling_condition.insert(tgt);
                    }
                }

                // On the `otherwise` edge, we only know that `discr` differs from every listed
                // value. That proves `discr != value` exactly when `value` is one of them.
                if c.polarity == Polarity::Ne
                    && let Ok(value) = c.value.try_to_bits(discr_layout.size)
                    && targets.all_values().contains(&value.into())
                {
                    edges_fulfilling_condition.insert(targets.otherwise());
                }

                let condition_targets = &state.targets[index];

                // Keep only the `Chain` effects whose edge fulfils `c`; drop `Goto` effects.
                let new_edges: Vec<_> = condition_targets
                    .iter()
                    .copied()
                    .filter(|&target| match target {
                        EdgeEffect::Goto { .. } => false,
                        EdgeEffect::Chain { succ_block, .. } => {
                            edges_fulfilling_condition.contains(&succ_block)
                        }
                    })
                    .collect();

                if new_edges.len() == condition_targets.len() {
                    // Every effect survived: fulfil `index` as is.
                    state.fulfilled.push(index);
                } else {
                    // Fulfilling `index` itself would also apply the dropped effects. Record the
                    // surviving ones under a fresh index and fulfil that one instead.
                    let index = state.targets.push(new_edges);
                    state.fulfilled.push(index);
                }
            }
        }

        // Introduce new conditions `discr ?= value`. Each one, when it holds, lets the block
        // jump straight to the branch it selects.
        let mut mk_condition = |value, polarity, target| {
            let c = Condition { place: discr_idx, value, polarity };
            state.push_condition(c, target);
        };
        if let Some((value, then_, else_)) = targets.as_static_if() {
            // Two-way switch: both `discr == value` and `discr != value` are expressible.
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            mk_condition(value, Polarity::Eq, then_);
            mk_condition(value, Polarity::Ne, else_);
        } else {
            // General switch: the `otherwise` branch would need a conjunction of `!=`
            // predicates, which `Condition` cannot express, so only equalities are generated.
            for (value, target) in targets.iter() {
                if let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) {
                    mk_condition(value, Polarity::Eq, target);
                }
            }
        }
    }
}
754
/// Propagates fulfilled conditions forward through the CFG.
///
/// Suppose every incoming edge of a block fulfils one of the block's entry conditions. That
/// condition is then marked as fulfilled in the block itself, so threading it requires no copy
/// of the block.
#[instrument(level = "debug", skip(body, entry_states))]
fn simplify_conditions(body: &Body<'_>, entry_states: &mut IndexVec<BasicBlock, ConditionSet>) {
    let basic_blocks = &body.basic_blocks;
    let reverse_postorder = basic_blocks.reverse_postorder();

    // Number of incoming edges of each block. Only edges out of blocks in the reverse
    // postorder (i.e. reachable blocks) are counted.
    let mut predecessors = IndexVec::from_elem(0, &entry_states);
    // Account for the implicit edge entering the function.
    predecessors[START_BLOCK] = 1;
    for &bb in reverse_postorder {
        let term = basic_blocks[bb].terminator();
        for s in term.successors() {
            predecessors[s] += 1;
        }
    }

    // For each block and each of its conditions: how many incoming edges fulfil it.
    let mut fulfill_in_pred_count = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(0, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    // Visiting in RPO makes it likely that predecessors are handled before their successors.
    for &bb in reverse_postorder {
        let preds = predecessors[bb];
        trace!(?bb, ?preds);

        // Every incoming edge would be redirected elsewhere: nothing to propagate.
        if preds == 0 {
            continue;
        }

        let state = &mut entry_states[bb];
        trace!(?state);

        // A condition fulfilled on every incoming edge holds on entry to `bb`.
        trace!(fulfilled_count = ?fulfill_in_pred_count[bb]);
        for (condition, &cond_preds) in fulfill_in_pred_count[bb].iter_enumerated() {
            if cond_preds == preds {
                trace!(?condition);
                state.fulfilled.push(condition);
            }
        }

        // Deduplicate the effects so the same edge is not counted twice below.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| state.targets[index].iter().copied())
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Successors of `bb` as they would be after applying the effects handled so far.
        let mut successors = basic_blocks[bb].terminator().successors().collect::<Vec<_>>();

        // Pop from a reversed sorted vector: effects are handled in ascending order, so every
        // `Goto` comes before any `Chain`.
        targets.reverse();
        while let Some(target) = targets.pop() {
            match target {
                EdgeEffect::Goto { target } => {
                    // `bb` would now jump only to `target`: move the edge counts accordingly.
                    predecessors[target] += 1;
                    for &s in successors.iter() {
                        predecessors[s] -= 1;
                    }
                    // Only effects on the surviving edge remain relevant.
                    targets.retain(|t| t.block() == target);
                    successors.clear();
                    successors.push(target);
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    // `predecessors` counts edges, so count every edge `bb -> succ_block`.
                    let count = successors.iter().filter(|&&s| s == succ_block).count();
                    fulfill_in_pred_count[succ_block][succ_condition] += count;
                }
            }
        }
    }
}
839
/// Clears the edge effects of conditions that would be too expensive to thread.
///
/// A condition's cost is a sum over its `Chain` effects whose successor condition is not
/// already fulfilled. Each such effect adds the cost of duplicating the successor block plus
/// the cost of that successor condition, with saturating `u8` arithmetic. Conditions whose
/// cost reaches `MAX_COST` lose all their effects.
#[instrument(level = "debug", skip(tcx, typing_env, body, entry_states))]
fn remove_costly_conditions<'tcx>(
    tcx: TyCtxt<'tcx>,
    typing_env: ty::TypingEnv<'tcx>,
    body: &Body<'tcx>,
    entry_states: &mut IndexVec<BasicBlock, ConditionSet>,
) {
    let basic_blocks = &body.basic_blocks;

    // Memoized `CostChecker` cost of each block. A cost that does not fit in a `u8` becomes
    // `MAX_COST`.
    let mut costs = IndexVec::from_elem(None, basic_blocks);
    let mut cost = |bb: BasicBlock| -> u8 {
        let c = *costs[bb].get_or_insert_with(|| {
            let bbdata = &basic_blocks[bb];
            let mut cost = CostChecker::new(tcx, typing_env, None, body);
            cost.visit_basic_block_data(bb, bbdata);
            cost.cost().try_into().unwrap_or(MAX_COST)
        });
        trace!("cost[{bb:?}] = {c}");
        c
    };

    // Every condition starts at `MAX_COST`. A successor condition whose cost has not been
    // computed yet when it is needed (e.g. through a cycle) therefore counts as too expensive.
    let mut condition_cost = IndexVec::from_fn_n(
        |bb: BasicBlock| IndexVec::from_elem_n(MAX_COST, entry_states[bb].targets.len()),
        entry_states.len(),
    );

    let reverse_postorder = basic_blocks.reverse_postorder();

    // Postorder: outside of cycles, successor costs are known before they are needed.
    for &bb in reverse_postorder.iter().rev() {
        let state = &entry_states[bb];
        trace!(?bb, ?state);

        let mut current_costs = IndexVec::from_elem(0u8, &state.targets);

        for (condition, targets) in state.targets.iter_enumerated() {
            for &target in targets {
                match target {
                    // Jumping directly costs nothing.
                    EdgeEffect::Goto { .. } => {}
                    // The successor already fulfils the condition: no duplication needed.
                    EdgeEffect::Chain { succ_block, succ_condition }
                        if entry_states[succ_block].fulfilled.contains(&succ_condition) => {}
                    EdgeEffect::Chain { succ_block, succ_condition } => {
                        // Cost of duplicating `succ_block`.
                        let duplication_cost = cost(succ_block);
                        // Cost of threading the rest of the chain.
                        let target_cost =
                            *condition_cost[succ_block].get(succ_condition).unwrap_or(&MAX_COST);
                        let cost = current_costs[condition]
                            .saturating_add(duplication_cost)
                            .saturating_add(target_cost);
                        trace!(?condition, ?succ_block, ?duplication_cost, ?target_cost);
                        current_costs[condition] = cost;
                    }
                }
            }
        }

        trace!("condition_cost[{bb:?}] = {:?}", current_costs);
        condition_cost[bb] = current_costs;
    }

    trace!(?condition_cost);

    // Drop the effects of every condition at or above the threshold.
    for &bb in reverse_postorder {
        for (index, targets) in entry_states[bb].targets.iter_enumerated_mut() {
            if condition_cost[bb][index] >= MAX_COST {
                trace!(?bb, ?index, ?targets, c = ?condition_cost[bb][index], "remove");
                targets.clear()
            }
        }
    }
}
915
/// Rewrites the body according to the fulfilled conditions of each block.
struct OpportunitySet<'a, 'tcx> {
    /// The blocks being rewritten; duplicated blocks are appended here.
    basic_blocks: &'a mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
    /// Entry state of each block, kept index-aligned with `basic_blocks`.
    entry_states: IndexVec<BasicBlock, ConditionSet>,
    /// Duplicates already created: `(block, condition)` maps to the copy of `block` in which
    /// `condition` is fulfilled.
    duplicates: FxHashMap<(BasicBlock, ConditionIndex), BasicBlock>,
}
923
impl<'a, 'tcx> OpportunitySet<'a, 'tcx> {
    /// Prepares the rewrite of `body`.
    ///
    /// Returns `None` when no block has a fulfilled condition, since nothing would change.
    fn new(
        body: &'a mut Body<'tcx>,
        mut entry_states: IndexVec<BasicBlock, ConditionSet>,
    ) -> Option<OpportunitySet<'a, 'tcx>> {
        trace!(def_id = ?body.source.def_id(), "apply");

        if entry_states.iter().all(|state| state.fulfilled.is_empty()) {
            return None;
        }

        // `active` is not used by the rewrite; release its memory.
        for state in entry_states.iter_mut() {
            state.active = Default::default();
        }
        let duplicates = Default::default();
        let basic_blocks = body.basic_blocks.as_mut();
        Some(OpportunitySet { basic_blocks, entry_states, duplicates })
    }

    /// Applies fulfilled conditions to every block reachable from `START_BLOCK`.
    ///
    /// The walk follows the possibly rewritten terminators, so newly created duplicates are
    /// visited too.
    #[instrument(level = "debug", skip(self))]
    fn apply(mut self) {
        let mut worklist = Vec::with_capacity(self.basic_blocks.len());
        worklist.push(START_BLOCK);

        // Growable because duplicated blocks are appended while iterating.
        let mut visited = GrowableBitSet::with_capacity(self.basic_blocks.len());

        while let Some(bb) = worklist.pop() {
            if !visited.insert(bb) {
                continue;
            }

            self.apply_once(bb);

            // Read the successors only after `apply_once`, which may rewrite the terminator.
            worklist.extend(self.basic_blocks[bb].terminator().successors());
        }
    }

    /// Applies to `bb` the edge effects of all its fulfilled conditions.
    #[instrument(level = "debug", skip(self))]
    fn apply_once(&mut self, bb: BasicBlock) {
        let state = &mut self.entry_states[bb];
        trace!(?state);

        // Move the effects out of `state.targets`: once applied to `bb`, they must not be
        // applied again.
        let mut targets: Vec<_> = state
            .fulfilled
            .iter()
            .flat_map(|&index| std::mem::take(&mut state.targets[index]))
            .collect();
        targets.sort();
        targets.dedup();
        trace!(?targets);

        // Pop from a reversed sorted vector. Effects are processed in ascending order (every
        // `Goto` before any `Chain`), while `targets` can still be edited inside the loop.
        targets.reverse();
        while let Some(target) = targets.pop() {
            debug!(?target);
            trace!(term = ?self.basic_blocks[bb].terminator().kind);

            // Each remaining effect refers to a current successor of `bb`. The match below
            // keeps `targets` consistent whenever the successors change.
            debug_assert!(
                self.basic_blocks[bb].terminator().successors().contains(&target.block()),
                "missing {target:?} in successors for {bb:?}, term={:?}",
                self.basic_blocks[bb].terminator(),
            );

            match target {
                EdgeEffect::Goto { target } => {
                    self.apply_goto(bb, target);

                    // `target` is now the only successor: drop effects about other blocks.
                    targets.retain(|t| t.block() == target);
                    // Same for effects still stored in `bb`'s state, which `apply_chain` would
                    // copy into any duplicate of `bb`.
                    for ts in self.entry_states[bb].targets.iter_mut() {
                        ts.retain(|t| t.block() == target);
                    }
                }
                EdgeEffect::Chain { succ_block, succ_condition } => {
                    let new_succ_block = self.apply_chain(bb, succ_block, succ_condition);

                    // The edge now leads to the duplicate: rename `succ_block` in pending effects.
                    if let Some(new_succ_block) = new_succ_block {
                        for t in targets.iter_mut() {
                            t.replace_block(succ_block, new_succ_block)
                        }
                        // Also in the effects still stored in `bb`'s state.
                        for t in
                            self.entry_states[bb].targets.iter_mut().flat_map(|ts| ts.iter_mut())
                        {
                            t.replace_block(succ_block, new_succ_block)
                        }
                    }
                }
            }

            trace!(post_term = ?self.basic_blocks[bb].terminator().kind);
        }
    }

    /// Replaces the terminator of `bb` by `goto target`.
    #[instrument(level = "debug", skip(self))]
    fn apply_goto(&mut self, bb: BasicBlock, target: BasicBlock) {
        self.basic_blocks[bb].terminator_mut().kind = TerminatorKind::Goto { target };
    }

    /// Redirects every edge `bb -> target` to a copy of `target` in which `condition` is
    /// fulfilled. The copy is created unless one already exists for `(target, condition)`.
    ///
    /// Returns the copy, or `None` (leaving `bb` untouched) when `target` already fulfils
    /// `condition`.
    #[instrument(level = "debug", skip(self), ret)]
    fn apply_chain(
        &mut self,
        bb: BasicBlock,
        target: BasicBlock,
        condition: ConditionIndex,
    ) -> Option<BasicBlock> {
        if self.entry_states[target].fulfilled.contains(&condition) {
            // No duplication needed.
            trace!("fulfilled");
            return None;
        }

        // Reuse the duplicate of `target` for `condition` if one was already created.
        let new_target = *self.duplicates.entry((target, condition)).or_insert_with(|| {
            // The clone is appended, so it receives a fresh block index.
            let new_target = self.basic_blocks.push(self.basic_blocks[target].clone());
            trace!(?target, ?new_target, ?condition, "clone");

            // The clone inherits `target`'s entry state, with `condition` additionally
            // fulfilled. Pushing keeps `entry_states` index-aligned with `basic_blocks`.
            let mut condition_set = self.entry_states[target].clone();
            condition_set.fulfilled.push(condition);
            let _new_target = self.entry_states.push(condition_set);
            debug_assert_eq!(new_target, _new_target);

            new_target
        });
        trace!(?target, ?new_target, ?condition, "reuse");

        // Point every edge `bb -> target` at the duplicate.
        self.basic_blocks[bb].terminator_mut().successors_mut(|s| {
            if *s == target {
                *s = new_target;
            }
        });

        Some(new_target)
    }
}