rustc_trait_selection/traits/
util.rs

1use std::collections::{BTreeMap, VecDeque};
2
3use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
4use rustc_hir::def_id::DefId;
5use rustc_infer::infer::InferCtxt;
6pub use rustc_infer::traits::util::*;
7use rustc_middle::bug;
8use rustc_middle::ty::{
9    self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
10};
11use rustc_span::Span;
12use smallvec::{SmallVec, smallvec};
13use tracing::debug;
14
15/// Return the trait and projection predicates that come from eagerly expanding the
16/// trait aliases in the list of clauses. For each trait predicate, record a stack
17/// of spans that trace from the user-written trait alias bound. For projection predicates,
18/// just record the span of the projection itself.
19///
20/// For trait aliases, we don't deduplicte the predicates, since we currently do not
21/// consider duplicated traits as a single trait for the purposes of our "one trait principal"
22/// restriction; however, for projections we do deduplicate them.
23///
24/// ```rust,ignore (fails)
25/// trait Bar {}
26/// trait Foo = Bar + Bar;
27///
28/// let dyn_incompatible: dyn Foo; // bad, two `Bar` principals.
29/// ```
30pub fn expand_trait_aliases<'tcx>(
31    tcx: TyCtxt<'tcx>,
32    clauses: impl IntoIterator<Item = (ty::Clause<'tcx>, Span)>,
33) -> (
34    Vec<(ty::PolyTraitPredicate<'tcx>, SmallVec<[Span; 1]>)>,
35    Vec<(ty::PolyProjectionPredicate<'tcx>, Span)>,
36) {
37    let mut trait_preds = vec![];
38    let mut projection_preds = vec![];
39    let mut seen_projection_preds = FxHashSet::default();
40
41    let mut queue: VecDeque<_> = clauses.into_iter().map(|(p, s)| (p, smallvec![s])).collect();
42
43    while let Some((clause, spans)) = queue.pop_front() {
44        match clause.kind().skip_binder() {
45            ty::ClauseKind::Trait(trait_pred) => {
46                if tcx.is_trait_alias(trait_pred.def_id()) {
47                    queue.extend(
48                        tcx.explicit_super_predicates_of(trait_pred.def_id())
49                            .iter_identity_copied()
50                            .map(|(super_clause, span)| {
51                                let mut spans = spans.clone();
52                                spans.push(span);
53                                (
54                                    super_clause.instantiate_supertrait(
55                                        tcx,
56                                        clause.kind().rebind(trait_pred.trait_ref),
57                                    ),
58                                    spans,
59                                )
60                            }),
61                    );
62                } else {
63                    trait_preds.push((clause.kind().rebind(trait_pred), spans));
64                }
65            }
66            ty::ClauseKind::Projection(projection_pred) => {
67                let projection_pred = clause.kind().rebind(projection_pred);
68                if !seen_projection_preds.insert(tcx.anonymize_bound_vars(projection_pred)) {
69                    continue;
70                }
71                projection_preds.push((projection_pred, *spans.last().unwrap()));
72            }
73            ty::ClauseKind::RegionOutlives(..)
74            | ty::ClauseKind::TypeOutlives(..)
75            | ty::ClauseKind::ConstArgHasType(_, _)
76            | ty::ClauseKind::WellFormed(_)
77            | ty::ClauseKind::ConstEvaluatable(_)
78            | ty::ClauseKind::HostEffect(..) => {}
79        }
80    }
81
82    (trait_preds, projection_preds)
83}
84
85///////////////////////////////////////////////////////////////////////////
86// Other
87///////////////////////////////////////////////////////////////////////////
88
89/// Casts a trait reference into a reference to one of its super
90/// traits; returns `None` if `target_trait_def_id` is not a
91/// supertrait.
92pub fn upcast_choices<'tcx>(
93    tcx: TyCtxt<'tcx>,
94    source_trait_ref: ty::PolyTraitRef<'tcx>,
95    target_trait_def_id: DefId,
96) -> Vec<ty::PolyTraitRef<'tcx>> {
97    if source_trait_ref.def_id() == target_trait_def_id {
98        return vec![source_trait_ref]; // Shortcut the most common case.
99    }
100
101    supertraits(tcx, source_trait_ref).filter(|r| r.def_id() == target_trait_def_id).collect()
102}
103
104pub(crate) fn closure_trait_ref_and_return_type<'tcx>(
105    tcx: TyCtxt<'tcx>,
106    fn_trait_def_id: DefId,
107    self_ty: Ty<'tcx>,
108    sig: ty::PolyFnSig<'tcx>,
109    tuple_arguments: TupleArgumentsFlag,
110) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> {
111    assert!(!self_ty.has_escaping_bound_vars());
112    let arguments_tuple = match tuple_arguments {
113        TupleArgumentsFlag::No => sig.skip_binder().inputs()[0],
114        TupleArgumentsFlag::Yes => Ty::new_tup(tcx, sig.skip_binder().inputs()),
115    };
116    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty, arguments_tuple]);
117    sig.map_bound(|sig| (trait_ref, sig.output()))
118}
119
120pub(crate) fn coroutine_trait_ref_and_outputs<'tcx>(
121    tcx: TyCtxt<'tcx>,
122    fn_trait_def_id: DefId,
123    self_ty: Ty<'tcx>,
124    sig: ty::GenSig<TyCtxt<'tcx>>,
125) -> (ty::TraitRef<'tcx>, Ty<'tcx>, Ty<'tcx>) {
126    assert!(!self_ty.has_escaping_bound_vars());
127    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty, sig.resume_ty]);
128    (trait_ref, sig.yield_ty, sig.return_ty)
129}
130
131pub(crate) fn future_trait_ref_and_outputs<'tcx>(
132    tcx: TyCtxt<'tcx>,
133    fn_trait_def_id: DefId,
134    self_ty: Ty<'tcx>,
135    sig: ty::GenSig<TyCtxt<'tcx>>,
136) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
137    assert!(!self_ty.has_escaping_bound_vars());
138    let trait_ref = ty::TraitRef::new(tcx, fn_trait_def_id, [self_ty]);
139    (trait_ref, sig.return_ty)
140}
141
142pub(crate) fn iterator_trait_ref_and_outputs<'tcx>(
143    tcx: TyCtxt<'tcx>,
144    iterator_def_id: DefId,
145    self_ty: Ty<'tcx>,
146    sig: ty::GenSig<TyCtxt<'tcx>>,
147) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
148    assert!(!self_ty.has_escaping_bound_vars());
149    let trait_ref = ty::TraitRef::new(tcx, iterator_def_id, [self_ty]);
150    (trait_ref, sig.yield_ty)
151}
152
153pub(crate) fn async_iterator_trait_ref_and_outputs<'tcx>(
154    tcx: TyCtxt<'tcx>,
155    async_iterator_def_id: DefId,
156    self_ty: Ty<'tcx>,
157    sig: ty::GenSig<TyCtxt<'tcx>>,
158) -> (ty::TraitRef<'tcx>, Ty<'tcx>) {
159    assert!(!self_ty.has_escaping_bound_vars());
160    let trait_ref = ty::TraitRef::new(tcx, async_iterator_def_id, [self_ty]);
161    (trait_ref, sig.yield_ty)
162}
163
164pub fn impl_item_is_final(tcx: TyCtxt<'_>, assoc_item: &ty::AssocItem) -> bool {
165    assoc_item.defaultness(tcx).is_final()
166        && tcx.defaultness(assoc_item.container_id(tcx)).is_final()
167}
168
/// Whether the inputs of a closure signature must be packed into a fresh tuple type to form
/// the arguments parameter of the trait reference built by
/// `closure_trait_ref_and_return_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum TupleArgumentsFlag {
    /// Build a new tuple type out of all of the signature's inputs.
    Yes,
    /// Use the signature's first input as-is.
    No,
}
173
174/// Executes `f` on `value` after replacing all escaping bound variables with placeholders
175/// and then replaces these placeholders with the original bound variables in the result.
176///
177/// In most places, bound variables should be replaced right when entering a binder, making
178/// this function unnecessary. However, normalization currently does not do that, so we have
179/// to do this lazily.
180///
181/// You should not add any additional uses of this function, at least not without first
182/// discussing it with t-types.
183///
184/// FIXME(@lcnr): We may even consider experimenting with eagerly replacing bound vars during
185/// normalization as well, at which point this function will be unnecessary and can be removed.
186pub fn with_replaced_escaping_bound_vars<
187    'a,
188    'tcx,
189    T: TypeFoldable<TyCtxt<'tcx>>,
190    R: TypeFoldable<TyCtxt<'tcx>>,
191>(
192    infcx: &'a InferCtxt<'tcx>,
193    universe_indices: &'a mut Vec<Option<ty::UniverseIndex>>,
194    value: T,
195    f: impl FnOnce(T) -> R,
196) -> R {
197    if value.has_escaping_bound_vars() {
198        let (value, mapped_regions, mapped_types, mapped_consts) =
199            BoundVarReplacer::replace_bound_vars(infcx, universe_indices, value);
200        let result = f(value);
201        PlaceholderReplacer::replace_placeholders(
202            infcx,
203            mapped_regions,
204            mapped_types,
205            mapped_consts,
206            universe_indices,
207            result,
208        )
209    } else {
210        f(value)
211    }
212}
213
214pub struct BoundVarReplacer<'a, 'tcx> {
215    infcx: &'a InferCtxt<'tcx>,
216    // These three maps track the bound variable that were replaced by placeholders. It might be
217    // nice to remove these since we already have the `kind` in the placeholder; we really just need
218    // the `var` (but we *could* bring that into scope if we were to track them as we pass them).
219    mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
220    mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
221    mapped_consts: BTreeMap<ty::PlaceholderConst, ty::BoundVar>,
222    // The current depth relative to *this* folding, *not* the entire normalization. In other words,
223    // the depth of binders we've passed here.
224    current_index: ty::DebruijnIndex,
225    // The `UniverseIndex` of the binding levels above us. These are optional, since we are lazy:
226    // we don't actually create a universe until we see a bound var we have to replace.
227    universe_indices: &'a mut Vec<Option<ty::UniverseIndex>>,
228}
229
230impl<'a, 'tcx> BoundVarReplacer<'a, 'tcx> {
231    /// Returns `Some` if we *were* able to replace bound vars. If there are any bound vars that
232    /// use a binding level above `universe_indices.len()`, we fail.
233    pub fn replace_bound_vars<T: TypeFoldable<TyCtxt<'tcx>>>(
234        infcx: &'a InferCtxt<'tcx>,
235        universe_indices: &'a mut Vec<Option<ty::UniverseIndex>>,
236        value: T,
237    ) -> (
238        T,
239        FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
240        FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
241        BTreeMap<ty::PlaceholderConst, ty::BoundVar>,
242    ) {
243        let mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion> =
244            FxIndexMap::default();
245        let mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy> = FxIndexMap::default();
246        let mapped_consts: BTreeMap<ty::PlaceholderConst, ty::BoundVar> = BTreeMap::new();
247
248        let mut replacer = BoundVarReplacer {
249            infcx,
250            mapped_regions,
251            mapped_types,
252            mapped_consts,
253            current_index: ty::INNERMOST,
254            universe_indices,
255        };
256
257        let value = value.fold_with(&mut replacer);
258
259        (value, replacer.mapped_regions, replacer.mapped_types, replacer.mapped_consts)
260    }
261
262    fn universe_for(&mut self, debruijn: ty::DebruijnIndex) -> ty::UniverseIndex {
263        let infcx = self.infcx;
264        let index =
265            self.universe_indices.len() + self.current_index.as_usize() - debruijn.as_usize() - 1;
266        let universe = self.universe_indices[index].unwrap_or_else(|| {
267            for i in self.universe_indices.iter_mut().take(index + 1) {
268                *i = i.or_else(|| Some(infcx.create_next_universe()))
269            }
270            self.universe_indices[index].unwrap()
271        });
272        universe
273    }
274}
275
276impl<'tcx> TypeFolder<TyCtxt<'tcx>> for BoundVarReplacer<'_, 'tcx> {
277    fn cx(&self) -> TyCtxt<'tcx> {
278        self.infcx.tcx
279    }
280
281    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
282        &mut self,
283        t: ty::Binder<'tcx, T>,
284    ) -> ty::Binder<'tcx, T> {
285        self.current_index.shift_in(1);
286        let t = t.super_fold_with(self);
287        self.current_index.shift_out(1);
288        t
289    }
290
291    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
292        match *r {
293            ty::ReBound(debruijn, _)
294                if debruijn.as_usize()
295                    >= self.current_index.as_usize() + self.universe_indices.len() =>
296            {
297                bug!(
298                    "Bound vars {r:#?} outside of `self.universe_indices`: {:#?}",
299                    self.universe_indices
300                );
301            }
302            ty::ReBound(debruijn, br) if debruijn >= self.current_index => {
303                let universe = self.universe_for(debruijn);
304                let p = ty::PlaceholderRegion { universe, bound: br };
305                self.mapped_regions.insert(p, br);
306                ty::Region::new_placeholder(self.infcx.tcx, p)
307            }
308            _ => r,
309        }
310    }
311
312    fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
313        match *t.kind() {
314            ty::Bound(debruijn, _)
315                if debruijn.as_usize() + 1
316                    > self.current_index.as_usize() + self.universe_indices.len() =>
317            {
318                bug!(
319                    "Bound vars {t:#?} outside of `self.universe_indices`: {:#?}",
320                    self.universe_indices
321                );
322            }
323            ty::Bound(debruijn, bound_ty) if debruijn >= self.current_index => {
324                let universe = self.universe_for(debruijn);
325                let p = ty::PlaceholderType { universe, bound: bound_ty };
326                self.mapped_types.insert(p, bound_ty);
327                Ty::new_placeholder(self.infcx.tcx, p)
328            }
329            _ if t.has_vars_bound_at_or_above(self.current_index) => t.super_fold_with(self),
330            _ => t,
331        }
332    }
333
334    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
335        match ct.kind() {
336            ty::ConstKind::Bound(debruijn, _)
337                if debruijn.as_usize() + 1
338                    > self.current_index.as_usize() + self.universe_indices.len() =>
339            {
340                bug!(
341                    "Bound vars {ct:#?} outside of `self.universe_indices`: {:#?}",
342                    self.universe_indices
343                );
344            }
345            ty::ConstKind::Bound(debruijn, bound_const) if debruijn >= self.current_index => {
346                let universe = self.universe_for(debruijn);
347                let p = ty::PlaceholderConst { universe, bound: bound_const };
348                self.mapped_consts.insert(p, bound_const);
349                ty::Const::new_placeholder(self.infcx.tcx, p)
350            }
351            _ => ct.super_fold_with(self),
352        }
353    }
354
355    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
356        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
357    }
358}
359
360/// The inverse of [`BoundVarReplacer`]: replaces placeholders with the bound vars from which they came.
361pub struct PlaceholderReplacer<'a, 'tcx> {
362    infcx: &'a InferCtxt<'tcx>,
363    mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
364    mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
365    mapped_consts: BTreeMap<ty::PlaceholderConst, ty::BoundVar>,
366    universe_indices: &'a [Option<ty::UniverseIndex>],
367    current_index: ty::DebruijnIndex,
368}
369
370impl<'a, 'tcx> PlaceholderReplacer<'a, 'tcx> {
371    pub fn replace_placeholders<T: TypeFoldable<TyCtxt<'tcx>>>(
372        infcx: &'a InferCtxt<'tcx>,
373        mapped_regions: FxIndexMap<ty::PlaceholderRegion, ty::BoundRegion>,
374        mapped_types: FxIndexMap<ty::PlaceholderType, ty::BoundTy>,
375        mapped_consts: BTreeMap<ty::PlaceholderConst, ty::BoundVar>,
376        universe_indices: &'a [Option<ty::UniverseIndex>],
377        value: T,
378    ) -> T {
379        let mut replacer = PlaceholderReplacer {
380            infcx,
381            mapped_regions,
382            mapped_types,
383            mapped_consts,
384            universe_indices,
385            current_index: ty::INNERMOST,
386        };
387        value.fold_with(&mut replacer)
388    }
389}
390
391impl<'tcx> TypeFolder<TyCtxt<'tcx>> for PlaceholderReplacer<'_, 'tcx> {
392    fn cx(&self) -> TyCtxt<'tcx> {
393        self.infcx.tcx
394    }
395
396    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
397        &mut self,
398        t: ty::Binder<'tcx, T>,
399    ) -> ty::Binder<'tcx, T> {
400        if !t.has_placeholders() && !t.has_infer() {
401            return t;
402        }
403        self.current_index.shift_in(1);
404        let t = t.super_fold_with(self);
405        self.current_index.shift_out(1);
406        t
407    }
408
409    fn fold_region(&mut self, r0: ty::Region<'tcx>) -> ty::Region<'tcx> {
410        let r1 = match *r0 {
411            ty::ReVar(vid) => self
412                .infcx
413                .inner
414                .borrow_mut()
415                .unwrap_region_constraints()
416                .opportunistic_resolve_var(self.infcx.tcx, vid),
417            _ => r0,
418        };
419
420        let r2 = match *r1 {
421            ty::RePlaceholder(p) => {
422                let replace_var = self.mapped_regions.get(&p);
423                match replace_var {
424                    Some(replace_var) => {
425                        let index = self
426                            .universe_indices
427                            .iter()
428                            .position(|u| matches!(u, Some(pu) if *pu == p.universe))
429                            .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
430                        let db = ty::DebruijnIndex::from_usize(
431                            self.universe_indices.len() - index + self.current_index.as_usize() - 1,
432                        );
433                        ty::Region::new_bound(self.cx(), db, *replace_var)
434                    }
435                    None => r1,
436                }
437            }
438            _ => r1,
439        };
440
441        debug!(?r0, ?r1, ?r2, "fold_region");
442
443        r2
444    }
445
446    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
447        let ty = self.infcx.shallow_resolve(ty);
448        match *ty.kind() {
449            ty::Placeholder(p) => {
450                let replace_var = self.mapped_types.get(&p);
451                match replace_var {
452                    Some(replace_var) => {
453                        let index = self
454                            .universe_indices
455                            .iter()
456                            .position(|u| matches!(u, Some(pu) if *pu == p.universe))
457                            .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
458                        let db = ty::DebruijnIndex::from_usize(
459                            self.universe_indices.len() - index + self.current_index.as_usize() - 1,
460                        );
461                        Ty::new_bound(self.infcx.tcx, db, *replace_var)
462                    }
463                    None => {
464                        if ty.has_infer() {
465                            ty.super_fold_with(self)
466                        } else {
467                            ty
468                        }
469                    }
470                }
471            }
472
473            _ if ty.has_placeholders() || ty.has_infer() => ty.super_fold_with(self),
474            _ => ty,
475        }
476    }
477
478    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
479        let ct = self.infcx.shallow_resolve_const(ct);
480        if let ty::ConstKind::Placeholder(p) = ct.kind() {
481            let replace_var = self.mapped_consts.get(&p);
482            match replace_var {
483                Some(replace_var) => {
484                    let index = self
485                        .universe_indices
486                        .iter()
487                        .position(|u| matches!(u, Some(pu) if *pu == p.universe))
488                        .unwrap_or_else(|| bug!("Unexpected placeholder universe."));
489                    let db = ty::DebruijnIndex::from_usize(
490                        self.universe_indices.len() - index + self.current_index.as_usize() - 1,
491                    );
492                    ty::Const::new_bound(self.infcx.tcx, db, *replace_var)
493                }
494                None => {
495                    if ct.has_infer() {
496                        ct.super_fold_with(self)
497                    } else {
498                        ct
499                    }
500                }
501            }
502        } else {
503            ct.super_fold_with(self)
504        }
505    }
506}