Skip to main content

rustc_expand/
proc_macro.rs

1use rustc_ast::tokenstream::TokenStream;
2use rustc_errors::ErrorGuaranteed;
3use rustc_middle::ty::{self, TyCtxt};
4use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
5use rustc_session::Session;
6use rustc_session::config::ProcMacroExecutionStrategy;
7use rustc_span::profiling::SpannedEventArgRecorder;
8use rustc_span::{LocalExpnId, Span};
9use {rustc_ast as ast, rustc_proc_macro as pm};
10
11use crate::base::{self, *};
12use crate::{errors, proc_macro_server};
13
14fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
15    pm::bridge::server::MaybeCrossThread {
16        cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
17            == ProcMacroExecutionStrategy::CrossThread,
18    }
19}
20
21pub struct BangProcMacro {
22    pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
23}
24
25impl base::BangProcMacro for BangProcMacro {
26    fn expand(
27        &self,
28        ecx: &mut ExtCtxt<'_>,
29        span: Span,
30        input: TokenStream,
31    ) -> Result<TokenStream, ErrorGuaranteed> {
32        let _timer =
33            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
34                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
35            });
36
37        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
38        let strategy = exec_strategy(ecx.sess);
39        let server = proc_macro_server::Rustc::new(ecx);
40        self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
41            ecx.dcx().emit_err(errors::ProcMacroPanicked {
42                span,
43                message: e
44                    .as_str()
45                    .map(|message| errors::ProcMacroPanickedHelp { message: message.into() }),
46            })
47        })
48    }
49}
50
51pub struct AttrProcMacro {
52    pub client: pm::bridge::client::Client<(pm::TokenStream, pm::TokenStream), pm::TokenStream>,
53}
54
55impl base::AttrProcMacro for AttrProcMacro {
56    fn expand(
57        &self,
58        ecx: &mut ExtCtxt<'_>,
59        span: Span,
60        annotation: TokenStream,
61        annotated: TokenStream,
62    ) -> Result<TokenStream, ErrorGuaranteed> {
63        let _timer =
64            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
65                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
66            });
67
68        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
69        let strategy = exec_strategy(ecx.sess);
70        let server = proc_macro_server::Rustc::new(ecx);
71        self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
72            |e| {
73                ecx.dcx().emit_err(errors::CustomAttributePanicked {
74                    span,
75                    message: e.as_str().map(|message| errors::CustomAttributePanickedHelp {
76                        message: message.into(),
77                    }),
78                })
79            },
80        )
81    }
82}
83
84pub struct DeriveProcMacro {
85    pub client: DeriveClient,
86}
87
88impl MultiItemModifier for DeriveProcMacro {
89    fn expand(
90        &self,
91        ecx: &mut ExtCtxt<'_>,
92        span: Span,
93        _meta_item: &ast::MetaItem,
94        item: Annotatable,
95        _is_derive_const: bool,
96    ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
97        let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
98            "expand_derive_proc_macro_outer",
99            |recorder| {
100                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
101            },
102        );
103
104        // We need special handling for statement items
105        // (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
106        let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
    Annotatable::Stmt(..) => true,
    _ => false,
}matches!(item, Annotatable::Stmt(..));
107
108        let input = item.to_tokens();
109
110        let invoc_id = ecx.current_expansion.id;
111
112        let res = if ecx.sess.opts.incremental.is_some()
113            && ecx.sess.opts.unstable_opts.cache_proc_macros
114        {
115            ty::tls::with(|tcx| {
116                let input = &*tcx.arena.alloc(input);
117                let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
118
119                QueryDeriveExpandCtx::enter(ecx, self.client, move || {
120                    tcx.derive_macro_expansion(key).cloned()
121                })
122            })
123        } else {
124            expand_derive_macro(invoc_id, input, ecx, self.client)
125        };
126
127        let Ok(output) = res else {
128            // error will already have been emitted
129            return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
130        };
131
132        let error_count_before = ecx.dcx().err_count();
133        let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
134        let mut items = ::alloc::vec::Vec::new()vec![];
135
136        loop {
137            match parser.parse_item(
138                ForceCollect::No,
139                if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
140            ) {
141                Ok(None) => break,
142                Ok(Some(item)) => {
143                    if is_stmt {
144                        items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
145                    } else {
146                        items.push(Annotatable::Item(item));
147                    }
148                }
149                Err(err) => {
150                    err.emit();
151                    break;
152                }
153            }
154        }
155
156        // fail if there have been errors emitted
157        if ecx.dcx().err_count() > error_count_before {
158            ecx.dcx().emit_err(errors::ProcMacroDeriveTokens { span });
159        }
160
161        ExpandResult::Ready(items)
162    }
163}
164
165/// Provide a query for computing the output of a derive macro.
166pub(super) fn provide_derive_macro_expansion<'tcx>(
167    tcx: TyCtxt<'tcx>,
168    key: (LocalExpnId, &'tcx TokenStream),
169) -> Result<&'tcx TokenStream, ()> {
170    let (invoc_id, input) = key;
171
172    // Make sure that we invalidate the query when the crate defining the proc macro changes
173    let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
174
175    QueryDeriveExpandCtx::with(|ecx, client| {
176        expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
177    })
178}
179
180type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
181
182fn expand_derive_macro(
183    invoc_id: LocalExpnId,
184    input: TokenStream,
185    ecx: &mut ExtCtxt<'_>,
186    client: DeriveClient,
187) -> Result<TokenStream, ()> {
188    let _timer =
189        ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
190            let invoc_expn_data = invoc_id.expn_data();
191            let span = invoc_expn_data.call_site;
192            let event_arg = invoc_expn_data.kind.descr();
193            recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
194        });
195
196    let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
197    let strategy = exec_strategy(ecx.sess);
198    let server = proc_macro_server::Rustc::new(ecx);
199
200    match client.run(&strategy, server, input, proc_macro_backtrace) {
201        Ok(stream) => Ok(stream),
202        Err(e) => {
203            let invoc_expn_data = invoc_id.expn_data();
204            let span = invoc_expn_data.call_site;
205            ecx.dcx().emit_err({
206                errors::ProcMacroDerivePanicked {
207                    span,
208                    message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
209                        message: message.into(),
210                    }),
211                }
212            });
213            Err(())
214        }
215    }
216}
217
218/// Stores the context necessary to expand a derive proc macro via a query.
219struct QueryDeriveExpandCtx {
220    /// Type-erased version of `&mut ExtCtxt`
221    expansion_ctx: *mut (),
222    client: DeriveClient,
223}
224
225impl QueryDeriveExpandCtx {
226    /// Store the extension context and the client into the thread local value.
227    /// It will be accessible via the `with` method while `f` is active.
228    fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
229    where
230        F: FnOnce() -> R,
231    {
232        // We need erasure to get rid of the lifetime
233        let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
234        DERIVE_EXPAND_CTX.set(&ctx, || f())
235    }
236
237    /// Accesses the thread local value of the derive expansion context.
238    /// Must be called while the `enter` function is active.
239    fn with<F, R>(f: F) -> R
240    where
241        F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
242    {
243        DERIVE_EXPAND_CTX.with(|ctx| {
244            let ectx = {
245                let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
246                // SAFETY: We can only get the value from `with` while the `enter` function
247                // is active (on the callstack), and that function's signature ensures that the
248                // lifetime is valid.
249                // If `with` is called at some other time, it will panic due to usage of
250                // `scoped_tls::with`.
251                unsafe { casted.as_mut().unwrap() }
252            };
253
254            f(ectx, ctx.client)
255        })
256    }
257}
258
259// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
260// context and derive Client. We do that using a thread-local.
261static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
    ::scoped_tls::ScopedKey {
        inner: {
            const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
                {
                    const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
                        =
                        { ::std::cell::Cell::new(::std::ptr::null()) };
                    unsafe {
                        ::std::thread::LocalKey::new(const {
                                    if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
                                        {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL:
                                                    ::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
                                                    =
                                                    ::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
                                                __RUST_STD_INTERNAL_VAL.get()
                                            }
                                    } else {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
                                                    =
                                                    __RUST_STD_INTERNAL_INIT;
                                                &__RUST_STD_INTERNAL_VAL
                                            }
                                    }
                                })
                    }
                };
            &FOO
        },
        _marker: ::std::marker::PhantomData,
    };scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);