1use rustc_ast as ast;
2use rustc_ast::tokenstream::TokenStream;
3use rustc_errors::ErrorGuaranteed;
4use rustc_middle::ty::{self, TyCtxt};
5use rustc_parse::parser::{AllowConstBlockItems, ForceCollect, Parser};
6use rustc_proc_macro as pm;
7use rustc_session::Session;
8use rustc_session::config::ProcMacroExecutionStrategy;
9use rustc_span::profiling::SpannedEventArgRecorder;
10use rustc_span::{LocalExpnId, Span};
11
12use crate::base::{self, *};
13use crate::{errors, proc_macro_server};
14
15fn exec_strategy(sess: &Session) -> impl pm::bridge::server::ExecutionStrategy + 'static {
16 pm::bridge::server::MaybeCrossThread {
17 cross_thread: sess.opts.unstable_opts.proc_macro_execution_strategy
18 == ProcMacroExecutionStrategy::CrossThread,
19 }
20}
21
22pub struct BangProcMacro {
23 pub client: pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>,
24}
25
26impl base::BangProcMacro for BangProcMacro {
27 fn expand(
28 &self,
29 ecx: &mut ExtCtxt<'_>,
30 span: Span,
31 input: TokenStream,
32 ) -> Result<TokenStream, ErrorGuaranteed> {
33 let _timer =
34 ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
35 recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
36 });
37
38 let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
39 let strategy = exec_strategy(ecx.sess);
40 let server = proc_macro_server::Rustc::new(ecx);
41 self.client.run(&strategy, server, input, proc_macro_backtrace).map_err(|e| {
42 ecx.dcx().emit_err(errors::ProcMacroPanicked {
43 span,
44 message: e
45 .as_str()
46 .map(|message| errors::ProcMacroPanickedHelp { message: message.into() }),
47 })
48 })
49 }
50}
51
52pub struct AttrProcMacro {
53 pub client: pm::bridge::client::Client<(pm::TokenStream, pm::TokenStream), pm::TokenStream>,
54}
55
56impl base::AttrProcMacro for AttrProcMacro {
57 fn expand(
58 &self,
59 ecx: &mut ExtCtxt<'_>,
60 span: Span,
61 annotation: TokenStream,
62 annotated: TokenStream,
63 ) -> Result<TokenStream, ErrorGuaranteed> {
64 let _timer =
65 ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
66 recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
67 });
68
69 let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
70 let strategy = exec_strategy(ecx.sess);
71 let server = proc_macro_server::Rustc::new(ecx);
72 self.client.run(&strategy, server, annotation, annotated, proc_macro_backtrace).map_err(
73 |e| {
74 ecx.dcx().emit_err(errors::CustomAttributePanicked {
75 span,
76 message: e.as_str().map(|message| errors::CustomAttributePanickedHelp {
77 message: message.into(),
78 }),
79 })
80 },
81 )
82 }
83}
84
85pub struct DeriveProcMacro {
86 pub client: DeriveClient,
87}
88
89impl MultiItemModifier for DeriveProcMacro {
90 fn expand(
91 &self,
92 ecx: &mut ExtCtxt<'_>,
93 span: Span,
94 _meta_item: &ast::MetaItem,
95 item: Annotatable,
96 _is_derive_const: bool,
97 ) -> ExpandResult<Vec<Annotatable>, Annotatable> {
98 let _timer = ecx.sess.prof.generic_activity_with_arg_recorder(
99 "expand_derive_proc_macro_outer",
100 |recorder| {
101 recorder.record_arg_with_span(ecx.sess.source_map(), ecx.expansion_descr(), span);
102 },
103 );
104
105 let is_stmt = #[allow(non_exhaustive_omitted_patterns)] match item {
Annotatable::Stmt(..) => true,
_ => false,
}matches!(item, Annotatable::Stmt(..));
108
109 let input = item.to_tokens();
110
111 let invoc_id = ecx.current_expansion.id;
112
113 let res = if ecx.sess.opts.incremental.is_some()
114 && ecx.sess.opts.unstable_opts.cache_proc_macros
115 {
116 ty::tls::with(|tcx| {
117 let input = &*tcx.arena.alloc(input);
118 let key: (LocalExpnId, &TokenStream) = (invoc_id, input);
119
120 QueryDeriveExpandCtx::enter(ecx, self.client, move || {
121 tcx.derive_macro_expansion(key).cloned()
122 })
123 })
124 } else {
125 expand_derive_macro(invoc_id, input, ecx, self.client)
126 };
127
128 let Ok(output) = res else {
129 return ExpandResult::Ready(::alloc::vec::Vec::new()vec![]);
131 };
132
133 let error_count_before = ecx.dcx().err_count();
134 let mut parser = Parser::new(&ecx.sess.psess, output, Some("proc-macro derive"));
135 let mut items = ::alloc::vec::Vec::new()vec![];
136
137 loop {
138 match parser.parse_item(
139 ForceCollect::No,
140 if is_stmt { AllowConstBlockItems::No } else { AllowConstBlockItems::Yes },
141 ) {
142 Ok(None) => break,
143 Ok(Some(item)) => {
144 if is_stmt {
145 items.push(Annotatable::Stmt(Box::new(ecx.stmt_item(span, item))));
146 } else {
147 items.push(Annotatable::Item(item));
148 }
149 }
150 Err(err) => {
151 err.emit();
152 break;
153 }
154 }
155 }
156
157 if ecx.dcx().err_count() > error_count_before {
159 ecx.dcx().emit_err(errors::ProcMacroDeriveTokens { span });
160 }
161
162 ExpandResult::Ready(items)
163 }
164}
165
166pub(super) fn provide_derive_macro_expansion<'tcx>(
168 tcx: TyCtxt<'tcx>,
169 key: (LocalExpnId, &'tcx TokenStream),
170) -> Result<&'tcx TokenStream, ()> {
171 let (invoc_id, input) = key;
172
173 let _ = tcx.crate_hash(invoc_id.expn_data().macro_def_id.unwrap().krate);
175
176 QueryDeriveExpandCtx::with(|ecx, client| {
177 expand_derive_macro(invoc_id, input.clone(), ecx, client).map(|ts| &*tcx.arena.alloc(ts))
178 })
179}
180
181type DeriveClient = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>;
182
183fn expand_derive_macro(
184 invoc_id: LocalExpnId,
185 input: TokenStream,
186 ecx: &mut ExtCtxt<'_>,
187 client: DeriveClient,
188) -> Result<TokenStream, ()> {
189 let _timer =
190 ecx.sess.prof.generic_activity_with_arg_recorder("expand_proc_macro", |recorder| {
191 let invoc_expn_data = invoc_id.expn_data();
192 let span = invoc_expn_data.call_site;
193 let event_arg = invoc_expn_data.kind.descr();
194 recorder.record_arg_with_span(ecx.sess.source_map(), event_arg.clone(), span);
195 });
196
197 let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace;
198 let strategy = exec_strategy(ecx.sess);
199 let server = proc_macro_server::Rustc::new(ecx);
200
201 match client.run(&strategy, server, input, proc_macro_backtrace) {
202 Ok(stream) => Ok(stream),
203 Err(e) => {
204 let invoc_expn_data = invoc_id.expn_data();
205 let span = invoc_expn_data.call_site;
206 ecx.dcx().emit_err({
207 errors::ProcMacroDerivePanicked {
208 span,
209 message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp {
210 message: message.into(),
211 }),
212 }
213 });
214 Err(())
215 }
216 }
217}
218
219struct QueryDeriveExpandCtx {
221 expansion_ctx: *mut (),
223 client: DeriveClient,
224}
225
226impl QueryDeriveExpandCtx {
227 fn enter<F, R>(ecx: &mut ExtCtxt<'_>, client: DeriveClient, f: F) -> R
230 where
231 F: FnOnce() -> R,
232 {
233 let ctx = Self { expansion_ctx: ecx as *mut _ as *mut (), client };
235 DERIVE_EXPAND_CTX.set(&ctx, || f())
236 }
237
238 fn with<F, R>(f: F) -> R
241 where
242 F: for<'a, 'b> FnOnce(&'b mut ExtCtxt<'a>, DeriveClient) -> R,
243 {
244 DERIVE_EXPAND_CTX.with(|ctx| {
245 let ectx = {
246 let casted = ctx.expansion_ctx.cast::<ExtCtxt<'_>>();
247 unsafe { casted.as_mut().unwrap() }
253 };
254
255 f(ectx, ctx.client)
256 })
257 }
258}
259
260static DERIVE_EXPAND_CTX: ::scoped_tls::ScopedKey<QueryDeriveExpandCtx> =
::scoped_tls::ScopedKey {
inner: {
const FOO: ::std::thread::LocalKey<::std::cell::Cell<*const ()>> =
{
const __RUST_STD_INTERNAL_INIT: ::std::cell::Cell<*const ()>
=
{ ::std::cell::Cell::new(::std::ptr::null()) };
unsafe {
::std::thread::LocalKey::new(const {
if ::std::mem::needs_drop::<::std::cell::Cell<*const ()>>()
{
|_|
{
#[thread_local]
static __RUST_STD_INTERNAL_VAL:
::std::thread::local_impl::EagerStorage<::std::cell::Cell<*const ()>>
=
::std::thread::local_impl::EagerStorage::new(__RUST_STD_INTERNAL_INIT);
__RUST_STD_INTERNAL_VAL.get()
}
} else {
|_|
{
#[thread_local]
static __RUST_STD_INTERNAL_VAL: ::std::cell::Cell<*const ()>
=
__RUST_STD_INTERNAL_INIT;
&__RUST_STD_INTERNAL_VAL
}
}
})
}
};
&FOO
},
_marker: ::std::marker::PhantomData,
};scoped_tls::scoped_thread_local!(static DERIVE_EXPAND_CTX: QueryDeriveExpandCtx);