rustc_mir_build/
check_tail_calls.rs

1use rustc_abi::ExternAbi;
2use rustc_data_structures::stack::ensure_sufficient_stack;
3use rustc_errors::Applicability;
4use rustc_hir::LangItem;
5use rustc_hir::def::DefKind;
6use rustc_middle::span_bug;
7use rustc_middle::thir::visit::{self, Visitor};
8use rustc_middle::thir::{BodyTy, Expr, ExprId, ExprKind, Thir};
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_span::def_id::{DefId, LocalDefId};
11use rustc_span::{DUMMY_SP, ErrorGuaranteed, Span};
12
13pub(crate) fn check_tail_calls(tcx: TyCtxt<'_>, def: LocalDefId) -> Result<(), ErrorGuaranteed> {
14    let (thir, expr) = tcx.thir_body(def)?;
15    let thir = &thir.borrow();
16
17    // If `thir` is empty, a type error occurred, skip this body.
18    if thir.exprs.is_empty() {
19        return Ok(());
20    }
21
22    let is_closure = matches!(tcx.def_kind(def), DefKind::Closure);
23    let caller_ty = tcx.type_of(def).skip_binder();
24
25    let mut visitor = TailCallCkVisitor {
26        tcx,
27        thir,
28        found_errors: Ok(()),
29        // FIXME(#132279): we're clearly in a body here.
30        typing_env: ty::TypingEnv::non_body_analysis(tcx, def),
31        is_closure,
32        caller_ty,
33    };
34
35    visitor.visit_expr(&thir[expr]);
36
37    visitor.found_errors
38}
39
40struct TailCallCkVisitor<'a, 'tcx> {
41    tcx: TyCtxt<'tcx>,
42    thir: &'a Thir<'tcx>,
43    typing_env: ty::TypingEnv<'tcx>,
44    /// Whatever the currently checked body is one of a closure
45    is_closure: bool,
46    /// The result of the checks, `Err(_)` if there was a problem with some
47    /// tail call, `Ok(())` if all of them were fine.
48    found_errors: Result<(), ErrorGuaranteed>,
49    /// Type of the caller function.
50    caller_ty: Ty<'tcx>,
51}
52
53impl<'tcx> TailCallCkVisitor<'_, 'tcx> {
54    fn check_tail_call(&mut self, call: &Expr<'_>, expr: &Expr<'_>) {
55        if self.is_closure {
56            self.report_in_closure(expr);
57            return;
58        }
59
60        let BodyTy::Fn(caller_sig) = self.thir.body_type else {
61            span_bug!(
62                call.span,
63                "`become` outside of functions should have been disallowed by hit_typeck"
64            )
65        };
66
67        let ExprKind::Scope { value, .. } = call.kind else {
68            span_bug!(call.span, "expected scope, found: {call:?}")
69        };
70        let value = &self.thir[value];
71
72        if matches!(
73            value.kind,
74            ExprKind::Binary { .. }
75                | ExprKind::Unary { .. }
76                | ExprKind::AssignOp { .. }
77                | ExprKind::Index { .. }
78        ) {
79            self.report_builtin_op(call, expr);
80            return;
81        }
82
83        let ExprKind::Call { ty, fun, ref args, from_hir_call, fn_span } = value.kind else {
84            self.report_non_call(value, expr);
85            return;
86        };
87
88        if !from_hir_call {
89            self.report_op(ty, args, fn_span, expr);
90        }
91
92        // Closures in thir look something akin to
93        // `for<'a> extern "rust-call" fn(&'a [closure@...], ()) -> <[closure@...] as FnOnce<()>>::Output {<[closure@...] as Fn<()>>::call}`
94        // So we have to check for them in this weird way...
95        if let &ty::FnDef(did, args) = ty.kind() {
96            let parent = self.tcx.parent(did);
97            if self.tcx.fn_trait_kind_from_def_id(parent).is_some()
98                && args.first().and_then(|arg| arg.as_type()).is_some_and(Ty::is_closure)
99            {
100                self.report_calling_closure(&self.thir[fun], args[1].as_type().unwrap(), expr);
101
102                // Tail calling is likely to cause unrelated errors (ABI, argument mismatches),
103                // skip them, producing an error about calling a closure is enough.
104                return;
105            };
106        }
107
108        // Erase regions since tail calls don't care about lifetimes
109        let callee_sig =
110            self.tcx.normalize_erasing_late_bound_regions(self.typing_env, ty.fn_sig(self.tcx));
111
112        if caller_sig.abi != callee_sig.abi {
113            self.report_abi_mismatch(expr.span, caller_sig.abi, callee_sig.abi);
114        }
115
116        if caller_sig.inputs_and_output != callee_sig.inputs_and_output {
117            if caller_sig.inputs() != callee_sig.inputs() {
118                self.report_arguments_mismatch(expr.span, caller_sig, callee_sig);
119            }
120
121            // FIXME(explicit_tail_calls): this currently fails for cases where opaques are used.
122            // e.g.
123            // ```
124            // fn a() -> impl Sized { become b() } // ICE
125            // fn b() -> u8 { 0 }
126            // ```
127            // we should think what is the expected behavior here.
128            // (we should probably just accept this by revealing opaques?)
129            if caller_sig.output() != callee_sig.output() {
130                span_bug!(expr.span, "hir typeck should have checked the return type already");
131            }
132        }
133
134        {
135            // `#[track_caller]` affects the ABI of a function (by adding a location argument),
136            // so a `track_caller` can only tail call other `track_caller` functions.
137            //
138            // The issue is however that we can't know if a function is `track_caller` or not at
139            // this point (THIR can be polymorphic, we may have an unresolved trait function).
140            // We could only allow functions that we *can* resolve and *are* `track_caller`,
141            // but that would turn changing `track_caller`-ness into a breaking change,
142            // which is probably undesirable.
143            //
144            // Also note that we don't check callee's `track_caller`-ness at all, mostly for the
145            // reasons above, but also because we can always tailcall the shim we'd generate for
146            // coercing the function to an `fn()` pointer. (although in that case the tailcall is
147            // basically useless -- the shim calls the actual function, so tailcalling the shim is
148            // equivalent to calling the function)
149            let caller_needs_location = self.needs_location(self.caller_ty);
150
151            if caller_needs_location {
152                self.report_track_caller_caller(expr.span);
153            }
154        }
155
156        if caller_sig.c_variadic {
157            self.report_c_variadic_caller(expr.span);
158        }
159
160        if callee_sig.c_variadic {
161            self.report_c_variadic_callee(expr.span);
162        }
163    }
164
165    /// Returns true if function of type `ty` needs location argument
166    /// (i.e. if a function is marked as `#[track_caller]`).
167    ///
168    /// Panics if the function's instance can't be immediately resolved.
169    fn needs_location(&self, ty: Ty<'tcx>) -> bool {
170        if let &ty::FnDef(did, substs) = ty.kind() {
171            let instance =
172                ty::Instance::expect_resolve(self.tcx, self.typing_env, did, substs, DUMMY_SP);
173
174            instance.def.requires_caller_location(self.tcx)
175        } else {
176            false
177        }
178    }
179
180    fn report_in_closure(&mut self, expr: &Expr<'_>) {
181        let err = self.tcx.dcx().span_err(expr.span, "`become` is not allowed in closures");
182        self.found_errors = Err(err);
183    }
184
185    fn report_builtin_op(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
186        let err = self
187            .tcx
188            .dcx()
189            .struct_span_err(value.span, "`become` does not support operators")
190            .with_note("using `become` on a builtin operator is not useful")
191            .with_span_suggestion(
192                value.span.until(expr.span),
193                "try using `return` instead",
194                "return ",
195                Applicability::MachineApplicable,
196            )
197            .emit();
198        self.found_errors = Err(err);
199    }
200
201    fn report_op(&mut self, fun_ty: Ty<'_>, args: &[ExprId], fn_span: Span, expr: &Expr<'_>) {
202        let mut err =
203            self.tcx.dcx().struct_span_err(fn_span, "`become` does not support operators");
204
205        if let &ty::FnDef(did, _substs) = fun_ty.kind()
206            && let parent = self.tcx.parent(did)
207            && matches!(self.tcx.def_kind(parent), DefKind::Trait)
208            && let Some(method) = op_trait_as_method_name(self.tcx, parent)
209        {
210            match args {
211                &[arg] => {
212                    let arg = &self.thir[arg];
213
214                    err.multipart_suggestion(
215                        "try using the method directly",
216                        vec![
217                            (fn_span.shrink_to_lo().until(arg.span), "(".to_owned()),
218                            (arg.span.shrink_to_hi(), format!(").{method}()")),
219                        ],
220                        Applicability::MaybeIncorrect,
221                    );
222                }
223                &[lhs, rhs] => {
224                    let lhs = &self.thir[lhs];
225                    let rhs = &self.thir[rhs];
226
227                    err.multipart_suggestion(
228                        "try using the method directly",
229                        vec![
230                            (lhs.span.shrink_to_lo(), format!("(")),
231                            (lhs.span.between(rhs.span), format!(").{method}(")),
232                            (rhs.span.between(expr.span.shrink_to_hi()), ")".to_owned()),
233                        ],
234                        Applicability::MaybeIncorrect,
235                    );
236                }
237                _ => span_bug!(expr.span, "operator with more than 2 args? {args:?}"),
238            }
239        }
240
241        self.found_errors = Err(err.emit());
242    }
243
244    fn report_non_call(&mut self, value: &Expr<'_>, expr: &Expr<'_>) {
245        let err = self
246            .tcx
247            .dcx()
248            .struct_span_err(value.span, "`become` requires a function call")
249            .with_span_note(value.span, "not a function call")
250            .with_span_suggestion(
251                value.span.until(expr.span),
252                "try using `return` instead",
253                "return ",
254                Applicability::MaybeIncorrect,
255            )
256            .emit();
257        self.found_errors = Err(err);
258    }
259
260    fn report_calling_closure(&mut self, fun: &Expr<'_>, tupled_args: Ty<'_>, expr: &Expr<'_>) {
261        let underscored_args = match tupled_args.kind() {
262            ty::Tuple(tys) if tys.is_empty() => "".to_owned(),
263            ty::Tuple(tys) => std::iter::repeat("_, ").take(tys.len() - 1).chain(["_"]).collect(),
264            _ => "_".to_owned(),
265        };
266
267        let err = self
268            .tcx
269            .dcx()
270            .struct_span_err(expr.span, "tail calling closures directly is not allowed")
271            .with_multipart_suggestion(
272                "try casting the closure to a function pointer type",
273                vec![
274                    (fun.span.shrink_to_lo(), "(".to_owned()),
275                    (fun.span.shrink_to_hi(), format!(" as fn({underscored_args}) -> _)")),
276                ],
277                Applicability::MaybeIncorrect,
278            )
279            .emit();
280        self.found_errors = Err(err);
281    }
282
283    fn report_abi_mismatch(&mut self, sp: Span, caller_abi: ExternAbi, callee_abi: ExternAbi) {
284        let err = self
285            .tcx
286            .dcx()
287            .struct_span_err(sp, "mismatched function ABIs")
288            .with_note("`become` requires caller and callee to have the same ABI")
289            .with_note(format!("caller ABI is `{caller_abi}`, while callee ABI is `{callee_abi}`"))
290            .emit();
291        self.found_errors = Err(err);
292    }
293
294    fn report_arguments_mismatch(
295        &mut self,
296        sp: Span,
297        caller_sig: ty::FnSig<'_>,
298        callee_sig: ty::FnSig<'_>,
299    ) {
300        let err = self
301            .tcx
302            .dcx()
303            .struct_span_err(sp, "mismatched signatures")
304            .with_note("`become` requires caller and callee to have matching signatures")
305            .with_note(format!("caller signature: `{caller_sig}`"))
306            .with_note(format!("callee signature: `{callee_sig}`"))
307            .emit();
308        self.found_errors = Err(err);
309    }
310
311    fn report_track_caller_caller(&mut self, sp: Span) {
312        let err = self
313            .tcx
314            .dcx()
315            .struct_span_err(
316                sp,
317                "a function marked with `#[track_caller]` cannot perform a tail-call",
318            )
319            .emit();
320
321        self.found_errors = Err(err);
322    }
323
324    fn report_c_variadic_caller(&mut self, sp: Span) {
325        let err = self
326            .tcx
327            .dcx()
328            // FIXME(explicit_tail_calls): highlight the `...`
329            .struct_span_err(sp, "tail-calls are not allowed in c-variadic functions")
330            .emit();
331
332        self.found_errors = Err(err);
333    }
334
335    fn report_c_variadic_callee(&mut self, sp: Span) {
336        let err = self
337            .tcx
338            .dcx()
339            // FIXME(explicit_tail_calls): highlight the function or something...
340            .struct_span_err(sp, "c-variadic functions can't be tail-called")
341            .emit();
342
343        self.found_errors = Err(err);
344    }
345}
346
347impl<'a, 'tcx> Visitor<'a, 'tcx> for TailCallCkVisitor<'a, 'tcx> {
348    fn thir(&self) -> &'a Thir<'tcx> {
349        &self.thir
350    }
351
352    fn visit_expr(&mut self, expr: &'a Expr<'tcx>) {
353        ensure_sufficient_stack(|| {
354            if let ExprKind::Become { value } = expr.kind {
355                let call = &self.thir[value];
356                self.check_tail_call(call, expr);
357            }
358
359            visit::walk_expr(self, expr);
360        });
361    }
362}
363
364fn op_trait_as_method_name(tcx: TyCtxt<'_>, trait_did: DefId) -> Option<&'static str> {
365    let m = match tcx.as_lang_item(trait_did)? {
366        LangItem::Add => "add",
367        LangItem::Sub => "sub",
368        LangItem::Mul => "mul",
369        LangItem::Div => "div",
370        LangItem::Rem => "rem",
371        LangItem::Neg => "neg",
372        LangItem::Not => "not",
373        LangItem::BitXor => "bitxor",
374        LangItem::BitAnd => "bitand",
375        LangItem::BitOr => "bitor",
376        LangItem::Shl => "shl",
377        LangItem::Shr => "shr",
378        LangItem::AddAssign => "add_assign",
379        LangItem::SubAssign => "sub_assign",
380        LangItem::MulAssign => "mul_assign",
381        LangItem::DivAssign => "div_assign",
382        LangItem::RemAssign => "rem_assign",
383        LangItem::BitXorAssign => "bitxor_assign",
384        LangItem::BitAndAssign => "bitand_assign",
385        LangItem::BitOrAssign => "bitor_assign",
386        LangItem::ShlAssign => "shl_assign",
387        LangItem::ShrAssign => "shr_assign",
388        LangItem::Index => "index",
389        LangItem::IndexMut => "index_mut",
390        _ => return None,
391    };
392
393    Some(m)
394}