rustc_infer/infer/snapshot/
fudge.rs
1use std::ops::Range;
2
3use rustc_data_structures::{snapshot_vec as sv, unify as ut};
4use rustc_middle::ty::{
5 self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
6 TypeSuperFoldable,
7};
8use rustc_type_ir::TypeVisitableExt;
9use tracing::instrument;
10use ut::UnifyKey;
11
12use super::VariableLengths;
13use crate::infer::type_variable::TypeVariableOrigin;
14use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
15use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
16
17fn vars_since_snapshot<'tcx, T>(
18 table: &UnificationTable<'_, 'tcx, T>,
19 snapshot_var_len: usize,
20) -> Range<T>
21where
22 T: UnifyKey,
23 super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
24{
25 T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
26}
27
28fn const_vars_since_snapshot<'tcx>(
29 table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
30 snapshot_var_len: usize,
31) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
32 let range = vars_since_snapshot(table, snapshot_var_len);
33
34 (
35 range.start.vid..range.end.vid,
36 (range.start.index()..range.end.index())
37 .map(|index| match table.probe_value(ConstVid::from_u32(index)) {
38 ConstVariableValue::Known { value: _ } => {
39 ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
40 }
41 ConstVariableValue::Unknown { origin, universe: _ } => origin,
42 })
43 .collect(),
44 )
45}
46
47impl<'tcx> InferCtxt<'tcx> {
48 #[instrument(skip(self, f), level = "debug")]
88 pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
89 where
90 F: FnOnce() -> Result<T, E>,
91 T: TypeFoldable<TyCtxt<'tcx>>,
92 {
93 let variable_lengths = self.variable_lengths();
94 let (snapshot_vars, value) = self.probe(|_| {
95 let value = f()?;
96 let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
102 Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
103 })?;
104
105 Ok(self.fudge_inference(snapshot_vars, value))
110 }
111
112 fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
113 &self,
114 snapshot_vars: SnapshotVarData,
115 value: T,
116 ) -> T {
117 if snapshot_vars.is_empty() {
120 value
121 } else {
122 value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
123 }
124 }
125}
126
127struct SnapshotVarData {
128 region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
129 type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
130 int_vars: Range<IntVid>,
131 float_vars: Range<FloatVid>,
132 const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
133}
134
135impl SnapshotVarData {
136 fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
137 let mut inner = infcx.inner.borrow_mut();
138 let region_vars = inner
139 .unwrap_region_constraints()
140 .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
141 let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
142 let int_vars =
143 vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
144 let float_vars =
145 vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
146
147 let const_vars = const_vars_since_snapshot(
148 &mut inner.const_unification_table(),
149 vars_pre_snapshot.const_var_len,
150 );
151 SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
152 }
153
154 fn is_empty(&self) -> bool {
155 let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
156 region_vars.0.is_empty()
157 && type_vars.0.is_empty()
158 && int_vars.is_empty()
159 && float_vars.is_empty()
160 && const_vars.0.is_empty()
161 }
162}
163
164struct InferenceFudger<'a, 'tcx> {
165 infcx: &'a InferCtxt<'tcx>,
166 snapshot_vars: SnapshotVarData,
167}
168
169impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
170 fn cx(&self) -> TyCtxt<'tcx> {
171 self.infcx.tcx
172 }
173
174 fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
175 if let &ty::Infer(infer_ty) = ty.kind() {
176 match infer_ty {
177 ty::TyVar(vid) => {
178 if self.snapshot_vars.type_vars.0.contains(&vid) {
179 let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
182 let origin = self.snapshot_vars.type_vars.1[idx];
183 self.infcx.next_ty_var_with_origin(origin)
184 } else {
185 debug_assert!(
191 self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
192 );
193 ty
194 }
195 }
196 ty::IntVar(vid) => {
197 if self.snapshot_vars.int_vars.contains(&vid) {
198 self.infcx.next_int_var()
199 } else {
200 ty
201 }
202 }
203 ty::FloatVar(vid) => {
204 if self.snapshot_vars.float_vars.contains(&vid) {
205 self.infcx.next_float_var()
206 } else {
207 ty
208 }
209 }
210 ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
211 unreachable!("unexpected fresh infcx var")
212 }
213 }
214 } else if ty.has_infer() {
215 ty.super_fold_with(self)
216 } else {
217 ty
218 }
219 }
220
221 fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
222 if let ty::ReVar(vid) = r.kind() {
223 if self.snapshot_vars.region_vars.0.contains(&vid) {
224 let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
225 let origin = self.snapshot_vars.region_vars.1[idx];
226 self.infcx.next_region_var(origin)
227 } else {
228 r
229 }
230 } else {
231 r
232 }
233 }
234
235 fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
236 if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
237 match infer_ct {
238 ty::InferConst::Var(vid) => {
239 if self.snapshot_vars.const_vars.0.contains(&vid) {
240 let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
241 let origin = self.snapshot_vars.const_vars.1[idx];
242 self.infcx.next_const_var_with_origin(origin)
243 } else {
244 ct
245 }
246 }
247 ty::InferConst::Fresh(_) => {
248 unreachable!("unexpected fresh infcx var")
249 }
250 }
251 } else if ct.has_infer() {
252 ct.super_fold_with(self)
253 } else {
254 ct
255 }
256 }
257}