// rustc_expand/proc_macro.rs

use rustc_ast as ast;
use rustc_ast::tokenstream::TokenStream;
use rustc_errors::ErrorGuaranteed;
use rustc_middle::ty::{self, TyCtxt};
use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
use rustc_proc_macro as pm;
use rustc_session::Session;
use rustc_session::config::ProcMacroExecutionStrategy;
use rustc_span::profiling::SpannedEventArgRecorder;
use rustc_span::{LocalExpnId, Span};

use crate::base::{self, *};
use crate::{errors, proc_macro_server};

15fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
16    pm::bridge::server::MaybeCrossThread {
17        cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
18            == ProcMacroExecutionStrategy::CrossThread,
19    }
20}
21
22pub struct BangProcMacro {
23    pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
24}
25
26impl base::BangProcMacro for BangProcMacro {
27    fn expand(
28        &self,
29        ecx: &mut ExtCtxt<'_>,
30        span: Span,
31        input: TokenStream,
32    ) -> Result<TokenStream, ErrorGuaranteed> {
33        let _timer =
34            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
35                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
36            });
37
38        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
39        let strategy = exec_strategy(ecx.sess);
40        let server = proc_macro_server::Rustc::new(ecx);
41        self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
42            ecx.dcx().emit_err(errors::ProcMacroPanicked {
43                span,
44                message: e
45                    .as_str()
46                    .map(|message| errors::ProcMacroPanickedHelp { message: message.into() }),
47            })
48        })
49    }
50}
51
52pub struct AttrProcMacro {
53    pub client: pm::bridge::client::Client<(pm::TokenStream, pm::TokenStream), pm::TokenStream>,
54}
55
56impl base::AttrProcMacro for AttrProcMacro {
57    fn expand(
58        &self,
59        ecx: &mut ExtCtxt<'_>,
60        span: Span,
61        annotation: TokenStream,
62        annotated: TokenStream,
63    ) -> Result<TokenStream, ErrorGuaranteed> {
64        let _timer =
65            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
66                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
67            });
68
69        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
70        let strategy = exec_strategy(ecx.sess);
71        let server = proc_macro_server::Rustc::new(ecx);
72        self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
73            |e| {
74                ecx.dcx().emit_err(errors::CustomAttributePanicked {
75                    span,
76                    message: e.as_str().map(|message| errors::CustomAttributePanickedHelp {
77                        message: message.into(),
78                    }),
79                })
80            },
81        )
82    }
83}
84
85pub struct DeriveProcMacro {
86    pub client: DeriveClient,
87}
88
89impl MultiItemModifier for DeriveProcMacro {
90    fn expand(
91        &self,
92        ecx: &mut ExtCtxt<'_>,
93        span: Span,
94        _meta_item: &ast::MetaItem,
95        item: Annotatable,
96        _is_derive_const: bool,
97    ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
98        let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
99            "expand_derive_proc_macro_outer",
100            |recorder| {
101                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
102            },
103        );
104
105        // We need special handling for statement items
106        // (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
107        let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
    Annotatable::Stmt(..) => true,
    _ => false,
}matches!(item, Annotatable::Stmt(..));
108
109        let input = item.to_tokens();
110
111        let invoc_id = ecx.current_expansion.id;
112
113        let res = if ecx.sess.opts.incremental.is_some()
114            && ecx.sess.opts.unstable_opts.cache_proc_macros
115        {
116            ty::tls::with(|tcx| {
117                let input = &*tcx.arena.alloc(input);
118                let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
119
120                QueryDeriveExpandCtx::enter(ecx, self.client, move || {
121                    tcx.derive_macro_expansion(key).cloned()
122                })
123            })
124        } else {
125            expand_derive_macro(invoc_id, input, ecx, self.client)
126        };
127
128        let Ok(output) = res else {
129            // error will already have been emitted
130            return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
131        };
132
133        let error_count_before = ecx.dcx().err_count();
134        let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
135        let mut items = ::alloc::vec::Vec::new()vec![];
136
137        loop {
138            match parser.parse_item(
139                ForceCollect::No,
140                if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
141            ) {
142                Ok(None) => break,
143                Ok(Some(item)) => {
144                    if is_stmt {
145                        items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
146                    } else {
147                        items.push(Annotatable::Item(item));
148                    }
149                }
150                Err(err) => {
151                    err.emit();
152                    break;
153                }
154            }
155        }
156
157        // fail if there have been errors emitted
158        if ecx.dcx().err_count() > error_count_before {
159            ecx.dcx().emit_err(errors::ProcMacroDeriveTokens { span });
160        }
161
162        ExpandResult::Ready(items)
163    }
164}
165
166/// Provide a query for computing the output of a derive macro.
167pub(super) fn provide_derive_macro_expansion<'tcx>(
168    tcx: TyCtxt<'tcx>,
169    key: (LocalExpnId, &'tcx TokenStream),
170) -> Result<&'tcx TokenStream, ()> {
171    let (invoc_id, input) = key;
172
173    // Make sure that we invalidate the query when the crate defining the proc macro changes
174    let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
175
176    QueryDeriveExpandCtx::with(|ecx, client| {
177        expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
178    })
179}
180
181type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
182
183fn expand_derive_macro(
184    invoc_id: LocalExpnId,
185    input: TokenStream,
186    ecx: &mut ExtCtxt<'_>,
187    client: DeriveClient,
188) -> Result<TokenStream, ()> {
189    let _timer =
190        ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
191            let invoc_expn_data = invoc_id.expn_data();
192            let span = invoc_expn_data.call_site;
193            let event_arg = invoc_expn_data.kind.descr();
194            recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
195        });
196
197    let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
198    let strategy = exec_strategy(ecx.sess);
199    let server = proc_macro_server::Rustc::new(ecx);
200
201    match client.run(&strategy, server, input, proc_macro_backtrace) {
202        Ok(stream) => Ok(stream),
203        Err(e) => {
204            let invoc_expn_data = invoc_id.expn_data();
205            let span = invoc_expn_data.call_site;
206            ecx.dcx().emit_err({
207                errors::ProcMacroDerivePanicked {
208                    span,
209                    message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
210                        message: message.into(),
211                    }),
212                }
213            });
214            Err(())
215        }
216    }
217}
218
219/// Stores the context necessary to expand a derive proc macro via a query.
220struct QueryDeriveExpandCtx {
221    /// Type-erased version of `&mut ExtCtxt`
222    expansion_ctx: *mut (),
223    client: DeriveClient,
224}
225
226impl QueryDeriveExpandCtx {
227    /// Store the extension context and the client into the thread local value.
228    /// It will be accessible via the `with` method while `f` is active.
229    fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
230    where
231        F: FnOnce() -> R,
232    {
233        // We need erasure to get rid of the lifetime
234        let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
235        DERIVE_EXPAND_CTX.set(&ctx, || f())
236    }
237
238    /// Accesses the thread local value of the derive expansion context.
239    /// Must be called while the `enter` function is active.
240    fn with<F, R>(f: F) -> R
241    where
242        F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
243    {
244        DERIVE_EXPAND_CTX.with(|ctx| {
245            let ectx = {
246                let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
247                // SAFETY: We can only get the value from `with` while the `enter` function
248                // is active (on the callstack), and that function's signature ensures that the
249                // lifetime is valid.
250                // If `with` is called at some other time, it will panic due to usage of
251                // `scoped_tls::with`.
252                unsafe { casted.as_mut().unwrap() }
253            };
254
255            f(ectx, ctx.client)
256        })
257    }
258}
259
260// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
261// context and derive Client. We do that using a thread-local.
262static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
    ::scoped_tls::ScopedKey {
        inner: {
            const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
                {
                    const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
                        =
                        { ::std::cell::Cell::new(::std::ptr::null()) };
                    unsafe {
                        ::std::thread::LocalKey::new(const {
                                    if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
                                        {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL:
                                                    ::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
                                                    =
                                                    ::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
                                                __RUST_STD_INTERNAL_VAL.get()
                                            }
                                    } else {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
                                                    =
                                                    __RUST_STD_INTERNAL_INIT;
                                                &__RUST_STD_INTERNAL_VAL
                                            }
                                    }
                                })
                    }
                };
            &FOO
        },
        _marker: ::std::marker::PhantomData,
    };scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);