rustc_type_ir_macros/
lib.rs

1use quote::{ToTokens, quote};
2use syn::visit_mut::VisitMut;
3use syn::{Attribute, parse_quote};
4use synstructure::decl_derive;
5
6decl_derive!(
7    [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
8);
9decl_derive!(
10    [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
11);
12decl_derive!(
13    [Lift_Generic] => lift_derive
14);
15#[cfg(not(feature = "nightly"))]
16decl_derive!(
17    [GenericTypeVisitable] => customizable_type_visitable_derive
18);
19
20fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
21    let mut ignored = false;
22    attrs.iter().for_each(|attr| {
23        if !attr.path().is_ident(name) {
24            return;
25        }
26        let _ = attr.parse_nested_meta(|nested| {
27            if nested.path.is_ident(meta) {
28                ignored = true;
29            }
30            Ok(())
31        });
32    });
33
34    ignored
35}
36
37fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
38    if let syn::Data::Union(_) = s.ast().data {
39        panic!("cannot derive on union")
40    }
41
42    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
43        s.add_impl_generic(parse_quote! { I });
44    }
45
46    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
47
48    s.add_where_predicate(parse_quote! { I: Interner });
49    s.add_bounds(synstructure::AddBounds::Fields);
50    let body_visit = s.each(|bind| {
51        quote! {
52            match ::rustc_type_ir::VisitorResult::branch(
53                ::rustc_type_ir::TypeVisitable::visit_with(#bind, __visitor)
54            ) {
55                ::core::ops::ControlFlow::Continue(()) => {},
56                ::core::ops::ControlFlow::Break(r) => {
57                    return ::rustc_type_ir::VisitorResult::from_residual(r);
58                },
59            }
60        }
61    });
62    s.bind_with(|_| synstructure::BindStyle::Move);
63
64    s.bound_impl(
65        quote!(::rustc_type_ir::TypeVisitable<I>),
66        quote! {
67            fn visit_with<__V: ::rustc_type_ir::TypeVisitor<I>>(
68                &self,
69                __visitor: &mut __V
70            ) -> __V::Result {
71                match *self { #body_visit }
72                <__V::Result as ::rustc_type_ir::VisitorResult>::output()
73            }
74        },
75    )
76}
77
78fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
79    if let syn::Data::Union(_) = s.ast().data {
80        panic!("cannot derive on union")
81    }
82
83    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
84        s.add_impl_generic(parse_quote! { I });
85    }
86
87    s.add_where_predicate(parse_quote! { I: Interner });
88    s.add_bounds(synstructure::AddBounds::Fields);
89    s.bind_with(|_| synstructure::BindStyle::Move);
90    let body_try_fold = s.each_variant(|vi| {
91        let bindings = vi.bindings();
92        vi.construct(|_, index| {
93            let bind = &bindings[index];
94
95            // retain value of fields with #[type_foldable(identity)]
96            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
97                bind.to_token_stream()
98            } else {
99                quote! {
100                    ::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
101                }
102            }
103        })
104    });
105
106    let body_fold = s.each_variant(|vi| {
107        let bindings = vi.bindings();
108        vi.construct(|_, index| {
109            let bind = &bindings[index];
110
111            // retain value of fields with #[type_foldable(identity)]
112            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
113                bind.to_token_stream()
114            } else {
115                quote! {
116                    ::rustc_type_ir::TypeFoldable::fold_with(#bind, __folder)
117                }
118            }
119        })
120    });
121
122    // We filter fields which get ignored and don't require them to implement
123    // `TypeFoldable`. We do so after generating `body_fold` as we still need
124    // to generate code for them.
125    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
126    s.add_bounds(synstructure::AddBounds::Fields);
127    s.bound_impl(
128        quote!(::rustc_type_ir::TypeFoldable<I>),
129        quote! {
130            fn try_fold_with<__F: ::rustc_type_ir::FallibleTypeFolder<I>>(
131                self,
132                __folder: &mut __F
133            ) -> Result<Self, __F::Error> {
134                Ok(match self { #body_try_fold })
135            }
136
137            fn fold_with<__F: ::rustc_type_ir::TypeFolder<I>>(
138                self,
139                __folder: &mut __F
140            ) -> Self {
141                match self { #body_fold }
142            }
143        },
144    )
145}
146
147fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
148    if let syn::Data::Union(_) = s.ast().data {
149        panic!("cannot derive on union")
150    }
151
152    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
153        s.add_impl_generic(parse_quote! { I });
154    }
155
156    s.add_bounds(synstructure::AddBounds::None);
157    s.add_where_predicate(parse_quote! { I: Interner });
158    s.add_impl_generic(parse_quote! { J });
159    s.add_where_predicate(parse_quote! { J: Interner });
160
161    let mut wc = vec![];
162    s.bind_with(|_| synstructure::BindStyle::Move);
163    let body_fold = s.each_variant(|vi| {
164        let bindings = vi.bindings();
165        vi.construct(|field, index| {
166            let ty = field.ty.clone();
167            let lifted_ty = lift(ty.clone());
168            wc.push(parse_quote! { #ty: ::rustc_type_ir::lift::Lift<J, Lifted = #lifted_ty> });
169            let bind = &bindings[index];
170            quote! {
171                #bind.lift_to_interner(interner)?
172            }
173        })
174    });
175    for wc in wc {
176        s.add_where_predicate(wc);
177    }
178
179    let (_, ty_generics, _) = s.ast().generics.split_for_impl();
180    let name = s.ast().ident.clone();
181    let self_ty: syn::Type = parse_quote! { #name #ty_generics };
182    let lifted_ty = lift(self_ty);
183
184    s.bound_impl(
185        quote!(::rustc_type_ir::lift::Lift<J>),
186        quote! {
187            type Lifted = #lifted_ty;
188
189            fn lift_to_interner(
190                self,
191                interner: J,
192            ) -> Option<Self::Lifted> {
193                Some(match self { #body_fold })
194            }
195        },
196    )
197}
198
199fn lift(mut ty: syn::Type) -> syn::Type {
200    struct ItoJ;
201    impl VisitMut for ItoJ {
202        fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
203            if i.qself.is_none() {
204                if let Some(first) = i.path.segments.first_mut()
205                    && first.ident == "I"
206                {
207                    *first = parse_quote! { J };
208                }
209            }
210            syn::visit_mut::visit_type_path_mut(self, i);
211        }
212    }
213
214    ItoJ.visit_type_mut(&mut ty);
215
216    ty
217}
218
219#[cfg(not(feature = "nightly"))]
220fn customizable_type_visitable_derive(
221    mut s: synstructure::Structure<'_>,
222) -> proc_macro2::TokenStream {
223    if let syn::Data::Union(_) = s.ast().data {
224        panic!("cannot derive on union")
225    }
226
227    s.add_impl_generic(parse_quote!(__V));
228    s.add_bounds(synstructure::AddBounds::Fields);
229    let body_visit = s.each(|bind| {
230        quote! {
231            ::rustc_type_ir::GenericTypeVisitable::<__V>::generic_visit_with(#bind, __visitor);
232        }
233    });
234    s.bind_with(|_| synstructure::BindStyle::Move);
235
236    s.bound_impl(
237        quote!(::rustc_type_ir::GenericTypeVisitable<__V>),
238        quote! {
239            fn generic_visit_with(
240                &self,
241                __visitor: &mut __V
242            ) {
243                match *self { #body_visit }
244            }
245        },
246    )
247}
248
/// `#[derive(GenericTypeVisitable)]` when the `nightly` feature is on.
///
/// The derive is accepted but expands to nothing: the input is ignored and an
/// empty token stream is returned, so no impl is generated.
#[cfg(feature = "nightly")]
#[proc_macro_derive(GenericTypeVisitable)]
pub fn customizable_type_visitable_derive(
    _input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    proc_macro::TokenStream::default()
}