1use std::mem;
49use std::sync::Arc;
50
51use rustc_index::{Idx, IndexVec};
52use thin_vec::ThinVec;
53use tracing::{debug, instrument};
54
55use crate::inherent::*;
56use crate::visit::{TypeVisitable, TypeVisitableExt as _};
57use crate::{self as ty, Interner, TypeFlags};
58
/// Uninhabited error type of infallible folders when the crate is built
/// without the `nightly` feature.
#[cfg(not(feature = "nightly"))]
type Never = std::convert::Infallible;

/// Uninhabited error type of infallible folders when the `nightly` feature is
/// enabled.
#[cfg(feature = "nightly")]
type Never = !;
64
65pub trait TypeFoldable<I: Interner>: TypeVisitable<I> + Clone {
77 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error>;
88
89 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
93 match self.try_fold_with(folder) {
94 Ok(t) => t,
95 }
96 }
97}
98
99pub trait TypeSuperFoldable<I: Interner>: TypeFoldable<I> {
101 fn try_super_fold_with<F: FallibleTypeFolder<I>>(
108 self,
109 folder: &mut F,
110 ) -> Result<Self, F::Error>;
111
112 fn super_fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
116 match self.try_super_fold_with(folder) {
117 Ok(t) => t,
118 }
119 }
120}
121
122pub trait TypeFolder<I: Interner>: FallibleTypeFolder<I, Error = Never> {
132 fn cx(&self) -> I;
133
134 fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
135 where
136 T: TypeFoldable<I>,
137 {
138 t.super_fold_with(self)
139 }
140
141 fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
142 t.super_fold_with(self)
143 }
144
145 fn fold_region(&mut self, r: I::Region) -> I::Region {
148 r
149 }
150
151 fn fold_const(&mut self, c: I::Const) -> I::Const {
152 c.super_fold_with(self)
153 }
154
155 fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
156 p.super_fold_with(self)
157 }
158}
159
160pub trait FallibleTypeFolder<I: Interner>: Sized {
168 type Error;
169
170 fn cx(&self) -> I;
171
172 fn try_fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> Result<ty::Binder<I, T>, Self::Error>
173 where
174 T: TypeFoldable<I>,
175 {
176 t.try_super_fold_with(self)
177 }
178
179 fn try_fold_ty(&mut self, t: I::Ty) -> Result<I::Ty, Self::Error> {
180 t.try_super_fold_with(self)
181 }
182
183 fn try_fold_region(&mut self, r: I::Region) -> Result<I::Region, Self::Error> {
186 Ok(r)
187 }
188
189 fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Self::Error> {
190 c.try_super_fold_with(self)
191 }
192
193 fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
194 p.try_super_fold_with(self)
195 }
196}
197
198impl<I: Interner, F> FallibleTypeFolder<I> for F
201where
202 F: TypeFolder<I>,
203{
204 type Error = Never;
205
206 fn cx(&self) -> I {
207 TypeFolder::cx(self)
208 }
209
210 fn try_fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> Result<ty::Binder<I, T>, Never>
211 where
212 T: TypeFoldable<I>,
213 {
214 Ok(self.fold_binder(t))
215 }
216
217 fn try_fold_ty(&mut self, t: I::Ty) -> Result<I::Ty, Never> {
218 Ok(self.fold_ty(t))
219 }
220
221 fn try_fold_region(&mut self, r: I::Region) -> Result<I::Region, Never> {
222 Ok(self.fold_region(r))
223 }
224
225 fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Never> {
226 Ok(self.fold_const(c))
227 }
228
229 fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Never> {
230 Ok(self.fold_predicate(p))
231 }
232}
233
234impl<I: Interner, T: TypeFoldable<I>, U: TypeFoldable<I>> TypeFoldable<I> for (T, U) {
238 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<(T, U), F::Error> {
239 Ok((self.0.try_fold_with(folder)?, self.1.try_fold_with(folder)?))
240 }
241}
242
243impl<I: Interner, A: TypeFoldable<I>, B: TypeFoldable<I>, C: TypeFoldable<I>> TypeFoldable<I>
244 for (A, B, C)
245{
246 fn try_fold_with<F: FallibleTypeFolder<I>>(
247 self,
248 folder: &mut F,
249 ) -> Result<(A, B, C), F::Error> {
250 Ok((
251 self.0.try_fold_with(folder)?,
252 self.1.try_fold_with(folder)?,
253 self.2.try_fold_with(folder)?,
254 ))
255 }
256}
257
258impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Option<T> {
259 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
260 Ok(match self {
261 Some(v) => Some(v.try_fold_with(folder)?),
262 None => None,
263 })
264 }
265}
266
267impl<I: Interner, T: TypeFoldable<I>, E: TypeFoldable<I>> TypeFoldable<I> for Result<T, E> {
268 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
269 Ok(match self {
270 Ok(v) => Ok(v.try_fold_with(folder)?),
271 Err(e) => Err(e.try_fold_with(folder)?),
272 })
273 }
274}
275
276impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Arc<T> {
277 fn try_fold_with<F: FallibleTypeFolder<I>>(mut self, folder: &mut F) -> Result<Self, F::Error> {
278 unsafe {
282 Arc::make_mut(&mut self);
288
289 let ptr = Arc::into_raw(self).cast::<mem::ManuallyDrop<T>>();
292 let mut unique = Arc::from_raw(ptr);
293
294 let slot = Arc::get_mut(&mut unique).unwrap_unchecked();
298
299 let owned = mem::ManuallyDrop::take(slot);
304 let folded = owned.try_fold_with(folder)?;
305 *slot = mem::ManuallyDrop::new(folded);
306
307 Ok(Arc::from_raw(Arc::into_raw(unique).cast()))
309 }
310 }
311}
312
313impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<T> {
314 fn try_fold_with<F: FallibleTypeFolder<I>>(mut self, folder: &mut F) -> Result<Self, F::Error> {
315 *self = (*self).try_fold_with(folder)?;
316 Ok(self)
317 }
318}
319
320impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Vec<T> {
321 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
322 self.into_iter().map(|t| t.try_fold_with(folder)).collect()
323 }
324}
325
326impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for ThinVec<T> {
327 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
328 self.into_iter().map(|t| t.try_fold_with(folder)).collect()
329 }
330}
331
332impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<[T]> {
333 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
334 Vec::from(self).try_fold_with(folder).map(Vec::into_boxed_slice)
335 }
336}
337
338impl<I: Interner, T: TypeFoldable<I>, Ix: Idx> TypeFoldable<I> for IndexVec<Ix, T> {
339 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
340 self.raw.try_fold_with(folder).map(IndexVec::from_raw)
341 }
342}
343
344struct Shifter<I: Interner> {
354 cx: I,
355 current_index: ty::DebruijnIndex,
356 amount: u32,
357}
358
359impl<I: Interner> Shifter<I> {
360 fn new(cx: I, amount: u32) -> Self {
361 Shifter { cx, current_index: ty::INNERMOST, amount }
362 }
363}
364
365impl<I: Interner> TypeFolder<I> for Shifter<I> {
366 fn cx(&self) -> I {
367 self.cx
368 }
369
370 fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
371 self.current_index.shift_in(1);
372 let t = t.super_fold_with(self);
373 self.current_index.shift_out(1);
374 t
375 }
376
377 fn fold_region(&mut self, r: I::Region) -> I::Region {
378 match r.kind() {
379 ty::ReBound(debruijn, br) if debruijn >= self.current_index => {
380 let debruijn = debruijn.shifted_in(self.amount);
381 Region::new_bound(self.cx, debruijn, br)
382 }
383 _ => r,
384 }
385 }
386
387 fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
388 match ty.kind() {
389 ty::Bound(debruijn, bound_ty) if debruijn >= self.current_index => {
390 let debruijn = debruijn.shifted_in(self.amount);
391 Ty::new_bound(self.cx, debruijn, bound_ty)
392 }
393
394 _ if ty.has_vars_bound_at_or_above(self.current_index) => ty.super_fold_with(self),
395 _ => ty,
396 }
397 }
398
399 fn fold_const(&mut self, ct: I::Const) -> I::Const {
400 match ct.kind() {
401 ty::ConstKind::Bound(debruijn, bound_ct) if debruijn >= self.current_index => {
402 let debruijn = debruijn.shifted_in(self.amount);
403 Const::new_bound(self.cx, debruijn, bound_ct)
404 }
405 _ => ct.super_fold_with(self),
406 }
407 }
408
409 fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
410 if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
411 }
412}
413
414pub fn shift_region<I: Interner>(cx: I, region: I::Region, amount: u32) -> I::Region {
415 match region.kind() {
416 ty::ReBound(debruijn, br) if amount > 0 => {
417 Region::new_bound(cx, debruijn.shifted_in(amount), br)
418 }
419 _ => region,
420 }
421}
422
423#[instrument(level = "trace", skip(cx), ret)]
424pub fn shift_vars<I: Interner, T>(cx: I, value: T, amount: u32) -> T
425where
426 T: TypeFoldable<I>,
427{
428 if amount == 0 || !value.has_escaping_bound_vars() {
429 value
430 } else {
431 value.fold_with(&mut Shifter::new(cx, amount))
432 }
433}
434
435pub fn fold_regions<I: Interner, T>(
439 cx: I,
440 value: T,
441 f: impl FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
442) -> T
443where
444 T: TypeFoldable<I>,
445{
446 value.fold_with(&mut RegionFolder::new(cx, f))
447}
448
449pub struct RegionFolder<I, F> {
457 cx: I,
458
459 current_index: ty::DebruijnIndex,
463
464 fold_region_fn: F,
468}
469
470impl<I, F> RegionFolder<I, F> {
471 #[inline]
472 pub fn new(cx: I, fold_region_fn: F) -> RegionFolder<I, F> {
473 RegionFolder { cx, current_index: ty::INNERMOST, fold_region_fn }
474 }
475}
476
477impl<I, F> TypeFolder<I> for RegionFolder<I, F>
478where
479 I: Interner,
480 F: FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
481{
482 fn cx(&self) -> I {
483 self.cx
484 }
485
486 fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
487 self.current_index.shift_in(1);
488 let t = t.super_fold_with(self);
489 self.current_index.shift_out(1);
490 t
491 }
492
493 #[instrument(skip(self), level = "debug", ret)]
494 fn fold_region(&mut self, r: I::Region) -> I::Region {
495 match r.kind() {
496 ty::ReBound(debruijn, _) if debruijn < self.current_index => {
497 debug!(?self.current_index, "skipped bound region");
498 r
499 }
500 _ => {
501 debug!(?self.current_index, "folding free region");
502 (self.fold_region_fn)(r, self.current_index)
503 }
504 }
505 }
506
507 fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
508 if t.has_type_flags(
509 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
510 ) {
511 t.super_fold_with(self)
512 } else {
513 t
514 }
515 }
516
517 fn fold_const(&mut self, ct: I::Const) -> I::Const {
518 if ct.has_type_flags(
519 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
520 ) {
521 ct.super_fold_with(self)
522 } else {
523 ct
524 }
525 }
526
527 fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
528 if p.has_type_flags(
529 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
530 ) {
531 p.super_fold_with(self)
532 } else {
533 p
534 }
535 }
536}