rustc_trait_selection/traits/
misc.rs

1//! Miscellaneous type-system utilities that are too small to deserve their own modules.
2
3use std::assert_matches::assert_matches;
4
5use hir::LangItem;
6use rustc_ast::Mutability;
7use rustc_hir as hir;
8use rustc_infer::infer::{RegionResolutionError, TyCtxtInferExt};
9use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt, TypeVisitableExt, TypingMode};
10
11use crate::regions::InferCtxtRegionExt;
12use crate::traits::{self, FulfillmentError, ObligationCause};
13
14pub enum CopyImplementationError<'tcx> {
15    InfringingFields(Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
16    NotAnAdt,
17    HasDestructor,
18    HasUnsafeFields,
19}
20
21pub enum ConstParamTyImplementationError<'tcx> {
22    UnsizedConstParamsFeatureRequired,
23    InvalidInnerTyOfBuiltinTy(Vec<(Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
24    InfrigingFields(Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>),
25    NotAnAdtOrBuiltinAllowed,
26}
27
28pub enum InfringingFieldsReason<'tcx> {
29    Fulfill(Vec<FulfillmentError<'tcx>>),
30    Regions(Vec<RegionResolutionError<'tcx>>),
31}
32
33/// Checks that the fields of the type (an ADT) all implement copy.
34///
35/// If fields don't implement copy, return an error containing a list of
36/// those violating fields.
37///
38/// If it's not an ADT, int ty, `bool`, float ty, `char`, raw pointer, `!`,
39/// a reference or an array returns `Err(NotAnAdt)`.
40///
41/// If the impl is `Safe`, `self_type` must not have unsafe fields. When used to
42/// generate suggestions in lints, `Safe` should be supplied so as to not
43/// suggest implementing `Copy` for types with unsafe fields.
44pub fn type_allowed_to_implement_copy<'tcx>(
45    tcx: TyCtxt<'tcx>,
46    param_env: ty::ParamEnv<'tcx>,
47    self_type: Ty<'tcx>,
48    parent_cause: ObligationCause<'tcx>,
49    impl_safety: hir::Safety,
50) -> Result<(), CopyImplementationError<'tcx>> {
51    let (adt, args) = match self_type.kind() {
52        // These types used to have a builtin impl.
53        // Now libcore provides that impl.
54        ty::Uint(_)
55        | ty::Int(_)
56        | ty::Bool
57        | ty::Float(_)
58        | ty::Char
59        | ty::RawPtr(..)
60        | ty::Never
61        | ty::Ref(_, _, hir::Mutability::Not)
62        | ty::Array(..) => return Ok(()),
63
64        &ty::Adt(adt, args) => (adt, args),
65
66        _ => return Err(CopyImplementationError::NotAnAdt),
67    };
68
69    all_fields_implement_trait(
70        tcx,
71        param_env,
72        self_type,
73        adt,
74        args,
75        parent_cause,
76        hir::LangItem::Copy,
77    )
78    .map_err(CopyImplementationError::InfringingFields)?;
79
80    if adt.has_dtor(tcx) {
81        return Err(CopyImplementationError::HasDestructor);
82    }
83
84    if impl_safety.is_safe() && self_type.has_unsafe_fields() {
85        return Err(CopyImplementationError::HasUnsafeFields);
86    }
87
88    Ok(())
89}
90
91/// Checks that the fields of the type (an ADT) all implement `(Unsized?)ConstParamTy`.
92///
93/// If fields don't implement `(Unsized?)ConstParamTy`, return an error containing a list of
94/// those violating fields.
95///
96/// If it's not an ADT, int ty, `bool` or `char`, returns `Err(NotAnAdtOrBuiltinAllowed)`.
97pub fn type_allowed_to_implement_const_param_ty<'tcx>(
98    tcx: TyCtxt<'tcx>,
99    param_env: ty::ParamEnv<'tcx>,
100    self_type: Ty<'tcx>,
101    lang_item: LangItem,
102    parent_cause: ObligationCause<'tcx>,
103) -> Result<(), ConstParamTyImplementationError<'tcx>> {
104    assert_matches!(lang_item, LangItem::ConstParamTy | LangItem::UnsizedConstParamTy);
105
106    let inner_tys: Vec<_> = match *self_type.kind() {
107        // Trivially okay as these types are all:
108        // - Sized
109        // - Contain no nested types
110        // - Have structural equality
111        ty::Uint(_) | ty::Int(_) | ty::Bool | ty::Char => return Ok(()),
112
113        // Handle types gated under `feature(unsized_const_params)`
114        // FIXME(unsized_const_params): Make `const N: [u8]` work then forbid references
115        ty::Slice(inner_ty) | ty::Ref(_, inner_ty, Mutability::Not)
116            if lang_item == LangItem::UnsizedConstParamTy =>
117        {
118            vec![inner_ty]
119        }
120        ty::Str if lang_item == LangItem::UnsizedConstParamTy => {
121            vec![Ty::new_slice(tcx, tcx.types.u8)]
122        }
123        ty::Str | ty::Slice(..) | ty::Ref(_, _, Mutability::Not) => {
124            return Err(ConstParamTyImplementationError::UnsizedConstParamsFeatureRequired);
125        }
126
127        ty::Array(inner_ty, _) => vec![inner_ty],
128
129        // `str` morally acts like a newtype around `[u8]`
130        ty::Tuple(inner_tys) => inner_tys.into_iter().collect(),
131
132        ty::Adt(adt, args) if adt.is_enum() || adt.is_struct() => {
133            all_fields_implement_trait(
134                tcx,
135                param_env,
136                self_type,
137                adt,
138                args,
139                parent_cause.clone(),
140                lang_item,
141            )
142            .map_err(ConstParamTyImplementationError::InfrigingFields)?;
143
144            vec![]
145        }
146
147        _ => return Err(ConstParamTyImplementationError::NotAnAdtOrBuiltinAllowed),
148    };
149
150    let mut infringing_inner_tys = vec![];
151    for inner_ty in inner_tys {
152        // We use an ocx per inner ty for better diagnostics
153        let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
154        let ocx = traits::ObligationCtxt::new_with_diagnostics(&infcx);
155
156        ocx.register_bound(
157            parent_cause.clone(),
158            param_env,
159            inner_ty,
160            tcx.require_lang_item(lang_item, Some(parent_cause.span)),
161        );
162
163        let errors = ocx.select_all_or_error();
164        if !errors.is_empty() {
165            infringing_inner_tys.push((inner_ty, InfringingFieldsReason::Fulfill(errors)));
166            continue;
167        }
168
169        // Check regions assuming the self type of the impl is WF
170        let errors = infcx.resolve_regions(parent_cause.body_id, param_env, [self_type]);
171        if !errors.is_empty() {
172            infringing_inner_tys.push((inner_ty, InfringingFieldsReason::Regions(errors)));
173            continue;
174        }
175    }
176
177    if !infringing_inner_tys.is_empty() {
178        return Err(ConstParamTyImplementationError::InvalidInnerTyOfBuiltinTy(
179            infringing_inner_tys,
180        ));
181    }
182
183    Ok(())
184}
185
186/// Check that all fields of a given `adt` implement `lang_item` trait.
187pub fn all_fields_implement_trait<'tcx>(
188    tcx: TyCtxt<'tcx>,
189    param_env: ty::ParamEnv<'tcx>,
190    self_type: Ty<'tcx>,
191    adt: AdtDef<'tcx>,
192    args: ty::GenericArgsRef<'tcx>,
193    parent_cause: ObligationCause<'tcx>,
194    lang_item: LangItem,
195) -> Result<(), Vec<(&'tcx ty::FieldDef, Ty<'tcx>, InfringingFieldsReason<'tcx>)>> {
196    let trait_def_id = tcx.require_lang_item(lang_item, Some(parent_cause.span));
197
198    let mut infringing = Vec::new();
199    for variant in adt.variants() {
200        for field in &variant.fields {
201            // Do this per-field to get better error messages.
202            let infcx = tcx.infer_ctxt().build(TypingMode::non_body_analysis());
203            let ocx = traits::ObligationCtxt::new_with_diagnostics(&infcx);
204
205            let unnormalized_ty = field.ty(tcx, args);
206            if unnormalized_ty.references_error() {
207                continue;
208            }
209
210            let field_span = tcx.def_span(field.did);
211            let field_ty_span = match tcx.hir().get_if_local(field.did) {
212                Some(hir::Node::Field(field_def)) => field_def.ty.span,
213                _ => field_span,
214            };
215
216            // FIXME(compiler-errors): This gives us better spans for bad
217            // projection types like in issue-50480.
218            // If the ADT has args, point to the cause we are given.
219            // If it does not, then this field probably doesn't normalize
220            // to begin with, and point to the bad field's span instead.
221            let normalization_cause = if field
222                .ty(tcx, traits::GenericArgs::identity_for_item(tcx, adt.did()))
223                .has_non_region_param()
224            {
225                parent_cause.clone()
226            } else {
227                ObligationCause::dummy_with_span(field_ty_span)
228            };
229            let ty = ocx.normalize(&normalization_cause, param_env, unnormalized_ty);
230            let normalization_errors = ocx.select_where_possible();
231
232            // NOTE: The post-normalization type may also reference errors,
233            // such as when we project to a missing type or we have a mismatch
234            // between expected and found const-generic types. Don't report an
235            // additional copy error here, since it's not typically useful.
236            if !normalization_errors.is_empty() || ty.references_error() {
237                tcx.dcx().span_delayed_bug(field_span, format!("couldn't normalize struct field `{unnormalized_ty}` when checking {tr} implementation", tr = tcx.def_path_str(trait_def_id)));
238                continue;
239            }
240
241            ocx.register_bound(
242                ObligationCause::dummy_with_span(field_ty_span),
243                param_env,
244                ty,
245                trait_def_id,
246            );
247            let errors = ocx.select_all_or_error();
248            if !errors.is_empty() {
249                infringing.push((field, ty, InfringingFieldsReason::Fulfill(errors)));
250            }
251
252            // Check regions assuming the self type of the impl is WF
253            let errors = infcx.resolve_regions(parent_cause.body_id, param_env, [self_type]);
254            if !errors.is_empty() {
255                infringing.push((field, ty, InfringingFieldsReason::Regions(errors)));
256            }
257        }
258    }
259
260    if infringing.is_empty() { Ok(()) } else { Err(infringing) }
261}