rustc_mir_transform/
cost_checker.rs

1use rustc_middle::bug;
2use rustc_middle::mir::visit::*;
3use rustc_middle::mir::*;
4use rustc_middle::ty::{self, Ty, TyCtxt};
5
/// Baseline cost of one simple operation: most rvalues, an `assume`, a call to an
/// intrinsic, a small non-constant `SwitchInt`, inline asm, or a cheap overflow assert.
const INSTR_COST: usize = 5;
/// Cost of a real call (including a `Drop` of a type that needs dropping).
/// The same amount is also granted as a bonus where a call is expected to go away.
const CALL_PENALTY: usize = 25;
/// Extra cost for a terminator whose unwind action is `UnwindAction::Cleanup`.
const LANDINGPAD_PENALTY: usize = 50;
/// Cost of an `UnwindResume` terminator.
const RESUME_PENALTY: usize = 45;
/// Cost of a non-constant `SwitchInt` with more than three targets.
const LARGE_SWITCH_PENALTY: usize = 20;
/// Bonus for a `SwitchInt` whose discriminant is a constant.
const CONST_SWITCH_BONUS: usize = 10;
12
13/// Verify that the callee body is compatible with the caller.
14#[derive(Clone)]
15pub(super) struct CostChecker<'b, 'tcx> {
16    tcx: TyCtxt<'tcx>,
17    typing_env: ty::TypingEnv<'tcx>,
18    penalty: usize,
19    bonus: usize,
20    callee_body: &'b Body<'tcx>,
21    instance: Option<ty::Instance<'tcx>>,
22}
23
24impl<'b, 'tcx> CostChecker<'b, 'tcx> {
25    pub(super) fn new(
26        tcx: TyCtxt<'tcx>,
27        typing_env: ty::TypingEnv<'tcx>,
28        instance: Option<ty::Instance<'tcx>>,
29        callee_body: &'b Body<'tcx>,
30    ) -> CostChecker<'b, 'tcx> {
31        CostChecker { tcx, typing_env, callee_body, instance, penalty: 0, bonus: 0 }
32    }
33
34    /// Add function-level costs not well-represented by the block-level costs.
35    ///
36    /// Needed because the `CostChecker` is used sometimes for just blocks,
37    /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere
38    /// to put this logic in the visitor.
39    pub(super) fn add_function_level_costs(&mut self) {
40        fn is_call_like(bbd: &BasicBlockData<'_>) -> bool {
41            use TerminatorKind::*;
42            match bbd.terminator().kind {
43                Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => {
44                    true
45                }
46
47                Goto { .. }
48                | SwitchInt { .. }
49                | UnwindResume
50                | UnwindTerminate(_)
51                | Return
52                | Unreachable => false,
53
54                Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
55                    unreachable!()
56                }
57            }
58        }
59
60        // If the only has one Call (or similar), inlining isn't increasing the total
61        // number of calls, so give extra encouragement to inlining that.
62        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 {
63            self.bonus += CALL_PENALTY;
64        }
65    }
66
67    pub(super) fn cost(&self) -> usize {
68        usize::saturating_sub(self.penalty, self.bonus)
69    }
70
71    fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
72        if let Some(instance) = self.instance {
73            instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
74        } else {
75            v
76        }
77    }
78}
79
80impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
81    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
82        // Most costs are in rvalues and terminators, not in statements.
83        match statement.kind {
84            StatementKind::Intrinsic(ref ndi) => {
85                self.penalty += match **ndi {
86                    NonDivergingIntrinsic::Assume(..) => INSTR_COST,
87                    NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
88                };
89            }
90            _ => self.super_statement(statement, location),
91        }
92    }
93
94    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
95        match rvalue {
96            Rvalue::NullaryOp(NullOp::UbChecks, ..)
97                if !self
98                    .tcx
99                    .sess
100                    .opts
101                    .unstable_opts
102                    .inline_mir_preserve_debug
103                    .unwrap_or(self.tcx.sess.ub_checks()) =>
104            {
105                // If this is in optimized MIR it's because it's used later,
106                // so if we don't need UB checks this session, give a bonus
107                // here to offset the cost of the call later.
108                self.bonus += CALL_PENALTY;
109            }
110            // These are essentially constants that didn't end up in an Operand,
111            // so treat them as also being free.
112            Rvalue::NullaryOp(..) => {}
113            _ => self.penalty += INSTR_COST,
114        }
115    }
116
117    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
118        match &terminator.kind {
119            TerminatorKind::Drop { place, unwind, .. } => {
120                // If the place doesn't actually need dropping, treat it like a regular goto.
121                let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
122                if ty.needs_drop(self.tcx, self.typing_env) {
123                    self.penalty += CALL_PENALTY;
124                    if let UnwindAction::Cleanup(_) = unwind {
125                        self.penalty += LANDINGPAD_PENALTY;
126                    }
127                }
128            }
129            TerminatorKind::Call { func, unwind, .. } => {
130                self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
131                    && self.tcx.intrinsic(def_id).is_some()
132                {
133                    // Don't give intrinsics the extra penalty for calls
134                    INSTR_COST
135                } else {
136                    CALL_PENALTY
137                };
138                if let UnwindAction::Cleanup(_) = unwind {
139                    self.penalty += LANDINGPAD_PENALTY;
140                }
141            }
142            TerminatorKind::TailCall { .. } => {
143                self.penalty += CALL_PENALTY;
144            }
145            TerminatorKind::SwitchInt { discr, targets } => {
146                if discr.constant().is_some() {
147                    // Not only will this become a `Goto`, but likely other
148                    // things will be removable as unreachable.
149                    self.bonus += CONST_SWITCH_BONUS;
150                } else if targets.all_targets().len() > 3 {
151                    // More than false/true/unreachable gets extra cost.
152                    self.penalty += LARGE_SWITCH_PENALTY;
153                } else {
154                    self.penalty += INSTR_COST;
155                }
156            }
157            TerminatorKind::Assert { unwind, msg, .. } => {
158                self.penalty += if msg.is_optional_overflow_check()
159                    && !self
160                        .tcx
161                        .sess
162                        .opts
163                        .unstable_opts
164                        .inline_mir_preserve_debug
165                        .unwrap_or(self.tcx.sess.overflow_checks())
166                {
167                    INSTR_COST
168                } else {
169                    CALL_PENALTY
170                };
171                if let UnwindAction::Cleanup(_) = unwind {
172                    self.penalty += LANDINGPAD_PENALTY;
173                }
174            }
175            TerminatorKind::UnwindResume => self.penalty += RESUME_PENALTY,
176            TerminatorKind::InlineAsm { unwind, .. } => {
177                self.penalty += INSTR_COST;
178                if let UnwindAction::Cleanup(_) = unwind {
179                    self.penalty += LANDINGPAD_PENALTY;
180                }
181            }
182            TerminatorKind::Unreachable => {
183                self.bonus += INSTR_COST;
184            }
185            TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
186            TerminatorKind::UnwindTerminate(..) => {}
187            kind @ (TerminatorKind::FalseUnwind { .. }
188            | TerminatorKind::FalseEdge { .. }
189            | TerminatorKind::Yield { .. }
190            | TerminatorKind::CoroutineDrop) => {
191                bug!("{kind:?} should not be in runtime MIR");
192            }
193        }
194    }
195}