// rustc_builtin_macros/deriving/coerce_pointee.rs

1use ast::HasAttrs;
2use ast::ptr::P;
3use rustc_ast::mut_visit::MutVisitor;
4use rustc_ast::visit::BoundKind;
5use rustc_ast::{
6    self as ast, GenericArg, GenericBound, GenericParamKind, Generics, ItemKind, MetaItem,
7    TraitBoundModifiers, VariantData, WherePredicate,
8};
9use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
10use rustc_errors::E0802;
11use rustc_expand::base::{Annotatable, ExtCtxt};
12use rustc_macros::Diagnostic;
13use rustc_span::{Ident, Span, Symbol, sym};
14use thin_vec::{ThinVec, thin_vec};
15
16use crate::errors;
17
/// Expands to a `Vec<Ident>` with one identifier per `::`-separated segment, each built from
/// the `sym::` constant of the same name and carrying the given span.
///
/// For example, `path!(span, core::ops)` becomes
/// `vec![Ident::new(sym::core, span), Ident::new(sym::ops, span)]`.
macro_rules! path {
    ($sp:expr, $($segment:ident)::*) => {
        vec![$(Ident::new(sym::$segment, $sp)),*]
    };
}
21
22pub(crate) fn expand_deriving_coerce_pointee(
23    cx: &ExtCtxt<'_>,
24    span: Span,
25    _mitem: &MetaItem,
26    item: &Annotatable,
27    push: &mut dyn FnMut(Annotatable),
28    _is_const: bool,
29) {
30    item.visit_with(&mut DetectNonGenericPointeeAttr { cx });
31
32    let (name_ident, generics) = if let Annotatable::Item(aitem) = item
33        && let ItemKind::Struct(struct_data, g) = &aitem.kind
34    {
35        if !matches!(
36            struct_data,
37            VariantData::Struct { fields, recovered: _ } | VariantData::Tuple(fields, _)
38                if !fields.is_empty())
39        {
40            cx.dcx().emit_err(RequireOneField { span });
41            return;
42        }
43        (aitem.ident, g)
44    } else {
45        cx.dcx().emit_err(RequireTransparent { span });
46        return;
47    };
48
49    // Convert generic parameters (from the struct) into generic args.
50    let self_params: Vec<_> = generics
51        .params
52        .iter()
53        .map(|p| match p.kind {
54            GenericParamKind::Lifetime => GenericArg::Lifetime(cx.lifetime(p.span(), p.ident)),
55            GenericParamKind::Type { .. } => GenericArg::Type(cx.ty_ident(p.span(), p.ident)),
56            GenericParamKind::Const { .. } => GenericArg::Const(cx.const_ident(p.span(), p.ident)),
57        })
58        .collect();
59    let type_params: Vec<_> = generics
60        .params
61        .iter()
62        .enumerate()
63        .filter_map(|(idx, p)| {
64            if let GenericParamKind::Type { .. } = p.kind {
65                Some((idx, p.span(), p.attrs().iter().any(|attr| attr.has_name(sym::pointee))))
66            } else {
67                None
68            }
69        })
70        .collect();
71
72    let pointee_param_idx = if type_params.is_empty() {
73        // `#[derive(CoercePointee)]` requires at least one generic type on the target `struct`
74        cx.dcx().emit_err(RequireOneGeneric { span });
75        return;
76    } else if type_params.len() == 1 {
77        // Regardless of the only type param being designed as `#[pointee]` or not, we can just use it as such
78        type_params[0].0
79    } else {
80        let mut pointees = type_params
81            .iter()
82            .filter_map(|&(idx, span, is_pointee)| is_pointee.then_some((idx, span)));
83        match (pointees.next(), pointees.next()) {
84            (Some((idx, _span)), None) => idx,
85            (None, _) => {
86                cx.dcx().emit_err(RequireOnePointee { span });
87                return;
88            }
89            (Some((_, one)), Some((_, another))) => {
90                cx.dcx().emit_err(TooManyPointees { one, another });
91                return;
92            }
93        }
94    };
95
96    // Create the type of `self`.
97    let path = cx.path_all(span, false, vec![name_ident], self_params.clone());
98    let self_type = cx.ty_path(path);
99
100    // Declare helper function that adds implementation blocks.
101    // FIXME(dingxiangfei2009): Investigate the set of attributes on target struct to be propagated to impls
102    let attrs = thin_vec![cx.attr_word(sym::automatically_derived, span),];
103    // # Validity assertion which will be checked later in `rustc_hir_analysis::coherence::builtins`.
104    {
105        let trait_path =
106            cx.path_all(span, true, path!(span, core::marker::CoercePointeeValidated), vec![]);
107        let trait_ref = cx.trait_ref(trait_path);
108        push(Annotatable::Item(
109            cx.item(
110                span,
111                Ident::empty(),
112                attrs.clone(),
113                ast::ItemKind::Impl(Box::new(ast::Impl {
114                    safety: ast::Safety::Default,
115                    polarity: ast::ImplPolarity::Positive,
116                    defaultness: ast::Defaultness::Final,
117                    constness: ast::Const::No,
118                    generics: Generics {
119                        params: generics
120                            .params
121                            .iter()
122                            .map(|p| match &p.kind {
123                                GenericParamKind::Lifetime => {
124                                    cx.lifetime_param(p.span(), p.ident, p.bounds.clone())
125                                }
126                                GenericParamKind::Type { default: _ } => {
127                                    cx.typaram(p.span(), p.ident, p.bounds.clone(), None)
128                                }
129                                GenericParamKind::Const { ty, kw_span: _, default: _ } => cx
130                                    .const_param(
131                                        p.span(),
132                                        p.ident,
133                                        p.bounds.clone(),
134                                        ty.clone(),
135                                        None,
136                                    ),
137                            })
138                            .collect(),
139                        where_clause: generics.where_clause.clone(),
140                        span: generics.span,
141                    },
142                    of_trait: Some(trait_ref),
143                    self_ty: self_type.clone(),
144                    items: ThinVec::new(),
145                })),
146            ),
147        ));
148    }
149    let mut add_impl_block = |generics, trait_symbol, trait_args| {
150        let mut parts = path!(span, core::ops);
151        parts.push(Ident::new(trait_symbol, span));
152        let trait_path = cx.path_all(span, true, parts, trait_args);
153        let trait_ref = cx.trait_ref(trait_path);
154        let item = cx.item(
155            span,
156            Ident::empty(),
157            attrs.clone(),
158            ast::ItemKind::Impl(Box::new(ast::Impl {
159                safety: ast::Safety::Default,
160                polarity: ast::ImplPolarity::Positive,
161                defaultness: ast::Defaultness::Final,
162                constness: ast::Const::No,
163                generics,
164                of_trait: Some(trait_ref),
165                self_ty: self_type.clone(),
166                items: ThinVec::new(),
167            })),
168        );
169        push(Annotatable::Item(item));
170    };
171
172    // Create unsized `self`, that is, one where the `#[pointee]` type arg is replaced with `__S`. For
173    // example, instead of `MyType<'a, T>`, it will be `MyType<'a, __S>`.
174    let s_ty = cx.ty_ident(span, Ident::new(sym::__S, span));
175    let mut alt_self_params = self_params;
176    alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone());
177    let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params));
178
179    // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
180    //
181    // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
182    let mut impl_generics = generics.clone();
183    let pointee_ty_ident = generics.params[pointee_param_idx].ident;
184    let mut self_bounds;
185    {
186        let pointee = &mut impl_generics.params[pointee_param_idx];
187        self_bounds = pointee.bounds.clone();
188        if !contains_maybe_sized_bound(&self_bounds)
189            && !contains_maybe_sized_bound_on_pointee(
190                &generics.where_clause.predicates,
191                pointee_ty_ident.name,
192            )
193        {
194            cx.dcx().emit_err(RequiresMaybeSized {
195                span: pointee_ty_ident.span,
196                name: pointee_ty_ident,
197            });
198            return;
199        }
200        let arg = GenericArg::Type(s_ty.clone());
201        let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
202        pointee.bounds.push(cx.trait_bound(unsize, false));
203        // Drop `#[pointee]` attribute since it should not be recognized outside `derive(CoercePointee)`
204        pointee.attrs.retain(|attr| !attr.has_name(sym::pointee));
205    }
206
207    // # Rewrite generic parameter bounds
208    // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
209    // Example:
210    // ```
211    // struct<
212    //     U: Trait<T>,
213    //     #[pointee] T: Trait<T> + ?Sized,
214    //     V: Trait<T>> ...
215    // ```
216    // ... generates this `impl` generic parameters
217    // ```
218    // impl<
219    //     U: Trait<T> + Trait<__S>,
220    //     T: Trait<T> + ?Sized + Unsize<__S>, // (**)
221    //     __S: Trait<__S> + ?Sized, // (*)
222    //     V: Trait<T> + Trait<__S>> ...
223    // ```
224    // The new bound marked with (*) has to be done separately.
225    // See next section
226    for (idx, (params, orig_params)) in
227        impl_generics.params.iter_mut().zip(&generics.params).enumerate()
228    {
229        // Default type parameters are rejected for `impl` block.
230        // We should drop them now.
231        match &mut params.kind {
232            ast::GenericParamKind::Const { default, .. } => *default = None,
233            ast::GenericParamKind::Type { default } => *default = None,
234            ast::GenericParamKind::Lifetime => {}
235        }
236        // We CANNOT rewrite `#[pointee]` type parameter bounds.
237        // This has been set in stone. (**)
238        // So we skip over it.
239        // Otherwise, we push extra bounds involving `__S`.
240        if idx != pointee_param_idx {
241            for bound in &orig_params.bounds {
242                let mut bound = bound.clone();
243                let mut substitution = TypeSubstitution {
244                    from_name: pointee_ty_ident.name,
245                    to_ty: &s_ty,
246                    rewritten: false,
247                };
248                substitution.visit_param_bound(&mut bound, BoundKind::Bound);
249                if substitution.rewritten {
250                    // We found use of `#[pointee]` somewhere,
251                    // so we make a new bound using `__S` in place of `#[pointee]`
252                    params.bounds.push(bound);
253                }
254            }
255        }
256    }
257
258    // # Insert `__S` type parameter
259    //
260    // We now insert `__S` with the missing bounds marked with (*) above.
261    // We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
262    {
263        let mut substitution =
264            TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
265        for bound in &mut self_bounds {
266            substitution.visit_param_bound(bound, BoundKind::Bound);
267        }
268    }
269
270    // # Rewrite `where` clauses
271    //
272    // Move on to `where` clauses.
273    // Example:
274    // ```
275    // struct MyPointer<#[pointee] T, ..>
276    // where
277    //   U: Trait<V> + Trait<T>,
278    //   Companion<T>: Trait<T>,
279    //   T: Trait<T> + ?Sized,
280    // { .. }
281    // ```
282    // ... will have a impl prelude like so
283    // ```
284    // impl<..> ..
285    // where
286    //   U: Trait<V> + Trait<T>,
287    //   U: Trait<__S>,
288    //   Companion<T>: Trait<T>,
289    //   Companion<__S>: Trait<__S>,
290    //   T: Trait<T> + ?Sized,
291    //   __S: Trait<__S> + ?Sized,
292    // ```
293    //
294    // We should also write a few new `where` bounds from `#[pointee] T` to `__S`
295    // as well as any bound that indirectly involves the `#[pointee] T` type.
296    for predicate in &generics.where_clause.predicates {
297        if let ast::WherePredicateKind::BoundPredicate(bound) = &predicate.kind {
298            let mut substitution = TypeSubstitution {
299                from_name: pointee_ty_ident.name,
300                to_ty: &s_ty,
301                rewritten: false,
302            };
303            let mut kind = ast::WherePredicateKind::BoundPredicate(bound.clone());
304            substitution.visit_where_predicate_kind(&mut kind);
305            if substitution.rewritten {
306                let predicate = ast::WherePredicate {
307                    attrs: predicate.attrs.clone(),
308                    kind,
309                    span: predicate.span,
310                    id: ast::DUMMY_NODE_ID,
311                    is_placeholder: false,
312                };
313                impl_generics.where_clause.predicates.push(predicate);
314            }
315        }
316    }
317
318    let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
319    impl_generics.params.insert(pointee_param_idx + 1, extra_param);
320
321    // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
322    let gen_args = vec![GenericArg::Type(alt_self_type)];
323    add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
324    add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args);
325}
326
327fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool {
328    for bound in predicates {
329        if let ast::WherePredicateKind::BoundPredicate(bound) = &bound.kind
330            && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee)
331        {
332            for bound in &bound.bounds {
333                if is_maybe_sized_bound(bound) {
334                    return true;
335                }
336            }
337        }
338    }
339    false
340}
341
342fn is_maybe_sized_bound(bound: &GenericBound) -> bool {
343    if let GenericBound::Trait(trait_ref) = bound
344        && let TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. } =
345            trait_ref.modifiers
346        && is_sized_marker(&trait_ref.trait_ref.path)
347    {
348        true
349    } else {
350        false
351    }
352}
353
354fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool {
355    bounds.iter().any(is_maybe_sized_bound)
356}
357
358fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool {
359    path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol)
360}
361
362fn is_sized_marker(path: &ast::Path) -> bool {
363    const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized];
364    const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized];
365    if path.segments.len() == 4 && path.is_global() {
366        path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE)
367            || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE)
368    } else if path.segments.len() == 3 {
369        path_segment_is_exact_match(&path.segments, &CORE_UNSIZE)
370            || path_segment_is_exact_match(&path.segments, &STD_UNSIZE)
371    } else {
372        *path == sym::Sized
373    }
374}
375
376struct TypeSubstitution<'a> {
377    from_name: Symbol,
378    to_ty: &'a ast::Ty,
379    rewritten: bool,
380}
381
382impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
383    fn visit_ty(&mut self, ty: &mut P<ast::Ty>) {
384        if let Some(name) = ty.kind.is_simple_path()
385            && name == self.from_name
386        {
387            **ty = self.to_ty.clone();
388            self.rewritten = true;
389        } else {
390            ast::mut_visit::walk_ty(self, ty);
391        }
392    }
393
394    fn visit_where_predicate_kind(&mut self, kind: &mut ast::WherePredicateKind) {
395        match kind {
396            rustc_ast::WherePredicateKind::BoundPredicate(bound) => {
397                bound
398                    .bound_generic_params
399                    .flat_map_in_place(|param| self.flat_map_generic_param(param));
400                self.visit_ty(&mut bound.bounded_ty);
401                for bound in &mut bound.bounds {
402                    self.visit_param_bound(bound, BoundKind::Bound)
403                }
404            }
405            rustc_ast::WherePredicateKind::RegionPredicate(_)
406            | rustc_ast::WherePredicateKind::EqPredicate(_) => {}
407        }
408    }
409}
410
411struct DetectNonGenericPointeeAttr<'a, 'b> {
412    cx: &'a ExtCtxt<'b>,
413}
414
415impl<'a, 'b> rustc_ast::visit::Visitor<'a> for DetectNonGenericPointeeAttr<'a, 'b> {
416    fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
417        if attr.has_name(sym::pointee) {
418            self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
419        }
420    }
421
422    fn visit_generic_param(&mut self, param: &'a rustc_ast::GenericParam) -> Self::Result {
423        let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
424
425        match &param.kind {
426            GenericParamKind::Type { default } => {
427                // The `default` may end up containing a block expression.
428                // The problem is block expressions  may define structs with generics.
429                // A user may attach a #[pointee] attribute to one of these generics
430                // We want to catch that. The simple solution is to just
431                // always raise a `NonGenericPointee` error when this happens.
432                //
433                // This solution does reject valid rust programs but,
434                // such a code would have to, in order:
435                // - Define a smart pointer struct.
436                // - Somewhere in this struct definition use a type with a const generic argument.
437                // - Calculate this const generic in a expression block.
438                // - Define a new smart pointer type in this block.
439                // - Have this smart pointer type have more than 1 generic type.
440                // In this case, the inner smart pointer derive would be complaining that it
441                // needs a pointer attribute. Meanwhile, the outer macro would be complaining
442                // that we attached a #[pointee] to a generic type argument while helpfully
443                // informing the user that #[pointee] can only be attached to generic pointer arguments
444                rustc_ast::visit::visit_opt!(error_on_pointee, visit_ty, default);
445            }
446
447            GenericParamKind::Const { .. } | GenericParamKind::Lifetime => {
448                rustc_ast::visit::walk_generic_param(&mut error_on_pointee, param);
449            }
450        }
451    }
452
453    fn visit_ty(&mut self, t: &'a rustc_ast::Ty) -> Self::Result {
454        let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
455        error_on_pointee.visit_ty(t)
456    }
457}
458
459struct AlwaysErrorOnGenericParam<'a, 'b> {
460    cx: &'a ExtCtxt<'b>,
461}
462
463impl<'a, 'b> rustc_ast::visit::Visitor<'a> for AlwaysErrorOnGenericParam<'a, 'b> {
464    fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
465        if attr.has_name(sym::pointee) {
466            self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
467        }
468    }
469}
470
// Emitted by `expand_deriving_coerce_pointee` when the derive is applied to an item that is
// not a `struct`.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_transparent, code = E0802)]
struct RequireTransparent {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
477
// Emitted when the target `struct` has no fields: a unit struct, or an empty `{}`/`()` body.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_field, code = E0802)]
struct RequireOneField {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
484
// Emitted when the target `struct` has no generic *type* parameter. Lifetime and const
// parameters do not count.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_generic, code = E0802)]
struct RequireOneGeneric {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
491
// Emitted when the target `struct` has several type parameters and none of them is marked
// `#[pointee]`.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_pointee, code = E0802)]
struct RequireOnePointee {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
498
// Emitted when more than one type parameter is marked `#[pointee]`. Only the first two marked
// parameters are reported.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_too_many_pointees, code = E0802)]
struct TooManyPointees {
    // Span of the first type parameter marked `#[pointee]`.
    #[primary_span]
    one: Span,
    // Span of the second type parameter marked `#[pointee]`, shown as a secondary label.
    #[label]
    another: Span,
}
507
// Emitted when the selected pointee type parameter has no `?Sized` bound, either on the
// parameter itself or in a `where` predicate on it.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_maybe_sized, code = E0802)]
struct RequiresMaybeSized {
    // Span of the pointee parameter's identifier.
    #[primary_span]
    span: Span,
    // The pointee parameter's identifier. As an unannotated field it is passed to the message
    // as a diagnostic argument (presumably `{$name}` in the fluent text - confirm).
    name: Ident,
}