rustc_macros/
type_visitable.rs

1use quote::quote;
2use syn::parse_quote;
3
4pub(super) fn type_visitable_derive(
5    mut s: synstructure::Structure<'_>,
6) -> proc_macro2::TokenStream {
7    if let syn::Data::Union(_) = s.ast().data {
8        panic!("cannot derive on union")
9    }
10
11    s.underscore_const(true);
12
13    // ignore fields with #[type_visitable(ignore)]
14    s.filter(|bi| {
15        let mut ignored = false;
16
17        bi.ast().attrs.iter().for_each(|attr| {
18            if !attr.path().is_ident("type_visitable") {
19                return;
20            }
21            let _ = attr.parse_nested_meta(|nested| {
22                if nested.path.is_ident("ignore") {
23                    ignored = true;
24                }
25                Ok(())
26            });
27        });
28
29        !ignored
30    });
31
32    if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
33        s.add_impl_generic(parse_quote! { 'tcx });
34    }
35
36    s.add_bounds(synstructure::AddBounds::Generics);
37    let body_visit = s.each(|bind| {
38        quote! {
39            match ::rustc_ast_ir::visit::VisitorResult::branch(
40                ::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor)
41            ) {
42                ::core::ops::ControlFlow::Continue(()) => {},
43                ::core::ops::ControlFlow::Break(r) => {
44                    return ::rustc_ast_ir::visit::VisitorResult::from_residual(r);
45                },
46            }
47        }
48    });
49    s.bind_with(|_| synstructure::BindStyle::Move);
50
51    s.bound_impl(
52        quote!(::rustc_middle::ty::visit::TypeVisitable<::rustc_middle::ty::TyCtxt<'tcx>>),
53        quote! {
54            fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<::rustc_middle::ty::TyCtxt<'tcx>>>(
55                &self,
56                __visitor: &mut __V
57            ) -> __V::Result {
58                match *self { #body_visit }
59                <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
60            }
61        },
62    )
63}