rustc_hir_analysis/check/compare_impl_item/
refine.rs

1use itertools::Itertools as _;
2use rustc_data_structures::fx::FxIndexSet;
3use rustc_hir as hir;
4use rustc_hir::def_id::{DefId, LocalDefId};
5use rustc_infer::infer::TyCtxtInferExt;
6use rustc_lint_defs::builtin::{REFINING_IMPL_TRAIT_INTERNAL, REFINING_IMPL_TRAIT_REACHABLE};
7use rustc_middle::span_bug;
8use rustc_middle::traits::ObligationCause;
9use rustc_middle::ty::{
10    self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperVisitable, TypeVisitable,
11    TypeVisitableExt, TypeVisitor, TypingMode,
12};
13use rustc_span::Span;
14use rustc_trait_selection::regions::InferCtxtRegionExt;
15use rustc_trait_selection::traits::{ObligationCtxt, elaborate, normalize_param_env_or_error};
16
17/// Check that an implementation does not refine an RPITIT from a trait method signature.
18pub(crate) fn check_refining_return_position_impl_trait_in_trait<'tcx>(
19    tcx: TyCtxt<'tcx>,
20    impl_m: ty::AssocItem,
21    trait_m: ty::AssocItem,
22    impl_trait_ref: ty::TraitRef<'tcx>,
23) {
24    if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id) {
25        return;
26    }
27
28    // unreachable traits don't have any library guarantees, there's no need to do this check.
29    let is_internal = trait_m
30        .container_id(tcx)
31        .as_local()
32        .is_some_and(|trait_def_id| !tcx.effective_visibilities(()).is_reachable(trait_def_id))
33        // If a type in the trait ref is private, then there's also no reason to do this check.
34        || impl_trait_ref.args.iter().any(|arg| {
35            if let Some(ty) = arg.as_type()
36                && let Some(self_visibility) = type_visibility(tcx, ty)
37            {
38                return !self_visibility.is_public();
39            }
40            false
41        });
42
43    let impl_def_id = impl_m.container_id(tcx);
44    let impl_m_args = ty::GenericArgs::identity_for_item(tcx, impl_m.def_id);
45    let trait_m_to_impl_m_args = impl_m_args.rebase_onto(tcx, impl_def_id, impl_trait_ref.args);
46    let bound_trait_m_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_m_to_impl_m_args);
47    let trait_m_sig = tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
48    // replace the self type of the trait ref with `Self` so that diagnostics render better.
49    let trait_m_sig_with_self_for_diag = tcx.liberate_late_bound_regions(
50        impl_m.def_id,
51        tcx.fn_sig(trait_m.def_id).instantiate(
52            tcx,
53            tcx.mk_args_from_iter(
54                [tcx.types.self_param.into()]
55                    .into_iter()
56                    .chain(trait_m_to_impl_m_args.iter().skip(1)),
57            ),
58        ),
59    );
60
61    let Ok(hidden_tys) = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id) else {
62        // Error already emitted, no need to delay another.
63        return;
64    };
65
66    if hidden_tys.items().any(|(_, &ty)| ty.skip_binder().references_error()) {
67        return;
68    }
69
70    let mut collector = ImplTraitInTraitCollector { tcx, types: FxIndexSet::default() };
71    trait_m_sig.visit_with(&mut collector);
72
73    // Bound that we find on RPITITs in the trait signature.
74    let mut trait_bounds = vec![];
75    // Bounds that we find on the RPITITs in the impl signature.
76    let mut impl_bounds = vec![];
77    // Pairs of trait and impl opaques.
78    let mut pairs = vec![];
79
80    for trait_projection in collector.types.into_iter().rev() {
81        let impl_opaque_args = trait_projection.args.rebase_onto(tcx, trait_m.def_id, impl_m_args);
82        let hidden_ty = hidden_tys[&trait_projection.def_id].instantiate(tcx, impl_opaque_args);
83
84        // If the hidden type is not an opaque, then we have "refined" the trait signature.
85        let ty::Alias(ty::Opaque, impl_opaque) = *hidden_ty.kind() else {
86            report_mismatched_rpitit_signature(
87                tcx,
88                trait_m_sig_with_self_for_diag,
89                trait_m.def_id,
90                impl_m.def_id,
91                None,
92                is_internal,
93            );
94            return;
95        };
96
97        // This opaque also needs to be from the impl method -- otherwise,
98        // it's a refinement to a TAIT.
99        if !tcx.hir().get_if_local(impl_opaque.def_id).is_some_and(|node| {
100            matches!(
101                node.expect_opaque_ty().origin,
102                hir::OpaqueTyOrigin::AsyncFn { parent, .. }  | hir::OpaqueTyOrigin::FnReturn { parent, .. }
103                    if parent == impl_m.def_id.expect_local()
104            )
105        }) {
106            report_mismatched_rpitit_signature(
107                tcx,
108                trait_m_sig_with_self_for_diag,
109                trait_m.def_id,
110                impl_m.def_id,
111                None,
112                is_internal,
113            );
114            return;
115        }
116
117        trait_bounds.extend(
118            tcx.item_bounds(trait_projection.def_id).iter_instantiated(tcx, trait_projection.args),
119        );
120        impl_bounds.extend(elaborate(
121            tcx,
122            tcx.explicit_item_bounds(impl_opaque.def_id)
123                .iter_instantiated_copied(tcx, impl_opaque.args),
124        ));
125
126        pairs.push((trait_projection, impl_opaque));
127    }
128
129    let hybrid_preds = tcx
130        .predicates_of(impl_def_id)
131        .instantiate_identity(tcx)
132        .into_iter()
133        .chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_m_to_impl_m_args))
134        .map(|(clause, _)| clause);
135    let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds));
136    let param_env = normalize_param_env_or_error(tcx, param_env, ObligationCause::dummy());
137
138    let ref infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
139    let ocx = ObligationCtxt::new(infcx);
140
141    // Normalize the bounds. This has two purposes:
142    //
143    // 1. Project the RPITIT projections from the trait to the opaques on the impl,
144    //    which means that they don't need to be mapped manually.
145    //
146    // 2. Deeply normalize any other projections that show up in the bound. That makes sure
147    //    that we don't consider `tests/ui/async-await/in-trait/async-associated-types.rs`
148    //    or `tests/ui/impl-trait/in-trait/refine-normalize.rs` to be refining.
149    let Ok((trait_bounds, impl_bounds)) =
150        ocx.deeply_normalize(&ObligationCause::dummy(), param_env, (trait_bounds, impl_bounds))
151    else {
152        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
153        return;
154    };
155
156    // Since we've normalized things, we need to resolve regions, since we'll
157    // possibly have introduced region vars during projection. We don't expect
158    // this resolution to have incurred any region errors -- but if we do, then
159    // just delay a bug.
160    let mut implied_wf_types = FxIndexSet::default();
161    implied_wf_types.extend(trait_m_sig.inputs_and_output);
162    implied_wf_types.extend(ocx.normalize(
163        &ObligationCause::dummy(),
164        param_env,
165        trait_m_sig.inputs_and_output,
166    ));
167    if !ocx.select_all_or_error().is_empty() {
168        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
169        return;
170    }
171    let errors = infcx.resolve_regions(impl_m.def_id.expect_local(), param_env, implied_wf_types);
172    if !errors.is_empty() {
173        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (regions)");
174        return;
175    }
176    // Resolve any lifetime variables that may have been introduced during normalization.
177    let Ok((trait_bounds, impl_bounds)) = infcx.fully_resolve((trait_bounds, impl_bounds)) else {
178        // If resolution didn't fully complete, we cannot continue checking RPITIT refinement, and
179        // delay a bug as the original code contains load-bearing errors.
180        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (resolution)");
181        return;
182    };
183
184    if trait_bounds.references_error() || impl_bounds.references_error() {
185        return;
186    }
187
188    // For quicker lookup, use an `IndexSet` (we don't use one earlier because
189    // it's not foldable..).
190    // Also, We have to anonymize binders in these types because they may contain
191    // `BrNamed` bound vars, which contain unique `DefId`s which correspond to syntax
192    // locations that we don't care about when checking bound equality.
193    let trait_bounds = FxIndexSet::from_iter(trait_bounds.fold_with(&mut Anonymize { tcx }));
194    let impl_bounds = impl_bounds.fold_with(&mut Anonymize { tcx });
195
196    // Find any clauses that are present in the impl's RPITITs that are not
197    // present in the trait's RPITITs. This will trigger on trivial predicates,
198    // too, since we *do not* use the trait solver to prove that the RPITIT's
199    // bounds are not stronger -- we're doing a simple, syntactic compatibility
200    // check between bounds. This is strictly forwards compatible, though.
201    for (clause, span) in impl_bounds {
202        if !trait_bounds.contains(&clause) {
203            report_mismatched_rpitit_signature(
204                tcx,
205                trait_m_sig_with_self_for_diag,
206                trait_m.def_id,
207                impl_m.def_id,
208                Some(span),
209                is_internal,
210            );
211            return;
212        }
213    }
214
215    // Make sure that the RPITIT doesn't capture fewer regions than
216    // the trait definition. We hard-error if it captures *more*, since that
217    // is literally unrepresentable in the type system; however, we may be
218    // promising stronger outlives guarantees if we capture *fewer* regions.
219    for (trait_projection, impl_opaque) in pairs {
220        let impl_variances = tcx.variances_of(impl_opaque.def_id);
221        let impl_captures: FxIndexSet<_> = impl_opaque
222            .args
223            .iter()
224            .zip_eq(impl_variances)
225            .filter(|(_, v)| **v == ty::Invariant)
226            .map(|(arg, _)| arg)
227            .collect();
228
229        let trait_variances = tcx.variances_of(trait_projection.def_id);
230        let mut trait_captures = FxIndexSet::default();
231        for (arg, variance) in trait_projection.args.iter().zip_eq(trait_variances) {
232            if *variance != ty::Invariant {
233                continue;
234            }
235            arg.visit_with(&mut CollectParams { params: &mut trait_captures });
236        }
237
238        if !trait_captures.iter().all(|arg| impl_captures.contains(arg)) {
239            report_mismatched_rpitit_captures(
240                tcx,
241                impl_opaque.def_id.expect_local(),
242                trait_captures,
243                is_internal,
244            );
245        }
246    }
247}
248
249struct ImplTraitInTraitCollector<'tcx> {
250    tcx: TyCtxt<'tcx>,
251    types: FxIndexSet<ty::AliasTy<'tcx>>,
252}
253
254impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
255    fn visit_ty(&mut self, ty: Ty<'tcx>) {
256        if let ty::Alias(ty::Projection, proj) = *ty.kind()
257            && self.tcx.is_impl_trait_in_trait(proj.def_id)
258        {
259            if self.types.insert(proj) {
260                for (pred, _) in self
261                    .tcx
262                    .explicit_item_bounds(proj.def_id)
263                    .iter_instantiated_copied(self.tcx, proj.args)
264                {
265                    pred.visit_with(self);
266                }
267            }
268        } else {
269            ty.super_visit_with(self);
270        }
271    }
272}
273
274fn report_mismatched_rpitit_signature<'tcx>(
275    tcx: TyCtxt<'tcx>,
276    trait_m_sig: ty::FnSig<'tcx>,
277    trait_m_def_id: DefId,
278    impl_m_def_id: DefId,
279    unmatched_bound: Option<Span>,
280    is_internal: bool,
281) {
282    let mapping = std::iter::zip(
283        tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
284        tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
285    )
286    .enumerate()
287    .filter_map(|(idx, (impl_bv, trait_bv))| {
288        if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
289            && let ty::BoundVariableKind::Region(trait_bv) = trait_bv
290        {
291            let var = ty::BoundVar::from_usize(idx);
292            Some((
293                ty::LateParamRegionKind::from_bound(var, impl_bv),
294                ty::LateParamRegionKind::from_bound(var, trait_bv),
295            ))
296        } else {
297            None
298        }
299    })
300    .collect();
301
302    let mut return_ty = trait_m_sig.output().fold_with(&mut super::RemapLateParam { tcx, mapping });
303
304    if tcx.asyncness(impl_m_def_id).is_async() && tcx.asyncness(trait_m_def_id).is_async() {
305        let ty::Alias(ty::Projection, future_ty) = return_ty.kind() else {
306            span_bug!(
307                tcx.def_span(trait_m_def_id),
308                "expected return type of async fn in trait to be a AFIT projection"
309            );
310        };
311        let Some(future_output_ty) = tcx
312            .explicit_item_bounds(future_ty.def_id)
313            .iter_instantiated_copied(tcx, future_ty.args)
314            .find_map(|(clause, _)| match clause.kind().no_bound_vars()? {
315                ty::ClauseKind::Projection(proj) => proj.term.as_type(),
316                _ => None,
317            })
318        else {
319            span_bug!(tcx.def_span(trait_m_def_id), "expected `Future` projection bound in AFIT");
320        };
321        return_ty = future_output_ty;
322    }
323
324    let (span, impl_return_span, pre, post) =
325        match tcx.hir_node_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
326            hir::FnRetTy::DefaultReturn(span) => (tcx.def_span(impl_m_def_id), span, "-> ", " "),
327            hir::FnRetTy::Return(ty) => (ty.span, ty.span, "", ""),
328        };
329    let trait_return_span =
330        tcx.hir().get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
331            hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
332            hir::FnRetTy::Return(ty) => ty.span,
333        });
334
335    let span = unmatched_bound.unwrap_or(span);
336    tcx.emit_node_span_lint(
337        if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
338        tcx.local_def_id_to_hir_id(impl_m_def_id.expect_local()),
339        span,
340        crate::errors::ReturnPositionImplTraitInTraitRefined {
341            impl_return_span,
342            trait_return_span,
343            pre,
344            post,
345            return_ty,
346            unmatched_bound,
347        },
348    );
349}
350
351fn type_visibility<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option<ty::Visibility<DefId>> {
352    match *ty.kind() {
353        ty::Ref(_, ty, _) => type_visibility(tcx, ty),
354        ty::Adt(def, args) => {
355            if def.is_fundamental() {
356                type_visibility(tcx, args.type_at(0))
357            } else {
358                Some(tcx.visibility(def.did()))
359            }
360        }
361        _ => None,
362    }
363}
364
365struct Anonymize<'tcx> {
366    tcx: TyCtxt<'tcx>,
367}
368
369impl<'tcx> TypeFolder<TyCtxt<'tcx>> for Anonymize<'tcx> {
370    fn cx(&self) -> TyCtxt<'tcx> {
371        self.tcx
372    }
373
374    fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
375    where
376        T: TypeFoldable<TyCtxt<'tcx>>,
377    {
378        self.tcx.anonymize_bound_vars(t)
379    }
380}
381
382struct CollectParams<'a, 'tcx> {
383    params: &'a mut FxIndexSet<ty::GenericArg<'tcx>>,
384}
385impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for CollectParams<'_, 'tcx> {
386    fn visit_ty(&mut self, ty: Ty<'tcx>) {
387        if let ty::Param(_) = ty.kind() {
388            self.params.insert(ty.into());
389        } else {
390            ty.super_visit_with(self);
391        }
392    }
393    fn visit_region(&mut self, r: ty::Region<'tcx>) {
394        match r.kind() {
395            ty::ReEarlyParam(_) | ty::ReLateParam(_) => {
396                self.params.insert(r.into());
397            }
398            _ => {}
399        }
400    }
401    fn visit_const(&mut self, ct: ty::Const<'tcx>) {
402        if let ty::ConstKind::Param(_) = ct.kind() {
403            self.params.insert(ct.into());
404        } else {
405            ct.super_visit_with(self);
406        }
407    }
408}
409
410fn report_mismatched_rpitit_captures<'tcx>(
411    tcx: TyCtxt<'tcx>,
412    impl_opaque_def_id: LocalDefId,
413    mut trait_captured_args: FxIndexSet<ty::GenericArg<'tcx>>,
414    is_internal: bool,
415) {
416    let Some(use_bound_span) =
417        tcx.hir_node_by_def_id(impl_opaque_def_id).expect_opaque_ty().bounds.iter().find_map(
418            |bound| match *bound {
419                rustc_hir::GenericBound::Use(_, span) => Some(span),
420                hir::GenericBound::Trait(_) | hir::GenericBound::Outlives(_) => None,
421            },
422        )
423    else {
424        // I have no idea when you would ever undercapture without a `use<..>`.
425        tcx.dcx().delayed_bug("expected use<..> to undercapture in an impl opaque");
426        return;
427    };
428
429    trait_captured_args
430        .sort_by_cached_key(|arg| !matches!(arg.unpack(), ty::GenericArgKind::Lifetime(_)));
431    let suggestion = format!("use<{}>", trait_captured_args.iter().join(", "));
432
433    tcx.emit_node_span_lint(
434        if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
435        tcx.local_def_id_to_hir_id(impl_opaque_def_id),
436        use_bound_span,
437        crate::errors::ReturnPositionImplTraitInTraitRefinedLifetimes {
438            suggestion_span: use_bound_span,
439            suggestion,
440        },
441    );
442}