use std::collections::hash_map::Entry;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
use tracing::instrument;
use crate::infer::region_constraints::VerifyIfEq;
use crate::infer::relate::{self as relate, Relate, RelateResult, TypeRelation};
#[instrument(level = "debug", skip(tcx))]
pub fn extract_verify_if_eq<'tcx>(
tcx: TyCtxt<'tcx>,
verify_if_eq_b: &ty::Binder<'tcx, VerifyIfEq<'tcx>>,
test_ty: Ty<'tcx>,
) -> Option<ty::Region<'tcx>> {
assert!(!verify_if_eq_b.has_escaping_bound_vars());
let mut m = MatchAgainstHigherRankedOutlives::new(tcx);
let verify_if_eq = verify_if_eq_b.skip_binder();
m.relate(verify_if_eq.ty, test_ty).ok()?;
if let ty::RegionKind::ReBound(depth, br) = verify_if_eq.bound.kind() {
assert!(depth == ty::INNERMOST);
match m.map.get(&br) {
Some(&r) => Some(r),
None => {
Some(tcx.lifetimes.re_static)
}
}
} else {
Some(verify_if_eq.bound)
}
}
/// Returns whether the type side of `outlives_predicate`, after region erasure,
/// can match `erased_ty`.
///
/// Regions are erased from the predicate first. If the erased type is identical
/// to `erased_ty`, the answer is `true` immediately. Otherwise a fresh
/// `MatchAgainstHigherRankedOutlives` relation is run with the erased predicate
/// type as the pattern.
///
/// # Panics
///
/// Panics if `outlives_predicate` has escaping bound vars.
#[instrument(level = "debug", skip(tcx))]
pub(super) fn can_match_erased_ty<'tcx>(
    tcx: TyCtxt<'tcx>,
    outlives_predicate: ty::Binder<'tcx, ty::TypeOutlivesPredicate<'tcx>>,
    erased_ty: Ty<'tcx>,
) -> bool {
    assert!(!outlives_predicate.has_escaping_bound_vars());
    let pattern_ty = tcx.erase_regions(outlives_predicate).skip_binder().0;

    // Cheap identity check first; fall back to structural matching only when needed.
    pattern_ty == erased_ty
        || MatchAgainstHigherRankedOutlives::new(tcx).relate(pattern_ty, erased_ty).is_ok()
}
/// A one-sided type relation that matches a *pattern* type against a *value*
/// type. Regions bound at the pattern's current binder depth are treated as
/// variables to be filled in from the value.
struct MatchAgainstHigherRankedOutlives<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// De Bruijn index of the bound regions that count as pattern variables.
    /// It starts at `ty::INNERMOST` and is shifted each time the relation walks
    /// into a binder.
    pattern_depth: ty::DebruijnIndex,
    /// The region each pattern variable has been matched to so far.
    map: FxHashMap<ty::BoundRegion, ty::Region<'tcx>>,
}
impl<'tcx> MatchAgainstHigherRankedOutlives<'tcx> {
    /// Creates a matcher with no bindings, positioned at the innermost binder.
    fn new(tcx: TyCtxt<'tcx>) -> Self {
        Self { tcx, pattern_depth: ty::INNERMOST, map: FxHashMap::default() }
    }
}
impl<'tcx> MatchAgainstHigherRankedOutlives<'tcx> {
    /// Produces the error result used whenever the pattern fails to match.
    fn no_match<T>(&self) -> RelateResult<'tcx, T> {
        Err(TypeError::Mismatch)
    }

    /// Binds the pattern variable `br` to `value`.
    ///
    /// The first time `br` is seen, `value` is recorded. On later occurrences
    /// the match succeeds only if `value` equals the region already recorded.
    /// On success, `value` is returned.
    #[instrument(level = "trace", skip(self))]
    fn bind(
        &mut self,
        br: ty::BoundRegion,
        value: ty::Region<'tcx>,
    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
        let consistent = match self.map.entry(br) {
            Entry::Vacant(slot) => {
                slot.insert(value);
                true
            }
            Entry::Occupied(slot) => *slot.get() == value,
        };

        if consistent { Ok(value) } else { self.no_match() }
    }
}
/// Matching relation: the first argument of every method is the *pattern* and
/// the second is the *value* being matched against it.
impl<'tcx> TypeRelation<TyCtxt<'tcx>> for MatchAgainstHigherRankedOutlives<'tcx> {
    fn cx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Relates `a` and `b` unless the position is bivariant. Bivariant positions
    /// are accepted without inspection, and `a` is returned.
    #[instrument(level = "trace", skip(self))]
    fn relate_with_variance<T: Relate<TyCtxt<'tcx>>>(
        &mut self,
        variance: ty::Variance,
        _: ty::VarianceDiagInfo<TyCtxt<'tcx>>,
        a: T,
        b: T,
    ) -> RelateResult<'tcx, T> {
        if variance != ty::Bivariant { self.relate(a, b) } else { Ok(a) }
    }

    /// A region bound at the current pattern depth is a pattern variable and is
    /// bound to `value` (see `bind`). Any other pattern region must be equal
    /// to `value`.
    #[instrument(skip(self), level = "trace")]
    fn regions(
        &mut self,
        pattern: ty::Region<'tcx>,
        value: ty::Region<'tcx>,
    ) -> RelateResult<'tcx, ty::Region<'tcx>> {
        if let ty::RegionKind::ReBound(depth, br) = pattern.kind()
            && depth == self.pattern_depth
        {
            self.bind(br, value)
        } else if pattern == value {
            Ok(pattern)
        } else {
            self.no_match()
        }
    }

    /// Error types and bound types in the pattern never match. Identical types
    /// match immediately. Anything else is compared structurally.
    #[instrument(skip(self), level = "trace")]
    fn tys(&mut self, pattern: Ty<'tcx>, value: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
        if matches!(pattern.kind(), ty::Error(_) | ty::Bound(..)) {
            self.no_match()
        } else if pattern == value {
            Ok(pattern)
        } else {
            relate::structurally_relate_tys(self, pattern, value)
        }
    }

    /// Identical constants match immediately; otherwise they are compared
    /// structurally.
    #[instrument(skip(self), level = "trace")]
    fn consts(
        &mut self,
        pattern: ty::Const<'tcx>,
        value: ty::Const<'tcx>,
    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
        if pattern == value {
            Ok(pattern)
        } else {
            relate::structurally_relate_consts(self, pattern, value)
        }
    }

    /// Walks into a binder on both sides.
    ///
    /// `pattern_depth` is shifted in for the duration of the inner relation, so
    /// the bound regions that count as pattern variables are still identified
    /// at the correct depth.
    #[instrument(skip(self), level = "trace")]
    fn binders<T>(
        &mut self,
        pattern: ty::Binder<'tcx, T>,
        value: ty::Binder<'tcx, T>,
    ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
    where
        T: Relate<TyCtxt<'tcx>>,
    {
        self.pattern_depth.shift_in(1);
        let inner = self.relate(pattern.skip_binder(), value.skip_binder());
        // Restore the depth before propagating any error. Otherwise a failed
        // match would leave `pattern_depth` shifted in and corrupt later use of
        // this matcher.
        self.pattern_depth.shift_out(1);
        Ok(pattern.rebind(inner?))
    }
}