use hir::{Expr, Pat};
use rustc_hir::{self as hir, LangItem};
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_middle::ty;
use rustc_session::{declare_lint, declare_lint_pass};
use rustc_span::{sym, Span};
use rustc_trait_selection::traits::ObligationCtxt;
use crate::lints::{
ForLoopsOverFalliblesDiag, ForLoopsOverFalliblesLoopSub, ForLoopsOverFalliblesQuestionMark,
ForLoopsOverFalliblesSuggestion,
};
use crate::{LateContext, LateLintPass, LintContext};
declare_lint! {
pub FOR_LOOPS_OVER_FALLIBLES,
Warn,
"for-looping over an `Option` or a `Result`, which is more clearly expressed as an `if let`"
}
declare_lint_pass!(ForLoopsOverFallibles => [FOR_LOOPS_OVER_FALLIBLES]);
impl<'tcx> LateLintPass<'tcx> for ForLoopsOverFallibles {
fn check_expr(&mut self, cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) {
let Some((pat, arg)) = extract_for_loop(expr) else { return };
let ty = cx.typeck_results().expr_ty(arg);
let (adt, args, ref_mutability) = match ty.kind() {
&ty::Adt(adt, args) => (adt, args, None),
&ty::Ref(_, ty, mutability) => match ty.kind() {
&ty::Adt(adt, args) => (adt, args, Some(mutability)),
_ => return,
},
_ => return,
};
let (article, ty, var) = match adt.did() {
did if cx.tcx.is_diagnostic_item(sym::Option, did) && ref_mutability.is_some() => {
("a", "Option", "Some")
}
did if cx.tcx.is_diagnostic_item(sym::Option, did) => ("an", "Option", "Some"),
did if cx.tcx.is_diagnostic_item(sym::Result, did) => ("a", "Result", "Ok"),
_ => return,
};
let ref_prefix = match ref_mutability {
None => "",
Some(ref_mutability) => ref_mutability.ref_prefix_str(),
};
let sub = if let Some(recv) = extract_iterator_next_call(cx, arg)
&& let Ok(recv_snip) = cx.sess().source_map().span_to_snippet(recv.span)
{
ForLoopsOverFalliblesLoopSub::RemoveNext {
suggestion: recv.span.between(arg.span.shrink_to_hi()),
recv_snip,
}
} else {
ForLoopsOverFalliblesLoopSub::UseWhileLet {
start_span: expr.span.with_hi(pat.span.lo()),
end_span: pat.span.between(arg.span),
var,
}
};
let question_mark = suggest_question_mark(cx, adt, args, expr.span)
.then(|| ForLoopsOverFalliblesQuestionMark { suggestion: arg.span.shrink_to_hi() });
let suggestion = ForLoopsOverFalliblesSuggestion {
var,
start_span: expr.span.with_hi(pat.span.lo()),
end_span: pat.span.between(arg.span),
};
cx.emit_span_lint(
FOR_LOOPS_OVER_FALLIBLES,
arg.span,
ForLoopsOverFalliblesDiag { article, ref_prefix, ty, sub, question_mark, suggestion },
);
}
}
fn extract_for_loop<'tcx>(expr: &Expr<'tcx>) -> Option<(&'tcx Pat<'tcx>, &'tcx Expr<'tcx>)> {
if let hir::ExprKind::DropTemps(e) = expr.kind
&& let hir::ExprKind::Match(iterexpr, [arm], hir::MatchSource::ForLoopDesugar) = e.kind
&& let hir::ExprKind::Call(_, [arg]) = iterexpr.kind
&& let hir::ExprKind::Loop(block, ..) = arm.body.kind
&& let [stmt] = block.stmts
&& let hir::StmtKind::Expr(e) = stmt.kind
&& let hir::ExprKind::Match(_, [_, some_arm], _) = e.kind
&& let hir::PatKind::Struct(_, [field], _) = some_arm.pat.kind
{
Some((field.pat, arg))
} else {
None
}
}
fn extract_iterator_next_call<'tcx>(
cx: &LateContext<'_>,
expr: &Expr<'tcx>,
) -> Option<&'tcx Expr<'tcx>> {
if let hir::ExprKind::MethodCall(_, recv, _, _) = expr.kind
&& cx
.typeck_results()
.type_dependent_def_id(expr.hir_id)
.is_some_and(|def_id| cx.tcx.is_lang_item(def_id, LangItem::IteratorNext))
{
Some(recv)
} else {
return None;
}
}
fn suggest_question_mark<'tcx>(
cx: &LateContext<'tcx>,
adt: ty::AdtDef<'tcx>,
args: ty::GenericArgsRef<'tcx>,
span: Span,
) -> bool {
let Some(body_id) = cx.enclosing_body else { return false };
let Some(into_iterator_did) = cx.tcx.get_diagnostic_item(sym::IntoIterator) else {
return false;
};
if !cx.tcx.is_diagnostic_item(sym::Result, adt.did()) {
return false;
}
{
let ty = cx.typeck_results().expr_ty(cx.tcx.hir().body(body_id).value);
let ty::Adt(ret_adt, ..) = ty.kind() else { return false };
if !cx.tcx.is_diagnostic_item(sym::Result, ret_adt.did()) {
return false;
}
}
let ty = args.type_at(0);
let infcx = cx.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let body_def_id = cx.tcx.hir().body_owner_def_id(body_id);
let cause =
ObligationCause::new(span, body_def_id, rustc_infer::traits::ObligationCauseCode::Misc);
ocx.register_bound(
cause,
cx.param_env,
infcx.tcx.erase_regions(ty),
into_iterator_did,
);
ocx.select_all_or_error().is_empty()
}