use std::ops::ControlFlow;
use rustc_data_structures::graph::iterate::{
NodeStatus, TriColorDepthFirstSearch, TriColorVisitor,
};
use rustc_hir::def::DefKind;
use rustc_middle::mir::{self, BasicBlock, BasicBlocks, Body, Terminator, TerminatorKind};
use rustc_middle::ty::{self, GenericArg, GenericArgs, Instance, Ty, TyCtxt};
use rustc_session::lint::builtin::UNCONDITIONAL_RECURSION;
use rustc_span::Span;
use crate::errors::UnconditionalRecursion;
pub(crate) fn check<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
check_call_recursion(tcx, body);
}
fn check_call_recursion<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
let def_id = body.source.def_id().expect_local();
if let DefKind::Fn | DefKind::AssocFn = tcx.def_kind(def_id) {
let trait_args = match tcx.trait_of_item(def_id.to_def_id()) {
Some(trait_def_id) => {
let trait_args_count = tcx.generics_of(trait_def_id).count();
&GenericArgs::identity_for_item(tcx, def_id)[..trait_args_count]
}
_ => &[],
};
check_recursion(tcx, body, CallRecursion { trait_args })
}
}
fn check_recursion<'tcx>(
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
classifier: impl TerminatorClassifier<'tcx>,
) {
let def_id = body.source.def_id().expect_local();
if let DefKind::Fn | DefKind::AssocFn = tcx.def_kind(def_id) {
let mut vis = Search { tcx, body, classifier, reachable_recursive_calls: vec![] };
if let Some(NonRecursive) =
TriColorDepthFirstSearch::new(&body.basic_blocks).run_from_start(&mut vis)
{
return;
}
if vis.reachable_recursive_calls.is_empty() {
return;
}
vis.reachable_recursive_calls.sort();
let sp = tcx.def_span(def_id);
let hir_id = tcx.local_def_id_to_hir_id(def_id);
tcx.emit_node_span_lint(UNCONDITIONAL_RECURSION, hir_id, sp, UnconditionalRecursion {
span: sp,
call_sites: vis.reachable_recursive_calls,
});
}
}
pub fn check_drop_recursion<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) {
let def_id = body.source.def_id().expect_local();
if let DefKind::AssocFn = tcx.def_kind(def_id)
&& let Some(trait_ref) =
tcx.impl_of_method(def_id.to_def_id()).and_then(|def_id| tcx.impl_trait_ref(def_id))
&& let Some(drop_trait) = tcx.lang_items().drop_trait()
&& drop_trait == trait_ref.instantiate_identity().def_id
&& let sig = tcx.fn_sig(def_id).instantiate_identity()
&& sig.inputs().skip_binder().len() == 1
{
if let ty::Ref(_, dropped_ty, _) =
tcx.liberate_late_bound_regions(def_id.to_def_id(), sig.input(0)).kind()
{
check_recursion(tcx, body, RecursiveDrop { drop_for: *dropped_ty });
}
}
}
trait TerminatorClassifier<'tcx> {
fn is_recursive_terminator(
&self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
terminator: &Terminator<'tcx>,
) -> bool;
}
struct NonRecursive;
struct Search<'mir, 'tcx, C: TerminatorClassifier<'tcx>> {
tcx: TyCtxt<'tcx>,
body: &'mir Body<'tcx>,
classifier: C,
reachable_recursive_calls: Vec<Span>,
}
struct CallRecursion<'tcx> {
trait_args: &'tcx [GenericArg<'tcx>],
}
struct RecursiveDrop<'tcx> {
drop_for: Ty<'tcx>,
}
impl<'tcx> TerminatorClassifier<'tcx> for CallRecursion<'tcx> {
fn is_recursive_terminator(
&self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
terminator: &Terminator<'tcx>,
) -> bool {
let TerminatorKind::Call { func, args, .. } = &terminator.kind else {
return false;
};
if args.len() != body.arg_count {
return false;
}
let caller = body.source.def_id();
let param_env = tcx.param_env(caller);
let func_ty = func.ty(body, tcx);
if let ty::FnDef(callee, args) = *func_ty.kind() {
let Ok(normalized_args) = tcx.try_normalize_erasing_regions(param_env, args) else {
return false;
};
let (callee, call_args) = if let Ok(Some(instance)) =
Instance::try_resolve(tcx, param_env, callee, normalized_args)
{
(instance.def_id(), instance.args)
} else {
(callee, normalized_args)
};
return callee == caller && &call_args[..self.trait_args.len()] == self.trait_args;
}
false
}
}
impl<'tcx> TerminatorClassifier<'tcx> for RecursiveDrop<'tcx> {
fn is_recursive_terminator(
&self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
terminator: &Terminator<'tcx>,
) -> bool {
let TerminatorKind::Drop { place, .. } = &terminator.kind else { return false };
let dropped_ty = place.ty(body, tcx).ty;
dropped_ty == self.drop_for
}
}
impl<'mir, 'tcx, C: TerminatorClassifier<'tcx>> TriColorVisitor<BasicBlocks<'tcx>>
for Search<'mir, 'tcx, C>
{
type BreakVal = NonRecursive;
fn node_examined(
&mut self,
bb: BasicBlock,
prior_status: Option<NodeStatus>,
) -> ControlFlow<Self::BreakVal> {
if let Some(NodeStatus::Visited) = prior_status {
return ControlFlow::Break(NonRecursive);
}
match self.body[bb].terminator().kind {
TerminatorKind::UnwindTerminate(_)
| TerminatorKind::CoroutineDrop
| TerminatorKind::UnwindResume
| TerminatorKind::Return
| TerminatorKind::Unreachable
| TerminatorKind::Yield { .. } => ControlFlow::Break(NonRecursive),
TerminatorKind::InlineAsm { ref targets, .. } => {
if !targets.is_empty() {
ControlFlow::Continue(())
} else {
ControlFlow::Break(NonRecursive)
}
}
TerminatorKind::Assert { .. }
| TerminatorKind::Call { .. }
| TerminatorKind::Drop { .. }
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Goto { .. }
| TerminatorKind::SwitchInt { .. } => ControlFlow::Continue(()),
TerminatorKind::TailCall { .. } => ControlFlow::Continue(()),
}
}
fn node_settled(&mut self, bb: BasicBlock) -> ControlFlow<Self::BreakVal> {
let terminator = self.body[bb].terminator();
if self.classifier.is_recursive_terminator(self.tcx, self.body, terminator) {
self.reachable_recursive_calls.push(terminator.source_info.span);
}
ControlFlow::Continue(())
}
fn ignore_edge(&mut self, bb: BasicBlock, target: BasicBlock) -> bool {
let terminator = self.body[bb].terminator();
let ignore_unwind = terminator.unwind() == Some(&mir::UnwindAction::Cleanup(target))
&& terminator.successors().count() > 1;
if ignore_unwind || self.classifier.is_recursive_terminator(self.tcx, self.body, terminator)
{
return true;
}
match &terminator.kind {
TerminatorKind::FalseEdge { imaginary_target, .. } => imaginary_target == &target,
_ => false,
}
}
}