use std::assert_matches::assert_matches;
use hir::LangItem;
use rustc_ast::Mutability;
use rustc_data_structures::fx::FxIndexSet;
use rustc_hir as hir;
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
use rustc_infer::infer::{RegionResolutionError, TyCtxtInferExt};
use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypeVisitableExt};
use super::outlives_bounds::InferCtxtExt;
use crate::regions::InferCtxtRegionExt;
use crate::traits::{self, FulfillmentError, ObligationCause};
/// Why one field (or one component type of a builtin type) failed a trait check.
pub enum InfringingFieldsReason<'tcx> {
    /// Proving the trait bound produced fulfillment errors.
    Fulfill(Vec<FulfillmentError<'tcx>>),
    /// Resolving regions (with the self type's implied bounds assumed) produced errors.
    Regions(Vec<RegionResolutionError<'tcx>>),
}

/// Reasons a type may not be given a `Copy` implementation, as reported by
/// [`type_allowed_to_implement_copy`].
pub enum CopyImplementationError<'tcx> {
    /// At least one field of the ADT failed the `Copy` check. Each entry carries the field
    /// definition, the field's normalized type and the reason it failed.
    InfringingFields(Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
    /// The self type is neither an ADT nor one of the builtin kinds accepted up front.
    NotAnAdt,
    /// Every field is fine, but the ADT has a destructor.
    HasDestructor,
}

/// Reasons a type may not implement `ConstParamTy` / `UnsizedConstParamTy`, as reported by
/// [`type_allowed_to_implement_const_param_ty`].
pub enum ConstParamTyImplementationError<'tcx> {
    /// `str`, a slice or a shared reference was checked against `ConstParamTy`; these are
    /// only considered for `UnsizedConstParamTy`.
    UnsizedConstParamsFeatureRequired,
    /// One or more component types of a builtin type (slice, reference, `str`, array or
    /// tuple) failed the check.
    InvalidInnerTyOfBuiltinTy(Vec<(Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
    /// One or more fields of a struct or enum failed the check.
    ///
    /// NOTE: the variant name is misspelled (sic). It is kept as-is because it is part of
    /// this public enum's interface.
    InfrigingFields(Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
    /// The self type is neither a struct/enum nor one of the accepted builtin kinds.
    NotAnAdtOrBuiltinAllowed,
}
/// Checks whether `self_type` is allowed to implement `Copy`.
///
/// The following are accepted without further inspection:
/// - integers, `bool`, floats and `char`;
/// - raw pointers, `!` and shared references;
/// - arrays.
///
/// For an ADT, every field of every variant must implement `Copy` (see
/// [`all_fields_implement_trait`]), and the ADT must not have a destructor. All other kinds
/// of types are rejected.
///
/// # Errors
///
/// * [`CopyImplementationError::NotAnAdt`] — `self_type` is neither an ADT nor one of the
///   accepted builtin kinds.
/// * [`CopyImplementationError::InfringingFields`] — at least one field failed the `Copy`
///   check. This is reported in preference to a destructor error.
/// * [`CopyImplementationError::HasDestructor`] — the fields are fine, but the ADT has a
///   destructor.
pub fn type_allowed_to_implement_copy<'tcx>(
    tcx: TyCtxt<'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    self_type: Ty<'tcx>,
    parent_cause: ObligationCause<'tcx>,
) -> Result<(), CopyImplementationError<'tcx>> {
    let &ty::Adt(adt_def, generic_args) = self_type.kind() else {
        // Non-ADT self types: a fixed set of builtin kinds is accepted outright,
        // and everything else is rejected.
        let accepted_builtin = matches!(
            self_type.kind(),
            ty::Uint(_)
                | ty::Int(_)
                | ty::Bool
                | ty::Float(_)
                | ty::Char
                | ty::RawPtr(..)
                | ty::Never
                | ty::Ref(_, _, hir::Mutability::Not)
                | ty::Array(..)
        );
        return if accepted_builtin { Ok(()) } else { Err(CopyImplementationError::NotAnAdt) };
    };

    // Field problems are checked (and reported) before the destructor check.
    if let Err(bad_fields) = all_fields_implement_trait(
        tcx,
        param_env,
        self_type,
        adt_def,
        generic_args,
        parent_cause,
        hir::LangItem::Copy,
    ) {
        return Err(CopyImplementationError::InfringingFields(bad_fields));
    }

    if adt_def.has_dtor(tcx) { Err(CopyImplementationError::HasDestructor) } else { Ok(()) }
}
/// Checks whether `self_type` may implement the const-parameter trait named by `lang_item`.
///
/// `lang_item` must be `LangItem::ConstParamTy` or `LangItem::UnsizedConstParamTy`. Any
/// other value fails the `assert_matches!` and panics.
///
/// How each kind of self type is handled:
/// - Integers, `bool` and `char` are accepted immediately.
/// - Slices and shared references are decomposed into their element/pointee type, and `str`
///   is checked as `[u8]`. This only happens for `UnsizedConstParamTy`; under
///   `ConstParamTy` they are rejected.
/// - Arrays and tuples are decomposed into their element types.
/// - Structs and enums have every field checked via [`all_fields_implement_trait`].
/// - Anything else is rejected.
///
/// Each component type from the decomposition is then checked separately. The check
/// requires the trait bound to hold, followed by region resolution that assumes the
/// implied bounds of `self_type`.
///
/// # Errors
///
/// * [`ConstParamTyImplementationError::UnsizedConstParamsFeatureRequired`] — `str`, a slice
///   or a shared reference was checked against `ConstParamTy`.
/// * [`ConstParamTyImplementationError::InfrigingFields`] — a struct/enum field failed.
/// * [`ConstParamTyImplementationError::InvalidInnerTyOfBuiltinTy`] — a component type of a
///   builtin type failed.
/// * [`ConstParamTyImplementationError::NotAnAdtOrBuiltinAllowed`] — unsupported type kind.
pub fn type_allowed_to_implement_const_param_ty<'tcx>(
    tcx: TyCtxt<'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    self_type: Ty<'tcx>,
    lang_item: LangItem,
    parent_cause: ObligationCause<'tcx>,
) -> Result<(), ConstParamTyImplementationError<'tcx>> {
    assert_matches!(lang_item, LangItem::ConstParamTy | LangItem::UnsizedConstParamTy);
    let unsized_allowed = lang_item == LangItem::UnsizedConstParamTy;

    // Component types that must themselves satisfy the trait. Arm order matters: the guarded
    // unsized arms must be tried before the catch-all rejection of the same kinds.
    let components: Vec<Ty<'tcx>> = match *self_type.kind() {
        ty::Uint(_) | ty::Int(_) | ty::Bool | ty::Char => return Ok(()),
        ty::Slice(elem_ty) | ty::Ref(_, elem_ty, Mutability::Not) if unsized_allowed => {
            vec![elem_ty]
        }
        // `str` is checked as if it were `[u8]`.
        ty::Str if unsized_allowed => vec![Ty::new_slice(tcx, tcx.types.u8)],
        ty::Str | ty::Slice(..) | ty::Ref(_, _, Mutability::Not) => {
            return Err(ConstParamTyImplementationError::UnsizedConstParamsFeatureRequired);
        }
        ty::Array(elem_ty, _) => vec![elem_ty],
        ty::Tuple(elem_tys) => elem_tys.into_iter().collect(),
        ty::Adt(adt_def, generic_args) if adt_def.is_enum() || adt_def.is_struct() => {
            // Fields are checked individually; no component types remain afterwards.
            if let Err(bad_fields) = all_fields_implement_trait(
                tcx,
                param_env,
                self_type,
                adt_def,
                generic_args,
                parent_cause.clone(),
                lang_item,
            ) {
                return Err(ConstParamTyImplementationError::InfrigingFields(bad_fields));
            }
            Vec::new()
        }
        _ => return Err(ConstParamTyImplementationError::NotAnAdtOrBuiltinAllowed),
    };

    let infringing_inner_tys: Vec<_> = components
        .into_iter()
        .filter_map(|component_ty| {
            // Each component gets its own inference context and obligation set.
            let infcx = tcx.infer_ctxt().build();
            let ocx = traits::ObligationCtxt::new_with_diagnostics(&infcx);
            let trait_def_id = tcx.require_lang_item(lang_item, Some(parent_cause.span));
            ocx.register_bound(parent_cause.clone(), param_env, component_ty, trait_def_id);

            let fulfill_errors = ocx.select_all_or_error();
            if !fulfill_errors.is_empty() {
                return Some((component_ty, InfringingFieldsReason::Fulfill(fulfill_errors)));
            }

            // Region check, run only once the bound itself holds. The implied bounds of
            // `self_type` are assumed. The set is built inline because the iterator
            // returned by `implied_bounds_tys` borrows it.
            let outlives_env = OutlivesEnvironment::with_bounds(
                param_env,
                infcx.implied_bounds_tys(
                    param_env,
                    parent_cause.body_id,
                    &FxIndexSet::from_iter([self_type]),
                ),
            );
            let region_errors = infcx.resolve_regions(&outlives_env);
            (!region_errors.is_empty())
                .then(|| (component_ty, InfringingFieldsReason::Regions(region_errors)))
        })
        .collect();

    if infringing_inner_tys.is_empty() {
        Ok(())
    } else {
        Err(ConstParamTyImplementationError::InvalidInnerTyOfBuiltinTy(infringing_inner_tys))
    }
}
/// Checks that every field of every variant of `adt` implements the trait identified by
/// `lang_item`. Fields are instantiated with `args`.
///
/// Each field is checked in its own inference context. Field types are normalized first;
/// the trait bound is then proven and regions are resolved, assuming the implied bounds
/// of `self_type`.
///
/// Some fields are skipped silently:
/// - fields whose type already references an error;
/// - fields that fail to normalize, or normalize to an error type. For these, a delayed
///   bug is recorded at the field's span.
///
/// # Errors
///
/// Returns every offending `(field, normalized field type, reason)` triple. The region
/// check runs even when the trait bound failed, so a single field can appear twice: once
/// with `Fulfill` and once with `Regions`.
pub fn all_fields_implement_trait<'tcx>(
    tcx: TyCtxt<'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    self_type: Ty<'tcx>,
    adt: AdtDef<'tcx>,
    args: ty::GenericArgsRef<'tcx>,
    parent_cause: ObligationCause<'tcx>,
    lang_item: LangItem,
) -> Result<(), Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>> {
    let trait_def_id = tcx.require_lang_item(lang_item, Some(parent_cause.span));
    let mut infringing = Vec::new();

    let every_field = adt.variants().iter().flat_map(|variant| variant.fields.iter());
    for field in every_field {
        let infcx = tcx.infer_ctxt().build();
        let ocx = traits::ObligationCtxt::new_with_diagnostics(&infcx);

        let unnormalized_ty = field.ty(tcx, args);
        if unnormalized_ty.references_error() {
            continue;
        }

        let field_span = tcx.def_span(field.did);
        // Prefer the span of the field's type annotation when the field is local.
        let field_ty_span =
            if let Some(hir::Node::Field(field_def)) = tcx.hir().get_if_local(field.did) {
                field_def.ty.span
            } else {
                field_span
            };

        // If the field's declared type mentions non-region generic parameters, attribute
        // normalization to the caller's cause. Otherwise, point at the field's type.
        let depends_on_non_region_params = field
            .ty(tcx, traits::GenericArgs::identity_for_item(tcx, adt.did()))
            .has_non_region_param();
        let normalization_cause = if depends_on_non_region_params {
            parent_cause.clone()
        } else {
            ObligationCause::dummy_with_span(field_ty_span)
        };

        let field_ty = ocx.normalize(&normalization_cause, param_env, unnormalized_ty);
        let normalization_errors = ocx.select_where_possible();
        if !normalization_errors.is_empty() || field_ty.references_error() {
            tcx.dcx().span_delayed_bug(
                field_span,
                format!(
                    "couldn't normalize struct field `{unnormalized_ty}` when checking {tr} implementation",
                    tr = tcx.def_path_str(trait_def_id)
                ),
            );
            continue;
        }

        ocx.register_bound(
            ObligationCause::dummy_with_span(field_ty_span),
            param_env,
            field_ty,
            trait_def_id,
        );
        let fulfill_errors = ocx.select_all_or_error();
        if !fulfill_errors.is_empty() {
            infringing.push((field, field_ty, InfringingFieldsReason::Fulfill(fulfill_errors)));
        }

        // Runs regardless of the fulfillment result above. The set is built inline because
        // the iterator returned by `implied_bounds_tys` borrows it.
        let outlives_env = OutlivesEnvironment::with_bounds(
            param_env,
            infcx.implied_bounds_tys(
                param_env,
                parent_cause.body_id,
                &FxIndexSet::from_iter([self_type]),
            ),
        );
        let region_errors = infcx.resolve_regions(&outlives_env);
        if !region_errors.is_empty() {
            infringing.push((field, field_ty, InfringingFieldsReason::Regions(region_errors)));
        }
    }

    if !infringing.is_empty() {
        return Err(infringing);
    }
    Ok(())
}