Skip to main content

rustc_mir_build/
check_tail_calls.rs

1use rustc_abi::ExternAbi;
2use rustc_data_structures::stack::ensure_sufficient_stack;
3use rustc_errors::Applicability;
4use rustc_hir::LangItem;
5use rustc_hir::def::DefKind;
6use rustc_hir::def_id::CRATE_DEF_ID;
7use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
8use rustc_middle::span_bug;
9use rustc_middle::thir::visit::{self, Visitor};
10use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
11use rustc_middle::ty::{self, Ty, TyCtxt};
12use rustc_span::def_id::{DefId, LocalDefId};
13use rustc_span::{ErrorGuaranteed, Span};
14
15pub(crate) fn check_tail_calls(tcx: TyCtxt<'_>, def: LocalDefId) -> Result<(), ErrorGuaranteed> {
16    let (thir, expr) = tcx.thir_body(def)?;
17    let thir = &thir.borrow();
18
19    // If `thir` is empty, a type error occurred, skip this body.
20    if thir.exprs.is_empty() {
21        return Ok(());
22    }
23
24    let is_closure = #[allow(non_exhaustive_omitted_patterns)] match tcx.def_kind(def) {
    DefKind::Closure => true,
    _ => false,
}matches!(tcx.def_kind(def), DefKind::Closure);
25
26    let mut visitor = TailCallCkVisitor {
27        tcx,
28        thir,
29        found_errors: Ok(()),
30        // FIXME(#132279): we're clearly in a body here.
31        typing_env: ty::TypingEnv::non_body_analysis(tcx, def),
32        is_closure,
33        caller_def_id: def,
34    };
35
36    visitor.visit_expr(&thir[expr]);
37
38    visitor.found_errors
39}
40
41struct TailCallCkVisitor<'a, 'tcx> {
42    tcx: TyCtxt<'tcx>,
43    thir: &'a Thir<'tcx>,
44    typing_env: ty::TypingEnv<'tcx>,
45    /// Whatever the currently checked body is one of a closure
46    is_closure: bool,
47    /// The result of the checks, `Err(_)` if there was a problem with some
48    /// tail call, `Ok(())` if all of them were fine.
49    found_errors: Result<(), ErrorGuaranteed>,
50    /// `LocalDefId` of the caller function.
51    caller_def_id: LocalDefId,
52}
53
54impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
55    fn check_tail_call(&mut self, call: &Expr<'_>, expr: &Expr<'_>) {
56        if self.is_closure {
57            self.report_in_closure(expr);
58            return;
59        }
60
61        let BodyTy::Fn(caller_sig) = self.thir.body_type else {
62            ::rustc_middle::util::bug::span_bug_fmt(call.span,
    format_args!("`become` outside of functions should have been disallowed by hir_typeck"))span_bug!(
63                call.span,
64                "`become` outside of functions should have been disallowed by hir_typeck"
65            )
66        };
67        // While the `caller_sig` does have its free regions erased, it does not have its
68        // binders anonymized. We call `erase_and_anonymize_regions` once again to anonymize any binders
69        // within the signature, such as in function pointer or `dyn Trait` args.
70        let caller_sig = self.tcx.erase_and_anonymize_regions(caller_sig);
71
72        let ExprKind::Scope { value, .. } = call.kind else {
73            ::rustc_middle::util::bug::span_bug_fmt(call.span,
    format_args!("expected scope, found: {0:?}", call))span_bug!(call.span, "expected scope, found: {call:?}")
74        };
75        let value = &self.thir[value];
76
77        if #[allow(non_exhaustive_omitted_patterns)] match value.kind {
    ExprKind::Binary { .. } | ExprKind::Unary { .. } | ExprKind::AssignOp { ..
        } | ExprKind::Index { .. } => true,
    _ => false,
}matches!(
78            value.kind,
79            ExprKind::Binary { .. }
80                | ExprKind::Unary { .. }
81                | ExprKind::AssignOp { .. }
82                | ExprKind::Index { .. }
83        ) {
84            self.report_builtin_op(call, expr);
85            return;
86        }
87
88        let ExprKind::Call { ty, fun, ref args, from_hir_call, fn_span } = value.kind else {
89            self.report_non_call(value, expr);
90            return;
91        };
92
93        if !from_hir_call {
94            self.report_op(ty, args, fn_span, expr);
95        }
96
97        if let &ty::FnDef(did, args) = ty.kind() {
98            // Closures in thir look something akin to
99            // `for<'a> extern "rust-call" fn(&'a [closure@...], ()) -> <[closure@...] as FnOnce<()>>::Output {<[closure@...] as Fn<()>>::call}`
100            // So we have to check for them in this weird way...
101            let parent = self.tcx.parent(did);
102            if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
103                && let Some(this) = args.first()
104                && let Some(this) = this.as_type()
105            {
106                if this.is_closure() {
107                    self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
108                } else {
109                    // This can happen when tail calling `Box` that wraps a function
110                    self.report_nonfn_callee(fn_span, self.thir[fun].span, this);
111                }
112
113                // Tail calling is likely to cause unrelated errors (ABI, argument mismatches),
114                // skip them, producing an error about calling a closure is enough.
115                return;
116            };
117
118            if self.tcx.intrinsic(did).is_some() {
119                self.report_calling_intrinsic(expr);
120            }
121        }
122
123        let (ty::FnDef(..) | ty::FnPtr(..)) = ty.kind() else {
124            self.report_nonfn_callee(fn_span, self.thir[fun].span, ty);
125
126            // `fn_sig` below panics otherwise
127            return;
128        };
129
130        // Erase regions since tail calls don't care about lifetimes
131        let callee_sig =
132            self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
133
134        if caller_sig.abi() != callee_sig.abi() {
135            self.report_abi_mismatch(expr.span, caller_sig.abi(), callee_sig.abi());
136        }
137
138        if !callee_sig.abi().supports_guaranteed_tail_call() {
139            self.report_unsupported_abi(expr.span, callee_sig.abi());
140        }
141
142        // FIXME(explicit_tail_calls): this currently fails for cases where opaques are used.
143        // e.g.
144        // ```
145        // fn a() -> impl Sized { become b() } // ICE
146        // fn b() -> u8 { 0 }
147        // ```
148        // we should think what is the expected behavior here.
149        // (we should probably just accept this by revealing opaques?)
150        if caller_sig.inputs_and_output != callee_sig.inputs_and_output
151            && !#[allow(non_exhaustive_omitted_patterns)] match callee_sig.abi() {
    ExternAbi::RustTail => true,
    _ => false,
}matches!(callee_sig.abi(), ExternAbi::RustTail)
152        {
153            let caller_ty = self.tcx.type_of(self.caller_def_id).skip_binder();
154
155            self.report_signature_mismatch(
156                expr.span,
157                self.tcx.liberate_late_bound_regions(
158                    CRATE_DEF_ID.to_def_id(),
159                    caller_ty.fn_sig(self.tcx),
160                ),
161                self.tcx.liberate_late_bound_regions(CRATE_DEF_ID.to_def_id(), ty.fn_sig(self.tcx)),
162            );
163        }
164
165        {
166            // `#[track_caller]` affects the ABI of a function (by adding a location argument),
167            // so a `track_caller` can only tail call other `track_caller` functions.
168            //
169            // The issue is however that we can't know if a function is `track_caller` or not at
170            // this point (THIR can be polymorphic, we may have an unresolved trait function).
171            // We could only allow functions that we *can* resolve and *are* `track_caller`,
172            // but that would turn changing `track_caller`-ness into a breaking change,
173            // which is probably undesirable.
174            //
175            // Also note that we don't check callee's `track_caller`-ness at all, mostly for the
176            // reasons above, but also because we can always tailcall the shim we'd generate for
177            // coercing the function to an `fn()` pointer. (although in that case the tailcall is
178            // basically useless -- the shim calls the actual function, so tailcalling the shim is
179            // equivalent to calling the function)
180            let caller_needs_location = self.caller_needs_location();
181
182            if caller_needs_location {
183                self.report_track_caller_caller(expr.span);
184            }
185        }
186
187        if caller_sig.c_variadic() {
188            self.report_c_variadic_caller(expr.span);
189        }
190
191        if callee_sig.c_variadic() {
192            self.report_c_variadic_callee(expr.span);
193        }
194
195        for &arg_ty in callee_sig.inputs() {
196            if !arg_ty.is_sized(self.tcx, self.typing_env) {
197                self.report_unsized_argument(expr.span, arg_ty);
198            }
199        }
200    }
201
202    /// Returns true if the caller function needs a location argument
203    /// (i.e. if a function is marked as `#[track_caller]`)
204    fn caller_needs_location(&self) -> bool {
205        let flags = self.tcx.codegen_fn_attrs(self.caller_def_id).flags;
206        flags.contains(CodegenFnAttrFlags::TRACK_CALLER)
207    }
208
209    fn report_in_closure(&mut self, expr: &Expr<'_>) {
210        let err = self.tcx.dcx().span_err(expr.span, "`become` is not allowed in closures");
211        self.found_errors = Err(err);
212    }
213
214    fn report_builtin_op(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
215        let err = self
216            .tcx
217            .dcx()
218            .struct_span_err(value.span, "`become` does not support operators")
219            .with_note("using `become` on a builtin operator is not useful")
220            .with_span_suggestion(
221                value.span.until(expr.span),
222                "try using `return` instead",
223                "return ",
224                Applicability::MachineApplicable,
225            )
226            .emit();
227        self.found_errors = Err(err);
228    }
229
230    fn report_op(&mut self, fun_ty: Ty<'_>, args: &[ExprId], fn_span: Span, expr: &Expr<'_>) {
231        let mut err =
232            self.tcx.dcx().struct_span_err(fn_span, "`become` does not support operators");
233
234        if let &ty::FnDef(did, _substs) = fun_ty.kind()
235            && let parent = self.tcx.parent(did)
236            && #[allow(non_exhaustive_omitted_patterns)] match self.tcx.def_kind(parent) {
    DefKind::Trait => true,
    _ => false,
}matches!(self.tcx.def_kind(parent), DefKind::Trait)
237            && let Some(method) = op_trait_as_method_name(self.tcx, parent)
238        {
239            match args {
240                &[arg] => {
241                    let arg = &self.thir[arg];
242
243                    err.multipart_suggestion(
244                        "try using the method directly",
245                        ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [(fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
                (arg.span.shrink_to_hi(),
                    ::alloc::__export::must_use({
                            ::alloc::fmt::format(format_args!(").{0}()", method))
                        }))]))vec![
246                            (fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
247                            (arg.span.shrink_to_hi(), format!(").{method}()")),
248                        ],
249                        Applicability::MaybeIncorrect,
250                    );
251                }
252                &[lhs, rhs] => {
253                    let lhs = &self.thir[lhs];
254                    let rhs = &self.thir[rhs];
255
256                    err.multipart_suggestion(
257                        "try using the method directly",
258                        ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [(lhs.span.shrink_to_lo(),
                    ::alloc::__export::must_use({
                            ::alloc::fmt::format(format_args!("("))
                        })),
                (lhs.span.between(rhs.span),
                    ::alloc::__export::must_use({
                            ::alloc::fmt::format(format_args!(").{0}(", method))
                        })),
                (rhs.span.between(expr.span.shrink_to_hi()),
                    ")".to_owned())]))vec![
259                            (lhs.span.shrink_to_lo(), format!("(")),
260                            (lhs.span.between(rhs.span), format!(").{method}(")),
261                            (rhs.span.between(expr.span.shrink_to_hi()), ")".to_owned()),
262                        ],
263                        Applicability::MaybeIncorrect,
264                    );
265                }
266                _ => ::rustc_middle::util::bug::span_bug_fmt(expr.span,
    format_args!("operator with more than 2 args? {0:?}", args))span_bug!(expr.span, "operator with more than 2 args? {args:?}"),
267            }
268        }
269
270        self.found_errors = Err(err.emit());
271    }
272
273    fn report_non_call(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
274        let err = self
275            .tcx
276            .dcx()
277            .struct_span_err(value.span, "`become` requires a function call")
278            .with_span_note(value.span, "not a function call")
279            .with_span_suggestion(
280                value.span.until(expr.span),
281                "try using `return` instead",
282                "return ",
283                Applicability::MaybeIncorrect,
284            )
285            .emit();
286        self.found_errors = Err(err);
287    }
288
289    fn report_calling_closure(&mut self, fun: &Expr<'_>, tupled_args: Ty<'_>, expr: &Expr<'_>) {
290        let underscored_args = match tupled_args.kind() {
291            ty::Tuple(tys) if tys.is_empty() => "".to_owned(),
292            ty::Tuple(tys) => std::iter::repeat_n("_, ", tys.len() - 1).chain(["_"]).collect(),
293            _ => "_".to_owned(),
294        };
295
296        let err = self
297            .tcx
298            .dcx()
299            .struct_span_err(expr.span, "tail calling closures directly is not allowed")
300            .with_multipart_suggestion(
301                "try casting the closure to a function pointer type",
302                ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [(fun.span.shrink_to_lo(), "(".to_owned()),
                (fun.span.shrink_to_hi(),
                    ::alloc::__export::must_use({
                            ::alloc::fmt::format(format_args!(" as fn({0}) -> _)",
                                    underscored_args))
                        }))]))vec![
303                    (fun.span.shrink_to_lo(), "(".to_owned()),
304                    (fun.span.shrink_to_hi(), format!(" as fn({underscored_args}) -> _)")),
305                ],
306                Applicability::MaybeIncorrect,
307            )
308            .emit();
309        self.found_errors = Err(err);
310    }
311
312    fn report_calling_intrinsic(&mut self, expr: &Expr<'_>) {
313        let err = self
314            .tcx
315            .dcx()
316            .struct_span_err(expr.span, "tail calling intrinsics is not allowed")
317            .emit();
318
319        self.found_errors = Err(err);
320    }
321
322    fn report_nonfn_callee(&mut self, call_sp: Span, fun_sp: Span, ty: Ty<'_>) {
323        let mut err = self
324            .tcx
325            .dcx()
326            .struct_span_err(
327                call_sp,
328                "tail calls can only be performed with function definitions or pointers",
329            )
330            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("callee has type `{0}`", ty))
    })format!("callee has type `{ty}`"));
331
332        let mut ty = ty;
333        let mut refs = 0;
334        while ty.is_box() || ty.is_ref() {
335            ty = ty.builtin_deref(false).unwrap();
336            refs += 1;
337        }
338
339        if refs > 0 && ty.is_fn() {
340            let thing = if ty.is_fn_ptr() { "pointer" } else { "definition" };
341
342            let derefs =
343                std::iter::once('(').chain(std::iter::repeat_n('*', refs)).collect::<String>();
344
345            err.multipart_suggestion(
346                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("consider dereferencing the expression to get a function {0}",
                thing))
    })format!("consider dereferencing the expression to get a function {thing}"),
347                ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [(fun_sp.shrink_to_lo(), derefs),
                (fun_sp.shrink_to_hi(), ")".to_owned())]))vec![(fun_sp.shrink_to_lo(), derefs), (fun_sp.shrink_to_hi(), ")".to_owned())],
348                Applicability::MachineApplicable,
349            );
350        }
351
352        let err = err.emit();
353        self.found_errors = Err(err);
354    }
355
356    fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
357        let err = self
358            .tcx
359            .dcx()
360            .struct_span_err(sp, "mismatched function ABIs")
361            .with_note("`become` requires caller and callee to have the same ABI")
362            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("caller ABI is `{0}`, while callee ABI is `{1}`",
                caller_abi, callee_abi))
    })format!("caller ABI is `{caller_abi}`, while callee ABI is `{callee_abi}`"))
363            .emit();
364        self.found_errors = Err(err);
365    }
366
367    fn report_unsupported_abi(&mut self, sp: Span, callee_abi: ExternAbi) {
368        let err = self
369            .tcx
370            .dcx()
371            .struct_span_err(sp, "ABI does not support guaranteed tail calls")
372            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("`become` is not supported for `extern {0}` functions",
                callee_abi))
    })format!("`become` is not supported for `extern {callee_abi}` functions"))
373            .emit();
374        self.found_errors = Err(err);
375    }
376
377    fn report_signature_mismatch(
378        &mut self,
379        sp: Span,
380        caller_sig: ty::FnSig<'_>,
381        callee_sig: ty::FnSig<'_>,
382    ) {
383        let err = self
384            .tcx
385            .dcx()
386            .struct_span_err(sp, "mismatched signatures")
387            .with_note("`become` requires caller and callee to have matching signatures")
388            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("caller signature: `{0}`",
                caller_sig))
    })format!("caller signature: `{caller_sig}`"))
389            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("callee signature: `{0}`",
                callee_sig))
    })format!("callee signature: `{callee_sig}`"))
390            .emit();
391        self.found_errors = Err(err);
392    }
393
394    fn report_track_caller_caller(&mut self, sp: Span) {
395        let err = self
396            .tcx
397            .dcx()
398            .struct_span_err(
399                sp,
400                "a function marked with `#[track_caller]` cannot perform a tail-call",
401            )
402            .emit();
403
404        self.found_errors = Err(err);
405    }
406
407    fn report_c_variadic_caller(&mut self, sp: Span) {
408        let err = self
409            .tcx
410            .dcx()
411            // FIXME(explicit_tail_calls): highlight the `...`
412            .struct_span_err(sp, "tail-calls are not allowed in c-variadic functions")
413            .emit();
414
415        self.found_errors = Err(err);
416    }
417
418    fn report_c_variadic_callee(&mut self, sp: Span) {
419        let err = self
420            .tcx
421            .dcx()
422            // FIXME(explicit_tail_calls): highlight the function or something...
423            .struct_span_err(sp, "c-variadic functions can't be tail-called")
424            .emit();
425
426        self.found_errors = Err(err);
427    }
428
429    fn report_unsized_argument(&mut self, sp: Span, arg_ty: Ty<'tcx>) {
430        let err = self
431            .tcx
432            .dcx()
433            .struct_span_err(sp, ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("unsized arguments cannot be used in a tail call"))
    })format!("unsized arguments cannot be used in a tail call"))
434            .with_note(::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("unsized argument of type `{0}`",
                arg_ty))
    })format!("unsized argument of type `{arg_ty}`"))
435            .emit();
436
437        self.found_errors = Err(err);
438    }
439}
440
441impl<'a, 'tcx> Visitor<'a, 'tcx> for TailCallCkVisitor<'a, 'tcx> {
442    fn thir(&self) -> &'a Thir<'tcx> {
443        &self.thir
444    }
445
446    fn visit_expr(&mut self, expr: &'a Expr<'tcx>) {
447        ensure_sufficient_stack(|| {
448            if let ExprKind::Become { value } = expr.kind {
449                let call = &self.thir[value];
450                self.check_tail_call(call, expr);
451            }
452
453            visit::walk_expr(self, expr);
454        });
455    }
456}
457
458fn op_trait_as_method_name(tcx: TyCtxt<'_>, trait_did: DefId) -> Option<&'static str> {
459    let m = match tcx.as_lang_item(trait_did)? {
460        LangItem::Add => "add",
461        LangItem::Sub => "sub",
462        LangItem::Mul => "mul",
463        LangItem::Div => "div",
464        LangItem::Rem => "rem",
465        LangItem::Neg => "neg",
466        LangItem::Not => "not",
467        LangItem::BitXor => "bitxor",
468        LangItem::BitAnd => "bitand",
469        LangItem::BitOr => "bitor",
470        LangItem::Shl => "shl",
471        LangItem::Shr => "shr",
472        LangItem::AddAssign => "add_assign",
473        LangItem::SubAssign => "sub_assign",
474        LangItem::MulAssign => "mul_assign",
475        LangItem::DivAssign => "div_assign",
476        LangItem::RemAssign => "rem_assign",
477        LangItem::BitXorAssign => "bitxor_assign",
478        LangItem::BitAndAssign => "bitand_assign",
479        LangItem::BitOrAssign => "bitor_assign",
480        LangItem::ShlAssign => "shl_assign",
481        LangItem::ShrAssign => "shr_assign",
482        LangItem::Index => "index",
483        LangItem::IndexMut => "index_mut",
484        _ => return None,
485    };
486
487    Some(m)
488}