use std::iter;
use derive_where::derive_where;
use rustc_ast_ir::Mutability;
use tracing::{instrument, trace};
use crate::error::{ExpectedFound, TypeError};
use crate::fold::TypeFoldable;
use crate::inherent::*;
use crate::{self as ty, Interner};
/// Outcome of relating two values: the related value `T`, or the [`TypeError`] explaining
/// why they could not be related.
pub type RelateResult<I, T> = Result<T, TypeError<I>>;
/// Extra context threaded through a relation so that diagnostics can point at the
/// invariant position being related.
#[derive_where(Clone, Copy, PartialEq, Eq, Debug, Default; I: Interner)]
pub enum VarianceDiagInfo<I: Interner> {
/// No extra information. This is the `Default` variant.
#[derive_where(default)]
None,
/// The values are being related in an invariant position.
Invariant {
/// The type owning the invariant parameter: either an item's type instantiated with
/// the args being related, or the whole `*mut T` / `&mut T` type.
ty: I::Ty,
/// Position of the invariant parameter within `ty`'s generic args (0 for pointee types).
param_index: u32,
},
}
impl<I: Interner> VarianceDiagInfo<I> {
    /// Combines `self` (outer context) with `other` (inner context).
    ///
    /// Existing `Invariant` information on `self` wins; if `self` carries nothing,
    /// `other` is returned unchanged.
    pub fn xform(self, other: VarianceDiagInfo<I>) -> VarianceDiagInfo<I> {
        if matches!(self, VarianceDiagInfo::None) { other } else { self }
    }
}
/// A strategy for relating two values (types, regions, consts, binders, ...).
///
/// Implementors supply the primitive hooks (`tys`, `regions`, `consts`, `binders`,
/// `relate_with_variance`); the provided methods build on top of them.
pub trait TypeRelation<I: Interner>: Sized {
    /// The interner this relation works in.
    fn cx(&self) -> I;

    /// Relates `a` with `b` by dispatching to `T`'s [`Relate`] implementation.
    fn relate<T: Relate<I>>(&mut self, a: T, b: T) -> RelateResult<I, T> {
        T::relate(self, a, b)
    }

    /// Relates the generic args of the item `item_def_id`.
    ///
    /// Each argument pair is related under the item's declared variance (from
    /// `variances_of`), and invariant positions are annotated with the item's type
    /// for diagnostics.
    #[instrument(skip(self), level = "trace")]
    fn relate_item_args(
        &mut self,
        item_def_id: I::DefId,
        a_arg: I::GenericArgs,
        b_arg: I::GenericArgs,
    ) -> RelateResult<I, I::GenericArgs> {
        let variances = self.cx().variances_of(item_def_id);
        relate_args_with_variances(self, item_def_id, variances, a_arg, b_arg, true)
    }

    /// Relates `a` and `b` under `variance`; `info` carries diagnostic context for
    /// the position being related.
    fn relate_with_variance<T: Relate<I>>(
        &mut self,
        variance: ty::Variance,
        info: VarianceDiagInfo<I>,
        a: T,
        b: T,
    ) -> RelateResult<I, T>;

    /// Relates two types.
    fn tys(&mut self, a: I::Ty, b: I::Ty) -> RelateResult<I, I::Ty>;

    /// Relates two regions.
    fn regions(&mut self, a: I::Region, b: I::Region) -> RelateResult<I, I::Region>;

    /// Relates two constants.
    fn consts(&mut self, a: I::Const, b: I::Const) -> RelateResult<I, I::Const>;

    /// Relates two binders wrapping relatable values.
    fn binders<T>(
        &mut self,
        a: ty::Binder<I, T>,
        b: ty::Binder<I, T>,
    ) -> RelateResult<I, ty::Binder<I, T>>
    where
        T: Relate<I>;
}
/// A value that can be related to another value of the same type by a [`TypeRelation`].
pub trait Relate<I: Interner>: TypeFoldable<I> + PartialEq + Copy {
/// Relates `a` and `b` using `relation`, yielding the related value or a [`TypeError`].
fn relate<R: TypeRelation<I>>(relation: &mut R, a: Self, b: Self) -> RelateResult<I, Self>;
}
#[inline]
pub fn relate_args_invariantly<I: Interner, R: TypeRelation<I>>(
relation: &mut R,
a_arg: I::GenericArgs,
b_arg: I::GenericArgs,
) -> RelateResult<I, I::GenericArgs> {
relation.cx().mk_args_from_iter(iter::zip(a_arg.iter(), b_arg.iter()).map(|(a, b)| {
relation.relate_with_variance(ty::Invariant, VarianceDiagInfo::default(), a, b)
}))
}
pub fn relate_args_with_variances<I: Interner, R: TypeRelation<I>>(
relation: &mut R,
ty_def_id: I::DefId,
variances: I::VariancesOf,
a_arg: I::GenericArgs,
b_arg: I::GenericArgs,
fetch_ty_for_diag: bool,
) -> RelateResult<I, I::GenericArgs> {
let cx = relation.cx();
let mut cached_ty = None;
let params = iter::zip(a_arg.iter(), b_arg.iter()).enumerate().map(|(i, (a, b))| {
let variance = variances.get(i).unwrap();
let variance_info = if variance == ty::Invariant && fetch_ty_for_diag {
let ty = *cached_ty.get_or_insert_with(|| cx.type_of(ty_def_id).instantiate(cx, a_arg));
VarianceDiagInfo::Invariant { ty, param_index: i.try_into().unwrap() }
} else {
VarianceDiagInfo::default()
};
relation.relate_with_variance(variance, variance_info, a, b)
});
cx.mk_args_from_iter(params)
}
/// Relates two function signatures.
///
/// C-variadic-ness and input counts must match exactly. Safety and ABI are related through
/// their own `Relate` impls. Inputs are related contravariantly; the output goes through
/// `relation.relate` directly.
impl<I: Interner> Relate<I> for ty::FnSig<I> {
fn relate<R: TypeRelation<I>>(
relation: &mut R,
a: ty::FnSig<I>,
b: ty::FnSig<I>,
) -> RelateResult<I, ty::FnSig<I>> {
let cx = relation.cx();
if a.c_variadic != b.c_variadic {
return Err(TypeError::VariadicMismatch({
let a = a.c_variadic;
let b = b.c_variadic;
ExpectedFound::new(true, a, b)
}));
}
let safety = relation.relate(a.safety, b.safety)?;
let abi = relation.relate(a.abi, b.abi)?;
let a_inputs = a.inputs();
let b_inputs = b.inputs();
if a_inputs.len() != b_inputs.len() {
return Err(TypeError::ArgCount);
}
// Tag each pair with whether it is the output, so a single closure (the only one holding
// `&mut relation`) can choose the variance: inputs contravariant, output via `relate`.
let inputs_and_output = iter::zip(a_inputs.iter(), b_inputs.iter())
.map(|(a, b)| ((a, b), false))
.chain(iter::once(((a.output(), b.output()), true)))
.map(|((a, b), is_output)| {
if is_output {
relation.relate(a, b)
} else {
relation.relate_with_variance(
ty::Contravariant,
VarianceDiagInfo::default(),
a,
b,
)
}
})
// Re-label type/mutability errors with the position `i` at which they occurred
// (the output gets index `a_inputs.len()`).
.enumerate()
.map(|(i, r)| match r {
Err(TypeError::Sorts(exp_found) | TypeError::ArgumentSorts(exp_found, _)) => {
Err(TypeError::ArgumentSorts(exp_found, i))
}
Err(TypeError::Mutability | TypeError::ArgumentMutability(_)) => {
Err(TypeError::ArgumentMutability(i))
}
r => r,
});
Ok(ty::FnSig {
// The `?` propagates the error from relating an input or the output.
inputs_and_output: cx.mk_type_list_from_iter(inputs_and_output)?,
c_variadic: a.c_variadic,
safety,
abi,
})
}
}
/// Bound constness only relates to an identical constness.
impl<I: Interner> Relate<I> for ty::BoundConstness {
    fn relate<R: TypeRelation<I>>(
        _relation: &mut R,
        a: ty::BoundConstness,
        b: ty::BoundConstness,
    ) -> RelateResult<I, ty::BoundConstness> {
        if a == b {
            return Ok(a);
        }
        Err(TypeError::ConstnessMismatch(ExpectedFound::new(true, a, b)))
    }
}
/// Relates two alias types with the same `def_id`.
///
/// Opaque types relate their args under the opaque's variances. Projection, weak and
/// inherent aliases relate their args invariantly.
impl<I: Interner> Relate<I> for ty::AliasTy<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::AliasTy<I>,
        b: ty::AliasTy<I>,
    ) -> RelateResult<I, ty::AliasTy<I>> {
        if a.def_id != b.def_id {
            return Err(TypeError::ProjectionMismatched(ExpectedFound::new(
                true, a.def_id, b.def_id,
            )));
        }

        let cx = relation.cx();
        let args = match a.kind(cx) {
            ty::Opaque => {
                let variances = cx.variances_of(a.def_id);
                // `false`: no `type_of`-based variance diagnostic info is fetched here.
                relate_args_with_variances(relation, a.def_id, variances, a.args, b.args, false)?
            }
            ty::Projection | ty::Weak | ty::Inherent => {
                relate_args_invariantly(relation, a.args, b.args)?
            }
        };
        Ok(ty::AliasTy::new_from_args(cx, a.def_id, args))
    }
}
/// Relates two alias terms with the same `def_id`.
///
/// Opaque types relate their args under the opaque's variances. Every other alias term
/// kind relates its args invariantly.
impl<I: Interner> Relate<I> for ty::AliasTerm<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::AliasTerm<I>,
        b: ty::AliasTerm<I>,
    ) -> RelateResult<I, ty::AliasTerm<I>> {
        if a.def_id != b.def_id {
            return Err(TypeError::ProjectionMismatched(ExpectedFound::new(
                true, a.def_id, b.def_id,
            )));
        }

        let cx = relation.cx();
        let args = match a.kind(cx) {
            ty::AliasTermKind::OpaqueTy => {
                let variances = cx.variances_of(a.def_id);
                // `false`: no `type_of`-based variance diagnostic info is fetched here.
                relate_args_with_variances(relation, a.def_id, variances, a.args, b.args, false)?
            }
            ty::AliasTermKind::ProjectionTy
            | ty::AliasTermKind::WeakTy
            | ty::AliasTermKind::InherentTy
            | ty::AliasTermKind::UnevaluatedConst
            | ty::AliasTermKind::ProjectionConst => {
                relate_args_invariantly(relation, a.args, b.args)?
            }
        };
        Ok(ty::AliasTerm::new_from_args(cx, a.def_id, args))
    }
}
/// Relates two existential projections with the same `def_id`.
///
/// The projected term is related first, then the args; both are related invariantly.
impl<I: Interner> Relate<I> for ty::ExistentialProjection<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::ExistentialProjection<I>,
        b: ty::ExistentialProjection<I>,
    ) -> RelateResult<I, ty::ExistentialProjection<I>> {
        if a.def_id != b.def_id {
            return Err(TypeError::ProjectionMismatched(ExpectedFound::new(
                true, a.def_id, b.def_id,
            )));
        }

        let no_info = VarianceDiagInfo::default;
        let term = relation.relate_with_variance(ty::Invariant, no_info(), a.term, b.term)?;
        let args = relation.relate_with_variance(ty::Invariant, no_info(), a.args, b.args)?;
        Ok(ty::ExistentialProjection { def_id: a.def_id, args, term })
    }
}
/// Trait refs relate only when they name the same trait; their args relate invariantly.
impl<I: Interner> Relate<I> for ty::TraitRef<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::TraitRef<I>,
        b: ty::TraitRef<I>,
    ) -> RelateResult<I, ty::TraitRef<I>> {
        if a.def_id == b.def_id {
            let related_args = relate_args_invariantly(relation, a.args, b.args)?;
            Ok(ty::TraitRef::new_from_args(relation.cx(), a.def_id, related_args))
        } else {
            Err(TypeError::Traits(ExpectedFound::new(true, a.def_id, b.def_id)))
        }
    }
}
/// Existential trait refs relate only when they name the same trait; their args relate
/// invariantly.
impl<I: Interner> Relate<I> for ty::ExistentialTraitRef<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::ExistentialTraitRef<I>,
        b: ty::ExistentialTraitRef<I>,
    ) -> RelateResult<I, ty::ExistentialTraitRef<I>> {
        if a.def_id == b.def_id {
            let related_args = relate_args_invariantly(relation, a.args, b.args)?;
            Ok(ty::ExistentialTraitRef { def_id: a.def_id, args: related_args })
        } else {
            Err(TypeError::Traits(ExpectedFound::new(true, a.def_id, b.def_id)))
        }
    }
}
/// Relates two types by their structure, recursing into components through `relation`.
///
/// # Panics
///
/// Panics if either side is an inference variable or a bound type, and (via `assert_eq!`)
/// if two alias types relate but have different alias kinds.
///
/// # Errors
///
/// Returns `TypeError::Sorts` for mismatched type shapes, plus more specific errors for
/// pointer mutability, array lengths and tuple arity.
#[instrument(level = "trace", skip(relation), ret)]
pub fn structurally_relate_tys<I: Interner, R: TypeRelation<I>>(
relation: &mut R,
a: I::Ty,
b: I::Ty,
) -> RelateResult<I, I::Ty> {
let cx = relation.cx();
match (a.kind(), b.kind()) {
// Inference variables and bound types must never reach this function.
(ty::Infer(_), _) | (_, ty::Infer(_)) => {
panic!("var types encountered in structurally_relate_tys")
}
(ty::Bound(..), _) | (_, ty::Bound(..)) => {
panic!("bound types encountered in structurally_relate_tys")
}
// An error type on either side yields an error type carrying that guarantee.
(ty::Error(guar), _) | (_, ty::Error(guar)) => Ok(Ty::new_error(cx, guar)),
// Equal primitive types relate to themselves; unequal ones fall through to `Sorts`.
(ty::Never, _)
| (ty::Char, _)
| (ty::Bool, _)
| (ty::Int(_), _)
| (ty::Uint(_), _)
| (ty::Float(_), _)
| (ty::Str, _)
if a == b =>
{
Ok(a)
}
// Params are compared by index only.
(ty::Param(a_p), ty::Param(b_p)) if a_p.index() == b_p.index() => {
Ok(a)
}
(ty::Placeholder(p1), ty::Placeholder(p2)) if p1 == p2 => Ok(a),
// Same ADT: relate args under the ADT's declared variances.
(ty::Adt(a_def, a_args), ty::Adt(b_def, b_args)) if a_def == b_def => {
let args = relation.relate_item_args(a_def.def_id(), a_args, b_args)?;
Ok(Ty::new_adt(cx, a_def, args))
}
(ty::Foreign(a_id), ty::Foreign(b_id)) if a_id == b_id => Ok(Ty::new_foreign(cx, a_id)),
// Trait objects with the same `dyn` representation: relate predicates, then region.
(ty::Dynamic(a_obj, a_region, a_repr), ty::Dynamic(b_obj, b_region, b_repr))
if a_repr == b_repr =>
{
Ok(Ty::new_dynamic(
cx,
relation.relate(a_obj, b_obj)?,
relation.relate(a_region, b_region)?,
a_repr,
))
}
// Coroutines, coroutine witnesses, closures and coroutine-closures with the same
// def id relate all of their args invariantly.
(ty::Coroutine(a_id, a_args), ty::Coroutine(b_id, b_args)) if a_id == b_id => {
let args = relate_args_invariantly(relation, a_args, b_args)?;
Ok(Ty::new_coroutine(cx, a_id, args))
}
(ty::CoroutineWitness(a_id, a_args), ty::CoroutineWitness(b_id, b_args))
if a_id == b_id =>
{
let args = relate_args_invariantly(relation, a_args, b_args)?;
Ok(Ty::new_coroutine_witness(cx, a_id, args))
}
(ty::Closure(a_id, a_args), ty::Closure(b_id, b_args)) if a_id == b_id => {
let args = relate_args_invariantly(relation, a_args, b_args)?;
Ok(Ty::new_closure(cx, a_id, args))
}
(ty::CoroutineClosure(a_id, a_args), ty::CoroutineClosure(b_id, b_args))
if a_id == b_id =>
{
let args = relate_args_invariantly(relation, a_args, b_args)?;
Ok(Ty::new_coroutine_closure(cx, a_id, args))
}
// Raw pointers: mutability must match. `*const T` is covariant in `T`; `*mut T` is
// invariant, with diagnostics pointing at the whole pointer type `a`.
(ty::RawPtr(a_ty, a_mutbl), ty::RawPtr(b_ty, b_mutbl)) => {
if a_mutbl != b_mutbl {
return Err(TypeError::Mutability);
}
let (variance, info) = match a_mutbl {
Mutability::Not => (ty::Covariant, VarianceDiagInfo::None),
Mutability::Mut => {
(ty::Invariant, VarianceDiagInfo::Invariant { ty: a, param_index: 0 })
}
};
let ty = relation.relate_with_variance(variance, info, a_ty, b_ty)?;
Ok(Ty::new_ptr(cx, ty, a_mutbl))
}
// References: same variance rules as raw pointers; the region is related before the
// pointee type.
(ty::Ref(a_r, a_ty, a_mutbl), ty::Ref(b_r, b_ty, b_mutbl)) => {
if a_mutbl != b_mutbl {
return Err(TypeError::Mutability);
}
let (variance, info) = match a_mutbl {
Mutability::Not => (ty::Covariant, VarianceDiagInfo::None),
Mutability::Mut => {
(ty::Invariant, VarianceDiagInfo::Invariant { ty: a, param_index: 0 })
}
};
let r = relation.relate(a_r, b_r)?;
let ty = relation.relate_with_variance(variance, info, a_ty, b_ty)?;
Ok(Ty::new_ref(cx, r, ty, a_mutbl))
}
// Arrays: relate the element type, then the length. If the lengths fail to relate and
// both evaluate to distinct `usize` values, report the more specific `FixedArraySize`;
// otherwise return the original length error.
(ty::Array(a_t, sz_a), ty::Array(b_t, sz_b)) => {
let t = relation.relate(a_t, b_t)?;
match relation.relate(sz_a, sz_b) {
Ok(sz) => Ok(Ty::new_array_with_const_len(cx, t, sz)),
Err(err) => {
let sz_a = sz_a.try_to_target_usize(cx);
let sz_b = sz_b.try_to_target_usize(cx);
match (sz_a, sz_b) {
(Some(sz_a_val), Some(sz_b_val)) if sz_a_val != sz_b_val => Err(
TypeError::FixedArraySize(ExpectedFound::new(true, sz_a_val, sz_b_val)),
),
_ => Err(err),
}
}
}
}
(ty::Slice(a_t), ty::Slice(b_t)) => {
let t = relation.relate(a_t, b_t)?;
Ok(Ty::new_slice(cx, t))
}
// Tuples: equal arity relates element-wise. Differing arity reports `TupleSize`, unless
// one side is the unit tuple, in which case `Sorts` is reported instead.
(ty::Tuple(as_), ty::Tuple(bs)) => {
if as_.len() == bs.len() {
Ok(Ty::new_tup_from_iter(
cx,
iter::zip(as_.iter(), bs.iter()).map(|(a, b)| relation.relate(a, b)),
)?)
} else if !(as_.is_empty() || bs.is_empty()) {
Err(TypeError::TupleSize(ExpectedFound::new(true, as_.len(), bs.len())))
} else {
Err(TypeError::Sorts(ExpectedFound::new(true, a, b)))
}
}
// Same fn item: relate args under the item's declared variances.
(ty::FnDef(a_def_id, a_args), ty::FnDef(b_def_id, b_args)) if a_def_id == b_def_id => {
let args = relation.relate_item_args(a_def_id, a_args, b_args)?;
Ok(Ty::new_fn_def(cx, a_def_id, args))
}
// Fn pointers: rebuild each full signature from its sig tys and header, then relate
// those.
(ty::FnPtr(a_sig_tys, a_hdr), ty::FnPtr(b_sig_tys, b_hdr)) => {
let fty = relation.relate(a_sig_tys.with(a_hdr), b_sig_tys.with(b_hdr))?;
Ok(Ty::new_fn_ptr(cx, fty))
}
// Aliases: relate the alias data first, then assert that the alias kinds agree.
(ty::Alias(a_kind, a_data), ty::Alias(b_kind, b_data)) => {
let alias_ty = relation.relate(a_data, b_data)?;
assert_eq!(a_kind, b_kind);
Ok(Ty::new_alias(cx, a_kind, alias_ty))
}
// Pattern types: relate the base type, then the pattern.
(ty::Pat(a_ty, a_pat), ty::Pat(b_ty, b_pat)) => {
let ty = relation.relate(a_ty, b_ty)?;
let pat = relation.relate(a_pat, b_pat)?;
Ok(Ty::new_pat(cx, ty, pat))
}
// Anything else, including every guarded arm above whose guard failed, is a shape
// mismatch.
_ => Err(TypeError::Sorts(ExpectedFound::new(true, a, b))),
}
}
pub fn structurally_relate_consts<I: Interner, R: TypeRelation<I>>(
relation: &mut R,
mut a: I::Const,
mut b: I::Const,
) -> RelateResult<I, I::Const> {
trace!(
"structurally_relate_consts::<{}>(a = {:?}, b = {:?})",
std::any::type_name::<R>(),
a,
b
);
let cx = relation.cx();
if cx.features().generic_const_exprs() {
a = cx.expand_abstract_consts(a);
b = cx.expand_abstract_consts(b);
}
trace!(
"structurally_relate_consts::<{}>(normed_a = {:?}, normed_b = {:?})",
std::any::type_name::<R>(),
a,
b
);
let is_match = match (a.kind(), b.kind()) {
(ty::ConstKind::Infer(_), _) | (_, ty::ConstKind::Infer(_)) => {
panic!("var types encountered in structurally_relate_consts: {:?} {:?}", a, b)
}
(ty::ConstKind::Error(_), _) => return Ok(a),
(_, ty::ConstKind::Error(_)) => return Ok(b),
(ty::ConstKind::Param(a_p), ty::ConstKind::Param(b_p)) if a_p.index() == b_p.index() => {
true
}
(ty::ConstKind::Placeholder(p1), ty::ConstKind::Placeholder(p2)) => p1 == p2,
(ty::ConstKind::Value(_, a_val), ty::ConstKind::Value(_, b_val)) => a_val == b_val,
(ty::ConstKind::Unevaluated(au), ty::ConstKind::Unevaluated(bu)) if au.def == bu.def => {
if cfg!(debug_assertions) {
let a_ty = cx.type_of(au.def).instantiate(cx, au.args);
let b_ty = cx.type_of(bu.def).instantiate(cx, bu.args);
assert_eq!(a_ty, b_ty);
}
let args = relation.relate_with_variance(
ty::Invariant,
VarianceDiagInfo::default(),
au.args,
bu.args,
)?;
return Ok(Const::new_unevaluated(cx, ty::UnevaluatedConst { def: au.def, args }));
}
(ty::ConstKind::Expr(ae), ty::ConstKind::Expr(be)) => {
let expr = relation.relate(ae, be)?;
return Ok(Const::new_expr(cx, expr));
}
_ => false,
};
if is_match { Ok(a) } else { Err(TypeError::ConstMismatch(ExpectedFound::new(true, a, b))) }
}
/// Binders are not related structurally here: the relation's own `binders` hook decides
/// how to handle them.
impl<I: Interner, T: Relate<I>> Relate<I> for ty::Binder<I, T> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::Binder<I, T>,
        b: ty::Binder<I, T>,
    ) -> RelateResult<I, ty::Binder<I, T>> {
        R::binders(relation, a, b)
    }
}
/// Predicate polarities only relate to an identical polarity.
impl<I: Interner> Relate<I> for ty::PredicatePolarity {
    fn relate<R: TypeRelation<I>>(
        _relation: &mut R,
        a: ty::PredicatePolarity,
        b: ty::PredicatePolarity,
    ) -> RelateResult<I, ty::PredicatePolarity> {
        if a == b { Ok(a) } else { Err(TypeError::PolarityMismatch(ExpectedFound::new(true, a, b))) }
    }
}
/// Relates two trait predicates: first their trait refs, then their polarities.
impl<I: Interner> Relate<I> for ty::TraitPredicate<I> {
    fn relate<R: TypeRelation<I>>(
        relation: &mut R,
        a: ty::TraitPredicate<I>,
        b: ty::TraitPredicate<I>,
    ) -> RelateResult<I, ty::TraitPredicate<I>> {
        let trait_ref = relation.relate(a.trait_ref, b.trait_ref)?;
        let polarity = relation.relate(a.polarity, b.polarity)?;
        Ok(ty::TraitPredicate { trait_ref, polarity })
    }
}