// proc_macro/bridge/server.rs

//! Server-side traits.
2
3use std::cell::Cell;
4use std::marker::PhantomData;
5
6use super::*;
7
8pub(super) struct HandleStore<S: Server> {
9    token_stream: handle::OwnedStore<MarkedTokenStream<S>>,
10    span: handle::InternedStore<MarkedSpan<S>>,
11}
12
13impl<S: Server> HandleStore<S> {
14    fn new(handle_counters: &'static client::HandleCounters) -> Self {
15        HandleStore {
16            token_stream: handle::OwnedStore::new(&handle_counters.token_stream),
17            span: handle::InternedStore::new(&handle_counters.span),
18        }
19    }
20}
21
22pub(super) type MarkedTokenStream<S> = Marked<<S as Server>::TokenStream, client::TokenStream>;
23pub(super) type MarkedSpan<S> = Marked<<S as Server>::Span, client::Span>;
24pub(super) type MarkedSymbol<S> = Marked<<S as Server>::Symbol, client::Symbol>;
25
26impl<S: Server> Encode<HandleStore<S>> for MarkedTokenStream<S> {
27    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
28        s.token_stream.alloc(self).encode(w, s);
29    }
30}
31
32impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedTokenStream<S> {
33    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
34        s.token_stream.take(handle::Handle::decode(r, &mut ()))
35    }
36}
37
38impl<'s, S: Server> Decode<'_, 's, HandleStore<S>> for &'s MarkedTokenStream<S> {
39    fn decode(r: &mut &[u8], s: &'s mut HandleStore<S>) -> Self {
40        &s.token_stream[handle::Handle::decode(r, &mut ())]
41    }
42}
43
44impl<S: Server> Encode<HandleStore<S>> for MarkedSpan<S> {
45    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
46        s.span.alloc(self).encode(w, s);
47    }
48}
49
50impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedSpan<S> {
51    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
52        s.span.copy(handle::Handle::decode(r, &mut ()))
53    }
54}
55
56struct Dispatcher<S: Server> {
57    handle_store: HandleStore<S>,
58    server: S,
59}
60
61macro_rules! define_server {
62    (
63        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)*;)*
64    ) => {
65        pub trait Server {
66            type TokenStream: 'static + Clone + Default;
67            type Span: 'static + Copy + Eq + Hash;
68            type Symbol: 'static;
69
70            fn globals(&mut self) -> ExpnGlobals<Self::Span>;
71
72            /// Intern a symbol received from RPC
73            fn intern_symbol(ident: &str) -> Self::Symbol;
74
75            /// Recover the string value of a symbol, and invoke a callback with it.
76            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));
77
78            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?;)*
79        }
80    }
81}
82with_api!(define_server, Self::TokenStream, Self::Span, Self::Symbol);
83
84macro_rules! define_dispatcher {
85    (
86        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)*;)*
87    ) => {
88        // FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
89        pub trait DispatcherTrait {
90            fn dispatch(&mut self, buf: Buffer) -> Buffer;
91        }
92
93        impl<S: Server> DispatcherTrait for Dispatcher<S> {
94            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
95                let Dispatcher { handle_store, server } = self;
96
97                let mut reader = &buf[..];
98                match ApiTags::decode(&mut reader, &mut ()) {
99                    $(ApiTags::$method => {
100                        let mut call_method = || {
101                            $(let $arg = <$arg_ty>::decode(&mut reader, handle_store).unmark();)*
102                            let r = server.$method($($arg),*);
103                            $(
104                                let r: $ret_ty = Mark::mark(r);
105                            )*
106                            r
107                        };
108                        // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
109                        // If client and server happen to use the same `std`,
110                        // `catch_unwind` asserts that the panic counter was 0,
111                        // even when the closure passed to it didn't panic.
112                        let r = if thread::panicking() {
113                            Ok(call_method())
114                        } else {
115                            panic::catch_unwind(panic::AssertUnwindSafe(call_method))
116                                .map_err(PanicMessage::from)
117                        };
118
119                        buf.clear();
120                        r.encode(&mut buf, handle_store);
121                    })*
122                }
123                buf
124            }
125        }
126    }
127}
128with_api!(define_dispatcher, MarkedTokenStream<S>, MarkedSpan<S>, MarkedSymbol<S>);
129
130pub trait ExecutionStrategy {
131    fn run_bridge_and_client(
132        &self,
133        dispatcher: &mut impl DispatcherTrait,
134        input: Buffer,
135        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
136        force_show_panics: bool,
137    ) -> Buffer;
138}
139
thread_local! {
    /// Set for as long as a proc-macro is being run by the same-thread
    /// executor on this thread. While it is set, nested proc-macro
    /// invocations (e.g. due to `TokenStream::expand_expr`) are forced onto
    /// a cross-thread executor.
    ///
    /// The thread-local state in the proc_macro client cannot cope with being
    /// re-entered: entering a nested macro would invalidate all `Symbol`s.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Holds `ALREADY_RUNNING_SAME_THREAD` (see its documentation) at `true`
/// for the guard's lifetime, which rules out same-thread reentrance.
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Raises the flag and returns the guard that will lower it again.
    ///
    /// # Panics
    ///
    /// Panics if the flag was already raised on this thread.
    fn new() -> Self {
        let was_running = ALREADY_RUNNING_SAME_THREAD.with(|flag| flag.replace(true));
        assert!(
            !was_running,
            "same-thread nesting (\"reentrance\") of proc macro executions is not supported"
        );
        Self(())
    }
}

impl Drop for RunningSameThreadGuard {
    /// Lowers the flag again.
    fn drop(&mut self) {
        ALREADY_RUNNING_SAME_THREAD.with(|flag| flag.set(false));
    }
}
171
/// Execution strategy that decides per call whether the client runs on the
/// current thread or on a separate one.
///
/// `P` is the `MessagePipe` type used when the cross-thread path is taken.
pub struct MaybeCrossThread<P> {
    /// When `true`, the cross-thread path is always used.
    cross_thread: bool,
    /// Ties the strategy to `P` without storing a value of it.
    marker: PhantomData<P>,
}

impl<P> MaybeCrossThread<P> {
    /// Creates the strategy; `cross_thread` is stored as given.
    pub const fn new(cross_thread: bool) -> Self {
        Self { marker: PhantomData, cross_thread }
    }
}
182
183impl<P> ExecutionStrategy for MaybeCrossThread<P>
184where
185    P: MessagePipe<Buffer> + Send + 'static,
186{
187    fn run_bridge_and_client(
188        &self,
189        dispatcher: &mut impl DispatcherTrait,
190        input: Buffer,
191        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
192        force_show_panics: bool,
193    ) -> Buffer {
194        if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() {
195            <CrossThread<P>>::new().run_bridge_and_client(
196                dispatcher,
197                input,
198                run_client,
199                force_show_panics,
200            )
201        } else {
202            SameThread.run_bridge_and_client(dispatcher, input, run_client, force_show_panics)
203        }
204    }
205}
206
207pub struct SameThread;
208
209impl ExecutionStrategy for SameThread {
210    fn run_bridge_and_client(
211        &self,
212        dispatcher: &mut impl DispatcherTrait,
213        input: Buffer,
214        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
215        force_show_panics: bool,
216    ) -> Buffer {
217        let _guard = RunningSameThreadGuard::new();
218
219        let mut dispatch = |buf| dispatcher.dispatch(buf);
220
221        run_client(BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics })
222    }
223}
224
/// Execution strategy that runs the client on a freshly spawned thread.
///
/// `P` is the `MessagePipe` type connecting the two threads; no value of it
/// is stored here.
pub struct CrossThread<P>(PhantomData<P>);

impl<P> CrossThread<P> {
    /// Creates the strategy.
    pub const fn new() -> Self {
        Self(PhantomData)
    }
}
232
233impl<P> ExecutionStrategy for CrossThread<P>
234where
235    P: MessagePipe<Buffer> + Send + 'static,
236{
237    fn run_bridge_and_client(
238        &self,
239        dispatcher: &mut impl DispatcherTrait,
240        input: Buffer,
241        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
242        force_show_panics: bool,
243    ) -> Buffer {
244        let (mut server, mut client) = P::new();
245
246        let join_handle = thread::spawn(move || {
247            let mut dispatch = |b: Buffer| -> Buffer {
248                client.send(b);
249                client.recv().expect("server died while client waiting for reply")
250            };
251
252            run_client(BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics })
253        });
254
255        while let Some(b) = server.recv() {
256            server.send(dispatcher.dispatch(b));
257        }
258
259        join_handle.join().unwrap()
260    }
261}
262
/// A bidirectional channel carrying `T` values between the server thread
/// and the client thread.
pub trait MessagePipe<T>: Sized {
    /// Creates a connected pair of endpoints.
    fn new() -> (Self, Self);

    /// Sends `value` to the opposite endpoint.
    fn send(&mut self, value: T);

    /// Receives the next message sent by the opposite endpoint.
    ///
    /// Yields `None` when the opposite endpoint has been destroyed and no
    /// message was received.
    fn recv(&mut self) -> Option<T>;
}
277
278fn run_server<
279    S: Server,
280    I: Encode<HandleStore<S>>,
281    O: for<'a, 's> Decode<'a, 's, HandleStore<S>>,
282>(
283    strategy: &impl ExecutionStrategy,
284    handle_counters: &'static client::HandleCounters,
285    server: S,
286    input: I,
287    run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
288    force_show_panics: bool,
289) -> Result<O, PanicMessage> {
290    let mut dispatcher = Dispatcher { handle_store: HandleStore::new(handle_counters), server };
291
292    let globals = dispatcher.server.globals();
293
294    let mut buf = Buffer::new();
295    (<ExpnGlobals<MarkedSpan<S>> as Mark>::mark(globals), input)
296        .encode(&mut buf, &mut dispatcher.handle_store);
297
298    buf = strategy.run_bridge_and_client(&mut dispatcher, buf, run_client, force_show_panics);
299
300    Result::decode(&mut &buf[..], &mut dispatcher.handle_store)
301}
302
303impl client::Client<crate::TokenStream, crate::TokenStream> {
304    pub fn run<S>(
305        &self,
306        strategy: &impl ExecutionStrategy,
307        server: S,
308        input: S::TokenStream,
309        force_show_panics: bool,
310    ) -> Result<S::TokenStream, PanicMessage>
311    where
312        S: Server,
313    {
314        let client::Client { handle_counters, run, _marker } = *self;
315        run_server(
316            strategy,
317            handle_counters,
318            server,
319            <MarkedTokenStream<S>>::mark(input),
320            run,
321            force_show_panics,
322        )
323        .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
324    }
325}
326
327impl client::Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> {
328    pub fn run<S>(
329        &self,
330        strategy: &impl ExecutionStrategy,
331        server: S,
332        input: S::TokenStream,
333        input2: S::TokenStream,
334        force_show_panics: bool,
335    ) -> Result<S::TokenStream, PanicMessage>
336    where
337        S: Server,
338    {
339        let client::Client { handle_counters, run, _marker } = *self;
340        run_server(
341            strategy,
342            handle_counters,
343            server,
344            (<MarkedTokenStream<S>>::mark(input), <MarkedTokenStream<S>>::mark(input2)),
345            run,
346            force_show_panics,
347        )
348        .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
349    }
350}