1use rustc_abi::ExternAbi;
2use rustc_data_structures::stack::ensure_sufficient_stack;
3use rustc_errors::Applicability;
4use rustc_hir::LangItem;
5use rustc_hir::def::DefKind;
6use rustc_hir::def_id::CRATE_DEF_ID;
7use rustc_middle::span_bug;
8use rustc_middle::thir::visit::{self, Visitor};
9use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
10use rustc_middle::ty::{self, Ty, TyCtxt};
11use rustc_span::def_id::{DefId, LocalDefId};
12use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
13
14pub(crate) fn check_tail_calls(tcx: TyCtxt<'_>, def: LocalDefId) -> Result<(), ErrorGuaranteed> {
15 let (thir, expr) = tcx.thir_body(def)?;
16 let thir = &thir.borrow();
17
18 if thir.exprs.is_empty() {
20 return Ok(());
21 }
22
23 let is_closure = matches!(tcx.def_kind(def), DefKind::Closure);
24 let caller_ty = tcx.type_of(def).skip_binder();
25
26 let mut visitor = TailCallCkVisitor {
27 tcx,
28 thir,
29 found_errors: Ok(()),
30 typing_env: ty::TypingEnv::non_body_analysis(tcx, def),
32 is_closure,
33 caller_ty,
34 };
35
36 visitor.visit_expr(&thir[expr]);
37
38 visitor.found_errors
39}
40
41struct TailCallCkVisitor<'a, 'tcx> {
42 tcx: TyCtxt<'tcx>,
43 thir: &'a Thir<'tcx>,
44 typing_env: ty::TypingEnv<'tcx>,
45 is_closure: bool,
47 found_errors: Result<(), ErrorGuaranteed>,
50 caller_ty: Ty<'tcx>,
52}
53
54impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
55 fn check_tail_call(&mut self, call: &Expr<'_>, expr: &Expr<'_>) {
56 if self.is_closure {
57 self.report_in_closure(expr);
58 return;
59 }
60
61 let BodyTy::Fn(caller_sig) = self.thir.body_type else {
62 span_bug!(
63 call.span,
64 "`become` outside of functions should have been disallowed by hir_typeck"
65 )
66 };
67 let caller_sig = self.tcx.erase_and_anonymize_regions(caller_sig);
71
72 let ExprKind::Scope { value, .. } = call.kind else {
73 span_bug!(call.span, "expected scope, found: {call:?}")
74 };
75 let value = &self.thir[value];
76
77 if matches!(
78 value.kind,
79 ExprKind::Binary { .. }
80 | ExprKind::Unary { .. }
81 | ExprKind::AssignOp { .. }
82 | ExprKind::Index { .. }
83 ) {
84 self.report_builtin_op(call, expr);
85 return;
86 }
87
88 let ExprKind::Call { ty, fun, ref args, from_hir_call, fn_span } = value.kind else {
89 self.report_non_call(value, expr);
90 return;
91 };
92
93 if !from_hir_call {
94 self.report_op(ty, args, fn_span, expr);
95 }
96
97 if let &ty::FnDef(did, args) = ty.kind() {
98 let parent = self.tcx.parent(did);
102 if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
103 && let Some(this) = args.first()
104 && let Some(this) = this.as_type()
105 {
106 if this.is_closure() {
107 self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
108 } else {
109 self.report_nonfn_callee(fn_span, self.thir[fun].span, this);
111 }
112
113 return;
116 };
117
118 if self.tcx.intrinsic(did).is_some() {
119 self.report_calling_intrinsic(expr);
120 }
121 }
122
123 let (ty::FnDef(..) | ty::FnPtr(..)) = ty.kind() else {
124 self.report_nonfn_callee(fn_span, self.thir[fun].span, ty);
125
126 return;
128 };
129
130 let callee_sig =
132 self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
133
134 if caller_sig.abi != callee_sig.abi {
135 self.report_abi_mismatch(expr.span, caller_sig.abi, callee_sig.abi);
136 }
137
138 if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
147 self.report_signature_mismatch(
148 expr.span,
149 self.tcx.liberate_late_bound_regions(
150 CRATE_DEF_ID.to_def_id(),
151 self.caller_ty.fn_sig(self.tcx),
152 ),
153 self.tcx.liberate_late_bound_regions(CRATE_DEF_ID.to_def_id(), ty.fn_sig(self.tcx)),
154 );
155 }
156
157 {
158 let caller_needs_location = self.needs_location(self.caller_ty);
173
174 if caller_needs_location {
175 self.report_track_caller_caller(expr.span);
176 }
177 }
178
179 if caller_sig.c_variadic {
180 self.report_c_variadic_caller(expr.span);
181 }
182
183 if callee_sig.c_variadic {
184 self.report_c_variadic_callee(expr.span);
185 }
186 }
187
188 fn needs_location(&self, ty: Ty<'tcx>) -> bool {
193 if let &ty::FnDef(did, substs) = ty.kind() {
194 let instance =
195 ty::Instance::expect_resolve(self.tcx, self.typing_env, did, substs, DUMMY_SP);
196
197 instance.def.requires_caller_location(self.tcx)
198 } else {
199 false
200 }
201 }
202
203 fn report_in_closure(&mut self, expr: &Expr<'_>) {
204 let err = self.tcx.dcx().span_err(expr.span, "`become` is not allowed in closures");
205 self.found_errors = Err(err);
206 }
207
208 fn report_builtin_op(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
209 let err = self
210 .tcx
211 .dcx()
212 .struct_span_err(value.span, "`become` does not support operators")
213 .with_note("using `become` on a builtin operator is not useful")
214 .with_span_suggestion(
215 value.span.until(expr.span),
216 "try using `return` instead",
217 "return ",
218 Applicability::MachineApplicable,
219 )
220 .emit();
221 self.found_errors = Err(err);
222 }
223
224 fn report_op(&mut self, fun_ty: Ty<'_>, args: &[ExprId], fn_span: Span, expr: &Expr<'_>) {
225 let mut err =
226 self.tcx.dcx().struct_span_err(fn_span, "`become` does not support operators");
227
228 if let &ty::FnDef(did, _substs) = fun_ty.kind()
229 && let parent = self.tcx.parent(did)
230 && matches!(self.tcx.def_kind(parent), DefKind::Trait)
231 && let Some(method) = op_trait_as_method_name(self.tcx, parent)
232 {
233 match args {
234 &[arg] => {
235 let arg = &self.thir[arg];
236
237 err.multipart_suggestion(
238 "try using the method directly",
239 vec![
240 (fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
241 (arg.span.shrink_to_hi(), format!(").{method}()")),
242 ],
243 Applicability::MaybeIncorrect,
244 );
245 }
246 &[lhs, rhs] => {
247 let lhs = &self.thir[lhs];
248 let rhs = &self.thir[rhs];
249
250 err.multipart_suggestion(
251 "try using the method directly",
252 vec![
253 (lhs.span.shrink_to_lo(), format!("(")),
254 (lhs.span.between(rhs.span), format!(").{method}(")),
255 (rhs.span.between(expr.span.shrink_to_hi()), ")".to_owned()),
256 ],
257 Applicability::MaybeIncorrect,
258 );
259 }
260 _ => span_bug!(expr.span, "operator with more than 2 args? {args:?}"),
261 }
262 }
263
264 self.found_errors = Err(err.emit());
265 }
266
267 fn report_non_call(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
268 let err = self
269 .tcx
270 .dcx()
271 .struct_span_err(value.span, "`become` requires a function call")
272 .with_span_note(value.span, "not a function call")
273 .with_span_suggestion(
274 value.span.until(expr.span),
275 "try using `return` instead",
276 "return ",
277 Applicability::MaybeIncorrect,
278 )
279 .emit();
280 self.found_errors = Err(err);
281 }
282
283 fn report_calling_closure(&mut self, fun: &Expr<'_>, tupled_args: Ty<'_>, expr: &Expr<'_>) {
284 let underscored_args = match tupled_args.kind() {
285 ty::Tuple(tys) if tys.is_empty() => "".to_owned(),
286 ty::Tuple(tys) => std::iter::repeat("_, ").take(tys.len() - 1).chain(["_"]).collect(),
287 _ => "_".to_owned(),
288 };
289
290 let err = self
291 .tcx
292 .dcx()
293 .struct_span_err(expr.span, "tail calling closures directly is not allowed")
294 .with_multipart_suggestion(
295 "try casting the closure to a function pointer type",
296 vec![
297 (fun.span.shrink_to_lo(), "(".to_owned()),
298 (fun.span.shrink_to_hi(), format!(" as fn({underscored_args}) -> _)")),
299 ],
300 Applicability::MaybeIncorrect,
301 )
302 .emit();
303 self.found_errors = Err(err);
304 }
305
306 fn report_calling_intrinsic(&mut self, expr: &Expr<'_>) {
307 let err = self
308 .tcx
309 .dcx()
310 .struct_span_err(expr.span, "tail calling intrinsics is not allowed")
311 .emit();
312
313 self.found_errors = Err(err);
314 }
315
316 fn report_nonfn_callee(&mut self, call_sp: Span, fun_sp: Span, ty: Ty<'_>) {
317 let mut err = self
318 .tcx
319 .dcx()
320 .struct_span_err(
321 call_sp,
322 "tail calls can only be performed with function definitions or pointers",
323 )
324 .with_note(format!("callee has type `{ty}`"));
325
326 let mut ty = ty;
327 let mut refs = 0;
328 while ty.is_box() || ty.is_ref() {
329 ty = ty.builtin_deref(false).unwrap();
330 refs += 1;
331 }
332
333 if refs > 0 && ty.is_fn() {
334 let thing = if ty.is_fn_ptr() { "pointer" } else { "definition" };
335
336 let derefs =
337 std::iter::once('(').chain(std::iter::repeat_n('*', refs)).collect::<String>();
338
339 err.multipart_suggestion(
340 format!("consider dereferencing the expression to get a function {thing}"),
341 vec![(fun_sp.shrink_to_lo(), derefs), (fun_sp.shrink_to_hi(), ")".to_owned())],
342 Applicability::MachineApplicable,
343 );
344 }
345
346 let err = err.emit();
347 self.found_errors = Err(err);
348 }
349
350 fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
351 let err = self
352 .tcx
353 .dcx()
354 .struct_span_err(sp, "mismatched function ABIs")
355 .with_note("`become` requires caller and callee to have the same ABI")
356 .with_note(format!("caller ABI is `{caller_abi}`, while callee ABI is `{callee_abi}`"))
357 .emit();
358 self.found_errors = Err(err);
359 }
360
361 fn report_signature_mismatch(
362 &mut self,
363 sp: Span,
364 caller_sig: ty::FnSig<'_>,
365 callee_sig: ty::FnSig<'_>,
366 ) {
367 let err = self
368 .tcx
369 .dcx()
370 .struct_span_err(sp, "mismatched signatures")
371 .with_note("`become` requires caller and callee to have matching signatures")
372 .with_note(format!("caller signature: `{caller_sig}`"))
373 .with_note(format!("callee signature: `{callee_sig}`"))
374 .emit();
375 self.found_errors = Err(err);
376 }
377
378 fn report_track_caller_caller(&mut self, sp: Span) {
379 let err = self
380 .tcx
381 .dcx()
382 .struct_span_err(
383 sp,
384 "a function marked with `#[track_caller]` cannot perform a tail-call",
385 )
386 .emit();
387
388 self.found_errors = Err(err);
389 }
390
391 fn report_c_variadic_caller(&mut self, sp: Span) {
392 let err = self
393 .tcx
394 .dcx()
395 .struct_span_err(sp, "tail-calls are not allowed in c-variadic functions")
397 .emit();
398
399 self.found_errors = Err(err);
400 }
401
402 fn report_c_variadic_callee(&mut self, sp: Span) {
403 let err = self
404 .tcx
405 .dcx()
406 .struct_span_err(sp, "c-variadic functions can't be tail-called")
408 .emit();
409
410 self.found_errors = Err(err);
411 }
412}
413
414impl<'a, 'tcx> Visitor<'a, 'tcx> for TailCallCkVisitor<'a, 'tcx> {
415 fn thir(&self) -> &'a Thir<'tcx> {
416 &self.thir
417 }
418
419 fn visit_expr(&mut self, expr: &'a Expr<'tcx>) {
420 ensure_sufficient_stack(|| {
421 if let ExprKind::Become { value } = expr.kind {
422 let call = &self.thir[value];
423 self.check_tail_call(call, expr);
424 }
425
426 visit::walk_expr(self, expr);
427 });
428 }
429}
430
431fn op_trait_as_method_name(tcx: TyCtxt<'_>, trait_did: DefId) -> Option<&'static str> {
432 let m = match tcx.as_lang_item(trait_did)? {
433 LangItem::Add => "add",
434 LangItem::Sub => "sub",
435 LangItem::Mul => "mul",
436 LangItem::Div => "div",
437 LangItem::Rem => "rem",
438 LangItem::Neg => "neg",
439 LangItem::Not => "not",
440 LangItem::BitXor => "bitxor",
441 LangItem::BitAnd => "bitand",
442 LangItem::BitOr => "bitor",
443 LangItem::Shl => "shl",
444 LangItem::Shr => "shr",
445 LangItem::AddAssign => "add_assign",
446 LangItem::SubAssign => "sub_assign",
447 LangItem::MulAssign => "mul_assign",
448 LangItem::DivAssign => "div_assign",
449 LangItem::RemAssign => "rem_assign",
450 LangItem::BitXorAssign => "bitxor_assign",
451 LangItem::BitAndAssign => "bitand_assign",
452 LangItem::BitOrAssign => "bitor_assign",
453 LangItem::ShlAssign => "shl_assign",
454 LangItem::ShrAssign => "shr_assign",
455 LangItem::Index => "index",
456 LangItem::IndexMut => "index_mut",
457 _ => return None,
458 };
459
460 Some(m)
461}