use rustc_middle::bug;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
/// Base cost of a single "ordinary" operation.
///
/// Charged for non-nullary rvalues, `Assume` intrinsics, calls to intrinsics, small
/// non-constant switches, inline asm, and optional overflow asserts whose checks are
/// disabled. Also credited as a bonus for each `Unreachable` terminator.
const INSTR_COST: usize = 5;
/// Cost of a real call.
///
/// Charged for non-intrinsic calls, tail calls, drops of types that need dropping,
/// `CopyNonOverlapping`, and asserts that are kept. Also credited as a bonus in some
/// cases (a single call-like block; `UbChecks` when UB checks are not preserved).
const CALL_PENALTY: usize = 25;
/// Extra cost added when a call-like terminator unwinds to an `UnwindAction::Cleanup` block.
const LANDINGPAD_PENALTY: usize = 50;
/// Cost of an `UnwindResume` terminator.
const RESUME_PENALTY: usize = 45;
/// Cost of a non-constant `SwitchInt` with more than three targets.
const LARGE_SWITCH_PENALTY: usize = 20;
/// Bonus credited for a `SwitchInt` whose discriminant is a constant.
const CONST_SWITCH_BONUS: usize = 10;
/// Accumulates an estimated cost for a MIR body by visiting its statements,
/// rvalues and terminators.
///
/// The net cost is `penalty - bonus`, saturating at zero (see `cost`).
#[derive(Clone)]
pub(crate) struct CostChecker<'b, 'tcx> {
tcx: TyCtxt<'tcx>,
/// Environment used when asking whether a dropped type `needs_drop`.
param_env: ParamEnv<'tcx>,
/// Running total of costs charged by the visitor.
penalty: usize,
/// Running total of credits; subtracted from `penalty` in `cost`.
bonus: usize,
/// The MIR body being costed; used to compute the types of dropped places.
callee_body: &'b Body<'tcx>,
/// When present, used to instantiate generic types from `callee_body` before
/// they are inspected (see `instantiate_ty`).
instance: Option<ty::Instance<'tcx>>,
}
impl<'b, 'tcx> CostChecker<'b, 'tcx> {
    /// Creates a checker for `callee_body` with no accumulated penalty or bonus.
    ///
    /// `instance`, if given, is used to instantiate generic types found in the
    /// body before they are inspected (e.g. when deciding whether a drop is needed).
    pub fn new(
        tcx: TyCtxt<'tcx>,
        param_env: ParamEnv<'tcx>,
        instance: Option<ty::Instance<'tcx>>,
        callee_body: &'b Body<'tcx>,
    ) -> CostChecker<'b, 'tcx> {
        CostChecker { tcx, param_env, callee_body, instance, penalty: 0, bonus: 0 }
    }

    /// Adds costs that depend on the body as a whole rather than on any single
    /// statement or terminator.
    ///
    /// If exactly one basic block ends in a call-like terminator, a
    /// `CALL_PENALTY` bonus is credited.
    ///
    /// # Panics
    ///
    /// Panics if the body contains a terminator that should not be present in
    /// runtime MIR (`Yield`, `CoroutineDrop`, `FalseEdge`, `FalseUnwind`).
    pub fn add_function_level_costs(&mut self) {
        // Returns whether `bbd` ends in a terminator that is treated as a call.
        fn is_call_like(bbd: &BasicBlockData<'_>) -> bool {
            use TerminatorKind::*;
            match &bbd.terminator().kind {
                Call { .. } | TailCall { .. } | Drop { .. } | Assert { .. } | InlineAsm { .. } => {
                    true
                }
                Goto { .. }
                | SwitchInt { .. }
                | UnwindResume
                | UnwindTerminate(_)
                | Return
                | Unreachable => false,
                // Report which terminator was found, matching the message used by
                // `visit_terminator`, instead of a bare `unreachable!()`.
                kind @ (Yield { .. } | CoroutineDrop | FalseEdge { .. } | FalseUnwind { .. }) => {
                    unreachable!("{kind:?} should not be in runtime MIR")
                }
            }
        }

        if self.callee_body.basic_blocks.iter().filter(|bbd| is_call_like(bbd)).count() == 1 {
            self.bonus += CALL_PENALTY;
        }
    }

    /// Returns the net cost: accumulated penalty minus accumulated bonus,
    /// saturating at zero.
    pub fn cost(&self) -> usize {
        usize::saturating_sub(self.penalty, self.bonus)
    }

    /// Instantiates `v` with the generic arguments of `self.instance`, if one
    /// was provided; otherwise returns `v` unchanged.
    fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
        if let Some(instance) = self.instance {
            instance.instantiate_mir(self.tcx, ty::EarlyBinder::bind(&v))
        } else {
            v
        }
    }
}
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
match statement.kind {
StatementKind::Intrinsic(ref ndi) => {
self.penalty += match **ndi {
NonDivergingIntrinsic::Assume(..) => INSTR_COST,
NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
};
}
_ => self.super_statement(statement, location),
}
}
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
match rvalue {
Rvalue::NullaryOp(NullOp::UbChecks, ..)
if !self
.tcx
.sess
.opts
.unstable_opts
.inline_mir_preserve_debug
.unwrap_or(self.tcx.sess.ub_checks()) =>
{
self.bonus += CALL_PENALTY;
}
Rvalue::NullaryOp(..) => {}
_ => self.penalty += INSTR_COST,
}
}
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
match &terminator.kind {
TerminatorKind::Drop { place, unwind, .. } => {
let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
if ty.needs_drop(self.tcx, self.param_env) {
self.penalty += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
}
}
TerminatorKind::Call { func, unwind, .. } => {
self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
&& self.tcx.intrinsic(def_id).is_some()
{
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
}
TerminatorKind::TailCall { .. } => {
self.penalty += CALL_PENALTY;
}
TerminatorKind::SwitchInt { discr, targets } => {
if discr.constant().is_some() {
self.bonus += CONST_SWITCH_BONUS;
} else if targets.all_targets().len() > 3 {
self.penalty += LARGE_SWITCH_PENALTY;
} else {
self.penalty += INSTR_COST;
}
}
TerminatorKind::Assert { unwind, msg, .. } => {
self.penalty += if msg.is_optional_overflow_check()
&& !self
.tcx
.sess
.opts
.unstable_opts
.inline_mir_preserve_debug
.unwrap_or(self.tcx.sess.overflow_checks())
{
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
}
TerminatorKind::UnwindResume => self.penalty += RESUME_PENALTY,
TerminatorKind::InlineAsm { unwind, .. } => {
self.penalty += INSTR_COST;
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Unreachable => {
self.bonus += INSTR_COST;
}
TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
TerminatorKind::UnwindTerminate(..) => {}
kind @ (TerminatorKind::FalseUnwind { .. }
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::Yield { .. }
| TerminatorKind::CoroutineDrop) => {
bug!("{kind:?} should not be in runtime MIR");
}
}
}
}