rustc_infer/infer/snapshot/
fudge.rs1use std::ops::Range;
2
3use rustc_data_structures::{snapshot_vec as sv, unify as ut};
4use rustc_middle::ty::{
5 self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid, TypeFoldable, TypeFolder,
6 TypeSuperFoldable,
7};
8use rustc_type_ir::TypeVisitableExt;
9use tracing::instrument;
10use ut::UnifyKey;
11
12use super::VariableLengths;
13use crate::infer::type_variable::TypeVariableOrigin;
14use crate::infer::unify_key::{ConstVariableValue, ConstVidKey};
15use crate::infer::{ConstVariableOrigin, InferCtxt, RegionVariableOrigin, UnificationTable};
16
17fn vars_since_snapshot<'tcx, T>(
18 table: &UnificationTable<'_, 'tcx, T>,
19 snapshot_var_len: usize,
20) -> Range<T>
21where
22 T: UnifyKey,
23 super::UndoLog<'tcx>: From<sv::UndoLog<ut::Delegate<T>>>,
24{
25 T::from_index(snapshot_var_len as u32)..T::from_index(table.len() as u32)
26}
27
28fn const_vars_since_snapshot<'tcx>(
29 table: &mut UnificationTable<'_, 'tcx, ConstVidKey<'tcx>>,
30 snapshot_var_len: usize,
31) -> (Range<ConstVid>, Vec<ConstVariableOrigin>) {
32 let range = vars_since_snapshot(table, snapshot_var_len);
33 let range = range.start.vid..range.end.vid;
34
35 (
36 range.clone(),
37 range
38 .map(|index| match table.probe_value(index) {
39 ConstVariableValue::Known { value: _ } => {
40 ConstVariableOrigin { param_def_id: None, span: rustc_span::DUMMY_SP }
41 }
42 ConstVariableValue::Unknown { origin, universe: _ } => origin,
43 })
44 .collect(),
45 )
46}
47
48impl<'tcx> InferCtxt<'tcx> {
49 #[instrument(skip(self, f), level = "debug")]
89 pub fn fudge_inference_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
90 where
91 F: FnOnce() -> Result<T, E>,
92 T: TypeFoldable<TyCtxt<'tcx>>,
93 {
94 let variable_lengths = self.variable_lengths();
95 let (snapshot_vars, value) = self.probe(|_| {
96 let value = f()?;
97 let snapshot_vars = SnapshotVarData::new(self, variable_lengths);
103 Ok((snapshot_vars, self.resolve_vars_if_possible(value)))
104 })?;
105
106 Ok(self.fudge_inference(snapshot_vars, value))
111 }
112
113 fn fudge_inference<T: TypeFoldable<TyCtxt<'tcx>>>(
114 &self,
115 snapshot_vars: SnapshotVarData,
116 value: T,
117 ) -> T {
118 if snapshot_vars.is_empty() {
121 value
122 } else {
123 value.fold_with(&mut InferenceFudger { infcx: self, snapshot_vars })
124 }
125 }
126}
127
/// The inference variables that were created during a fudging probe, grouped
/// by kind.
///
/// Each field holds the half-open range of variable ids that were new. For
/// region, type and const variables, the range is paired with a `Vec` that
/// holds one origin per variable, indexed by offset from the range start.
struct SnapshotVarData {
    region_vars: (Range<RegionVid>, Vec<RegionVariableOrigin>),
    type_vars: (Range<TyVid>, Vec<TypeVariableOrigin>),
    // Int and float variables carry no origin here; they are recreated with
    // `next_int_var` / `next_float_var`.
    int_vars: Range<IntVid>,
    float_vars: Range<FloatVid>,
    const_vars: (Range<ConstVid>, Vec<ConstVariableOrigin>),
}
135
136impl SnapshotVarData {
137 fn new(infcx: &InferCtxt<'_>, vars_pre_snapshot: VariableLengths) -> SnapshotVarData {
138 let mut inner = infcx.inner.borrow_mut();
139 let region_vars = inner
140 .unwrap_region_constraints()
141 .vars_since_snapshot(vars_pre_snapshot.region_constraints_len);
142 let type_vars = inner.type_variables().vars_since_snapshot(vars_pre_snapshot.type_var_len);
143 let int_vars =
144 vars_since_snapshot(&inner.int_unification_table(), vars_pre_snapshot.int_var_len);
145 let float_vars =
146 vars_since_snapshot(&inner.float_unification_table(), vars_pre_snapshot.float_var_len);
147
148 let const_vars = const_vars_since_snapshot(
149 &mut inner.const_unification_table(),
150 vars_pre_snapshot.const_var_len,
151 );
152 SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars }
153 }
154
155 fn is_empty(&self) -> bool {
156 let SnapshotVarData { region_vars, type_vars, int_vars, float_vars, const_vars } = self;
157 region_vars.0.is_empty()
158 && type_vars.0.is_empty()
159 && int_vars.is_empty()
160 && float_vars.is_empty()
161 && const_vars.0.is_empty()
162 }
163}
164
/// Type folder that replaces the inference variables recorded in
/// `snapshot_vars` with fresh variables created through `infcx`.
struct InferenceFudger<'a, 'tcx> {
    /// Context used to create the replacement variables.
    infcx: &'a InferCtxt<'tcx>,
    /// Ranges (and, where kept, origins) of the variables to replace.
    snapshot_vars: SnapshotVarData,
}
169
170impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for InferenceFudger<'a, 'tcx> {
171 fn cx(&self) -> TyCtxt<'tcx> {
172 self.infcx.tcx
173 }
174
175 fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
176 if let &ty::Infer(infer_ty) = ty.kind() {
177 match infer_ty {
178 ty::TyVar(vid) => {
179 if self.snapshot_vars.type_vars.0.contains(&vid) {
180 let idx = vid.as_usize() - self.snapshot_vars.type_vars.0.start.as_usize();
183 let origin = self.snapshot_vars.type_vars.1[idx];
184 self.infcx.next_ty_var_with_origin(origin)
185 } else {
186 debug_assert!(
192 self.infcx.inner.borrow_mut().type_variables().probe(vid).is_unknown()
193 );
194 ty
195 }
196 }
197 ty::IntVar(vid) => {
198 if self.snapshot_vars.int_vars.contains(&vid) {
199 self.infcx.next_int_var()
200 } else {
201 ty
202 }
203 }
204 ty::FloatVar(vid) => {
205 if self.snapshot_vars.float_vars.contains(&vid) {
206 self.infcx.next_float_var()
207 } else {
208 ty
209 }
210 }
211 ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) => {
212 unreachable!("unexpected fresh infcx var")
213 }
214 }
215 } else if ty.has_infer() {
216 ty.super_fold_with(self)
217 } else {
218 ty
219 }
220 }
221
222 fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
223 if let ty::ReVar(vid) = r.kind() {
224 if self.snapshot_vars.region_vars.0.contains(&vid) {
225 let idx = vid.index() - self.snapshot_vars.region_vars.0.start.index();
226 let origin = self.snapshot_vars.region_vars.1[idx];
227 self.infcx.next_region_var(origin)
228 } else {
229 r
230 }
231 } else {
232 r
233 }
234 }
235
236 fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
237 if let ty::ConstKind::Infer(infer_ct) = ct.kind() {
238 match infer_ct {
239 ty::InferConst::Var(vid) => {
240 if self.snapshot_vars.const_vars.0.contains(&vid) {
241 let idx = vid.index() - self.snapshot_vars.const_vars.0.start.index();
242 let origin = self.snapshot_vars.const_vars.1[idx];
243 self.infcx.next_const_var_with_origin(origin)
244 } else {
245 ct
246 }
247 }
248 ty::InferConst::Fresh(_) => {
249 unreachable!("unexpected fresh infcx var")
250 }
251 }
252 } else if ct.has_infer() {
253 ct.super_fold_with(self)
254 } else {
255 ct
256 }
257 }
258}