rustc_type_ir_macros/
lib.rs1use indexmap::IndexSet;
2use quote::{ToTokens, quote};
3use syn::visit_mut::VisitMut;
4use syn::{Attribute, parse_quote};
5use synstructure::decl_derive;
6
7decl_derive!(
8 [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive
9);
10decl_derive!(
11 [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive
12);
13decl_derive!(
14 [Lift_Generic, attributes(lift)] => lift_derive
15);
16#[cfg(not(feature = "nightly"))]
17decl_derive!(
18 [GenericTypeVisitable] => customizable_type_visitable_derive
19);
20
21struct TransformedTy {
22 ty: syn::Type,
23 generic_parameter_bounds: IndexSet<syn::Ident>,
24}
25
26enum TypeParameterPath {
27 Interner,
28 GenericParameter(syn::Ident),
29}
30
/// Answer returned by a type-parameter visitor, telling `transform_type_parameters`
/// whether to keep walking into the path it has just handled.
enum TypeParameterTransform {
    /// Descend into the path's nested types as usual.
    Continue,
    /// Do not descend into the path (e.g. because the visitor replaced it wholesale).
    Stop,
}
34}
35
36type TypeParameterVisitor =
37 fn(TypeParameterPath, &mut syn::TypePath, &mut IndexSet<syn::Ident>) -> TypeParameterTransform;
38
39fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool {
40 let mut ignored = false;
41 attrs.iter().for_each(|attr| {
42 if !attr.path().is_ident(name) {
43 return;
44 }
45 let _ = attr.parse_nested_meta(|nested| {
46 if nested.path.is_ident(meta) {
47 ignored = true;
48 }
49 Ok(())
50 });
51 });
52
53 ignored
54}
55
56fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
57 if let syn::Data::Union(_) = s.ast().data {
58 panic!("cannot derive on union")
59 }
60
61 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
62 s.add_impl_generic(parse_quote! { I });
63 }
64
65 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore"));
66
67 s.add_where_predicate(parse_quote! { I: Interner });
68 s.add_bounds(synstructure::AddBounds::Fields);
69 let body_visit = s.each(|bind| {
70 quote! {
71 match ::rustc_type_ir::VisitorResult::branch(
72 ::rustc_type_ir::TypeVisitable::visit_with(#bind, __visitor)
73 ) {
74 ::core::ops::ControlFlow::Continue(()) => {},
75 ::core::ops::ControlFlow::Break(r) => {
76 return ::rustc_type_ir::VisitorResult::from_residual(r);
77 },
78 }
79 }
80 });
81 s.bind_with(|_| synstructure::BindStyle::Move);
82
83 s.bound_impl(
84 quote!(::rustc_type_ir::TypeVisitable<I>),
85 quote! {
86 fn visit_with<__V: ::rustc_type_ir::TypeVisitor<I>>(
87 &self,
88 __visitor: &mut __V
89 ) -> __V::Result {
90 match *self { #body_visit }
91 <__V::Result as ::rustc_type_ir::VisitorResult>::output()
92 }
93 },
94 )
95}
96
97fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
98 if let syn::Data::Union(_) = s.ast().data {
99 panic!("cannot derive on union")
100 }
101
102 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
103 s.add_impl_generic(parse_quote! { I });
104 }
105
106 s.add_where_predicate(parse_quote! { I: Interner });
107 s.add_bounds(synstructure::AddBounds::Fields);
108 let generic_parameters =
109 s.ast().generics.type_params().map(|ty| ty.ident.clone()).collect::<Vec<_>>();
110 let mut generic_parameter_bounds = IndexSet::new();
111 s.bind_with(|_| synstructure::BindStyle::Move);
112 let body_try_fold = s.each_variant(|vi| {
113 let bindings = vi.bindings();
114 vi.construct(|_, index| {
115 let bind = &bindings[index];
116
117 if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
119 bind.to_token_stream()
120 } else {
121 for param in
122 type_foldable_generic_parameters(bind.ast().ty.clone(), &generic_parameters)
123 {
124 generic_parameter_bounds.insert(param);
125 }
126
127 quote! {
128 ::rustc_type_ir::TypeFoldable::try_fold_with(#bind, __folder)?
129 }
130 }
131 })
132 });
133
134 let body_fold = s.each_variant(|vi| {
135 let bindings = vi.bindings();
136 vi.construct(|_, index| {
137 let bind = &bindings[index];
138
139 if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") {
141 bind.to_token_stream()
142 } else {
143 quote! {
144 ::rustc_type_ir::TypeFoldable::fold_with(#bind, __folder)
145 }
146 }
147 })
148 });
149
150 s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity"));
154 s.add_bounds(synstructure::AddBounds::Fields);
155 for param in generic_parameter_bounds {
156 s.add_where_predicate(parse_quote! { #param: ::rustc_type_ir::TypeFoldable<I> });
157 }
158 s.bound_impl(
159 quote!(::rustc_type_ir::TypeFoldable<I>),
160 quote! {
161 fn try_fold_with<__F: ::rustc_type_ir::FallibleTypeFolder<I>>(
162 self,
163 __folder: &mut __F
164 ) -> Result<Self, __F::Error> {
165 Ok(match self { #body_try_fold })
166 }
167
168 fn fold_with<__F: ::rustc_type_ir::TypeFolder<I>>(
169 self,
170 __folder: &mut __F
171 ) -> Self {
172 match self { #body_fold }
173 }
174 },
175 )
176}
177
178fn type_foldable_generic_parameters(
179 ty: syn::Type,
180 generic_parameters: &[syn::Ident],
181) -> IndexSet<syn::Ident> {
182 transform_type_parameters(ty, generic_parameters, |path, _, generic_parameter_bounds| {
183 if let TypeParameterPath::GenericParameter(param) = path {
184 generic_parameter_bounds.insert(param);
185 }
186 TypeParameterTransform::Continue
187 })
188 .generic_parameter_bounds
189}
190
191fn lift_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
204 if let syn::Data::Union(_) = s.ast().data {
205 panic!("cannot derive on union")
206 }
207
208 if !s.ast().generics.type_params().any(|ty| ty.ident == "I") {
209 s.add_impl_generic(parse_quote! { I });
210 }
211
212 s.add_bounds(synstructure::AddBounds::None);
213 s.add_impl_generic(parse_quote! { J });
214 s.add_where_predicate(parse_quote! { J: Interner });
215 s.add_where_predicate(parse_quote! { I: ::rustc_type_ir::LiftInto<J> });
216
217 let generic_parameters =
218 s.ast().generics.type_params().map(|ty| ty.ident.clone()).collect::<Vec<_>>();
219
220 let mut wc = vec![];
221 s.bind_with(|_| synstructure::BindStyle::Move);
222 let body_fold = s.each_variant(|vi| {
223 let bindings = vi.bindings();
224 vi.construct(|field, index| {
225 let ty = field.ty.clone();
226 let bind = &bindings[index];
227 if has_ignore_attr(&field.attrs, "lift", "identity") {
229 return bind.to_token_stream();
230 }
231
232 let lifted = lift(ty.clone(), &generic_parameters);
233
234 for param in lifted.generic_parameter_bounds {
239 wc.push(parse_quote! { #param: ::rustc_type_ir::lift::Lift<J> });
240 }
241
242 if is_type_phantom(&ty) {
243 return quote! {
244 PhantomData
245 };
246 }
247
248 quote! {
249 #bind.lift_to_interner(interner)
250 }
251 })
252 });
253 for wc in wc {
254 s.add_where_predicate(wc);
255 }
256
257 let (_, ty_generics, _) = s.ast().generics.split_for_impl();
258 let name = s.ast().ident.clone();
259 let self_ty: syn::Type = parse_quote! { #name #ty_generics };
260 let lifted = lift(self_ty, &generic_parameters);
261 let lifted_ty = lifted.ty;
262
263 s.bound_impl(
264 quote!(::rustc_type_ir::lift::Lift<J>),
265 quote! {
266 type Lifted = #lifted_ty;
267
268 fn lift_to_interner(
269 self,
270 interner: J,
271 ) -> Self::Lifted {
272 match self { #body_fold }
273 }
274 },
275 )
276}
277
278fn get_first_path_segment(ty: &syn::Type) -> Option<&syn::PathSegment> {
279 if let syn::Type::Path(ty) = ty
280 && ty.path.segments.len() == 1
281 {
282 ty.path.segments.first()
283 } else {
284 None
285 }
286}
287
288fn is_type_phantom(ty: &syn::Type) -> bool {
290 get_first_path_segment(ty).is_some_and(|segment| segment.ident == "PhantomData")
291}
292
293fn lift(ty: syn::Type, generic_parameters: &[syn::Ident]) -> TransformedTy {
294 transform_type_parameters(ty, generic_parameters, |path, ty, generic_parameter_bounds| {
295 match path {
296 TypeParameterPath::Interner => {
297 *ty.path.segments.first_mut().unwrap() = parse_quote! { J };
298 TypeParameterTransform::Continue
299 }
300 TypeParameterPath::GenericParameter(param) => {
301 generic_parameter_bounds.insert(param.clone());
302 *ty = parse_quote! { <#param as ::rustc_type_ir::lift::Lift<J>>::Lifted };
303 TypeParameterTransform::Stop
304 }
305 }
306 })
307}
308
309fn transform_type_parameters(
310 mut ty: syn::Type,
311 generic_parameters: &[syn::Ident],
312 visit: TypeParameterVisitor,
313) -> TransformedTy {
314 struct TypeParameterTransformer<'a> {
315 generic_parameters: &'a [syn::Ident],
316 generic_parameter_bounds: IndexSet<syn::Ident>,
317 visit: TypeParameterVisitor,
318 }
319
320 impl VisitMut for TypeParameterTransformer<'_> {
321 fn visit_type_path_mut(&mut self, i: &mut syn::TypePath) {
322 let path = if i.qself.is_none() {
323 let segments_len = i.path.segments.len();
324 i.path.segments.first().and_then(|first| {
325 if first.ident == "I" {
326 Some(TypeParameterPath::Interner)
327 } else if segments_len == 1
328 && matches!(first.arguments, syn::PathArguments::None)
329 && self.generic_parameters.contains(&first.ident)
330 {
331 Some(TypeParameterPath::GenericParameter(first.ident.clone()))
332 } else {
333 None
334 }
335 })
336 } else {
337 None
338 };
339
340 if let Some(path) = path {
341 if let TypeParameterTransform::Stop =
342 (self.visit)(path, i, &mut self.generic_parameter_bounds)
343 {
344 return;
345 }
346 }
347
348 syn::visit_mut::visit_type_path_mut(self, i);
349 }
350 }
351
352 let mut visitor = TypeParameterTransformer {
353 generic_parameters,
354 generic_parameter_bounds: IndexSet::new(),
355 visit,
356 };
357 visitor.visit_type_mut(&mut ty);
358 TransformedTy { ty, generic_parameter_bounds: visitor.generic_parameter_bounds }
359}
360
361#[cfg(not(feature = "nightly"))]
362fn customizable_type_visitable_derive(
363 mut s: synstructure::Structure<'_>,
364) -> proc_macro2::TokenStream {
365 if let syn::Data::Union(_) = s.ast().data {
366 panic!("cannot derive on union")
367 }
368
369 s.add_impl_generic(parse_quote!(__V));
370 s.add_bounds(synstructure::AddBounds::Fields);
371 let body_visit = s.each(|bind| {
372 quote! {
373 ::rustc_type_ir::GenericTypeVisitable::<__V>::generic_visit_with(#bind, __visitor);
374 }
375 });
376 s.bind_with(|_| synstructure::BindStyle::Move);
377
378 s.bound_impl(
379 quote!(::rustc_type_ir::GenericTypeVisitable<__V>),
380 quote! {
381 fn generic_visit_with(
382 &self,
383 __visitor: &mut __V
384 ) {
385 match *self { #body_visit }
386 }
387 },
388 )
389}
390
/// With the `nightly` feature enabled, `#[derive(GenericTypeVisitable)]` is still
/// accepted but expands to nothing. The input item is ignored and an empty token
/// stream is returned.
#[cfg(feature = "nightly")]
#[proc_macro_derive(GenericTypeVisitable)]
pub fn customizable_type_visitable_derive(_: proc_macro::TokenStream) -> proc_macro::TokenStream {
    proc_macro::TokenStream::default()
}