// Source: rustc_hir_analysis/check/compare_impl_item/refine.rs

1use itertools::Itertools as _;
2use rustc_data_structures::fx::FxIndexSet;
3use rustc_hir as hir;
4use rustc_hir::def_id::{DefId, LocalDefId};
5use rustc_infer::infer::TyCtxtInferExt;
6use rustc_lint_defs::builtin::{REFINING_IMPL_TRAIT_INTERNAL, REFINING_IMPL_TRAIT_REACHABLE};
7use rustc_middle::span_bug;
8use rustc_middle::traits::ObligationCause;
9use rustc_middle::ty::print::{with_no_trimmed_paths, with_types_for_signature};
10use rustc_middle::ty::{
11    self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperVisitable, TypeVisitable,
12    TypeVisitableExt, TypeVisitor, TypingMode, Unnormalized,
13};
14use rustc_span::Span;
15use rustc_trait_selection::regions::InferCtxtRegionExt;
16use rustc_trait_selection::traits::{ObligationCtxt, elaborate, normalize_param_env_or_error};
17
18/// Check that an implementation does not refine an RPITIT from a trait method signature.
19pub(crate) fn check_refining_return_position_impl_trait_in_trait<'tcx>(
20    tcx: TyCtxt<'tcx>,
21    impl_m: ty::AssocItem,
22    trait_m: ty::AssocItem,
23    impl_trait_ref: ty::TraitRef<'tcx>,
24) {
25    if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id) {
26        return;
27    }
28
29    // unreachable traits don't have any library guarantees, there's no need to do this check.
30    let is_internal = trait_m
31        .container_id(tcx)
32        .as_local()
33        .is_some_and(|trait_def_id| !tcx.effective_visibilities(()).is_reachable(trait_def_id))
34        // If a type in the trait ref is private, then there's also no reason to do this check.
35        || impl_trait_ref.args.iter().any(|arg| {
36            if let Some(ty) = arg.as_type()
37                && let Some(self_visibility) = type_visibility(tcx, ty)
38            {
39                return !self_visibility.is_public();
40            }
41            false
42        });
43
44    let impl_def_id = impl_m.container_id(tcx);
45    let impl_m_args = ty::GenericArgs::identity_for_item(tcx, impl_m.def_id);
46    let trait_m_to_impl_m_args = impl_m_args.rebase_onto(tcx, impl_def_id, impl_trait_ref.args);
47    let bound_trait_m_sig =
48        tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_m_to_impl_m_args).skip_norm_wip();
49    let trait_m_sig = tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
50    // replace the self type of the trait ref with `Self` so that diagnostics render better.
51    let trait_m_sig_with_self_for_diag = tcx.liberate_late_bound_regions(
52        impl_m.def_id,
53        tcx.fn_sig(trait_m.def_id)
54            .instantiate(
55                tcx,
56                tcx.mk_args_from_iter(
57                    [tcx.types.self_param.into()]
58                        .into_iter()
59                        .chain(trait_m_to_impl_m_args.iter().skip(1)),
60                ),
61            )
62            .skip_norm_wip(),
63    );
64
65    let Ok(hidden_tys) = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id) else {
66        // Error already emitted, no need to delay another.
67        return;
68    };
69
70    if hidden_tys.items().any(|(_, &ty)| ty.skip_binder().references_error()) {
71        return;
72    }
73
74    let mut collector = ImplTraitInTraitCollector { tcx, types: FxIndexSet::default() };
75    trait_m_sig.visit_with(&mut collector);
76
77    // Bound that we find on RPITITs in the trait signature.
78    let mut trait_bounds = ::alloc::vec::Vec::new()vec![];
79    // Bounds that we find on the RPITITs in the impl signature.
80    let mut impl_bounds = ::alloc::vec::Vec::new()vec![];
81    // Pairs of trait and impl opaques.
82    let mut pairs = ::alloc::vec::Vec::new()vec![];
83
84    for trait_projection in collector.types.into_iter().rev() {
85        let impl_opaque_args = trait_projection.args.rebase_onto(tcx, trait_m.def_id, impl_m_args);
86        let hidden_ty = hidden_tys[&trait_projection.kind.def_id()]
87            .instantiate(tcx, impl_opaque_args)
88            .skip_norm_wip();
89
90        // If the hidden type is not an opaque, then we have "refined" the trait signature.
91        let ty::Alias(
92            impl_opaque @ ty::AliasTy { kind: ty::Opaque { def_id: impl_opaque_def_id }, .. },
93        ) = *hidden_ty.kind()
94        else {
95            report_mismatched_rpitit_signature(
96                tcx,
97                trait_m_sig_with_self_for_diag,
98                trait_m.def_id,
99                impl_m.def_id,
100                None,
101                is_internal,
102            );
103            return;
104        };
105
106        // This opaque also needs to be from the impl method -- otherwise,
107        // it's a refinement to a TAIT.
108        if !tcx.hir_get_if_local(impl_opaque_def_id).is_some_and(|node| {
109            #[allow(non_exhaustive_omitted_patterns)] match node.expect_opaque_ty().origin
    {
    hir::OpaqueTyOrigin::AsyncFn { parent, .. } |
        hir::OpaqueTyOrigin::FnReturn { parent, .. } if
        parent == impl_m.def_id.expect_local() => true,
    _ => false,
}matches!(
110                node.expect_opaque_ty().origin,
111                hir::OpaqueTyOrigin::AsyncFn { parent, .. }  | hir::OpaqueTyOrigin::FnReturn { parent, .. }
112                    if parent == impl_m.def_id.expect_local()
113            )
114        }) {
115            report_mismatched_rpitit_signature(
116                tcx,
117                trait_m_sig_with_self_for_diag,
118                trait_m.def_id,
119                impl_m.def_id,
120                None,
121                is_internal,
122            );
123            return;
124        }
125
126        trait_bounds.extend(
127            tcx.item_bounds(trait_projection.kind.def_id())
128                .iter_instantiated(tcx, trait_projection.args)
129                .map(Unnormalized::skip_norm_wip),
130        );
131        impl_bounds.extend(elaborate(
132            tcx,
133            tcx.explicit_item_bounds(impl_opaque_def_id)
134                .iter_instantiated_copied(tcx, impl_opaque.args)
135                .map(Unnormalized::skip_norm_wip),
136        ));
137
138        pairs.push((trait_projection, impl_opaque));
139    }
140
141    let hybrid_preds = tcx
142        .predicates_of(impl_def_id)
143        .instantiate_identity(tcx)
144        .into_iter()
145        .chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_m_to_impl_m_args))
146        .map(|(clause, _)| clause.skip_norm_wip());
147    let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds));
148    let param_env = normalize_param_env_or_error(tcx, param_env, ObligationCause::dummy());
149
150    let ref infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
151    let ocx = ObligationCtxt::new(infcx);
152
153    // Normalize the bounds. This has two purposes:
154    //
155    // 1. Project the RPITIT projections from the trait to the opaques on the impl,
156    //    which means that they don't need to be mapped manually.
157    //
158    // 2. Deeply normalize any other projections that show up in the bound. That makes sure
159    //    that we don't consider `tests/ui/async-await/in-trait/async-associated-types.rs`
160    //    or `tests/ui/impl-trait/in-trait/refine-normalize.rs` to be refining.
161    let Ok((trait_bounds, impl_bounds)) = ocx.deeply_normalize(
162        &ObligationCause::dummy(),
163        param_env,
164        Unnormalized::new_wip((trait_bounds, impl_bounds)),
165    ) else {
166        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
167        return;
168    };
169
170    // Since we've normalized things, we need to resolve regions, since we'll
171    // possibly have introduced region vars during projection. We don't expect
172    // this resolution to have incurred any region errors -- but if we do, then
173    // just delay a bug.
174    let mut implied_wf_types = FxIndexSet::default();
175    implied_wf_types.extend(trait_m_sig.inputs_and_output);
176    implied_wf_types.extend(ocx.normalize(
177        &ObligationCause::dummy(),
178        param_env,
179        Unnormalized::new_wip(trait_m_sig.inputs_and_output),
180    ));
181    if !ocx.evaluate_obligations_error_on_ambiguity().is_empty() {
182        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
183        return;
184    }
185    let errors = infcx.resolve_regions(impl_m.def_id.expect_local(), param_env, implied_wf_types);
186    if !errors.is_empty() {
187        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (regions)");
188        return;
189    }
190    // Resolve any lifetime variables that may have been introduced during normalization.
191    let Ok((trait_bounds, impl_bounds)) = infcx.fully_resolve((trait_bounds, impl_bounds)) else {
192        // If resolution didn't fully complete, we cannot continue checking RPITIT refinement, and
193        // delay a bug as the original code contains load-bearing errors.
194        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (resolution)");
195        return;
196    };
197
198    if trait_bounds.references_error() || impl_bounds.references_error() {
199        return;
200    }
201
202    // For quicker lookup, use an `IndexSet` (we don't use one earlier because
203    // it's not foldable..).
204    // Also, We have to anonymize binders in these types because they may contain
205    // `BrNamed` bound vars, which contain unique `DefId`s which correspond to syntax
206    // locations that we don't care about when checking bound equality.
207    let trait_bounds = FxIndexSet::from_iter(trait_bounds.fold_with(&mut Anonymize { tcx }));
208    let impl_bounds = impl_bounds.fold_with(&mut Anonymize { tcx });
209
210    // Find any clauses that are present in the impl's RPITITs that are not
211    // present in the trait's RPITITs. This will trigger on trivial predicates,
212    // too, since we *do not* use the trait solver to prove that the RPITIT's
213    // bounds are not stronger -- we're doing a simple, syntactic compatibility
214    // check between bounds. This is strictly forwards compatible, though.
215    for (clause, span) in impl_bounds {
216        if !trait_bounds.contains(&clause) {
217            report_mismatched_rpitit_signature(
218                tcx,
219                trait_m_sig_with_self_for_diag,
220                trait_m.def_id,
221                impl_m.def_id,
222                Some(span),
223                is_internal,
224            );
225            return;
226        }
227    }
228
229    // Make sure that the RPITIT doesn't capture fewer regions than
230    // the trait definition. We hard-error if it captures *more*, since that
231    // is literally unrepresentable in the type system; however, we may be
232    // promising stronger outlives guarantees if we capture *fewer* regions.
233    for (trait_projection, impl_opaque) in pairs {
234        let impl_variances = tcx.variances_of(impl_opaque.kind.def_id());
235        let impl_captures: FxIndexSet<_> = impl_opaque
236            .args
237            .iter()
238            .zip_eq(impl_variances)
239            .filter(|(_, v)| **v == ty::Invariant)
240            .map(|(arg, _)| arg)
241            .collect();
242
243        let trait_variances = tcx.variances_of(trait_projection.kind.def_id());
244        let mut trait_captures = FxIndexSet::default();
245        for (arg, variance) in trait_projection.args.iter().zip_eq(trait_variances) {
246            if *variance != ty::Invariant {
247                continue;
248            }
249            arg.visit_with(&mut CollectParams { params: &mut trait_captures });
250        }
251
252        if !trait_captures.iter().all(|arg| impl_captures.contains(arg)) {
253            report_mismatched_rpitit_captures(
254                tcx,
255                impl_opaque.kind.def_id().expect_local(),
256                trait_captures,
257                is_internal,
258            );
259        }
260    }
261}
262
263struct ImplTraitInTraitCollector<'tcx> {
264    tcx: TyCtxt<'tcx>,
265    types: FxIndexSet<ty::AliasTy<'tcx>>,
266}
267
268impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
269    fn visit_ty(&mut self, ty: Ty<'tcx>) {
270        if let ty::Alias(proj @ ty::AliasTy { kind: ty::Projection { def_id }, .. }) = *ty.kind()
271            && self.tcx.is_impl_trait_in_trait(def_id)
272        {
273            if self.types.insert(proj) {
274                for (pred, _) in self
275                    .tcx
276                    .explicit_item_bounds(def_id)
277                    .iter_instantiated_copied(self.tcx, proj.args)
278                    .map(Unnormalized::skip_norm_wip)
279                {
280                    pred.visit_with(self);
281                }
282            }
283        } else {
284            ty.super_visit_with(self);
285        }
286    }
287}
288
289fn report_mismatched_rpitit_signature<'tcx>(
290    tcx: TyCtxt<'tcx>,
291    trait_m_sig: ty::FnSig<'tcx>,
292    trait_m_def_id: DefId,
293    impl_m_def_id: DefId,
294    unmatched_bound: Option<Span>,
295    is_internal: bool,
296) {
297    let mapping = std::iter::zip(
298        tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
299        tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
300    )
301    .enumerate()
302    .filter_map(|(idx, (impl_bv, trait_bv))| {
303        if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
304            && let ty::BoundVariableKind::Region(trait_bv) = trait_bv
305        {
306            let var = ty::BoundVar::from_usize(idx);
307            Some((
308                ty::LateParamRegionKind::from_bound(var, impl_bv),
309                ty::LateParamRegionKind::from_bound(var, trait_bv),
310            ))
311        } else {
312            None
313        }
314    })
315    .collect();
316
317    let mut return_ty = trait_m_sig.output().fold_with(&mut super::RemapLateParam { tcx, mapping });
318
319    if tcx.asyncness(impl_m_def_id).is_async() && tcx.asyncness(trait_m_def_id).is_async() {
320        let &ty::Alias(ty::AliasTy {
321            kind: ty::Projection { def_id: future_ty_def_id }, args, ..
322        }) = return_ty.kind()
323        else {
324            ::rustc_middle::util::bug::span_bug_fmt(tcx.def_span(trait_m_def_id),
    format_args!("expected return type of async fn in trait to be a AFIT projection"));span_bug!(
325                tcx.def_span(trait_m_def_id),
326                "expected return type of async fn in trait to be a AFIT projection"
327            );
328        };
329        let Some(future_output_ty) = tcx
330            .explicit_item_bounds(future_ty_def_id)
331            .iter_instantiated_copied(tcx, args)
332            .map(Unnormalized::skip_norm_wip)
333            .find_map(|(clause, _)| match clause.kind().no_bound_vars()? {
334                ty::ClauseKind::Projection(proj) => proj.term.as_type(),
335                _ => None,
336            })
337        else {
338            ::rustc_middle::util::bug::span_bug_fmt(tcx.def_span(trait_m_def_id),
    format_args!("expected `Future` projection bound in AFIT"));span_bug!(tcx.def_span(trait_m_def_id), "expected `Future` projection bound in AFIT");
339        };
340        return_ty = future_output_ty;
341    }
342
343    let (span, impl_return_span, pre, post) =
344        match tcx.hir_node_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
345            hir::FnRetTy::DefaultReturn(span) => (tcx.def_span(impl_m_def_id), span, "-> ", " "),
346            hir::FnRetTy::Return(ty) => (ty.span, ty.span, "", ""),
347        };
348    let trait_return_span =
349        tcx.hir_get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
350            hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
351            hir::FnRetTy::Return(ty) => ty.span,
352        });
353
354    // Use ForSignature mode to ensure RPITITs are printed as `impl Trait` rather than
355    // `impl Trait { T::method(..) }` when RTN is enabled.
356    //
357    // We use `with_no_trimmed_paths!` to avoid triggering the `trimmed_def_paths` query,
358    // which requires diagnostic context (via `must_produce_diag`). Since we're formatting
359    // the type before creating the diagnostic, we need to avoid this query. This is the
360    // standard approach used elsewhere in the compiler for formatting types in suggestions
361    // (e.g., see `rustc_hir_typeck/src/demand.rs`).
362    let return_ty_suggestion =
363        {
    let _guard = NoTrimmedGuard::new();
    {
        let _guard =
            ::rustc_middle::ty::print::pretty::RtnModeHelper::with(RtnMode::ForSignature);
        ::alloc::__export::must_use({
                ::alloc::fmt::format(format_args!("{0}", return_ty))
            })
    }
}with_no_trimmed_paths!(with_types_for_signature!(format!("{return_ty}")));
364
365    let span = unmatched_bound.unwrap_or(span);
366    tcx.emit_node_span_lint(
367        if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
368        tcx.local_def_id_to_hir_id(impl_m_def_id.expect_local()),
369        span,
370        crate::errors::ReturnPositionImplTraitInTraitRefined {
371            impl_return_span,
372            trait_return_span,
373            pre,
374            post,
375            return_ty: return_ty_suggestion,
376            unmatched_bound,
377        },
378    );
379}
380
381fn type_visibility<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option<ty::Visibility<DefId>> {
382    match *ty.kind() {
383        ty::Ref(_, ty, _) => type_visibility(tcx, ty),
384        ty::Adt(def, args) => {
385            if def.is_fundamental() {
386                type_visibility(tcx, args.type_at(0))
387            } else {
388                Some(tcx.visibility(def.did()))
389            }
390        }
391        _ => None,
392    }
393}
394
395struct Anonymize<'tcx> {
396    tcx: TyCtxt<'tcx>,
397}
398
399impl<'tcx> TypeFolder<TyCtxt<'tcx>> for Anonymize<'tcx> {
400    fn cx(&self) -> TyCtxt<'tcx> {
401        self.tcx
402    }
403
404    fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
405    where
406        T: TypeFoldable<TyCtxt<'tcx>>,
407    {
408        self.tcx.anonymize_bound_vars(t)
409    }
410}
411
412struct CollectParams<'a, 'tcx> {
413    params: &'a mut FxIndexSet<ty::GenericArg<'tcx>>,
414}
415impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for CollectParams<'_, 'tcx> {
416    fn visit_ty(&mut self, ty: Ty<'tcx>) {
417        if let ty::Param(_) = ty.kind() {
418            self.params.insert(ty.into());
419        } else {
420            ty.super_visit_with(self);
421        }
422    }
423    fn visit_region(&mut self, r: ty::Region<'tcx>) {
424        match r.kind() {
425            ty::ReEarlyParam(_) | ty::ReLateParam(_) => {
426                self.params.insert(r.into());
427            }
428            _ => {}
429        }
430    }
431    fn visit_const(&mut self, ct: ty::Const<'tcx>) {
432        if let ty::ConstKind::Param(_) = ct.kind() {
433            self.params.insert(ct.into());
434        } else {
435            ct.super_visit_with(self);
436        }
437    }
438}
439
440fn report_mismatched_rpitit_captures<'tcx>(
441    tcx: TyCtxt<'tcx>,
442    impl_opaque_def_id: LocalDefId,
443    mut trait_captured_args: FxIndexSet<ty::GenericArg<'tcx>>,
444    is_internal: bool,
445) {
446    let Some(use_bound_span) =
447        tcx.hir_node_by_def_id(impl_opaque_def_id).expect_opaque_ty().bounds.iter().find_map(
448            |bound| match *bound {
449                rustc_hir::GenericBound::Use(_, span) => Some(span),
450                hir::GenericBound::Trait(_) | hir::GenericBound::Outlives(_) => None,
451            },
452        )
453    else {
454        // I have no idea when you would ever undercapture without a `use<..>`.
455        tcx.dcx().delayed_bug("expected use<..> to undercapture in an impl opaque");
456        return;
457    };
458
459    trait_captured_args
460        .sort_by_cached_key(|arg| !#[allow(non_exhaustive_omitted_patterns)] match arg.kind() {
    ty::GenericArgKind::Lifetime(_) => true,
    _ => false,
}matches!(arg.kind(), ty::GenericArgKind::Lifetime(_)));
461    let suggestion = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("use<{0}>",
                trait_captured_args.iter().join(", ")))
    })format!("use<{}>", trait_captured_args.iter().join(", "));
462
463    tcx.emit_node_span_lint(
464        if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
465        tcx.local_def_id_to_hir_id(impl_opaque_def_id),
466        use_bound_span,
467        crate::errors::ReturnPositionImplTraitInTraitRefinedLifetimes {
468            suggestion_span: use_bound_span,
469            suggestion,
470        },
471    );
472}