rustc_macros/
extension.rs

1use proc_macro2::Ident;
2use quote::quote;
3use syn::parse::{Parse, ParseStream};
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::{
7    Attribute, Generics, ImplItem, Pat, PatIdent, Path, Signature, Token, TraitItem,
8    TraitItemConst, TraitItemFn, TraitItemMacro, TraitItemType, Type, Visibility, WhereClause,
9    braced, parse_macro_input,
10};
11
12pub(crate) fn extension(
13    attr: proc_macro::TokenStream,
14    input: proc_macro::TokenStream,
15) -> proc_macro::TokenStream {
16    let ExtensionAttr { vis, trait_ } = parse_macro_input!(attr as ExtensionAttr);
17    let Impl { attrs, generics, self_ty, items, wc } = parse_macro_input!(input as Impl);
18    let headers: Vec<_> = items
19        .iter()
20        .map(|item| match item {
21            ImplItem::Fn(f) => TraitItem::Fn(TraitItemFn {
22                attrs: scrub_attrs(&f.attrs),
23                sig: scrub_header(f.sig.clone()),
24                default: None,
25                semi_token: Some(Token![;](f.block.span())),
26            }),
27            ImplItem::Const(ct) => TraitItem::Const(TraitItemConst {
28                attrs: scrub_attrs(&ct.attrs),
29                const_token: ct.const_token,
30                ident: ct.ident.clone(),
31                generics: ct.generics.clone(),
32                colon_token: ct.colon_token,
33                ty: ct.ty.clone(),
34                default: None,
35                semi_token: ct.semi_token,
36            }),
37            ImplItem::Type(ty) => TraitItem::Type(TraitItemType {
38                attrs: scrub_attrs(&ty.attrs),
39                type_token: ty.type_token,
40                ident: ty.ident.clone(),
41                generics: ty.generics.clone(),
42                colon_token: None,
43                bounds: Punctuated::new(),
44                default: None,
45                semi_token: ty.semi_token,
46            }),
47            ImplItem::Macro(mac) => TraitItem::Macro(TraitItemMacro {
48                attrs: scrub_attrs(&mac.attrs),
49                mac: mac.mac.clone(),
50                semi_token: mac.semi_token,
51            }),
52            ImplItem::Verbatim(stream) => TraitItem::Verbatim(stream.clone()),
53            _ => unimplemented!(),
54        })
55        .collect();
56
57    quote! {
58        #(#attrs)*
59        #vis trait #trait_ {
60            #(#headers)*
61        }
62
63        impl #generics #trait_ for #self_ty #wc {
64            #(#items)*
65        }
66    }
67    .into()
68}
69
70/// Only keep `#[doc]` attrs.
71fn scrub_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
72    attrs
73        .into_iter()
74        .cloned()
75        .filter(|attr| {
76            let ident = &attr.path().segments[0].ident;
77            ident == "doc" || ident == "must_use"
78        })
79        .collect()
80}
81
82/// Scrub arguments so that they're valid for trait signatures.
83fn scrub_header(mut sig: Signature) -> Signature {
84    for (idx, input) in sig.inputs.iter_mut().enumerate() {
85        match input {
86            syn::FnArg::Receiver(rcvr) => {
87                // `mut self` -> `self`
88                if rcvr.reference.is_none() {
89                    rcvr.mutability.take();
90                }
91            }
92            syn::FnArg::Typed(arg) => match &mut *arg.pat {
93                Pat::Ident(arg) => {
94                    // `ref mut ident @ pat` -> `ident`
95                    arg.by_ref.take();
96                    arg.mutability.take();
97                    arg.subpat.take();
98                }
99                _ => {
100                    // `pat` -> `__arg0`
101                    arg.pat = Box::new(
102                        PatIdent {
103                            attrs: vec![],
104                            by_ref: None,
105                            mutability: None,
106                            ident: Ident::new(&format!("__arg{idx}"), arg.pat.span()),
107                            subpat: None,
108                        }
109                        .into(),
110                    )
111                }
112            },
113        }
114    }
115    sig
116}
117
118struct ExtensionAttr {
119    vis: Visibility,
120    trait_: Path,
121}
122
123impl Parse for ExtensionAttr {
124    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
125        let vis = input.parse()?;
126        let _: Token![trait] = input.parse()?;
127        let trait_ = input.parse()?;
128        Ok(ExtensionAttr { vis, trait_ })
129    }
130}
131
132struct Impl {
133    attrs: Vec<Attribute>,
134    generics: Generics,
135    self_ty: Type,
136    items: Vec<ImplItem>,
137    wc: Option<WhereClause>,
138}
139
140impl Parse for Impl {
141    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
142        let attrs = input.call(Attribute::parse_outer)?;
143        let _: Token![impl] = input.parse()?;
144        let generics = input.parse()?;
145        let self_ty = input.parse()?;
146        let wc = input.parse()?;
147
148        let content;
149        let _brace_token = braced!(content in input);
150        let mut items = Vec::new();
151        while !content.is_empty() {
152            items.push(content.parse()?);
153        }
154
155        Ok(Impl { attrs, generics, self_ty, items, wc })
156    }
157}