// rustc_mir_transform/cost_checker.rs

use rustc_middle::bug;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt};

// Weights of the inlining cost model. `CostChecker` adds these to its
// `penalty` or `bonus` totals; `CostChecker::cost` reports the difference.

/// Cost of a cheap, instruction-like operation (assignments, `Assume`,
/// intrinsic calls, small switches, inline asm); also the bonus for `Unreachable`.
const INSTR_COST: usize = 5;
/// Cost of something that behaves like a real call (non-intrinsic calls, tail
/// calls, drops that actually drop, `CopyNonOverlapping`, non-optional asserts).
/// Also used as a bonus to offset a call that is expected to go away.
const CALL_PENALTY: usize = 25;
/// Extra cost when a terminator unwinds into a cleanup block.
const LANDINGPAD_PENALTY: usize = 50;
/// Cost of an `UnwindResume` terminator.
const RESUME_PENALTY: usize = 45;

/// Cost of a `SwitchInt` with more than three targets.
const LARGE_SWITCH_PENALTY: usize = 20;
/// Bonus for a `SwitchInt` whose discriminant is a constant or a runtime check.
const CONST_SWITCH_BONUS: usize = 10;

13/// Verify that the callee body is compatible with the caller.
14#[derive(Clone)]
15pub(super) struct CostChecker<'b, 'tcx> {
16    tcx: TyCtxt<'tcx>,
17    typing_env: ty::TypingEnv<'tcx>,
18    penalty: usize,
19    bonus: usize,
20    callee_body: &'b Body<'tcx>,
21    instance: Option<ty::Instance<'tcx>>,
22}
23
24impl<'b, 'tcx> CostChecker<'b, 'tcx> {
25    pub(super) fn new(
26        tcx: TyCtxt<'tcx>,
27        typing_env: ty::TypingEnv<'tcx>,
28        instance: Option<ty::Instance<'tcx>>,
29        callee_body: &'b Body<'tcx>,
30    ) -> CostChecker<'b, 'tcx> {
31        CostChecker { tcx, typing_env, callee_body, instance, penalty: 0, bonus: 0 }
32    }
33
34    /// Add function-level costs not well-represented by the block-level costs.
35    ///
36    /// Needed because the `CostChecker` is used sometimes for just blocks,
37    /// and even the full `Inline` doesn't call `visit_body`, so there's nowhere
38    /// to put this logic in the visitor.
39    pub(super) fn add_function_level_costs(&mut self) {
40        // If the only has one Call (or similar), inlining isn't increasing the total
41        // number of calls, so give extra encouragement to inlining that.
42        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd.terminator())).count()
43            == 1
44        {
45            self.bonus += CALL_PENALTY;
46        }
47    }
48
49    pub(super) fn cost(&self) -> usize {
50        usize::saturating_sub(self.penalty, self.bonus)
51    }
52
53    fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
54        if let Some(instance) = self.instance {
55            instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
56        } else {
57            v
58        }
59    }
60}
61
62impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
63    fn visit_operand(&mut self, operand: &Operand<'tcx>, _: Location) {
64        match operand {
65            Operand::RuntimeChecks(RuntimeChecks::UbChecks) => {
66                if !self
67                    .tcx
68                    .sess
69                    .opts
70                    .unstable_opts
71                    .inline_mir_preserve_debug
72                    .unwrap_or(self.tcx.sess.ub_checks())
73                {
74                    // If this is in optimized MIR it's because it's used later, so if we don't need UB
75                    // checks this session, give a bonus here to offset the cost of the call later.
76                    self.bonus += CALL_PENALTY;
77                }
78            }
79            _ => {}
80        }
81    }
82
83    fn visit_statement(&mut self, statement: &Statement<'tcx>, loc: Location) {
84        // Most costs are in rvalues and terminators, not in statements.
85        match statement.kind {
86            StatementKind::Intrinsic(ref ndi) => {
87                self.penalty += match **ndi {
88                    NonDivergingIntrinsic::Assume(..) => INSTR_COST,
89                    NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
90                };
91            }
92            StatementKind::Assign(..) => self.penalty += INSTR_COST,
93            _ => {}
94        }
95        self.super_statement(statement, loc)
96    }
97
98    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, loc: Location) {
99        match &terminator.kind {
100            TerminatorKind::Drop { place, unwind, .. } => {
101                // If the place doesn't actually need dropping, treat it like a regular goto.
102                let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
103                if ty.needs_drop(self.tcx, self.typing_env) {
104                    self.penalty += CALL_PENALTY;
105                    if let UnwindAction::Cleanup(_) = unwind {
106                        self.penalty += LANDINGPAD_PENALTY;
107                    }
108                }
109            }
110            TerminatorKind::Call { func, unwind, .. } => {
111                self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
112                    && self.tcx.intrinsic(def_id).is_some()
113                {
114                    // Don't give intrinsics the extra penalty for calls
115                    INSTR_COST
116                } else {
117                    CALL_PENALTY
118                };
119                if let UnwindAction::Cleanup(_) = unwind {
120                    self.penalty += LANDINGPAD_PENALTY;
121                }
122            }
123            TerminatorKind::TailCall { .. } => {
124                self.penalty += CALL_PENALTY;
125            }
126            TerminatorKind::SwitchInt { discr, targets } => {
127                if matches!(discr, Operand::Constant(_) | Operand::RuntimeChecks(_)) {
128                    // Not only will this become a `Goto`, but likely other
129                    // things will be removable as unreachable.
130                    self.bonus += CONST_SWITCH_BONUS;
131                } else if targets.all_targets().len() > 3 {
132                    // More than false/true/unreachable gets extra cost.
133                    self.penalty += LARGE_SWITCH_PENALTY;
134                } else {
135                    self.penalty += INSTR_COST;
136                }
137            }
138            TerminatorKind::Assert { unwind, msg, .. } => {
139                self.penalty += if msg.is_optional_overflow_check()
140                    && !self
141                        .tcx
142                        .sess
143                        .opts
144                        .unstable_opts
145                        .inline_mir_preserve_debug
146                        .unwrap_or(self.tcx.sess.overflow_checks())
147                {
148                    INSTR_COST
149                } else {
150                    CALL_PENALTY
151                };
152                if let UnwindAction::Cleanup(_) = unwind {
153                    self.penalty += LANDINGPAD_PENALTY;
154                }
155            }
156            TerminatorKind::UnwindResume => self.penalty += RESUME_PENALTY,
157            TerminatorKind::InlineAsm { unwind, .. } => {
158                self.penalty += INSTR_COST;
159                if let UnwindAction::Cleanup(_) = unwind {
160                    self.penalty += LANDINGPAD_PENALTY;
161                }
162            }
163            TerminatorKind::Unreachable => {
164                self.bonus += INSTR_COST;
165            }
166            TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
167            TerminatorKind::UnwindTerminate(..) => {}
168            kind @ (TerminatorKind::FalseUnwind { .. }
169            | TerminatorKind::FalseEdge { .. }
170            | TerminatorKind::Yield { .. }
171            | TerminatorKind::CoroutineDrop) => {
172                bug!("{kind:?} should not be in runtime MIR");
173            }
174        }
175        self.super_terminator(terminator, loc)
176    }
177}
178
179/// A terminator that's more call-like (might do a bunch of work, might panic, etc)
180/// than it is goto-/return-like (no side effects, etc).
181///
182/// Used to treat multi-call functions (which could inline exponentially)
183/// different from those that only do one or none of these "complex" things.
184pub(super) fn is_call_like(terminator: &Terminator<'_>) -> bool {
185    use TerminatorKind::*;
186    match terminator.kind {
187        Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => true,
188
189        Goto { .. }
190        | SwitchInt { .. }
191        | UnwindResume
192        | UnwindTerminate(_)
193        | Return
194        | Unreachable => false,
195
196        Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. } => {
197            unreachable!()
198        }
199    }
200}