rustc_infer/infer/snapshot/
fudge.rs
1use std::ops::Range;
2
3use rustc_data_structures::{snapshot_vec as sv, unify as ut};
4use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
5use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};
6use rustc_type_ir::visit::TypeVisitableExt;
7use tracing::instrument;
8use ut::UnifyKey;
9
10use super::VariableLengths;
11use crate::infer::type_variable::TypeVariableOrigin;
12use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
13use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
14
15fn vars_since_snapshot<'tcx, T>(
16 table: &UnificationTable<'_, 'tcx, T>,
17 snapshot_var_len: usize,
18) -> Range<T>
19where
20 T: UnifyKey,
21 super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
22{
23 T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
24}
25
26fn const_vars_since_snapshot<'tcx>(
27 table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
28 snapshot_var_len: usize,
29) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
30 let range = vars_since_snapshot(table, snapshot_var_len);
31
32 (
33 range.start.vid..range.end.vid,
34 (range.start.index()..range.end.index())
35 .map(|index| match table.probe_value(ConstVid::from_u32(index)) {
36 ConstVariableValue::Known { value: _ } => {
37 ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
38 }
39 ConstVariableValue::Unknown { origin, universe: _ } => origin,
40 })
41 .collect(),
42 )
43}
44
45impl<'tcx> InferCtxt<'tcx> {
46 #[instrument(skip(self, f), level = "debug")]
86 pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
87 where
88 F: FnOnce() -> Result<T, E>,
89 T: TypeFoldable<TyCtxt<'tcx>>,
90 {
91 let variable_lengths = self.variable_lengths();
92 let (snapshot_vars, value) = self.probe(|_| {
93 let value = f()?;
94 let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
100 Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
101 })?;
102
103 Ok(self.fudge_inference(snapshot_vars, value))
108 }
109
110 fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
111 &self,
112 snapshot_vars: SnapshotVarData,
113 value: T,
114 ) -> T {
115 if snapshot_vars.is_empty() {
118 value
119 } else {
120 value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
121 }
122 }
123}
124
125struct SnapshotVarData {
126 region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
127 type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
128 int_vars: Range<IntVid>,
129 float_vars: Range<FloatVid>,
130 const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
131}
132
133impl SnapshotVarData {
134 fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
135 let mut inner = infcx.inner.borrow_mut();
136 let region_vars = inner
137 .unwrap_region_constraints()
138 .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
139 let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
140 let int_vars =
141 vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
142 let float_vars =
143 vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
144
145 let const_vars = const_vars_since_snapshot(
146 &mut inner.const_unification_table(),
147 vars_pre_snapshot.const_var_len,
148 );
149 SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
150 }
151
152 fn is_empty(&self) -> bool {
153 let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
154 region_vars.0.is_empty()
155 && type_vars.0.is_empty()
156 && int_vars.is_empty()
157 && float_vars.is_empty()
158 && const_vars.0.is_empty()
159 }
160}
161
162struct InferenceFudger<'a, 'tcx> {
163 infcx: &'a InferCtxt<'tcx>,
164 snapshot_vars: SnapshotVarData,
165}
166
167impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
168 fn cx(&self) -> TyCtxt<'tcx> {
169 self.infcx.tcx
170 }
171
172 fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
173 if let &ty::Infer(infer_ty) = ty.kind() {
174 match infer_ty {
175 ty::TyVar(vid) => {
176 if self.snapshot_vars.type_vars.0.contains(&vid) {
177 let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
180 let origin = self.snapshot_vars.type_vars.1[idx];
181 self.infcx.next_ty_var_with_origin(origin)
182 } else {
183 debug_assert!(
189 self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
190 );
191 ty
192 }
193 }
194 ty::IntVar(vid) => {
195 if self.snapshot_vars.int_vars.contains(&vid) {
196 self.infcx.next_int_var()
197 } else {
198 ty
199 }
200 }
201 ty::FloatVar(vid) => {
202 if self.snapshot_vars.float_vars.contains(&vid) {
203 self.infcx.next_float_var()
204 } else {
205 ty
206 }
207 }
208 ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
209 unreachable!("unexpected fresh infcx var")
210 }
211 }
212 } else if ty.has_infer() {
213 ty.super_fold_with(self)
214 } else {
215 ty
216 }
217 }
218
219 fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
220 if let ty::ReVar(vid) = r.kind() {
221 if self.snapshot_vars.region_vars.0.contains(&vid) {
222 let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
223 let origin = self.snapshot_vars.region_vars.1[idx];
224 self.infcx.next_region_var(origin)
225 } else {
226 r
227 }
228 } else {
229 r
230 }
231 }
232
233 fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
234 if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
235 match infer_ct {
236 ty::InferConst::Var(vid) => {
237 if self.snapshot_vars.const_vars.0.contains(&vid) {
238 let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
239 let origin = self.snapshot_vars.const_vars.1[idx];
240 self.infcx.next_const_var_with_origin(origin)
241 } else {
242 ct
243 }
244 }
245 ty::InferConst::Fresh(_) => {
246 unreachable!("unexpected fresh infcx var")
247 }
248 }
249 } else if ct.has_infer() {
250 ct.super_fold_with(self)
251 } else {
252 ct
253 }
254 }
255}