// rustc_type_ir_macros/lib.rs

1use quote::{ToTokens, quote};
2use syn::visit_mut::VisitMut;
3use syn::{Attribute, parse_quote};
4use synstructure::decl_derive;
5
6decl_derive!(
7    [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
8);
9decl_derive!(
10    [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
11);
12decl_derive!(
13    [Lift_Generic] => lift_derive
14);
15
16fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
17    let mut ignored = false;
18    attrs.iter().for_each(|attr| {
19        if !attr.path().is_ident(name) {
20            return;
21        }
22        let _ = attr.parse_nested_meta(|nested| {
23            if nested.path.is_ident(meta) {
24                ignored = true;
25            }
26            Ok(())
27        });
28    });
29
30    ignored
31}
32
33fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
34    if let syn::Data::Union(_) = s.ast().data {
35        panic!("cannot derive on union")
36    }
37
38    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
39        s.add_impl_generic(parse_quote! { I });
40    }
41
42    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
43
44    s.add_where_predicate(parse_quote! { I: Interner });
45    s.add_bounds(synstructure::AddBounds::Fields);
46    let body_visit = s.each(|bind| {
47        quote! {
48            match ::rustc_ast_ir::visit::VisitorResult::branch(
49                ::rustc_type_ir::visit::TypeVisitable::visit_with(#bind, __visitor)
50            ) {
51                ::core::ops::ControlFlow::Continue(()) => {},
52                ::core::ops::ControlFlow::Break(r) => {
53                    return ::rustc_ast_ir::visit::VisitorResult::from_residual(r);
54                },
55            }
56        }
57    });
58    s.bind_with(|_| synstructure::BindStyle::Move);
59
60    s.bound_impl(
61        quote!(::rustc_type_ir::visit::TypeVisitable<I>),
62        quote! {
63            fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor<I>>(
64                &self,
65                __visitor: &mut __V
66            ) -> __V::Result {
67                match *self { #body_visit }
68                <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
69            }
70        },
71    )
72}
73
74fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
75    if let syn::Data::Union(_) = s.ast().data {
76        panic!("cannot derive on union")
77    }
78
79    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
80        s.add_impl_generic(parse_quote! { I });
81    }
82
83    s.add_where_predicate(parse_quote! { I: Interner });
84    s.add_bounds(synstructure::AddBounds::Fields);
85    s.bind_with(|_| synstructure::BindStyle::Move);
86    let body_fold = s.each_variant(|vi| {
87        let bindings = vi.bindings();
88        vi.construct(|_, index| {
89            let bind = &bindings[index];
90
91            // retain value of fields with #[type_foldable(identity)]
92            if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
93                bind.to_token_stream()
94            } else {
95                quote! {
96                    ::rustc_type_ir::fold::TypeFoldable::try_fold_with(#bind, __folder)?
97                }
98            }
99        })
100    });
101
102    // We filter fields which get ignored and don't require them to implement
103    // `TypeFoldable`. We do so after generating `body_fold` as we still need
104    // to generate code for them.
105    s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
106    s.add_bounds(synstructure::AddBounds::Fields);
107    s.bound_impl(
108        quote!(::rustc_type_ir::fold::TypeFoldable<I>),
109        quote! {
110            fn try_fold_with<__F: ::rustc_type_ir::fold::FallibleTypeFolder<I>>(
111                self,
112                __folder: &mut __F
113            ) -> Result<Self, __F::Error> {
114                Ok(match self { #body_fold })
115            }
116        },
117    )
118}
119
120fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
121    if let syn::Data::Union(_) = s.ast().data {
122        panic!("cannot derive on union")
123    }
124
125    if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
126        s.add_impl_generic(parse_quote! { I });
127    }
128
129    s.add_bounds(synstructure::AddBounds::None);
130    s.add_where_predicate(parse_quote! { I: Interner });
131    s.add_impl_generic(parse_quote! { J });
132    s.add_where_predicate(parse_quote! { J: Interner });
133
134    let mut wc = vec![];
135    s.bind_with(|_| synstructure::BindStyle::Move);
136    let body_fold = s.each_variant(|vi| {
137        let bindings = vi.bindings();
138        vi.construct(|field, index| {
139            let ty = field.ty.clone();
140            let lifted_ty = lift(ty.clone());
141            wc.push(parse_quote! { #ty: ::rustc_type_ir::lift::Lift<J, Lifted = #lifted_ty> });
142            let bind = &bindings[index];
143            quote! {
144                #bind.lift_to_interner(interner)?
145            }
146        })
147    });
148    for wc in wc {
149        s.add_where_predicate(wc);
150    }
151
152    let (_, ty_generics, _) = s.ast().generics.split_for_impl();
153    let name = s.ast().ident.clone();
154    let self_ty: syn::Type = parse_quote! { #name #ty_generics };
155    let lifted_ty = lift(self_ty);
156
157    s.bound_impl(
158        quote!(::rustc_type_ir::lift::Lift<J>),
159        quote! {
160            type Lifted = #lifted_ty;
161
162            fn lift_to_interner(
163                self,
164                interner: J,
165            ) -> Option<Self::Lifted> {
166                Some(match self { #body_fold })
167            }
168        },
169    )
170}
171
172fn lift(mut ty: syn::Type) -> syn::Type {
173    struct ItoJ;
174    impl VisitMut for ItoJ {
175        fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
176            if i.qself.is_none() {
177                if let Some(first) = i.path.segments.first_mut() {
178                    if first.ident == "I" {
179                        *first = parse_quote! { J };
180                    }
181                }
182            }
183            syn::visit_mut::visit_type_path_mut(self, i);
184        }
185    }
186
187    ItoJ.visit_type_mut(&mut ty);
188
189    ty
190}