Skip to main content

rustc_expand/
proc_macro.rs

1use rustc_ast as ast;
2use rustc_ast::tokenstream::TokenStream;
3use rustc_data_structures::profiling::TimingGuard;
4use rustc_errors::ErrorGuaranteed;
5use rustc_middle::ty::{self, TyCtxt};
6use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
7use rustc_proc_macro as pm;
8use rustc_session::Session;
9use rustc_session::config::ProcMacroExecutionStrategy;
10use rustc_span::profiling::SpannedEventArgRecorder;
11use rustc_span::{LocalExpnId, Span};
12
13use crate::base::{self, *};
14use crate::{errors, proc_macro_server};
15
16fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
17    pm::bridge::server::MaybeCrossThread {
18        cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
19            == ProcMacroExecutionStrategy::CrossThread,
20    }
21}
22
23fn record_expand_proc_macro<'a>(
24    ecx: &ExtCtxt<'a>,
25    name: &'static str,
26    span: Span,
27) -> TimingGuard<'a> {
28    ecx.sess.prof.generic_activity_with_arg_recorder(name, |recorder| {
29        recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
30    })
31}
32
33pub struct BangProcMacro {
34    pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
35}
36
37impl base::BangProcMacro for BangProcMacro {
38    fn expand(
39        &self,
40        ecx: &mut ExtCtxt<'_>,
41        span: Span,
42        input: TokenStream,
43    ) -> Result<TokenStream, ErrorGuaranteed> {
44        let _timer = record_expand_proc_macro(ecx, "expand_proc_macro", span);
45
46        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
47        let strategy = exec_strategy(ecx.sess);
48        let server = proc_macro_server::Rustc::new(ecx);
49        self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
50            ecx.dcx().emit_err(errors::ProcMacroPanicked {
51                span,
52                message: e.into_string().map(|message| errors::ProcMacroPanickedHelp { message }),
53            })
54        })
55    }
56}
57
58pub struct AttrProcMacro {
59    pub client: pm::bridge::client::Client<(pm::TokenStream, pm::TokenStream), pm::TokenStream>,
60}
61
62impl base::AttrProcMacro for AttrProcMacro {
63    fn expand(
64        &self,
65        ecx: &mut ExtCtxt<'_>,
66        span: Span,
67        annotation: TokenStream,
68        annotated: TokenStream,
69    ) -> Result<TokenStream, ErrorGuaranteed> {
70        let _timer = record_expand_proc_macro(ecx, "expand_proc_macro", span);
71
72        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
73        let strategy = exec_strategy(ecx.sess);
74        let server = proc_macro_server::Rustc::new(ecx);
75        self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
76            |e| {
77                ecx.dcx().emit_err(errors::CustomAttributePanicked {
78                    span,
79                    message: e
80                        .into_string()
81                        .map(|message| errors::CustomAttributePanickedHelp { message }),
82                })
83            },
84        )
85    }
86}
87
88pub struct DeriveProcMacro {
89    pub client: DeriveClient,
90}
91
92impl MultiItemModifier for DeriveProcMacro {
93    fn expand(
94        &self,
95        ecx: &mut ExtCtxt<'_>,
96        span: Span,
97        _meta_item: &ast::MetaItem,
98        item: Annotatable,
99        _is_derive_const: bool,
100    ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
101        let _timer = record_expand_proc_macro(ecx, "expand_derive_proc_macro_outer", span);
102
103        // We need special handling for statement items
104        // (e.g. `fn foo() { #[derive(Debug)] struct Bar; }`)
105        let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
    Annotatable::Stmt(..) => true,
    _ => false,
}matches!(item, Annotatable::Stmt(..));
106
107        let input = item.to_tokens();
108
109        let invoc_id = ecx.current_expansion.id;
110
111        let res = if ecx.sess.opts.incremental.is_some()
112            && ecx.sess.opts.unstable_opts.cache_proc_macros
113        {
114            ty::tls::with(|tcx| {
115                let input = &*tcx.arena.alloc(input);
116                let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
117
118                QueryDeriveExpandCtx::enter(ecx, self.client, move || {
119                    tcx.derive_macro_expansion(key).cloned()
120                })
121            })
122        } else {
123            expand_derive_macro(invoc_id, input, ecx, self.client)
124        };
125
126        let Ok(output) = res else {
127            // error will already have been emitted
128            return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
129        };
130
131        let error_count_before = ecx.dcx().err_count();
132        let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
133        let mut items = ::alloc::vec::Vec::new()vec![];
134
135        loop {
136            match parser.parse_item(
137                ForceCollect::No,
138                if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
139            ) {
140                Ok(None) => break,
141                Ok(Some(item)) => {
142                    if is_stmt {
143                        items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
144                    } else {
145                        items.push(Annotatable::Item(item));
146                    }
147                }
148                Err(err) => {
149                    err.emit();
150                    break;
151                }
152            }
153        }
154
155        // fail if there have been errors emitted
156        if ecx.dcx().err_count() > error_count_before {
157            ecx.dcx().emit_err(errors::ProcMacroDeriveTokens { span });
158        }
159
160        ExpandResult::Ready(items)
161    }
162}
163
164/// Provide a query for computing the output of a derive macro.
165pub(super) fn provide_derive_macro_expansion<'tcx>(
166    tcx: TyCtxt<'tcx>,
167    key: (LocalExpnId, &'tcx TokenStream),
168) -> Result<&'tcx TokenStream, ()> {
169    let (invoc_id, input) = key;
170
171    // Make sure that we invalidate the query when the crate defining the proc macro changes
172    let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
173
174    QueryDeriveExpandCtx::with(|ecx, client| {
175        expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
176    })
177}
178
179type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
180
181fn expand_derive_macro(
182    invoc_id: LocalExpnId,
183    input: TokenStream,
184    ecx: &mut ExtCtxt<'_>,
185    client: DeriveClient,
186) -> Result<TokenStream, ()> {
187    let _timer =
188        ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
189            let invoc_expn_data = invoc_id.expn_data();
190            let span = invoc_expn_data.call_site;
191            let event_arg = invoc_expn_data.kind.descr();
192            recorder.record_arg_with_span(ecx.sess.source_map(), event_arg, span);
193        });
194
195    let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
196    let strategy = exec_strategy(ecx.sess);
197    let server = proc_macro_server::Rustc::new(ecx);
198
199    match client.run(&strategy, server, input, proc_macro_backtrace) {
200        Ok(stream) => Ok(stream),
201        Err(e) => {
202            let invoc_expn_data = invoc_id.expn_data();
203            let span = invoc_expn_data.call_site;
204            ecx.dcx().emit_err({
205                errors::ProcMacroDerivePanicked {
206                    span,
207                    message: e
208                        .into_string()
209                        .map(|message| errors::ProcMacroDerivePanickedHelp { message }),
210                }
211            });
212            Err(())
213        }
214    }
215}
216
217/// Stores the context necessary to expand a derive proc macro via a query.
218struct QueryDeriveExpandCtx {
219    /// Type-erased version of `&mut ExtCtxt`
220    expansion_ctx: *mut (),
221    client: DeriveClient,
222}
223
224impl QueryDeriveExpandCtx {
225    /// Store the extension context and the client into the thread local value.
226    /// It will be accessible via the `with` method while `f` is active.
227    fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
228    where
229        F: FnOnce() -> R,
230    {
231        // We need erasure to get rid of the lifetime
232        let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
233        DERIVE_EXPAND_CTX.set(&ctx, f)
234    }
235
236    /// Accesses the thread local value of the derive expansion context.
237    /// Must be called while the `enter` function is active.
238    fn with<F, R>(f: F) -> R
239    where
240        F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
241    {
242        DERIVE_EXPAND_CTX.with(|ctx| {
243            let ectx = {
244                let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
245                // SAFETY: We can only get the value from `with` while the `enter` function
246                // is active (on the callstack), and that function's signature ensures that the
247                // lifetime is valid.
248                // If `with` is called at some other time, it will panic due to usage of
249                // `scoped_tls::with`.
250                unsafe { casted.as_mut().unwrap() }
251            };
252
253            f(ectx, ctx.client)
254        })
255    }
256}
257
258// When we invoke a query to expand a derive proc macro, we need to provide it with the expansion
259// context and derive Client. We do that using a thread-local.
260static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
    ::scoped_tls::ScopedKey {
        inner: {
            const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
                {
                    const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
                        =
                        { ::std::cell::Cell::new(::std::ptr::null()) };
                    unsafe {
                        ::std::thread::LocalKey::new(const {
                                    if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
                                        {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL:
                                                    ::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
                                                    =
                                                    ::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
                                                __RUST_STD_INTERNAL_VAL.get()
                                            }
                                    } else {
                                        |_|
                                            {
                                                #[thread_local]
                                                static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
                                                    =
                                                    __RUST_STD_INTERNAL_INIT;
                                                &__RUST_STD_INTERNAL_VAL
                                            }
                                    }
                                })
                    }
                };
            &FOO
        },
        _marker: ::std::marker::PhantomData,
    };scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);