rustc_type_ir_macros/
lib.rs
1use quote::{ToTokens, quote};
2use syn::visit_mut::VisitMut;
3use syn::{Attribute, parse_quote};
4use synstructure::decl_derive;
5
6decl_derive!(
7 [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
8);
9decl_derive!(
10 [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
11);
12decl_derive!(
13 [Lift_Generic] => lift_derive
14);
15
16fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
17 let mut ignored = false;
18 attrs.iter().for_each(|attr| {
19 if !attr.path().is_ident(name) {
20 return;
21 }
22 let _ = attr.parse_nested_meta(|nested| {
23 if nested.path.is_ident(meta) {
24 ignored = true;
25 }
26 Ok(())
27 });
28 });
29
30 ignored
31}
32
33fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
34 if let syn::Data::Union(_) = s.ast().data {
35 panic!("cannot derive on union")
36 }
37
38 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
39 s.add_impl_generic(parse_quote! { I });
40 }
41
42 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
43
44 s.add_where_predicate(parse_quote! { I: Interner });
45 s.add_bounds(synstructure::AddBounds::Fields);
46 let body_visit = s.each(|bind| {
47 quote! {
48 match ::rustc_ast_ir::visit::VisitorResult::branch(
49 ::rustc_type_ir::visit::TypeVisitable::visit_with(#bind, __visitor)
50 ) {
51 ::core::ops::ControlFlow::Continue(()) => {},
52 ::core::ops::ControlFlow::Break(r) => {
53 return ::rustc_ast_ir::visit::VisitorResult::from_residual(r);
54 },
55 }
56 }
57 });
58 s.bind_with(|_| synstructure::BindStyle::Move);
59
60 s.bound_impl(
61 quote!(::rustc_type_ir::visit::TypeVisitable<I>),
62 quote! {
63 fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor<I>>(
64 &self,
65 __visitor: &mut __V
66 ) -> __V::Result {
67 match *self { #body_visit }
68 <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
69 }
70 },
71 )
72}
73
74fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
75 if let syn::Data::Union(_) = s.ast().data {
76 panic!("cannot derive on union")
77 }
78
79 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
80 s.add_impl_generic(parse_quote! { I });
81 }
82
83 s.add_where_predicate(parse_quote! { I: Interner });
84 s.add_bounds(synstructure::AddBounds::Fields);
85 s.bind_with(|_| synstructure::BindStyle::Move);
86 let body_fold = s.each_variant(|vi| {
87 let bindings = vi.bindings();
88 vi.construct(|_, index| {
89 let bind = &bindings[index];
90
91 if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
93 bind.to_token_stream()
94 } else {
95 quote! {
96 ::rustc_type_ir::fold::TypeFoldable::try_fold_with(#bind, __folder)?
97 }
98 }
99 })
100 });
101
102 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
106 s.add_bounds(synstructure::AddBounds::Fields);
107 s.bound_impl(
108 quote!(::rustc_type_ir::fold::TypeFoldable<I>),
109 quote! {
110 fn try_fold_with<__F: ::rustc_type_ir::fold::FallibleTypeFolder<I>>(
111 self,
112 __folder: &mut __F
113 ) -> Result<Self, __F::Error> {
114 Ok(match self { #body_fold })
115 }
116 },
117 )
118}
119
120fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
121 if let syn::Data::Union(_) = s.ast().data {
122 panic!("cannot derive on union")
123 }
124
125 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
126 s.add_impl_generic(parse_quote! { I });
127 }
128
129 s.add_bounds(synstructure::AddBounds::None);
130 s.add_where_predicate(parse_quote! { I: Interner });
131 s.add_impl_generic(parse_quote! { J });
132 s.add_where_predicate(parse_quote! { J: Interner });
133
134 let mut wc = vec![];
135 s.bind_with(|_| synstructure::BindStyle::Move);
136 let body_fold = s.each_variant(|vi| {
137 let bindings = vi.bindings();
138 vi.construct(|field, index| {
139 let ty = field.ty.clone();
140 let lifted_ty = lift(ty.clone());
141 wc.push(parse_quote! { #ty: ::rustc_type_ir::lift::Lift<J, Lifted = #lifted_ty> });
142 let bind = &bindings[index];
143 quote! {
144 #bind.lift_to_interner(interner)?
145 }
146 })
147 });
148 for wc in wc {
149 s.add_where_predicate(wc);
150 }
151
152 let (_, ty_generics, _) = s.ast().generics.split_for_impl();
153 let name = s.ast().ident.clone();
154 let self_ty: syn::Type = parse_quote! { #name #ty_generics };
155 let lifted_ty = lift(self_ty);
156
157 s.bound_impl(
158 quote!(::rustc_type_ir::lift::Lift<J>),
159 quote! {
160 type Lifted = #lifted_ty;
161
162 fn lift_to_interner(
163 self,
164 interner: J,
165 ) -> Option<Self::Lifted> {
166 Some(match self { #body_fold })
167 }
168 },
169 )
170}
171
172fn lift(mut ty: syn::Type) -> syn::Type {
173 struct ItoJ;
174 impl VisitMut for ItoJ {
175 fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
176 if i.qself.is_none() {
177 if let Some(first) = i.path.segments.first_mut() {
178 if first.ident == "I" {
179 *first = parse_quote! { J };
180 }
181 }
182 }
183 syn::visit_mut::visit_type_path_mut(self, i);
184 }
185 }
186
187 ItoJ.visit_type_mut(&mut ty);
188
189 ty
190}