use relate::lattice::{LatticeOp, LatticeOpKind};
use rustc_middle::bug;
use rustc_middle::ty::relate::solver_relating::RelateExt as NextSolverRelate;
use rustc_middle::ty::{Const, ImplSubject, TypingMode};
use super::*;
use crate::infer::relate::type_relating::TypeRelating;
use crate::infer::relate::{Relate, TypeRelation};
use crate::traits::Obligation;
use crate::traits::solve::Goal;
/// Flag accepted by the relate entry points on [`At`].
///
/// Within this file the value is only forwarded to `TypeRelating::new` on the
/// old-solver path; the next-solver path does not read it.
// NOTE(review): presumably `Yes` lets relating define opaque types while `No`
// treats them opaquely — confirm against `TypeRelating`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DefineOpaqueTypes {
    Yes,
    No,
}
/// An inference context bundled with the obligation cause and parameter
/// environment that the relate operations on this type use.
///
/// Values are created with `InferCtxt::at`.
#[derive(Copy, Clone)]
pub struct At<'a, 'tcx> {
    /// Inference context the relations are performed in.
    pub infcx: &'a InferCtxt<'tcx>,
    /// Cause recorded in the traces and obligations that get produced.
    pub cause: &'a ObligationCause<'tcx>,
    /// Parameter environment handed to the underlying relation.
    pub param_env: ty::ParamEnv<'tcx>,
}
impl<'tcx> InferCtxt<'tcx> {
    /// Pairs this inference context with `cause` and `param_env`, yielding an
    /// [`At`] on which relate operations can be invoked.
    #[inline]
    pub fn at<'a>(
        &'a self,
        cause: &'a ObligationCause<'tcx>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> At<'a, 'tcx> {
        At { infcx: self, cause, param_env }
    }

    /// Creates a new inference context by copying or cloning every field of
    /// `self`, including the selection and evaluation caches.
    pub fn fork(&self) -> Self {
        // Exhaustive destructuring: adding a field to `InferCtxt` forces this
        // function to be revisited.
        let Self {
            tcx,
            typing_mode,
            considering_regions,
            skip_leak_check,
            inner,
            lexical_region_resolutions,
            selection_cache,
            evaluation_cache,
            reported_trait_errors,
            reported_signature_mismatch,
            tainted_by_errors,
            universe,
            next_trait_solver,
            obligation_inspector,
        } = self;

        Self {
            tcx: *tcx,
            typing_mode: *typing_mode,
            considering_regions: *considering_regions,
            skip_leak_check: *skip_leak_check,
            inner: inner.clone(),
            lexical_region_resolutions: lexical_region_resolutions.clone(),
            selection_cache: selection_cache.clone(),
            evaluation_cache: evaluation_cache.clone(),
            reported_trait_errors: reported_trait_errors.clone(),
            reported_signature_mismatch: reported_signature_mismatch.clone(),
            tainted_by_errors: tainted_by_errors.clone(),
            universe: universe.clone(),
            next_trait_solver: *next_trait_solver,
            obligation_inspector: obligation_inspector.clone(),
        }
    }

    /// Like [`InferCtxt::fork`], but the new context uses the given
    /// `typing_mode`.
    ///
    /// Unlike `fork`, the selection and evaluation caches start out as their
    /// defaults instead of being cloned, and the projection cache of the
    /// fork's `inner` state is cleared.
    // NOTE(review): presumably the caches are dropped because their entries
    // may depend on the typing mode — confirm.
    pub fn fork_with_typing_mode(&self, typing_mode: TypingMode<'tcx>) -> Self {
        let forked = Self {
            tcx: self.tcx,
            typing_mode,
            considering_regions: self.considering_regions,
            skip_leak_check: self.skip_leak_check,
            inner: self.inner.clone(),
            lexical_region_resolutions: self.lexical_region_resolutions.clone(),
            selection_cache: Default::default(),
            evaluation_cache: Default::default(),
            reported_trait_errors: self.reported_trait_errors.clone(),
            reported_signature_mismatch: self.reported_signature_mismatch.clone(),
            tainted_by_errors: self.tainted_by_errors.clone(),
            universe: self.universe.clone(),
            next_trait_solver: self.next_trait_solver,
            obligation_inspector: self.obligation_inspector.clone(),
        };

        {
            // Keep the mutable borrow scoped so `forked` can be moved out below.
            let mut inner = forked.inner.borrow_mut();
            inner.projection_cache().clear();
        }

        forked
    }
}
/// Relatable, copyable values for which a [`TypeTrace`] can be built from an
/// obligation cause and a pair of values.
pub trait ToTrace<'tcx>: Copy + Relate<TyCtxt<'tcx>> {
    /// Builds the trace for the pair `a` / `b` under `cause`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx>;
}
impl<'a, 'tcx> At<'a, 'tcx> {
pub fn sup<T>(
self,
define_opaque_types: DefineOpaqueTypes,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
if self.infcx.next_trait_solver {
NextSolverRelate::relate(
self.infcx,
self.param_env,
expected,
ty::Contravariant,
actual,
)
.map(|goals| self.goals_to_obligations(goals))
} else {
let mut op = TypeRelating::new(
self.infcx,
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
ty::Contravariant,
);
op.relate(expected, actual)?;
Ok(InferOk { value: (), obligations: op.into_obligations() })
}
}
pub fn sub<T>(
self,
define_opaque_types: DefineOpaqueTypes,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
if self.infcx.next_trait_solver {
NextSolverRelate::relate(self.infcx, self.param_env, expected, ty::Covariant, actual)
.map(|goals| self.goals_to_obligations(goals))
} else {
let mut op = TypeRelating::new(
self.infcx,
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
define_opaque_types,
ty::Covariant,
);
op.relate(expected, actual)?;
Ok(InferOk { value: (), obligations: op.into_obligations() })
}
}
pub fn eq<T>(
self,
define_opaque_types: DefineOpaqueTypes,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
self.eq_trace(
define_opaque_types,
ToTrace::to_trace(self.cause, expected, actual),
expected,
actual,
)
}
pub fn eq_trace<T>(
self,
define_opaque_types: DefineOpaqueTypes,
trace: TypeTrace<'tcx>,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: Relate<TyCtxt<'tcx>>,
{
if self.infcx.next_trait_solver {
NextSolverRelate::relate(self.infcx, self.param_env, expected, ty::Invariant, actual)
.map(|goals| self.goals_to_obligations(goals))
} else {
let mut op = TypeRelating::new(
self.infcx,
trace,
self.param_env,
define_opaque_types,
ty::Invariant,
);
op.relate(expected, actual)?;
Ok(InferOk { value: (), obligations: op.into_obligations() })
}
}
pub fn relate<T>(
self,
define_opaque_types: DefineOpaqueTypes,
expected: T,
variance: ty::Variance,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
match variance {
ty::Covariant => self.sub(define_opaque_types, expected, actual),
ty::Invariant => self.eq(define_opaque_types, expected, actual),
ty::Contravariant => self.sup(define_opaque_types, expected, actual),
ty::Bivariant => panic!("Bivariant given to `relate()`"),
}
}
pub fn lub<T>(self, expected: T, actual: T) -> InferResult<'tcx, T>
where
T: ToTrace<'tcx>,
{
let mut op = LatticeOp::new(
self.infcx,
ToTrace::to_trace(self.cause, expected, actual),
self.param_env,
LatticeOpKind::Lub,
);
let value = op.relate(expected, actual)?;
Ok(InferOk { value, obligations: op.into_obligations() })
}
fn goals_to_obligations(
&self,
goals: Vec<Goal<'tcx, ty::Predicate<'tcx>>>,
) -> InferOk<'tcx, ()> {
InferOk {
value: (),
obligations: goals
.into_iter()
.map(|goal| {
Obligation::new(
self.infcx.tcx,
self.cause.clone(),
goal.param_env,
goal.predicate,
)
})
.collect(),
}
}
}
impl<'tcx> ToTrace<'tcx> for ImplSubject<'tcx> {
    /// Delegates to the `ToTrace` impl of the wrapped trait ref or type.
    ///
    /// Both subjects must be the same variant; mixing `Trait` with `Inherent`
    /// triggers `bug!`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        match a {
            ImplSubject::Trait(trait_ref_a) => match b {
                ImplSubject::Trait(trait_ref_b) => {
                    ToTrace::to_trace(cause, trait_ref_a, trait_ref_b)
                }
                ImplSubject::Inherent(_) => bug!("can not trace TraitRef and Ty"),
            },
            ImplSubject::Inherent(ty_a) => match b {
                ImplSubject::Inherent(ty_b) => ToTrace::to_trace(cause, ty_a, ty_b),
                ImplSubject::Trait(_) => bug!("can not trace TraitRef and Ty"),
            },
        }
    }
}
impl<'tcx> ToTrace<'tcx> for Ty<'tcx> {
    /// Records the two types, converted via `into()`, as `ValuePairs::Terms`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Terms(ExpectedFound::new(a.into(), b.into()));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::Region<'tcx> {
    /// Records the two regions as `ValuePairs::Regions`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Regions(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for Const<'tcx> {
    /// Records the two constants, converted via `into()`, as
    /// `ValuePairs::Terms`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Terms(ExpectedFound::new(a.into(), b.into()));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::GenericArg<'tcx> {
    /// Unpacks both arguments and records them by kind: lifetimes as
    /// `ValuePairs::Regions`, types and constants as `ValuePairs::Terms`.
    ///
    /// Arguments of differing kinds trigger `bug!`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = match (a.unpack(), b.unpack()) {
            (GenericArgKind::Lifetime(lhs), GenericArgKind::Lifetime(rhs)) => {
                ValuePairs::Regions(ExpectedFound::new(lhs, rhs))
            }
            (GenericArgKind::Type(lhs), GenericArgKind::Type(rhs)) => {
                ValuePairs::Terms(ExpectedFound::new(lhs.into(), rhs.into()))
            }
            (GenericArgKind::Const(lhs), GenericArgKind::Const(rhs)) => {
                ValuePairs::Terms(ExpectedFound::new(lhs.into(), rhs.into()))
            }
            // The message prints the original (packed) generic arguments.
            _ => bug!("relating different kinds: {a:?} {b:?}"),
        };
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::Term<'tcx> {
    /// Records the two terms as `ValuePairs::Terms`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Terms(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> {
    /// Records the two trait references as `ValuePairs::TraitRefs`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::TraitRefs(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::AliasTy<'tcx> {
    /// Records the two alias types, converted via `into()`, as
    /// `ValuePairs::Aliases`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Aliases(ExpectedFound::new(a.into(), b.into()));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::AliasTerm<'tcx> {
    /// Records the two alias terms as `ValuePairs::Aliases`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::Aliases(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::FnSig<'tcx> {
    /// Wraps both signatures with `ty::Binder::dummy` and records them as
    /// `ValuePairs::PolySigs`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let sig_a = ty::Binder::dummy(a);
        let sig_b = ty::Binder::dummy(b);
        let values = ValuePairs::PolySigs(ExpectedFound::new(sig_a, sig_b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::PolyFnSig<'tcx> {
    /// Records the two signatures as `ValuePairs::PolySigs`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::PolySigs(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialTraitRef<'tcx> {
    /// Records the two existential trait references as
    /// `ValuePairs::ExistentialTraitRef`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::ExistentialTraitRef(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}
impl<'tcx> ToTrace<'tcx> for ty::PolyExistentialProjection<'tcx> {
    /// Records the two existential projections as
    /// `ValuePairs::ExistentialProjection`.
    fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
        let cause = cause.clone();
        let values = ValuePairs::ExistentialProjection(ExpectedFound::new(a, b));
        TypeTrace { cause, values }
    }
}