Skip to main content

rustc_expand/
proc_macro.rs

1use rustc_ast::tokenstream::TokenStream;
2use rustc_errors::ErrorGuaranteed;
3use rustc_middle::ty::{self, TyCtxt};
4use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
5use rustc_session::Session;
6use rustc_session::config::ProcMacroExecutionStrategy;
7use rustc_span::profiling::SpannedEventArgRecorder;
8use rustc_span::{LocalExpnId, Span};
9use {rustc_ast as ast, rustc_proc_macro as pm};
10
11use crate::base::{self, *};
12use crate::{errors, proc_macro_server};
13
14fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
15    pm::bridge::server::MaybeCrossThread {
16        cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
17            == ProcMacroExecutionStrategy::CrossThread,
18    }
19}
20
21pub struct BangProcMacro {
22    pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
23}
24
25impl base::BangProcMacro for BangProcMacro {
26    fn expand(
27        &self,
28        ecx: &mut ExtCtxt<'_>,
29        span: Span,
30        input: TokenStream,
31    ) -> Result<TokenStream, ErrorGuaranteed> {
32        let _timer =
33            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
34                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
35            });
36
37        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
38        let strategy = exec_strategy(ecx.sess);
39        let server = proc_macro_server::Rustc::new(ecx);
40        self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
41            ecx.dcx().emit_err(errors::ProcMacroPanicked {
42                span,
43                message: e
44                    .as_str()
45                    .map(|message| errors::ProcMacroPanickedHelp { message: message.into() }),
46            })
47        })
48    }
49}
50
51pub struct AttrProcMacro {
52    pub client: pm::bridge::client::Client<(pm::TokenStream, pm::TokenStream), pm::TokenStream>,
53}
54
55impl base::AttrProcMacro for AttrProcMacro {
56    fn expand(
57        &self,
58        ecx: &mut ExtCtxt<'_>,
59        span: Span,
60        annotation: TokenStream,
61        annotated: TokenStream,
62    ) -> Result<TokenStream, ErrorGuaranteed> {
63        let _timer =
64            ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
65                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
66            });
67
68        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
69        let strategy = exec_strategy(ecx.sess);
70        let server = proc_macro_server::Rustc::new(ecx);
71        self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
72            |e| {
73                ecx.dcx().emit_err(errors::CustomAttributePanicked {
74                    span,
75                    message: e.as_str().map(|message| errors::CustomAttributePanickedHelp {
76                        message: message.into(),
77                    }),
78                })
79            },
80        )
81    }
82}
83
84pub struct DeriveProcMacro {
85    pub client: DeriveClient,
86}
87
88impl MultiItemModifier for DeriveProcMacro {
89    fn expand(
90        &self,
91        ecx: &mut ExtCtxt<'_>,
92        span: Span,
93        _meta_item: &ast::MetaItem,
94        item: Annotatable,
95        _is_derive_const: bool,
96    ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
97        let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
98            "expand_derive_proc_macro_outer",
99            |recorder| {
100                recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
101            },
102        );
103
104        // We need special handling for statement items
105        // (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
106        let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
    Annotatable::Stmt(..) => true,
    _ => false,
}matches!(item, Annotatable::Stmt(..));
107
108        // We used to have an alternative behaviour for crates that needed it.
109        // We had a lint for a long time, but now we just emit a hard error.
110        // Eventually we might remove the special case hard error check
111        // altogether. See #73345.
112        crate::base::ann_pretty_printing_compatibility_hack(&item, &ecx.sess.psess);
113        let input = item.to_tokens();
114
115        let invoc_id = ecx.current_expansion.id;
116
117        let res = if ecx.sess.opts.incremental.is_some()
118            && ecx.sess.opts.unstable_opts.cache_proc_macros
119        {
120            ty::tls::with(|tcx| {
121                let input = &*tcx.arena.alloc(input);
122                let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
123
124                QueryDeriveExpandCtx::enter(ecx, self.client, move || {
125                    tcx.derive_macro_expansion(key).cloned()
126                })
127            })
128        } else {
129            expand_derive_macro(invoc_id, input, ecx, self.client)
130        };
131
132        let Ok(output) = res else {
133            // error will already have been emitted
134            return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
135        };
136
137        let error_count_before = ecx.dcx().err_count();
138        let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
139        let mut items = ::alloc::vec::Vec::new()vec![];
140
141        loop {
142            match parser.parse_item(
143                ForceCollect::No,
144                if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
145            ) {
146                Ok(None) => break,
147                Ok(Some(item)) => {
148                    if is_stmt {
149                        items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
150                    } else {
151                        items.push(Annotatable::Item(item));
152                    }
153                }
154                Err(err) => {
155                    err.emit();
156                    break;
157                }
158            }
159        }
160
161        // fail if there have been errors emitted
162        if ecx.dcx().err_count() > error_count_before {
163            ecx.dcx().emit_err(errors::ProcMacroDeriveTokens { span });
164        }
165
166        ExpandResult::Ready(items)
167    }
168}
169
170/// Provide a query for computing the output of a derive macro.
171pub(super) fn provide_derive_macro_expansion<'tcx>(
172    tcx: TyCtxt<'tcx>,
173    key: (LocalExpnId, &'tcx TokenStream),
174) -> Result<&'tcx TokenStream, ()> {
175    let (invoc_id, input) = key;
176
177    // Make sure that we invalidate the query when the crate defining the proc macro changes
178    let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
179
180    QueryDeriveExpandCtx::with(|ecx, client| {
181        expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
182    })
183}
184
185type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
186
187fn expand_derive_macro(
188    invoc_id: LocalExpnId,
189    input: TokenStream,
190    ecx: &mut ExtCtxt<'_>,
191    client: DeriveClient,
192) -> Result<TokenStream, ()> {
193    let _timer =
194        ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
195            let invoc_expn_data = invoc_id.expn_data();
196            let span = invoc_expn_data.call_site;
197            let event_arg = invoc_expn_data.kind.descr();
198            recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
199        });
200
201    let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
202    let strategy = exec_strategy(ecx.sess);
203    let server = proc_macro_server::Rustc::new(ecx);
204
205    match client.run(&strategy, server, input, proc_macro_backtrace) {
206        Ok(stream) => Ok(stream),
207        Err(e) => {
208            let invoc_expn_data = invoc_id.expn_data();
209            let span = invoc_expn_data.call_site;
210            ecx.dcx().emit_err({
211                errors::ProcMacroDerivePanicked {
212                    span,
213                    message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
214                        message: message.into(),
215                    }),
216                }
217            });
218            Err(())
219        }
220    }
221}
222
223/// Stores the context necessary to expand a derive proc macro via a query.
224struct QueryDeriveExpandCtx {
225    /// Type-erased version of `&mut ExtCtxt`
226    expansion_ctx: *mut (),
227    client: DeriveClient,
228}
229
230impl QueryDeriveExpandCtx {
231    /// Store the extension context and the client into the thread local value.
232    /// It will be accessible via the `with` method while `f` is active.
233    fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
234    where
235        F: FnOnce() -> R,
236    {
237        // We need erasure to get rid of the lifetime
238        let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
239        DERIVE_EXPAND_CTX.set(&ctx, || f())
240    }
241
242    /// Accesses the thread local value of the derive expansion context.
243    /// Must be called while the `enter` function is active.
244    fn with<F, R>(f: F) -> R
245    where
246        F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
247    {
248        DERIVE_EXPAND_CTX.with(|ctx| {
249            let ectx = {
250                let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
251                // SAFETY: We can only get the value from `with` while the `enter` function
252                // is active (on the callstack), and that function's signature ensures that the
253                // lifetime is valid.
254                // If `with` is called at some other time, it will panic due to usage of
255                // `scoped_tls::with`.
256                unsafe { casted.as_mut().unwrap() }
257            };
258
259            f(ectx, ctx.client)
260        })
261    }
262}
263
264// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
265// context and derive Client. We do that using a thread-local.
266static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
    ::scoped_tls::ScopedKey {
        inner: {
            const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
                {
                    const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
                        =
                        { ::std::cell::Cell::new(::std::ptr::null()) };
                    unsafe {
                        ::std::thread::LocalKey::new(const {
                                    if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
                                        {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL:
                                                    ::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
                                                    =
                                                    ::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
                                                __RUST_STD_INTERNAL_VAL.get()
                                            }
                                    } else {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
                                                    =
                                                    __RUST_STD_INTERNAL_INIT;
                                                &__RUST_STD_INTERNAL_VAL
                                            }
                                    }
                                })
                    }
                };
            &FOO
        },
        _marker: ::std::marker::PhantomData,
    };scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);