1use std::convert::Infallible;
49use std::mem;
50use std::sync::Arc;
51
52use rustc_index::{Idx, IndexVec};
53use thin_vec::ThinVec;
54use tracing::{debug, instrument};
55
56use crate::inherent::*;
57use crate::visit::{TypeVisitable, TypeVisitableExt as _};
58use crate::{self as ty, BoundVarIndexKind, Interner, TypeFlags};
59
/// A value that can be rebuilt by passing it through a type folder.
///
/// `try_fold_with` is the fallible entry point (its error type is the folder's
/// `FallibleTypeFolder::Error`); `fold_with` is the infallible entry point used
/// with a `TypeFolder`. The container implementations in this file (tuples,
/// `Option`, `Result`, `Arc`, `Box`, `Vec`, ...) fold each element in turn.
///
/// NOTE(review): the two methods are expected to behave identically apart from
/// fallibility — implementors should keep them in sync.
pub trait TypeFoldable<I: Interner>: TypeVisitable<I> + Clone {
    /// Folds `self` with `folder`, returning the first error the folder produces.
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error>;

    /// Folds `self` with `folder`; cannot fail.
    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
99
/// Folding of a value's *contents*, as opposed to the value itself.
///
/// The default `TypeFolder` / `FallibleTypeFolder` hooks in this file (such as
/// `fold_ty` and `fold_const`) call `super_fold_with` / `try_super_fold_with`.
/// A folder that overrides one of those hooks can therefore call these methods
/// to keep recursing into the children of the value it is handling.
pub trait TypeSuperFoldable<I: Interner>: TypeFoldable<I> {
    /// Folds the components of `self` with `folder`, propagating the first error.
    ///
    /// NOTE(review): since the default folder hooks call this method, it is
    /// presumably expected to fold only the children of `self` and not hand
    /// `self` back to the folder (which would recurse forever) — confirm
    /// against the implementing types.
    fn try_super_fold_with<F: FallibleTypeFolder<I>>(
        self,
        folder: &mut F,
    ) -> Result<Self, F::Error>;

    /// Infallible counterpart of [`TypeSuperFoldable::try_super_fold_with`].
    fn super_fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
118
/// An infallible transformation over the type-level values of an [`Interner`].
///
/// Every `fold_*` hook below has a default implementation. All defaults except
/// `fold_region` recurse into the value with `super_fold_with`; `fold_region`
/// returns the region unchanged. A folder overrides the hooks for the kinds of
/// values it wants to rewrite, and may call `super_fold_with` itself to keep
/// recursing.
///
/// See [`FallibleTypeFolder`] for the fallible counterpart.
pub trait TypeFolder<I: Interner>: Sized {
    /// Returns the interner this folder works with.
    fn cx(&self) -> I;

    /// Folds a `ty::Binder`; by default recurses into the bound value.
    ///
    /// Folders that track binder depth (`Shifter` and `RegionFolder` in this
    /// file) override this to adjust their De Bruijn index around the
    /// recursive fold.
    fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
    where
        T: TypeFoldable<I>,
    {
        t.super_fold_with(self)
    }

    /// Folds a type; by default recurses into its components.
    fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
        t.super_fold_with(self)
    }

    /// Folds a region; by default returns it unchanged.
    fn fold_region(&mut self, r: I::Region) -> I::Region {
        // Unlike the other hooks, the default does not call `super_fold_with`.
        // NOTE(review): presumably regions have no foldable children to
        // recurse into — confirm against the `Region` definition.
        r
    }

    /// Folds a constant; by default recurses into its components.
    fn fold_const(&mut self, c: I::Const) -> I::Const {
        c.super_fold_with(self)
    }

    /// Folds a predicate; by default recurses into its components.
    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
        p.super_fold_with(self)
    }

    /// Folds a list of clauses; by default recurses into its components.
    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
        c.super_fold_with(self)
    }
}
160
/// The fallible counterpart of [`TypeFolder`].
///
/// Mirrors `TypeFolder` hook for hook, but every hook returns
/// `Result<_, Self::Error>` so that a fold can be aborted. The defaults likewise
/// recurse with `try_super_fold_with`, except `try_fold_region`, which returns
/// the region unchanged.
pub trait FallibleTypeFolder<I: Interner>: Sized {
    /// The error a fold with this folder can fail with; it is propagated out of
    /// `TypeFoldable::try_fold_with`.
    type Error;

    /// Returns the interner this folder works with.
    fn cx(&self) -> I;

    /// Folds a `ty::Binder`; by default recurses into the bound value.
    fn try_fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> Result<ty::Binder<I, T>, Self::Error>
    where
        T: TypeFoldable<I>,
    {
        t.try_super_fold_with(self)
    }

    /// Folds a type; by default recurses into its components.
    fn try_fold_ty(&mut self, t: I::Ty) -> Result<I::Ty, Self::Error> {
        t.try_super_fold_with(self)
    }

    /// Folds a region; by default succeeds with the region unchanged.
    fn try_fold_region(&mut self, r: I::Region) -> Result<I::Region, Self::Error> {
        // Same as `TypeFolder::fold_region`: no recursion for regions.
        Ok(r)
    }

    /// Folds a constant; by default recurses into its components.
    fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Self::Error> {
        c.try_super_fold_with(self)
    }

    /// Folds a predicate; by default recurses into its components.
    fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
        p.try_super_fold_with(self)
    }

    /// Folds a list of clauses; by default recurses into its components.
    fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
        c.try_super_fold_with(self)
    }
}
202
203impl<I: Interner, T: TypeFoldable<I>, U: TypeFoldable<I>> TypeFoldable<I> for (T, U) {
207    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<(T, U), F::Error> {
208        Ok((self.0.try_fold_with(folder)?, self.1.try_fold_with(folder)?))
209    }
210
211    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
212        (self.0.fold_with(folder), self.1.fold_with(folder))
213    }
214}
215
216impl<I: Interner, A: TypeFoldable<I>, B: TypeFoldable<I>, C: TypeFoldable<I>> TypeFoldable<I>
217    for (A, B, C)
218{
219    fn try_fold_with<F: FallibleTypeFolder<I>>(
220        self,
221        folder: &mut F,
222    ) -> Result<(A, B, C), F::Error> {
223        Ok((
224            self.0.try_fold_with(folder)?,
225            self.1.try_fold_with(folder)?,
226            self.2.try_fold_with(folder)?,
227        ))
228    }
229
230    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
231        (self.0.fold_with(folder), self.1.fold_with(folder), self.2.fold_with(folder))
232    }
233}
234
235impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Option<T> {
236    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
237        Ok(match self {
238            Some(v) => Some(v.try_fold_with(folder)?),
239            None => None,
240        })
241    }
242
243    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
244        Some(self?.fold_with(folder))
245    }
246}
247
248impl<I: Interner, T: TypeFoldable<I>, E: TypeFoldable<I>> TypeFoldable<I> for Result<T, E> {
249    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
250        Ok(match self {
251            Ok(v) => Ok(v.try_fold_with(folder)?),
252            Err(e) => Err(e.try_fold_with(folder)?),
253        })
254    }
255
256    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
257        match self {
258            Ok(v) => Ok(v.fold_with(folder)),
259            Err(e) => Err(e.fold_with(folder)),
260        }
261    }
262}
263
/// Folds the value inside an `Arc`, reusing the `Arc`'s allocation.
///
/// `Arc::make_mut` first makes `arc` the unique owner of its contents (cloning
/// `T` into a fresh allocation only if the `Arc` was shared). The contained `T`
/// is then moved out, passed to `fold`, and the result is written back into the
/// same allocation.
///
/// # Errors
///
/// Returns whatever error `fold` returns. The allocation is viewed as
/// `Arc<ManuallyDrop<T>>` at that point, so dropping it on the error path frees
/// the memory without running `T`'s destructor on the already-moved-out slot.
fn fold_arc<T: Clone, E>(
    mut arc: Arc<T>,
    fold: impl FnOnce(T) -> Result<T, E>,
) -> Result<Arc<T>, E> {
    unsafe {
        // Make `arc` the sole owner of its contents. This is done *before* the
        // `ManuallyDrop` cast, so a panic in `T::clone` (called by `make_mut`
        // when the `Arc` is shared) drops `arc` normally.
        Arc::make_mut(&mut arc);

        // SAFETY: `ManuallyDrop<T>` is `#[repr(transparent)]` over `T`, so the
        // pointer from `into_raw` may be handed back to `from_raw` with the
        // wrapper type.
        let ptr = Arc::into_raw(arc).cast::<mem::ManuallyDrop<T>>();
        let mut unique = Arc::from_raw(ptr);

        // SAFETY: after `make_mut` no other `Arc` or `Weak` points at this
        // allocation, so `get_mut` cannot return `None`.
        let slot = Arc::get_mut(&mut unique).unwrap_unchecked();

        // SAFETY: the value is moved out of `slot` exactly once. Afterwards
        // `slot` is either overwritten with the folded value, or, if `fold`
        // errors or panics, dropped as a `ManuallyDrop`, which never runs `T`'s
        // destructor. The assignment below likewise "drops" the old
        // `ManuallyDrop`, which is a no-op.
        let owned = mem::ManuallyDrop::take(slot);
        let folded = fold(owned)?;
        *slot = mem::ManuallyDrop::new(folded);

        // SAFETY: the slot holds a live `T` again; cast back to `Arc<T>`
        // (same layout as above).
        Ok(Arc::from_raw(Arc::into_raw(unique).cast()))
    }
}
301
302impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Arc<T> {
303    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
304        fold_arc(self, |t| t.try_fold_with(folder))
305    }
306
307    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
308        match fold_arc::<T, Infallible>(self, |t| Ok(t.fold_with(folder))) {
309            Ok(t) => t,
310        }
311    }
312}
313
314impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<T> {
315    fn try_fold_with<F: FallibleTypeFolder<I>>(mut self, folder: &mut F) -> Result<Self, F::Error> {
316        *self = (*self).try_fold_with(folder)?;
317        Ok(self)
318    }
319
320    fn fold_with<F: TypeFolder<I>>(mut self, folder: &mut F) -> Self {
321        *self = (*self).fold_with(folder);
322        self
323    }
324}
325
326impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Vec<T> {
327    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
328        self.into_iter().map(|t| t.try_fold_with(folder)).collect()
329    }
330
331    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
332        self.into_iter().map(|t| t.fold_with(folder)).collect()
333    }
334}
335
336impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for ThinVec<T> {
337    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
338        self.into_iter().map(|t| t.try_fold_with(folder)).collect()
339    }
340
341    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
342        self.into_iter().map(|t| t.fold_with(folder)).collect()
343    }
344}
345
346impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<[T]> {
347    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
348        Vec::from(self).try_fold_with(folder).map(Vec::into_boxed_slice)
349    }
350
351    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
352        Vec::into_boxed_slice(Vec::from(self).fold_with(folder))
353    }
354}
355
356impl<I: Interner, T: TypeFoldable<I>, Ix: Idx> TypeFoldable<I> for IndexVec<Ix, T> {
357    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
358        self.raw.try_fold_with(folder).map(IndexVec::from_raw)
359    }
360
361    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
362        IndexVec::from_raw(self.raw.fold_with(folder))
363    }
364}
365
/// Folder that shifts escaping bound variables outward by a fixed number of
/// binder levels; driven by [`shift_vars`].
struct Shifter<I: Interner> {
    cx: I,
    /// Binder depth reached so far during the fold (starts at `ty::INNERMOST`).
    /// Bound variables with an index at or above this refer to binders outside
    /// the value being folded and are the ones that get shifted.
    current_index: ty::DebruijnIndex,
    /// Number of binder levels to shift escaping bound variables by.
    amount: u32,
}
380
381impl<I: Interner> Shifter<I> {
382    fn new(cx: I, amount: u32) -> Self {
383        Shifter { cx, current_index: ty::INNERMOST, amount }
384    }
385}
386
387impl<I: Interner> TypeFolder<I> for Shifter<I> {
388    fn cx(&self) -> I {
389        self.cx
390    }
391
392    fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
393        self.current_index.shift_in(1);
394        let t = t.super_fold_with(self);
395        self.current_index.shift_out(1);
396        t
397    }
398
399    fn fold_region(&mut self, r: I::Region) -> I::Region {
400        match r.kind() {
401            ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), br)
402                if debruijn >= self.current_index =>
403            {
404                let debruijn = debruijn.shifted_in(self.amount);
405                Region::new_bound(self.cx, debruijn, br)
406            }
407            _ => r,
408        }
409    }
410
411    fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
412        match ty.kind() {
413            ty::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty)
414                if debruijn >= self.current_index =>
415            {
416                let debruijn = debruijn.shifted_in(self.amount);
417                Ty::new_bound(self.cx, debruijn, bound_ty)
418            }
419
420            _ if ty.has_vars_bound_at_or_above(self.current_index) => ty.super_fold_with(self),
421            _ => ty,
422        }
423    }
424
425    fn fold_const(&mut self, ct: I::Const) -> I::Const {
426        match ct.kind() {
427            ty::ConstKind::Bound(ty::BoundVarIndexKind::Bound(debruijn), bound_ct)
428                if debruijn >= self.current_index =>
429            {
430                let debruijn = debruijn.shifted_in(self.amount);
431                Const::new_bound(self.cx, debruijn, bound_ct)
432            }
433            _ => ct.super_fold_with(self),
434        }
435    }
436
437    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
438        if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
439    }
440}
441
442pub fn shift_region<I: Interner>(cx: I, region: I::Region, amount: u32) -> I::Region {
443    match region.kind() {
444        ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), br) if amount > 0 => {
445            Region::new_bound(cx, debruijn.shifted_in(amount), br)
446        }
447        _ => region,
448    }
449}
450
451#[instrument(level = "trace", skip(cx), ret)]
452pub fn shift_vars<I: Interner, T>(cx: I, value: T, amount: u32) -> T
453where
454    T: TypeFoldable<I>,
455{
456    if amount == 0 || !value.has_escaping_bound_vars() {
457        value
458    } else {
459        value.fold_with(&mut Shifter::new(cx, amount))
460    }
461}
462
463pub fn fold_regions<I: Interner, T>(
467    cx: I,
468    value: T,
469    f: impl FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
470) -> T
471where
472    T: TypeFoldable<I>,
473{
474    value.fold_with(&mut RegionFolder::new(cx, f))
475}
476
/// Folder that applies a callback to every region it visits, except regions
/// bound inside the value being folded and canonical bound regions; used by
/// [`fold_regions`].
pub struct RegionFolder<I, F> {
    cx: I,

    /// Binder depth reached so far during the fold (starts at `ty::INNERMOST`).
    /// Regions bound below this depth are skipped; the depth is also passed to
    /// the callback.
    current_index: ty::DebruijnIndex,

    /// Callback invoked as `fold_region_fn(region, current_index)` for each
    /// region that is not skipped; its return value replaces the region.
    fold_region_fn: F,
}
497
498impl<I, F> RegionFolder<I, F> {
499    #[inline]
500    pub fn new(cx: I, fold_region_fn: F) -> RegionFolder<I, F> {
501        RegionFolder { cx, current_index: ty::INNERMOST, fold_region_fn }
502    }
503}
504
505impl<I, F> TypeFolder<I> for RegionFolder<I, F>
506where
507    I: Interner,
508    F: FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
509{
510    fn cx(&self) -> I {
511        self.cx
512    }
513
514    fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
515        self.current_index.shift_in(1);
516        let t = t.super_fold_with(self);
517        self.current_index.shift_out(1);
518        t
519    }
520
521    #[instrument(skip(self), level = "debug", ret)]
522    fn fold_region(&mut self, r: I::Region) -> I::Region {
523        match r.kind() {
524            ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), _)
525                if debruijn < self.current_index =>
526            {
527                debug!(?self.current_index, "skipped bound region");
528                r
529            }
530            ty::ReBound(ty::BoundVarIndexKind::Canonical, _) => {
531                debug!(?self.current_index, "skipped bound region");
532                r
533            }
534            _ => {
535                debug!(?self.current_index, "folding free region");
536                (self.fold_region_fn)(r, self.current_index)
537            }
538        }
539    }
540
541    fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
542        if t.has_type_flags(
543            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
544        ) {
545            t.super_fold_with(self)
546        } else {
547            t
548        }
549    }
550
551    fn fold_const(&mut self, ct: I::Const) -> I::Const {
552        if ct.has_type_flags(
553            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
554        ) {
555            ct.super_fold_with(self)
556        } else {
557            ct
558        }
559    }
560
561    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
562        if p.has_type_flags(
563            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
564        ) {
565            p.super_fold_with(self)
566        } else {
567            p
568        }
569    }
570}