1use std::assert_matches::debug_assert_matches;
6
7use rustc_data_structures::fx::FxHashMap;
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::{DefId, LocalDefId};
10use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
11use rustc_middle::ty::{self, Ty, TyCtxt};
12use rustc_span::{ErrorGuaranteed, Span};
13use rustc_type_ir::visit::TypeVisitableExt;
14
15type RemapTable = FxHashMap<u32, u32>;
16
17struct ParamIndexRemapper<'tcx> {
18 tcx: TyCtxt<'tcx>,
19 remap_table: RemapTable,
20}
21
22impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ParamIndexRemapper<'tcx> {
23 fn cx(&self) -> TyCtxt<'tcx> {
24 self.tcx
25 }
26
27 fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
28 if !ty.has_param() {
29 return ty;
30 }
31
32 if let ty::Param(param) = ty.kind()
33 && let Some(index) = self.remap_table.get(¶m.index)
34 {
35 return Ty::new_param(self.tcx, *index, param.name);
36 }
37 ty.super_fold_with(self)
38 }
39
40 fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
41 if let ty::ReEarlyParam(param) = r.kind()
42 && let Some(index) = self.remap_table.get(¶m.index).copied()
43 {
44 return ty::Region::new_early_param(
45 self.tcx,
46 ty::EarlyParamRegion { index, name: param.name },
47 );
48 }
49 r
50 }
51
52 fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
53 if let ty::ConstKind::Param(param) = ct.kind()
54 && let Some(idx) = self.remap_table.get(¶m.index)
55 {
56 let param = ty::ParamConst::new(*idx, param.name);
57 return ty::Const::new_param(self.tcx, param);
58 }
59 ct.super_fold_with(self)
60 }
61}
62
63#[derive(Clone, Copy, Debug, PartialEq)]
64enum FnKind {
65 Free,
66 AssocInherentImpl,
67 AssocTrait,
68 AssocTraitImpl,
69}
70
71fn fn_kind<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> FnKind {
72 debug_assert_matches!(tcx.def_kind(def_id), DefKind::Fn | DefKind::AssocFn);
73
74 let parent = tcx.parent(def_id);
75 match tcx.def_kind(parent) {
76 DefKind::Trait => FnKind::AssocTrait,
77 DefKind::Impl { of_trait: true } => FnKind::AssocTraitImpl,
78 DefKind::Impl { of_trait: false } => FnKind::AssocInherentImpl,
79 _ => FnKind::Free,
80 }
81}
82
83#[derive(Clone, Copy, Debug, PartialEq)]
86enum InheritanceKind {
87 WithParent(bool),
97 Own,
101}
102
103fn build_generics<'tcx>(
104 tcx: TyCtxt<'tcx>,
105 sig_id: DefId,
106 parent: Option<DefId>,
107 inh_kind: InheritanceKind,
108) -> ty::Generics {
109 let mut own_params = vec![];
110
111 let sig_generics = tcx.generics_of(sig_id);
112 if let InheritanceKind::WithParent(has_self) = inh_kind
113 && let Some(parent_def_id) = sig_generics.parent
114 {
115 let sig_parent_generics = tcx.generics_of(parent_def_id);
116 own_params.append(&mut sig_parent_generics.own_params.clone());
117 if !has_self {
118 own_params.remove(0);
119 }
120 }
121 own_params.append(&mut sig_generics.own_params.clone());
122
123 own_params.sort_by_key(|key| key.kind.is_ty_or_const());
137
138 let param_def_id_to_index =
139 own_params.iter().map(|param| (param.def_id, param.index)).collect();
140
141 let (parent_count, has_self) = if let Some(def_id) = parent {
142 let parent_generics = tcx.generics_of(def_id);
143 let parent_kind = tcx.def_kind(def_id);
144 (parent_generics.count(), parent_kind == DefKind::Trait)
145 } else {
146 (0, false)
147 };
148
149 for (idx, param) in own_params.iter_mut().enumerate() {
150 param.index = (idx + parent_count) as u32;
151 if let ty::GenericParamDefKind::Type { has_default, .. }
159 | ty::GenericParamDefKind::Const { has_default, .. } = &mut param.kind
160 {
161 *has_default = false;
162 }
163 }
164
165 ty::Generics {
166 parent,
167 parent_count,
168 own_params,
169 param_def_id_to_index,
170 has_self,
171 has_late_bound_regions: sig_generics.has_late_bound_regions,
172 }
173}
174
175fn build_predicates<'tcx>(
176 tcx: TyCtxt<'tcx>,
177 sig_id: DefId,
178 parent: Option<DefId>,
179 inh_kind: InheritanceKind,
180 args: ty::GenericArgsRef<'tcx>,
181) -> ty::GenericPredicates<'tcx> {
182 struct PredicatesCollector<'tcx> {
183 tcx: TyCtxt<'tcx>,
184 preds: Vec<(ty::Clause<'tcx>, Span)>,
185 args: ty::GenericArgsRef<'tcx>,
186 }
187
188 impl<'tcx> PredicatesCollector<'tcx> {
189 fn new(tcx: TyCtxt<'tcx>, args: ty::GenericArgsRef<'tcx>) -> PredicatesCollector<'tcx> {
190 PredicatesCollector { tcx, preds: vec![], args }
191 }
192
193 fn with_own_preds(
194 mut self,
195 f: impl Fn(DefId) -> ty::GenericPredicates<'tcx>,
196 def_id: DefId,
197 ) -> Self {
198 let preds = f(def_id).instantiate_own(self.tcx, self.args);
199 self.preds.extend(preds);
200 self
201 }
202
203 fn with_preds(
204 mut self,
205 f: impl Fn(DefId) -> ty::GenericPredicates<'tcx> + Copy,
206 def_id: DefId,
207 ) -> Self {
208 let preds = f(def_id);
209 if let Some(parent_def_id) = preds.parent {
210 self = self.with_own_preds(f, parent_def_id);
211 }
212 self.with_own_preds(f, def_id)
213 }
214 }
215 let collector = PredicatesCollector::new(tcx, args);
216
217 let preds = match inh_kind {
221 InheritanceKind::WithParent(false) => {
222 collector.with_preds(|def_id| tcx.explicit_predicates_of(def_id), sig_id)
223 }
224 InheritanceKind::WithParent(true) => {
225 collector.with_preds(|def_id| tcx.predicates_of(def_id), sig_id)
226 }
227 InheritanceKind::Own => {
228 collector.with_own_preds(|def_id| tcx.predicates_of(def_id), sig_id)
229 }
230 }
231 .preds;
232
233 ty::GenericPredicates { parent, predicates: tcx.arena.alloc_from_iter(preds) }
234}
235
236fn build_generic_args<'tcx>(
237 tcx: TyCtxt<'tcx>,
238 sig_id: DefId,
239 def_id: LocalDefId,
240 args: ty::GenericArgsRef<'tcx>,
241) -> ty::GenericArgsRef<'tcx> {
242 let caller_generics = tcx.generics_of(def_id);
243 let callee_generics = tcx.generics_of(sig_id);
244
245 let mut remap_table = FxHashMap::default();
246 for caller_param in &caller_generics.own_params {
247 let callee_index = callee_generics.param_def_id_to_index(tcx, caller_param.def_id).unwrap();
248 remap_table.insert(callee_index, caller_param.index);
249 }
250
251 let mut folder = ParamIndexRemapper { tcx, remap_table };
252 args.fold_with(&mut folder)
253}
254
255fn create_generic_args<'tcx>(
256 tcx: TyCtxt<'tcx>,
257 def_id: LocalDefId,
258 sig_id: DefId,
259) -> ty::GenericArgsRef<'tcx> {
260 let caller_kind = fn_kind(tcx, def_id.into());
261 let callee_kind = fn_kind(tcx, sig_id);
262 match (caller_kind, callee_kind) {
263 (FnKind::Free, FnKind::Free)
264 | (FnKind::Free, FnKind::AssocTrait)
265 | (FnKind::AssocInherentImpl, FnKind::Free)
266 | (FnKind::AssocTrait, FnKind::Free)
267 | (FnKind::AssocTrait, FnKind::AssocTrait) => {
268 let args = ty::GenericArgs::identity_for_item(tcx, sig_id);
269 build_generic_args(tcx, sig_id, def_id, args)
270 }
271
272 (FnKind::AssocTraitImpl, FnKind::AssocTrait) => {
273 let callee_generics = tcx.generics_of(sig_id);
274 let parent = tcx.parent(def_id.into());
275 let parent_args =
276 tcx.impl_trait_header(parent).unwrap().trait_ref.instantiate_identity().args;
277
278 let trait_args = ty::GenericArgs::identity_for_item(tcx, sig_id);
279 let method_args =
280 tcx.mk_args_from_iter(trait_args.iter().skip(callee_generics.parent_count));
281 let method_args = build_generic_args(tcx, sig_id, def_id, method_args);
282
283 tcx.mk_args_from_iter(parent_args.iter().chain(method_args))
284 }
285
286 (FnKind::AssocInherentImpl, FnKind::AssocTrait) => {
287 let parent = tcx.parent(def_id.into());
288 let self_ty = tcx.type_of(parent).instantiate_identity();
289 let generic_self_ty = ty::GenericArg::from(self_ty);
290
291 let trait_args = ty::GenericArgs::identity_for_item(tcx, sig_id);
292 let trait_args = build_generic_args(tcx, sig_id, def_id, trait_args);
293
294 let args = std::iter::once(generic_self_ty).chain(trait_args.iter().skip(1));
295 tcx.mk_args_from_iter(args)
296 }
297
298 (FnKind::AssocTraitImpl, _)
301 | (_, FnKind::AssocTraitImpl)
302 | (_, FnKind::AssocInherentImpl) => unreachable!(),
303 }
304}
305
306pub(crate) fn inherit_generics_for_delegation_item<'tcx>(
318 tcx: TyCtxt<'tcx>,
319 def_id: LocalDefId,
320 sig_id: DefId,
321) -> ty::Generics {
322 let caller_kind = fn_kind(tcx, def_id.into());
323 let callee_kind = fn_kind(tcx, sig_id);
324 match (caller_kind, callee_kind) {
325 (FnKind::Free, FnKind::Free) | (FnKind::Free, FnKind::AssocTrait) => {
326 build_generics(tcx, sig_id, None, InheritanceKind::WithParent(true))
327 }
328
329 (FnKind::AssocTraitImpl, FnKind::AssocTrait) => {
330 build_generics(tcx, sig_id, Some(tcx.parent(def_id.into())), InheritanceKind::Own)
331 }
332
333 (FnKind::AssocInherentImpl, FnKind::AssocTrait)
334 | (FnKind::AssocTrait, FnKind::AssocTrait)
335 | (FnKind::AssocInherentImpl, FnKind::Free)
336 | (FnKind::AssocTrait, FnKind::Free) => build_generics(
337 tcx,
338 sig_id,
339 Some(tcx.parent(def_id.into())),
340 InheritanceKind::WithParent(false),
341 ),
342
343 (FnKind::AssocTraitImpl, _)
346 | (_, FnKind::AssocTraitImpl)
347 | (_, FnKind::AssocInherentImpl) => unreachable!(),
348 }
349}
350
351pub(crate) fn inherit_predicates_for_delegation_item<'tcx>(
352 tcx: TyCtxt<'tcx>,
353 def_id: LocalDefId,
354 sig_id: DefId,
355) -> ty::GenericPredicates<'tcx> {
356 let args = create_generic_args(tcx, def_id, sig_id);
357 let caller_kind = fn_kind(tcx, def_id.into());
358 let callee_kind = fn_kind(tcx, sig_id);
359 match (caller_kind, callee_kind) {
360 (FnKind::Free, FnKind::Free) | (FnKind::Free, FnKind::AssocTrait) => {
361 build_predicates(tcx, sig_id, None, InheritanceKind::WithParent(true), args)
362 }
363
364 (FnKind::AssocTraitImpl, FnKind::AssocTrait) => build_predicates(
365 tcx,
366 sig_id,
367 Some(tcx.parent(def_id.into())),
368 InheritanceKind::Own,
369 args,
370 ),
371
372 (FnKind::AssocInherentImpl, FnKind::AssocTrait)
373 | (FnKind::AssocTrait, FnKind::AssocTrait)
374 | (FnKind::AssocInherentImpl, FnKind::Free)
375 | (FnKind::AssocTrait, FnKind::Free) => build_predicates(
376 tcx,
377 sig_id,
378 Some(tcx.parent(def_id.into())),
379 InheritanceKind::WithParent(false),
380 args,
381 ),
382
383 (FnKind::AssocTraitImpl, _)
386 | (_, FnKind::AssocTraitImpl)
387 | (_, FnKind::AssocInherentImpl) => unreachable!(),
388 }
389}
390
391fn check_constraints<'tcx>(
392 tcx: TyCtxt<'tcx>,
393 def_id: LocalDefId,
394 sig_id: DefId,
395) -> Result<(), ErrorGuaranteed> {
396 let mut ret = Ok(());
397
398 let mut emit = |descr| {
399 ret = Err(tcx.dcx().emit_err(crate::errors::UnsupportedDelegation {
400 span: tcx.def_span(def_id),
401 descr,
402 callee_span: tcx.def_span(sig_id),
403 }));
404 };
405
406 if let Some(local_sig_id) = sig_id.as_local()
407 && tcx.hir().opt_delegation_sig_id(local_sig_id).is_some()
408 {
409 emit("recursive delegation is not supported yet");
410 }
411
412 ret
413}
414
415pub(crate) fn inherit_sig_for_delegation_item<'tcx>(
416 tcx: TyCtxt<'tcx>,
417 def_id: LocalDefId,
418) -> &'tcx [Ty<'tcx>] {
419 let sig_id = tcx.hir().opt_delegation_sig_id(def_id).unwrap();
420 let caller_sig = tcx.fn_sig(sig_id);
421 if let Err(err) = check_constraints(tcx, def_id, sig_id) {
422 let sig_len = caller_sig.instantiate_identity().skip_binder().inputs().len() + 1;
423 let err_type = Ty::new_error(tcx, err);
424 return tcx.arena.alloc_from_iter((0..sig_len).map(|_| err_type));
425 }
426 let args = create_generic_args(tcx, def_id, sig_id);
427
428 let sig = caller_sig.instantiate(tcx, args).skip_binder();
431 let sig_iter = sig.inputs().iter().cloned().chain(std::iter::once(sig.output()));
432 tcx.arena.alloc_from_iter(sig_iter)
433}