rustc_infer/infer/snapshot/
fudge.rs

1use std::fmt::Debug;
2use std::ops::Range;
3
4use rustc_data_structures::{snapshot_vec as sv, unify as ut};
5use rustc_middle::ty::{
6    self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
7    TypeSuperFoldable, TypeVisitableExt,
8};
9use tracing::instrument;
10use ut::UnifyKey;
11
12use super::VariableLengths;
13use crate::infer::type_variable::TypeVariableOrigin;
14use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
15use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
16
17fn vars_since_snapshot<'tcx, T>(
18    table: &UnificationTable<'_, 'tcx, T>,
19    snapshot_var_len: usize,
20) -> Range<T>
21where
22    T: UnifyKey,
23    super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
24{
25    T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
26}
27
28fn const_vars_since_snapshot<'tcx>(
29    table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
30    snapshot_var_len: usize,
31) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
32    let range = vars_since_snapshot(table, snapshot_var_len);
33    let range = range.start.vid..range.end.vid;
34
35    (
36        range.clone(),
37        range
38            .map(|index| match table.probe_value(index) {
39                ConstVariableValue::Known { value: _ } => {
40                    ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
41                }
42                ConstVariableValue::Unknown { origin, universe: _ } => origin,
43            })
44            .collect(),
45    )
46}
47
48impl<'tcx> InferCtxt<'tcx> {
49    /// This rather funky routine is used while processing expected
50    /// types. What happens here is that we want to propagate a
51    /// coercion through the return type of a fn to its
52    /// argument. Consider the type of `Option::Some`, which is
53    /// basically `for<T> fn(T) -> Option<T>`. So if we have an
54    /// expression `Some(&[1, 2, 3])`, and that has the expected type
55    /// `Option<&[u32]>`, we would like to type check `&[1, 2, 3]`
56    /// with the expectation of `&[u32]`. This will cause us to coerce
57    /// from `&[u32; 3]` to `&[u32]` and make the users life more
58    /// pleasant.
59    ///
60    /// The way we do this is using `fudge_inference_if_ok`. What the
61    /// routine actually does is to start a snapshot and execute the
62    /// closure `f`. In our example above, what this closure will do
63    /// is to unify the expectation (`Option<&[u32]>`) with the actual
64    /// return type (`Option<?T>`, where `?T` represents the variable
65    /// instantiated for `T`). This will cause `?T` to be unified
66    /// with `&?a [u32]`, where `?a` is a fresh lifetime variable. The
67    /// input type (`?T`) is then returned by `f()`.
68    ///
69    /// At this point, `fudge_inference_if_ok` will normalize all type
70    /// variables, converting `?T` to `&?a [u32]` and end the
71    /// snapshot. The problem is that we can't just return this type
72    /// out, because it references the region variable `?a`, and that
73    /// region variable was popped when we popped the snapshot.
74    ///
75    /// So what we do is to keep a list (`region_vars`, in the code below)
76    /// of region variables created during the snapshot (here, `?a`). We
77    /// fold the return value and replace any such regions with a *new*
78    /// region variable (e.g., `?b`) and return the result (`&?b [u32]`).
79    /// This can then be used as the expectation for the fn argument.
80    ///
81    /// The important point here is that, for soundness purposes, the
82    /// regions in question are not particularly important. We will
83    /// use the expected types to guide coercions, but we will still
84    /// type-check the resulting types from those coercions against
85    /// the actual types (`?T`, `Option<?T>`) -- and remember that
86    /// after the snapshot is popped, the variable `?T` is no longer
87    /// unified.
88    #[instrument(skip(self, f), level = "debug", ret)]
89    pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
90    where
91        F: FnOnce() -> Result<T, E>,
92        T: TypeFoldable<TyCtxt<'tcx>>,
93        E: Debug,
94    {
95        let variable_lengths = self.variable_lengths();
96        let (snapshot_vars, value) = self.probe(|_| {
97            let value = f()?;
98            // At this point, `value` could in principle refer
99            // to inference variables that have been created during
100            // the snapshot. Once we exit `probe()`, those are
101            // going to be popped, so we will have to
102            // eliminate any references to them.
103            let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
104            Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
105        })?;
106
107        // At this point, we need to replace any of the now-popped
108        // type/region variables that appear in `value` with a fresh
109        // variable of the appropriate kind. We can't do this during
110        // the probe because they would just get popped then too. =)
111        Ok(self.fudge_inference(snapshot_vars, value))
112    }
113
114    fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
115        &self,
116        snapshot_vars: SnapshotVarData,
117        value: T,
118    ) -> T {
119        // Micro-optimization: if no variables have been created, then
120        // `value` can't refer to any of them. =) So we can just return it.
121        if snapshot_vars.is_empty() {
122            value
123        } else {
124            value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
125        }
126    }
127}
128
129struct SnapshotVarData {
130    region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
131    type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
132    int_vars: Range<IntVid>,
133    float_vars: Range<FloatVid>,
134    const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
135}
136
137impl SnapshotVarData {
138    fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
139        let mut inner = infcx.inner.borrow_mut();
140        let region_vars = inner
141            .unwrap_region_constraints()
142            .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
143        let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
144        let int_vars =
145            vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
146        let float_vars =
147            vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
148
149        let const_vars = const_vars_since_snapshot(
150            &mut inner.const_unification_table(),
151            vars_pre_snapshot.const_var_len,
152        );
153        SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
154    }
155
156    fn is_empty(&self) -> bool {
157        let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
158        region_vars.0.is_empty()
159            && type_vars.0.is_empty()
160            && int_vars.is_empty()
161            && float_vars.is_empty()
162            && const_vars.0.is_empty()
163    }
164}
165
166struct InferenceFudger<'a, 'tcx> {
167    infcx: &'a InferCtxt<'tcx>,
168    snapshot_vars: SnapshotVarData,
169}
170
171impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
172    fn cx(&self) -> TyCtxt<'tcx> {
173        self.infcx.tcx
174    }
175
176    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
177        if let &ty::Infer(infer_ty) = ty.kind() {
178            match infer_ty {
179                ty::TyVar(vid) => {
180                    if self.snapshot_vars.type_vars.0.contains(&vid) {
181                        // This variable was created during the fudging.
182                        // Recreate it with a fresh variable here.
183                        let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
184                        let origin = self.snapshot_vars.type_vars.1[idx];
185                        self.infcx.next_ty_var_with_origin(origin)
186                    } else {
187                        // This variable was created before the
188                        // "fudging". Since we refresh all type
189                        // variables to their binding anyhow, we know
190                        // that it is unbound, so we can just return
191                        // it.
192                        debug_assert!(
193                            self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
194                        );
195                        ty
196                    }
197                }
198                ty::IntVar(vid) => {
199                    if self.snapshot_vars.int_vars.contains(&vid) {
200                        self.infcx.next_int_var()
201                    } else {
202                        ty
203                    }
204                }
205                ty::FloatVar(vid) => {
206                    if self.snapshot_vars.float_vars.contains(&vid) {
207                        self.infcx.next_float_var()
208                    } else {
209                        ty
210                    }
211                }
212                ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
213                    unreachable!("unexpected fresh infcx var")
214                }
215            }
216        } else if ty.has_infer() {
217            ty.super_fold_with(self)
218        } else {
219            ty
220        }
221    }
222
223    fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
224        if let ty::ReVar(vid) = r.kind() {
225            if self.snapshot_vars.region_vars.0.contains(&vid) {
226                let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
227                let origin = self.snapshot_vars.region_vars.1[idx];
228                self.infcx.next_region_var(origin)
229            } else {
230                r
231            }
232        } else {
233            r
234        }
235    }
236
237    fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
238        if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
239            match infer_ct {
240                ty::InferConst::Var(vid) => {
241                    if self.snapshot_vars.const_vars.0.contains(&vid) {
242                        let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
243                        let origin = self.snapshot_vars.const_vars.1[idx];
244                        self.infcx.next_const_var_with_origin(origin)
245                    } else {
246                        ct
247                    }
248                }
249                ty::InferConst::Fresh(_) => {
250                    unreachable!("unexpected fresh infcx var")
251                }
252            }
253        } else if ct.has_infer() {
254            ct.super_fold_with(self)
255        } else {
256            ct
257        }
258    }
259}