// rustc_hir_analysis/check/compare_impl_item/refine.rs
1use itertools::Itertools as _;
2use rustc_data_structures::fx::FxIndexSet;
3use rustc_hir as hir;
4use rustc_hir::def_id::{DefId, LocalDefId};
5use rustc_infer::infer::TyCtxtInferExt;
6use rustc_lint_defs::builtin::{REFINING_IMPL_TRAIT_INTERNAL, REFINING_IMPL_TRAIT_REACHABLE};
7use rustc_middle::span_bug;
8use rustc_middle::traits::ObligationCause;
9use rustc_middle::ty::{
10 self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperVisitable, TypeVisitable,
11 TypeVisitableExt, TypeVisitor, TypingMode,
12};
13use rustc_span::Span;
14use rustc_trait_selection::regions::InferCtxtRegionExt;
15use rustc_trait_selection::traits::{ObligationCtxt, elaborate, normalize_param_env_or_error};
16
/// Checks whether the impl method `impl_m` "refines" the return-position `impl Trait` types
/// (RPITITs) of the trait method `trait_m` it implements. If it does, this emits
/// `REFINING_IMPL_TRAIT_INTERNAL` or `REFINING_IMPL_TRAIT_REACHABLE`.
///
/// Three situations are reported:
/// 1. The impl's hidden type for an RPITIT is not an opaque type. The same applies if it is an
///    opaque that is not the impl method's own return-position (`FnReturn`) or `async fn`
///    (`AsyncFn`) opaque.
/// 2. An elaborated explicit bound of the impl opaque is not present among the trait RPITIT's
///    item bounds. The comparison is syntactic, done after deep normalization and after
///    anonymizing bound variables.
/// 3. A generic parameter captured (invariantly) by the trait RPITIT is not captured
///    (invariantly) by the impl opaque.
///
/// The function returns silently when hidden types cannot be collected or already contain
/// errors. Failures in normalization, selection, region resolution, or final resolution are
/// reported through `delayed_bug` rather than as a user-facing diagnostic.
pub(crate) fn check_refining_return_position_impl_trait_in_trait<'tcx>(
    tcx: TyCtxt<'tcx>,
    impl_m: ty::AssocItem,
    trait_m: ty::AssocItem,
    impl_trait_ref: ty::TraitRef<'tcx>,
) {
    // Bail out early when this query reports nothing to compare for the impl method.
    if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id) {
        return;
    }

    // `is_internal` only selects which lint is emitted. It is true when either:
    // - the trait is a local item that is not reachable per `effective_visibilities`, or
    // - some type argument of the trait ref resolves (via `type_visibility`) to a non-public type.
    // Otherwise the `..._REACHABLE` lint is used.
    let is_internal = trait_m
        .container_id(tcx)
        .as_local()
        .is_some_and(|trait_def_id| !tcx.effective_visibilities(()).is_reachable(trait_def_id))
        || impl_trait_ref.args.iter().any(|arg| {
            if let Some(ty) = arg.as_type()
                && let Some(self_visibility) = type_visibility(tcx, ty)
            {
                return !self_visibility.is_public();
            }
            false
        });

    let impl_def_id = impl_m.container_id(tcx);
    let impl_m_args = ty::GenericArgs::identity_for_item(tcx, impl_m.def_id);
    // Rebase the impl method's identity args onto the trait ref's args. The result is used to
    // instantiate the trait method's signature (and predicates) in terms of the impl.
    let trait_m_to_impl_m_args = impl_m_args.rebase_onto(tcx, impl_def_id, impl_trait_ref.args);
    let bound_trait_m_sig = tcx.fn_sig(trait_m.def_id).instantiate(tcx, trait_m_to_impl_m_args);
    let trait_m_sig = tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
    // Same signature, but with the first generic arg replaced by `tcx.types.self_param`.
    // This version is only passed to the diagnostic-reporting helper.
    let trait_m_sig_with_self_for_diag = tcx.liberate_late_bound_regions(
        impl_m.def_id,
        tcx.fn_sig(trait_m.def_id).instantiate(
            tcx,
            tcx.mk_args_from_iter(
                [tcx.types.self_param.into()]
                    .into_iter()
                    .chain(trait_m_to_impl_m_args.iter().skip(1)),
            ),
        ),
    );

    // Hidden types the impl chose for the trait's RPITITs.
    // On `Err` nothing more is reported here.
    // NOTE(review): presumably the query has already emitted an error in that case — confirm.
    let Ok(hidden_tys) = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id) else {
        return;
    };

    // Don't check refinement against hidden types that already contain errors.
    if hidden_tys.items().any(|(_, &ty)| ty.skip_binder().references_error()) {
        return;
    }

    // Collect the RPITIT projections in the trait signature, including ones that appear in the
    // explicit bounds of other RPITITs (see `ImplTraitInTraitCollector`).
    let mut collector = ImplTraitInTraitCollector { tcx, types: FxIndexSet::default() };
    trait_m_sig.visit_with(&mut collector);

    // Item bounds of the trait's RPITITs.
    let mut trait_bounds = vec![];
    // Elaborated explicit bounds (with spans) of the impl opaques that stand in for them.
    let mut impl_bounds = vec![];
    // `(trait RPITIT projection, impl opaque)` pairs, used by the capture check at the end.
    let mut pairs = vec![];

    // `collector.types` lists an RPITIT before the RPITITs found in its bounds.
    // This loop visits them in the reverse of that order.
    for trait_projection in collector.types.into_iter().rev() {
        let impl_opaque_args = trait_projection.args.rebase_onto(tcx, trait_m.def_id, impl_m_args);
        let hidden_ty = hidden_tys[&trait_projection.def_id].instantiate(tcx, impl_opaque_args);

        // A hidden type that is not an opaque type is reported as a refinement.
        let ty::Alias(ty::Opaque, impl_opaque) = *hidden_ty.kind() else {
            report_mismatched_rpitit_signature(
                tcx,
                trait_m_sig_with_self_for_diag,
                trait_m.def_id,
                impl_m.def_id,
                None,
                is_internal,
            );
            return;
        };

        // The opaque must also be a local `FnReturn`/`AsyncFn` opaque whose parent is this impl
        // method. Any other opaque, including a non-local one, is reported as a refinement.
        if !tcx.hir().get_if_local(impl_opaque.def_id).is_some_and(|node| {
            matches!(
                node.expect_opaque_ty().origin,
                hir::OpaqueTyOrigin::AsyncFn { parent, .. } | hir::OpaqueTyOrigin::FnReturn { parent, .. }
                    if parent == impl_m.def_id.expect_local()
            )
        }) {
            report_mismatched_rpitit_signature(
                tcx,
                trait_m_sig_with_self_for_diag,
                trait_m.def_id,
                impl_m.def_id,
                None,
                is_internal,
            );
            return;
        }

        trait_bounds.extend(
            tcx.item_bounds(trait_projection.def_id).iter_instantiated(tcx, trait_projection.args),
        );
        impl_bounds.extend(elaborate(
            tcx,
            tcx.explicit_item_bounds(impl_opaque.def_id)
                .iter_instantiated_copied(tcx, impl_opaque.args),
        ));

        pairs.push((trait_projection, impl_opaque));
    }

    // Param-env built from the impl's predicates plus the trait method's own predicates
    // (instantiated for the impl), then normalized.
    let hybrid_preds = tcx
        .predicates_of(impl_def_id)
        .instantiate_identity(tcx)
        .into_iter()
        .chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_m_to_impl_m_args))
        .map(|(clause, _)| clause);
    let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds));
    let param_env = normalize_param_env_or_error(tcx, param_env, ObligationCause::dummy());

    let ref infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
    let ocx = ObligationCtxt::new(infcx);

    // Deeply normalize both bound lists in that param-env before the syntactic comparison below.
    // NOTE(review): presumably this also maps the trait RPITIT projections onto the impl's
    // opaques — confirm.
    let Ok((trait_bounds, impl_bounds)) =
        ocx.deeply_normalize(&ObligationCause::dummy(), param_env, (trait_bounds, impl_bounds))
    else {
        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
        return;
    };

    // Resolve pending obligations and region constraints. The trait signature's inputs/output,
    // both as written and normalized, serve as the implied well-formedness types. Any failure
    // becomes a delayed bug.
    let mut implied_wf_types = FxIndexSet::default();
    implied_wf_types.extend(trait_m_sig.inputs_and_output);
    implied_wf_types.extend(ocx.normalize(
        &ObligationCause::dummy(),
        param_env,
        trait_m_sig.inputs_and_output,
    ));
    if !ocx.select_all_or_error().is_empty() {
        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (selection)");
        return;
    }
    let errors = infcx.resolve_regions(impl_m.def_id.expect_local(), param_env, implied_wf_types);
    if !errors.is_empty() {
        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (regions)");
        return;
    }
    // Replace any inference variables left in the bounds with their resolved values.
    let Ok((trait_bounds, impl_bounds)) = infcx.fully_resolve((trait_bounds, impl_bounds)) else {
        tcx.dcx().delayed_bug("encountered errors when checking RPITIT refinement (resolution)");
        return;
    };

    if trait_bounds.references_error() || impl_bounds.references_error() {
        return;
    }

    // Anonymize bound variables on both sides so that clauses whose binders differ only in
    // bound-variable identity can compare equal. The trait side goes into a set for lookup.
    let trait_bounds = FxIndexSet::from_iter(trait_bounds.fold_with(&mut Anonymize { tcx }));
    let impl_bounds = impl_bounds.fold_with(&mut Anonymize { tcx });

    // Every impl bound must appear verbatim among the trait bounds. This is a plain membership
    // test, not a proof via the trait solver, so an impl bound that is merely implied by the
    // trait bounds is still reported. Only the first offending bound is reported.
    for (clause, span) in impl_bounds {
        if !trait_bounds.contains(&clause) {
            report_mismatched_rpitit_signature(
                tcx,
                trait_m_sig_with_self_for_diag,
                trait_m.def_id,
                impl_m.def_id,
                Some(span),
                is_internal,
            );
            return;
        }
    }

    // Capture check. The generic args of the impl opaque whose variance is `Invariant` count as
    // its captures. Every generic parameter reachable through an `Invariant` arg of the trait
    // projection must be among those captures; otherwise report.
    for (trait_projection, impl_opaque) in pairs {
        let impl_variances = tcx.variances_of(impl_opaque.def_id);
        let impl_captures: FxIndexSet<_> = impl_opaque
            .args
            .iter()
            .zip_eq(impl_variances)
            .filter(|(_, v)| **v == ty::Invariant)
            .map(|(arg, _)| arg)
            .collect();

        let trait_variances = tcx.variances_of(trait_projection.def_id);
        let mut trait_captures = FxIndexSet::default();
        for (arg, variance) in trait_projection.args.iter().zip_eq(trait_variances) {
            if *variance != ty::Invariant {
                continue;
            }
            arg.visit_with(&mut CollectParams { params: &mut trait_captures });
        }

        if !trait_captures.iter().all(|arg| impl_captures.contains(arg)) {
            report_mismatched_rpitit_captures(
                tcx,
                impl_opaque.def_id.expect_local(),
                trait_captures,
                is_internal,
            );
        }
    }
}
248
249struct ImplTraitInTraitCollector<'tcx> {
250 tcx: TyCtxt<'tcx>,
251 types: FxIndexSet<ty::AliasTy<'tcx>>,
252}
253
254impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
255 fn visit_ty(&mut self, ty: Ty<'tcx>) {
256 if let ty::Alias(ty::Projection, proj) = *ty.kind()
257 && self.tcx.is_impl_trait_in_trait(proj.def_id)
258 {
259 if self.types.insert(proj) {
260 for (pred, _) in self
261 .tcx
262 .explicit_item_bounds(proj.def_id)
263 .iter_instantiated_copied(self.tcx, proj.args)
264 {
265 pred.visit_with(self);
266 }
267 }
268 } else {
269 ty.super_visit_with(self);
270 }
271 }
272}
273
274fn report_mismatched_rpitit_signature<'tcx>(
275 tcx: TyCtxt<'tcx>,
276 trait_m_sig: ty::FnSig<'tcx>,
277 trait_m_def_id: DefId,
278 impl_m_def_id: DefId,
279 unmatched_bound: Option<Span>,
280 is_internal: bool,
281) {
282 let mapping = std::iter::zip(
283 tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
284 tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
285 )
286 .enumerate()
287 .filter_map(|(idx, (impl_bv, trait_bv))| {
288 if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
289 && let ty::BoundVariableKind::Region(trait_bv) = trait_bv
290 {
291 let var = ty::BoundVar::from_usize(idx);
292 Some((
293 ty::LateParamRegionKind::from_bound(var, impl_bv),
294 ty::LateParamRegionKind::from_bound(var, trait_bv),
295 ))
296 } else {
297 None
298 }
299 })
300 .collect();
301
302 let mut return_ty = trait_m_sig.output().fold_with(&mut super::RemapLateParam { tcx, mapping });
303
304 if tcx.asyncness(impl_m_def_id).is_async() && tcx.asyncness(trait_m_def_id).is_async() {
305 let ty::Alias(ty::Projection, future_ty) = return_ty.kind() else {
306 span_bug!(
307 tcx.def_span(trait_m_def_id),
308 "expected return type of async fn in trait to be a AFIT projection"
309 );
310 };
311 let Some(future_output_ty) = tcx
312 .explicit_item_bounds(future_ty.def_id)
313 .iter_instantiated_copied(tcx, future_ty.args)
314 .find_map(|(clause, _)| match clause.kind().no_bound_vars()? {
315 ty::ClauseKind::Projection(proj) => proj.term.as_type(),
316 _ => None,
317 })
318 else {
319 span_bug!(tcx.def_span(trait_m_def_id), "expected `Future` projection bound in AFIT");
320 };
321 return_ty = future_output_ty;
322 }
323
324 let (span, impl_return_span, pre, post) =
325 match tcx.hir_node_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
326 hir::FnRetTy::DefaultReturn(span) => (tcx.def_span(impl_m_def_id), span, "-> ", " "),
327 hir::FnRetTy::Return(ty) => (ty.span, ty.span, "", ""),
328 };
329 let trait_return_span =
330 tcx.hir().get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
331 hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
332 hir::FnRetTy::Return(ty) => ty.span,
333 });
334
335 let span = unmatched_bound.unwrap_or(span);
336 tcx.emit_node_span_lint(
337 if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
338 tcx.local_def_id_to_hir_id(impl_m_def_id.expect_local()),
339 span,
340 crate::errors::ReturnPositionImplTraitInTraitRefined {
341 impl_return_span,
342 trait_return_span,
343 pre,
344 post,
345 return_ty,
346 unmatched_bound,
347 },
348 );
349}
350
351fn type_visibility<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option<ty::Visibility<DefId>> {
352 match *ty.kind() {
353 ty::Ref(_, ty, _) => type_visibility(tcx, ty),
354 ty::Adt(def, args) => {
355 if def.is_fundamental() {
356 type_visibility(tcx, args.type_at(0))
357 } else {
358 Some(tcx.visibility(def.did()))
359 }
360 }
361 _ => None,
362 }
363}
364
365struct Anonymize<'tcx> {
366 tcx: TyCtxt<'tcx>,
367}
368
369impl<'tcx> TypeFolder<TyCtxt<'tcx>> for Anonymize<'tcx> {
370 fn cx(&self) -> TyCtxt<'tcx> {
371 self.tcx
372 }
373
374 fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
375 where
376 T: TypeFoldable<TyCtxt<'tcx>>,
377 {
378 self.tcx.anonymize_bound_vars(t)
379 }
380}
381
382struct CollectParams<'a, 'tcx> {
383 params: &'a mut FxIndexSet<ty::GenericArg<'tcx>>,
384}
385impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for CollectParams<'_, 'tcx> {
386 fn visit_ty(&mut self, ty: Ty<'tcx>) {
387 if let ty::Param(_) = ty.kind() {
388 self.params.insert(ty.into());
389 } else {
390 ty.super_visit_with(self);
391 }
392 }
393 fn visit_region(&mut self, r: ty::Region<'tcx>) {
394 match r.kind() {
395 ty::ReEarlyParam(_) | ty::ReLateParam(_) => {
396 self.params.insert(r.into());
397 }
398 _ => {}
399 }
400 }
401 fn visit_const(&mut self, ct: ty::Const<'tcx>) {
402 if let ty::ConstKind::Param(_) = ct.kind() {
403 self.params.insert(ct.into());
404 } else {
405 ct.super_visit_with(self);
406 }
407 }
408}
409
410fn report_mismatched_rpitit_captures<'tcx>(
411 tcx: TyCtxt<'tcx>,
412 impl_opaque_def_id: LocalDefId,
413 mut trait_captured_args: FxIndexSet<ty::GenericArg<'tcx>>,
414 is_internal: bool,
415) {
416 let Some(use_bound_span) =
417 tcx.hir_node_by_def_id(impl_opaque_def_id).expect_opaque_ty().bounds.iter().find_map(
418 |bound| match *bound {
419 rustc_hir::GenericBound::Use(_, span) => Some(span),
420 hir::GenericBound::Trait(_) | hir::GenericBound::Outlives(_) => None,
421 },
422 )
423 else {
424 tcx.dcx().delayed_bug("expected use<..> to undercapture in an impl opaque");
426 return;
427 };
428
429 trait_captured_args
430 .sort_by_cached_key(|arg| !matches!(arg.unpack(), ty::GenericArgKind::Lifetime(_)));
431 let suggestion = format!("use<{}>", trait_captured_args.iter().join(", "));
432
433 tcx.emit_node_span_lint(
434 if is_internal { REFINING_IMPL_TRAIT_INTERNAL } else { REFINING_IMPL_TRAIT_REACHABLE },
435 tcx.local_def_id_to_hir_id(impl_opaque_def_id),
436 use_bound_span,
437 crate::errors::ReturnPositionImplTraitInTraitRefinedLifetimes {
438 suggestion_span: use_bound_span,
439 suggestion,
440 },
441 );
442}