rustc_infer/infer/canonical/
instantiate.rs

1//! This module contains code to instantiate new values into a
2//! `Canonical<'tcx, T>`.
3//!
4//! For an overview of what canonicalization is and how it fits into
//! rustc, check out the [canonicalization chapter][c] of the chalk book.
6//!
7//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html
8
9use rustc_macros::extension;
10use rustc_middle::ty::{
11    self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
12    TypeVisitableExt, TypeVisitor,
13};
14use rustc_type_ir::{TypeFlags, TypeVisitable};
15
16use crate::infer::canonical::{Canonical, CanonicalVarValues};
17
18/// FIXME(-Znext-solver): This or public because it is shared with the
19/// new trait solver implementation. We should deduplicate canonicalization.
20#[extension(pub trait CanonicalExt<'tcx, V>)]
21impl<'tcx, V> Canonical<'tcx, V> {
22    /// Instantiate the wrapped value, replacing each canonical value
23    /// with the value given in `var_values`.
24    fn instantiate(&self, tcx: TyCtxt<'tcx>, var_values: &CanonicalVarValues<'tcx>) -> V
25    where
26        V: TypeFoldable<TyCtxt<'tcx>>,
27    {
28        self.instantiate_projected(tcx, var_values, |value| value.clone())
29    }
30
31    /// Allows one to apply a instantiation to some subset of
32    /// `self.value`. Invoke `projection_fn` with `self.value` to get
33    /// a value V that is expressed in terms of the same canonical
34    /// variables bound in `self` (usually this extracts from subset
35    /// of `self`). Apply the instantiation `var_values` to this value
36    /// V, replacing each of the canonical variables.
37    fn instantiate_projected<T>(
38        &self,
39        tcx: TyCtxt<'tcx>,
40        var_values: &CanonicalVarValues<'tcx>,
41        projection_fn: impl FnOnce(&V) -> T,
42    ) -> T
43    where
44        T: TypeFoldable<TyCtxt<'tcx>>,
45    {
46        assert_eq!(self.variables.len(), var_values.len());
47        let value = projection_fn(&self.value);
48        instantiate_value(tcx, var_values, value)
49    }
50}
51
52/// Instantiate the values from `var_values` into `value`. `var_values`
53/// must be values for the set of canonical variables that appear in
54/// `value`.
55pub(super) fn instantiate_value<'tcx, T>(
56    tcx: TyCtxt<'tcx>,
57    var_values: &CanonicalVarValues<'tcx>,
58    value: T,
59) -> T
60where
61    T: TypeFoldable<TyCtxt<'tcx>>,
62{
63    if var_values.var_values.is_empty() {
64        return value;
65    }
66
67    value.fold_with(&mut CanonicalInstantiator {
68        tcx,
69        var_values: var_values.var_values,
70        cache: Default::default(),
71    })
72}
73
74/// Replaces the bound vars in a canonical binder with var values.
75struct CanonicalInstantiator<'tcx> {
76    tcx: TyCtxt<'tcx>,
77
78    // The values that the bound vars are are being instantiated with.
79    var_values: ty::GenericArgsRef<'tcx>,
80
81    // Because we use `ty::BoundVarIndexKind::Canonical`, we can cache
82    // based only on the entire ty, not worrying about a `DebruijnIndex`
83    cache: DelayedMap<Ty<'tcx>, Ty<'tcx>>,
84}
85
86impl<'tcx> TypeFolder<TyCtxt<'tcx>> for CanonicalInstantiator<'tcx> {
87    fn cx(&self) -> TyCtxt<'tcx> {
88        self.tcx
89    }
90
91    fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
92        match *t.kind() {
93            ty::Bound(ty::BoundVarIndexKind::Canonical, bound_ty) => {
94                self.var_values[bound_ty.var.as_usize()].expect_ty()
95            }
96            _ => {
97                if !t.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
98                    t
99                } else if let Some(&t) = self.cache.get(&t) {
100                    t
101                } else {
102                    let res = t.super_fold_with(self);
103                    assert!(self.cache.insert(t, res));
104                    res
105                }
106            }
107        }
108    }
109
110    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
111        match r.kind() {
112            ty::ReBound(ty::BoundVarIndexKind::Canonical, br) => {
113                self.var_values[br.var.as_usize()].expect_region()
114            }
115            _ => r,
116        }
117    }
118
119    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
120        match ct.kind() {
121            ty::ConstKind::Bound(ty::BoundVarIndexKind::Canonical, bound_const) => {
122                self.var_values[bound_const.var.as_usize()].expect_const()
123            }
124            _ => ct.super_fold_with(self),
125        }
126    }
127
128    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
129        if p.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) { p.super_fold_with(self) } else { p }
130    }
131
132    fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
133        if !c.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
134            return c;
135        }
136
137        // Our cache key is `(clauses, var_values)`, but we also don't care about
138        // var values that aren't named in the clauses, since they can change without
139        // affecting the output. Since `ParamEnv`s are cached first, we compute the
140        // last var value that is mentioned in the clauses, and cut off the list so
141        // that we have more hits in the cache.
142
143        // We also cache the computation of "highest var named by clauses" since that
144        // is both expensive (depending on the size of the clauses) and a pure function.
145        let index = *self
146            .tcx
147            .highest_var_in_clauses_cache
148            .lock()
149            .entry(c)
150            .or_insert_with(|| highest_var_in_clauses(c));
151        let c_args = &self.var_values[..=index];
152
153        if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
154            c
155        } else {
156            let folded = c.super_fold_with(self);
157            self.tcx.clauses_cache.lock().insert((c, c_args), folded);
158            folded
159        }
160    }
161}
162
163fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
164    struct HighestVarInClauses {
165        max_var: usize,
166    }
167    impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
168        fn visit_ty(&mut self, t: Ty<'tcx>) {
169            if let ty::Bound(ty::BoundVarIndexKind::Canonical, bound_ty) = *t.kind() {
170                self.max_var = self.max_var.max(bound_ty.var.as_usize());
171            } else if t.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
172                t.super_visit_with(self);
173            }
174        }
175        fn visit_region(&mut self, r: ty::Region<'tcx>) {
176            if let ty::ReBound(ty::BoundVarIndexKind::Canonical, bound_region) = r.kind() {
177                self.max_var = self.max_var.max(bound_region.var.as_usize());
178            }
179        }
180        fn visit_const(&mut self, ct: ty::Const<'tcx>) {
181            if let ty::ConstKind::Bound(ty::BoundVarIndexKind::Canonical, bound_const) = ct.kind() {
182                self.max_var = self.max_var.max(bound_const.var.as_usize());
183            } else if ct.has_type_flags(TypeFlags::HAS_CANONICAL_BOUND) {
184                ct.super_visit_with(self);
185            }
186        }
187    }
188    let mut visitor = HighestVarInClauses { max_var: 0 };
189    c.visit_with(&mut visitor);
190    visitor.max_var
191}