rustc_infer/infer/snapshot/fudge.rs

1use std::ops::Range;
2
3use rustc_data_structures::{snapshot_vec as sv, unify as ut};
4use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
5use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
6use rustc_type_ir::visit::TypeVisitableExt;
7use tracing::instrument;
8use ut::UnifyKey;
9
10use super::VariableLengths;
11use crate::infer::type_variable::TypeVariableOrigin;
12use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
13use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
14
15fn vars_since_snapshot<'tcx, T>(
16    table: &UnificationTable<'_, 'tcx, T>,
17    snapshot_var_len: usize,
18) -> Range<T>
19where
20    T: UnifyKey,
21    super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
22{
23    T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
24}
25
26fn const_vars_since_snapshot<'tcx>(
27    table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
28    snapshot_var_len: usize,
29) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
30    let range = vars_since_snapshot(table, snapshot_var_len);
31
32    (
33        range.start.vid..range.end.vid,
34        (range.start.index()..range.end.index())
35            .map(|index| match table.probe_value(ConstVid::from_u32(index)) {
36                ConstVariableValue::Known { value: _ } => {
37                    ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
38                }
39                ConstVariableValue::Unknown { origin, universe: _ } => origin,
40            })
41            .collect(),
42    )
43}
44
45impl<'tcx> InferCtxt<'tcx> {
46    /// This rather funky routine is used while processing expected
47    /// types. What happens here is that we want to propagate a
48    /// coercion through the return type of a fn to its
49    /// argument. Consider the type of `Option::Some`, which is
50    /// basically `for<T> fn(T) -> Option<T>`. So if we have an
51    /// expression `Some(&[1, 2, 3])`, and that has the expected type
52    /// `Option<&[u32]>`, we would like to type check `&[1, 2, 3]`
53    /// with the expectation of `&[u32]`. This will cause us to coerce
54    /// from `&[u32; 3]` to `&[u32]` and make the users life more
55    /// pleasant.
56    ///
57    /// The way we do this is using `fudge_inference_if_ok`. What the
58    /// routine actually does is to start a snapshot and execute the
59    /// closure `f`. In our example above, what this closure will do
60    /// is to unify the expectation (`Option<&[u32]>`) with the actual
61    /// return type (`Option<?T>`, where `?T` represents the variable
62    /// instantiated for `T`). This will cause `?T` to be unified
63    /// with `&?a [u32]`, where `?a` is a fresh lifetime variable. The
64    /// input type (`?T`) is then returned by `f()`.
65    ///
66    /// At this point, `fudge_inference_if_ok` will normalize all type
67    /// variables, converting `?T` to `&?a [u32]` and end the
68    /// snapshot. The problem is that we can't just return this type
69    /// out, because it references the region variable `?a`, and that
70    /// region variable was popped when we popped the snapshot.
71    ///
72    /// So what we do is to keep a list (`region_vars`, in the code below)
73    /// of region variables created during the snapshot (here, `?a`). We
74    /// fold the return value and replace any such regions with a *new*
75    /// region variable (e.g., `?b`) and return the result (`&?b [u32]`).
76    /// This can then be used as the expectation for the fn argument.
77    ///
78    /// The important point here is that, for soundness purposes, the
79    /// regions in question are not particularly important. We will
80    /// use the expected types to guide coercions, but we will still
81    /// type-check the resulting types from those coercions against
82    /// the actual types (`?T`, `Option<?T>`) -- and remember that
83    /// after the snapshot is popped, the variable `?T` is no longer
84    /// unified.
85    #[instrument(skip(self, f), level = "debug")]
86    pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
87    where
88        F: FnOnce() -> Result<T, E>,
89        T: TypeFoldable<TyCtxt<'tcx>>,
90    {
91        let variable_lengths = self.variable_lengths();
92        let (snapshot_vars, value) = self.probe(|_| {
93            let value = f()?;
94            // At this point, `value` could in principle refer
95            // to inference variables that have been created during
96            // the snapshot. Once we exit `probe()`, those are
97            // going to be popped, so we will have to
98            // eliminate any references to them.
99            let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
100            Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
101        })?;
102
103        // At this point, we need to replace any of the now-popped
104        // type/region variables that appear in `value` with a fresh
105        // variable of the appropriate kind. We can't do this during
106        // the probe because they would just get popped then too. =)
107        Ok(self.fudge_inference(snapshot_vars, value))
108    }
109
110    fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
111        &self,
112        snapshot_vars: SnapshotVarData,
113        value: T,
114    ) -> T {
115        // Micro-optimization: if no variables have been created, then
116        // `value` can't refer to any of them. =) So we can just return it.
117        if snapshot_vars.is_empty() {
118            value
119        } else {
120            value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
121        }
122    }
123}
124
125struct SnapshotVarData {
126    region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
127    type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
128    int_vars: Range<IntVid>,
129    float_vars: Range<FloatVid>,
130    const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
131}
132
133impl SnapshotVarData {
134    fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
135        let mut inner = infcx.inner.borrow_mut();
136        let region_vars = inner
137            .unwrap_region_constraints()
138            .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
139        let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
140        let int_vars =
141            vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
142        let float_vars =
143            vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
144
145        let const_vars = const_vars_since_snapshot(
146            &mut inner.const_unification_table(),
147            vars_pre_snapshot.const_var_len,
148        );
149        SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
150    }
151
152    fn is_empty(&self) -> bool {
153        let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
154        region_vars.0.is_empty()
155            && type_vars.0.is_empty()
156            && int_vars.is_empty()
157            && float_vars.is_empty()
158            && const_vars.0.is_empty()
159    }
160}
161
162struct InferenceFudger<'a, 'tcx> {
163    infcx: &'a InferCtxt<'tcx>,
164    snapshot_vars: SnapshotVarData,
165}
166
167impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
168    fn cx(&self) -> TyCtxt<'tcx> {
169        self.infcx.tcx
170    }
171
172    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
173        if let &ty::Infer(infer_ty) = ty.kind() {
174            match infer_ty {
175                ty::TyVar(vid) => {
176                    if self.snapshot_vars.type_vars.0.contains(&vid) {
177                        // This variable was created during the fudging.
178                        // Recreate it with a fresh variable here.
179                        let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
180                        let origin = self.snapshot_vars.type_vars.1[idx];
181                        self.infcx.next_ty_var_with_origin(origin)
182                    } else {
183                        // This variable was created before the
184                        // "fudging". Since we refresh all type
185                        // variables to their binding anyhow, we know
186                        // that it is unbound, so we can just return
187                        // it.
188                        debug_assert!(
189                            self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
190                        );
191                        ty
192                    }
193                }
194                ty::IntVar(vid) => {
195                    if self.snapshot_vars.int_vars.contains(&vid) {
196                        self.infcx.next_int_var()
197                    } else {
198                        ty
199                    }
200                }
201                ty::FloatVar(vid) => {
202                    if self.snapshot_vars.float_vars.contains(&vid) {
203                        self.infcx.next_float_var()
204                    } else {
205                        ty
206                    }
207                }
208                ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
209                    unreachable!("unexpected fresh infcx var")
210                }
211            }
212        } else if ty.has_infer() {
213            ty.super_fold_with(self)
214        } else {
215            ty
216        }
217    }
218
219    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
220        if let ty::ReVar(vid) = r.kind() {
221            if self.snapshot_vars.region_vars.0.contains(&vid) {
222                let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
223                let origin = self.snapshot_vars.region_vars.1[idx];
224                self.infcx.next_region_var(origin)
225            } else {
226                r
227            }
228        } else {
229            r
230        }
231    }
232
233    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
234        if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
235            match infer_ct {
236                ty::InferConst::Var(vid) => {
237                    if self.snapshot_vars.const_vars.0.contains(&vid) {
238                        let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
239                        let origin = self.snapshot_vars.const_vars.1[idx];
240                        self.infcx.next_const_var_with_origin(origin)
241                    } else {
242                        ct
243                    }
244                }
245                ty::InferConst::Fresh(_) => {
246                    unreachable!("unexpected fresh infcx var")
247                }
248            }
249        } else if ct.has_infer() {
250            ct.super_fold_with(self)
251        } else {
252            ct
253        }
254    }
255}