Skip to main content

rustc_type_ir_macros/
lib.rs

1use indexmap::IndexSet;
2use quote::{ToTokens, quote};
3use syn::visit_mut::VisitMut;
4use syn::{Attribute, parse_quote};
5use synstructure::decl_derive;
6
7decl_derive!(
8    [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
9);
10decl_derive!(
11    [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
12);
13decl_derive!(
14    [Lift_Generic, attributes(lift)] => lift_derive
15);
16#[cfg(not(feature = "nightly"))]
17decl_derive!(
18    [GenericTypeVisitable] => customizable_type_visitable_derive
19);
20
21struct TransformedTy {
22    ty: syn::Type,
23    generic_parameter_bounds: IndexSet<syn::Ident>,
24}
25
26enum TypeParameterPath {
27    Interner,
28    GenericParameter(syn::Ident),
29}
30
/// Tells [`transform_type_parameters`] whether to keep descending into a path after
/// the visitor callback has handled it.
enum TypeParameterTransform {
    /// Recurse into the (possibly rewritten) path's components, e.g. generic arguments.
    Continue,
    /// Do not recurse: the callback has fully handled this path.
    Stop,
}
35
36type TypeParameterVisitor =
37    fn(TypeParameterPath, &mut syn::TypePath, &mut IndexSet<syn::Ident>) -> TypeParameterTransform;
38
39fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
40    let mut ignored = false;
41    attrs.iter().for_each(|attr| {
42        if !attr.path().is_ident(name) {
43            return;
44        }
45        let _ = attr.parse_nested_meta(|nested| {
46            if nested.path.is_ident(meta) {
47                ignored = true;
48            }
49            Ok(())
50        });
51    });
52
53    ignored
54}
55
56fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
57    if let syn::Data::Union(_) = s.ast().data {
58        panic!("cannot derive on union")
59    }
60
61    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
62        s.add_impl_generic(parse_quote! { I });
63    }
64
65    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
66
67    s.add_where_predicate(parse_quote! { I: Interner });
68    s.add_bounds(synstructure::AddBounds::Fields);
69    let body_visit = s.each(|bind| {
70        quote! {
71            match ::rustc_type_ir::VisitorResult::branch(
72                ::rustc_type_ir::TypeVisitable::visit_with(#bind, __visitor)
73            ) {
74                ::core::ops::ControlFlow::Continue(()) => {},
75                ::core::ops::ControlFlow::Break(r) => {
76                    return ::rustc_type_ir::VisitorResult::from_residual(r);
77                },
78            }
79        }
80    });
81    s.bind_with(|_| synstructure::BindStyle::Move);
82
83    s.bound_impl(
84        quote!(::rustc_type_ir::TypeVisitable<I>),
85        quote! {
86            fn visit_with<__V: ::rustc_type_ir::TypeVisitor<I>>(
87                &self,
88                __visitor: &mut __V
89            ) -> __V::Result {
90                match *self { #body_visit }
91                <__V::Result as ::rustc_type_ir::VisitorResult>::output()
92            }
93        },
94    )
95}
96
97fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
98    if let syn::Data::Union(_) = s.ast().data {
99        panic!("cannot derive on union")
100    }
101
102    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
103        s.add_impl_generic(parse_quote! { I });
104    }
105
106    s.add_where_predicate(parse_quote! { I: Interner });
107    s.add_bounds(synstructure::AddBounds::Fields);
108    let generic_parameters =
109        s.ast().generics.type_params().map(|ty| ty.ident.clone()).collect::<Vec<_>>();
110    let mut generic_parameter_bounds = IndexSet::new();
111    s.bind_with(|_| synstructure::BindStyle::Move);
112    let body_try_fold = s.each_variant(|vi| {
113        let bindings = vi.bindings();
114        vi.construct(|_, index| {
115            let bind = &bindings[index];
116
117            // retain value of fields with #[type_foldable(identity)]
118            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
119                bind.to_token_stream()
120            } else {
121                for param in
122                    type_foldable_generic_parameters(bind.ast().ty.clone(), &generic_parameters)
123                {
124                    generic_parameter_bounds.insert(param);
125                }
126
127                quote! {
128                    ::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
129                }
130            }
131        })
132    });
133
134    let body_fold = s.each_variant(|vi| {
135        let bindings = vi.bindings();
136        vi.construct(|_, index| {
137            let bind = &bindings[index];
138
139            // retain value of fields with #[type_foldable(identity)]
140            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
141                bind.to_token_stream()
142            } else {
143                quote! {
144                    ::rustc_type_ir::TypeFoldable::fold_with(#bind, __folder)
145                }
146            }
147        })
148    });
149
150    // We filter fields which get ignored and don't require them to implement
151    // `TypeFoldable`. We do so after generating `body_fold` as we still need
152    // to generate code for them.
153    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
154    s.add_bounds(synstructure::AddBounds::Fields);
155    for param in generic_parameter_bounds {
156        s.add_where_predicate(parse_quote! { #param: ::rustc_type_ir::TypeFoldable<I> });
157    }
158    s.bound_impl(
159        quote!(::rustc_type_ir::TypeFoldable<I>),
160        quote! {
161            fn try_fold_with<__F: ::rustc_type_ir::FallibleTypeFolder<I>>(
162                self,
163                __folder: &mut __F
164            ) -> Result<Self, __F::Error> {
165                Ok(match self { #body_try_fold })
166            }
167
168            fn fold_with<__F: ::rustc_type_ir::TypeFolder<I>>(
169                self,
170                __folder: &mut __F
171            ) -> Self {
172                match self { #body_fold }
173            }
174        },
175    )
176}
177
178fn type_foldable_generic_parameters(
179    ty: syn::Type,
180    generic_parameters: &[syn::Ident],
181) -> IndexSet<syn::Ident> {
182    transform_type_parameters(ty, generic_parameters, |path, _, generic_parameter_bounds| {
183        if let TypeParameterPath::GenericParameter(param) = path {
184            generic_parameter_bounds.insert(param);
185        }
186        TypeParameterTransform::Continue
187    })
188    .generic_parameter_bounds
189}
190
191/// `Lift_Generic` is specialised for structs/enums parameterised by an interner
192/// `I: Interner`. It derives `Lift<J>` by rewriting interner associated types
193/// from `I::Assoc` to `J::Assoc`. The required associated type lift bounds are
194/// supplied by `I: LiftInto<J>`.
195///
196/// Ordinary generic parameters still get explicit `Lift<J>` bounds. Interner
197/// independent fields must either implement `Lift` manually or use
198/// `#[lift(identity)]`.
199///
200/// `PhantomData` is a special case that occurs enough in the code base to be
201/// handled here directly. We collect any generic bounds from the type then
202/// produce another `PhantomData`.
203fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
204    if let syn::Data::Union(_) = s.ast().data {
205        panic!("cannot derive on union")
206    }
207
208    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
209        s.add_impl_generic(parse_quote! { I });
210    }
211
212    s.add_bounds(synstructure::AddBounds::None);
213    s.add_impl_generic(parse_quote! { J });
214    s.add_where_predicate(parse_quote! { J: Interner });
215    s.add_where_predicate(parse_quote! { I: ::rustc_type_ir::LiftInto<J> });
216
217    let generic_parameters =
218        s.ast().generics.type_params().map(|ty| ty.ident.clone()).collect::<Vec<_>>();
219
220    let mut wc = vec![];
221    s.bind_with(|_| synstructure::BindStyle::Move);
222    let body_fold = s.each_variant(|vi| {
223        let bindings = vi.bindings();
224        vi.construct(|field, index| {
225            let ty = field.ty.clone();
226            let bind = &bindings[index];
227            // Allow field to be ignored from lift
228            if has_ignore_attr(&field.attrs, "lift", "identity") {
229                return bind.to_token_stream();
230            }
231
232            let lifted = lift(ty.clone(), &generic_parameters);
233
234            // Field types involving ordinary generic parameters still need
235            // explicit bounds for those parameters, e.g. `Binder<I, T>` needs
236            // `T: Lift<J>` so its own derived `Lift` impl applies. Interner
237            // associated types are covered by `I: LiftInto<J>`.
238            for param in lifted.generic_parameter_bounds {
239                wc.push(parse_quote! { #param: ::rustc_type_ir::lift::Lift<J> });
240            }
241
242            if is_type_phantom(&ty) {
243                return quote! {
244                    PhantomData
245                };
246            }
247
248            quote! {
249                #bind.lift_to_interner(interner)
250            }
251        })
252    });
253    for wc in wc {
254        s.add_where_predicate(wc);
255    }
256
257    let (_, ty_generics, _) = s.ast().generics.split_for_impl();
258    let name = s.ast().ident.clone();
259    let self_ty: syn::Type = parse_quote! { #name #ty_generics };
260    let lifted = lift(self_ty, &generic_parameters);
261    let lifted_ty = lifted.ty;
262
263    s.bound_impl(
264        quote!(::rustc_type_ir::lift::Lift<J>),
265        quote! {
266            type Lifted = #lifted_ty;
267
268            fn lift_to_interner(
269                self,
270                interner: J,
271            ) -> Self::Lifted {
272                match self { #body_fold }
273            }
274        },
275    )
276}
277
278fn get_first_path_segment(ty: &syn::Type) -> Option<&syn::PathSegment> {
279    if let syn::Type::Path(ty) = ty
280        && ty.path.segments.len() == 1
281    {
282        ty.path.segments.first()
283    } else {
284        None
285    }
286}
287
288/// Return if the type is `PhantomData`
289fn is_type_phantom(ty: &syn::Type) -> bool {
290    get_first_path_segment(ty).is_some_and(|segment| segment.ident == "PhantomData")
291}
292
293fn lift(ty: syn::Type, generic_parameters: &[syn::Ident]) -> TransformedTy {
294    transform_type_parameters(ty, generic_parameters, |path, ty, generic_parameter_bounds| {
295        match path {
296            TypeParameterPath::Interner => {
297                *ty.path.segments.first_mut().unwrap() = parse_quote! { J };
298                TypeParameterTransform::Continue
299            }
300            TypeParameterPath::GenericParameter(param) => {
301                generic_parameter_bounds.insert(param.clone());
302                *ty = parse_quote! { <#param as ::rustc_type_ir::lift::Lift<J>>::Lifted };
303                TypeParameterTransform::Stop
304            }
305        }
306    })
307}
308
309fn transform_type_parameters(
310    mut ty: syn::Type,
311    generic_parameters: &[syn::Ident],
312    visit: TypeParameterVisitor,
313) -> TransformedTy {
314    struct TypeParameterTransformer<'a> {
315        generic_parameters: &'a [syn::Ident],
316        generic_parameter_bounds: IndexSet<syn::Ident>,
317        visit: TypeParameterVisitor,
318    }
319
320    impl VisitMut for TypeParameterTransformer<'_> {
321        fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
322            let path = if i.qself.is_none() {
323                let segments_len = i.path.segments.len();
324                i.path.segments.first().and_then(|first| {
325                    if first.ident == "I" {
326                        Some(TypeParameterPath::Interner)
327                    } else if segments_len == 1
328                        && matches!(first.arguments, syn::PathArguments::None)
329                        && self.generic_parameters.contains(&first.ident)
330                    {
331                        Some(TypeParameterPath::GenericParameter(first.ident.clone()))
332                    } else {
333                        None
334                    }
335                })
336            } else {
337                None
338            };
339
340            if let Some(path) = path {
341                if let TypeParameterTransform::Stop =
342                    (self.visit)(path, i, &mut self.generic_parameter_bounds)
343                {
344                    return;
345                }
346            }
347
348            syn::visit_mut::visit_type_path_mut(self, i);
349        }
350    }
351
352    let mut visitor = TypeParameterTransformer {
353        generic_parameters,
354        generic_parameter_bounds: IndexSet::new(),
355        visit,
356    };
357    visitor.visit_type_mut(&mut ty);
358    TransformedTy { ty, generic_parameter_bounds: visitor.generic_parameter_bounds }
359}
360
361#[cfg(not(feature = "nightly"))]
362fn customizable_type_visitable_derive(
363    mut s: synstructure::Structure<'_>,
364) -> proc_macro2::TokenStream {
365    if let syn::Data::Union(_) = s.ast().data {
366        panic!("cannot derive on union")
367    }
368
369    s.add_impl_generic(parse_quote!(__V));
370    s.add_bounds(synstructure::AddBounds::Fields);
371    let body_visit = s.each(|bind| {
372        quote! {
373            ::rustc_type_ir::GenericTypeVisitable::<__V>::generic_visit_with(#bind, __visitor);
374        }
375    });
376    s.bind_with(|_| synstructure::BindStyle::Move);
377
378    s.bound_impl(
379        quote!(::rustc_type_ir::GenericTypeVisitable<__V>),
380        quote! {
381            fn generic_visit_with(
382                &self,
383                __visitor: &mut __V
384            ) {
385                match *self { #body_visit }
386            }
387        },
388    )
389}
390
/// With the `nightly` feature enabled, the `GenericTypeVisitable` derive expands to
/// nothing: the annotated item gets no generated impl.
// NOTE(review): presumably the trait is not needed (or is provided some other way)
// on nightly builds — confirm against `rustc_type_ir`.
#[cfg(feature = "nightly")]
#[proc_macro_derive(GenericTypeVisitable)]
pub fn customizable_type_visitable_derive(
    _input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    proc_macro::TokenStream::default()
}