rustc_infer/infer/snapshot/
fudge.rs1use std::fmt::Debug;
2use std::ops::Range;
3
4use rustc_data_structures::{snapshot_vec as sv, unify as ut};
5use rustc_middle::ty::{
6 self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
7 TypeSuperFoldable, TypeVisitableExt,
8};
9use tracing::instrument;
10use ut::UnifyKey;
11
12use super::VariableLengths;
13use crate::infer::type_variable::TypeVariableOrigin;
14use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
15use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
16
17fn vars_since_snapshot<'tcx, T>(
18 table: &UnificationTable<'_, 'tcx, T>,
19 snapshot_var_len: usize,
20) -> Range<T>
21where
22 T: UnifyKey,
23 super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
24{
25 T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
26}
27
28fn const_vars_since_snapshot<'tcx>(
29 table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
30 snapshot_var_len: usize,
31) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
32 let range = vars_since_snapshot(table, snapshot_var_len);
33 let range = range.start.vid..range.end.vid;
34
35 (
36 range.clone(),
37 range
38 .map(|index| match table.probe_value(index) {
39 ConstVariableValue::Known { value: _ } => {
40 ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
41 }
42 ConstVariableValue::Unknown { origin, universe: _ } => origin,
43 })
44 .collect(),
45 )
46}
47
48impl<'tcx> InferCtxt<'tcx> {
49 #[instrument(skip(self, f), level = "debug", ret)]
89 pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
90 where
91 F: FnOnce() -> Result<T, E>,
92 T: TypeFoldable<TyCtxt<'tcx>>,
93 E: Debug,
94 {
95 let variable_lengths = self.variable_lengths();
96 let (snapshot_vars, value) = self.probe(|_| {
97 let value = f()?;
98 let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
104 Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
105 })?;
106
107 Ok(self.fudge_inference(snapshot_vars, value))
112 }
113
114 fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
115 &self,
116 snapshot_vars: SnapshotVarData,
117 value: T,
118 ) -> T {
119 if snapshot_vars.is_empty() {
122 value
123 } else {
124 value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
125 }
126 }
127}
128
129struct SnapshotVarData {
130 region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
131 type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
132 int_vars: Range<IntVid>,
133 float_vars: Range<FloatVid>,
134 const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
135}
136
137impl SnapshotVarData {
138 fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
139 let mut inner = infcx.inner.borrow_mut();
140 let region_vars = inner
141 .unwrap_region_constraints()
142 .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
143 let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
144 let int_vars =
145 vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
146 let float_vars =
147 vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
148
149 let const_vars = const_vars_since_snapshot(
150 &mut inner.const_unification_table(),
151 vars_pre_snapshot.const_var_len,
152 );
153 SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
154 }
155
156 fn is_empty(&self) -> bool {
157 let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
158 region_vars.0.is_empty()
159 && type_vars.0.is_empty()
160 && int_vars.is_empty()
161 && float_vars.is_empty()
162 && const_vars.0.is_empty()
163 }
164}
165
166struct InferenceFudger<'a, 'tcx> {
167 infcx: &'a InferCtxt<'tcx>,
168 snapshot_vars: SnapshotVarData,
169}
170
171impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
172 fn cx(&self) -> TyCtxt<'tcx> {
173 self.infcx.tcx
174 }
175
176 fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
177 if let &ty::Infer(infer_ty) = ty.kind() {
178 match infer_ty {
179 ty::TyVar(vid) => {
180 if self.snapshot_vars.type_vars.0.contains(&vid) {
181 let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
184 let origin = self.snapshot_vars.type_vars.1[idx];
185 self.infcx.next_ty_var_with_origin(origin)
186 } else {
187 debug_assert!(
193 self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
194 );
195 ty
196 }
197 }
198 ty::IntVar(vid) => {
199 if self.snapshot_vars.int_vars.contains(&vid) {
200 self.infcx.next_int_var()
201 } else {
202 ty
203 }
204 }
205 ty::FloatVar(vid) => {
206 if self.snapshot_vars.float_vars.contains(&vid) {
207 self.infcx.next_float_var()
208 } else {
209 ty
210 }
211 }
212 ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
213 unreachable!("unexpected fresh infcx var")
214 }
215 }
216 } else if ty.has_infer() {
217 ty.super_fold_with(self)
218 } else {
219 ty
220 }
221 }
222
223 fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
224 if let ty::ReVar(vid) = r.kind() {
225 if self.snapshot_vars.region_vars.0.contains(&vid) {
226 let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
227 let origin = self.snapshot_vars.region_vars.1[idx];
228 self.infcx.next_region_var(origin)
229 } else {
230 r
231 }
232 } else {
233 r
234 }
235 }
236
237 fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
238 if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
239 match infer_ct {
240 ty::InferConst::Var(vid) => {
241 if self.snapshot_vars.const_vars.0.contains(&vid) {
242 let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
243 let origin = self.snapshot_vars.const_vars.1[idx];
244 self.infcx.next_const_var_with_origin(origin)
245 } else {
246 ct
247 }
248 }
249 ty::InferConst::Fresh(_) => {
250 unreachable!("unexpected fresh infcx var")
251 }
252 }
253 } else if ct.has_infer() {
254 ct.super_fold_with(self)
255 } else {
256 ct
257 }
258 }
259}