Skip to main content

proc_macro/bridge/
server.rs

1//! Server-side traits.
2
3use std::cell::Cell;
4use std::marker::PhantomData;
5
6use super::*;
7
8pub(super) struct HandleStore<S: Server> {
9    token_stream: handle::OwnedStore<MarkedTokenStream<S>>,
10    span: handle::InternedStore<MarkedSpan<S>>,
11}
12
13impl<S: Server> HandleStore<S> {
14    fn new(handle_counters: &'static client::HandleCounters) -> Self {
15        HandleStore {
16            token_stream: handle::OwnedStore::new(&handle_counters.token_stream),
17            span: handle::InternedStore::new(&handle_counters.span),
18        }
19    }
20}
21
/// The server's `TokenStream` type, wrapped in `Marked` together with the
/// client-side `client::TokenStream` type it stands for on the bridge.
pub(super) type MarkedTokenStream<S> = Marked<<S as Server>::TokenStream, client::TokenStream>;
/// The server's `Span` type, wrapped in `Marked` together with `client::Span`.
pub(super) type MarkedSpan<S> = Marked<<S as Server>::Span, client::Span>;
/// The server's `Symbol` type, wrapped in `Marked` together with `client::Symbol`.
pub(super) type MarkedSymbol<S> = Marked<<S as Server>::Symbol, client::Symbol>;
25
26impl<S: Server> Encode<HandleStore<S>> for MarkedTokenStream<S> {
27    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
28        s.token_stream.alloc(self).encode(w, s);
29    }
30}
31
32impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedTokenStream<S> {
33    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
34        s.token_stream.take(handle::Handle::decode(r, &mut ()))
35    }
36}
37
38impl<'s, S: Server> Decode<'_, 's, HandleStore<S>> for &'s MarkedTokenStream<S> {
39    fn decode(r: &mut &[u8], s: &'s mut HandleStore<S>) -> Self {
40        &s.token_stream[handle::Handle::decode(r, &mut ())]
41    }
42}
43
44impl<S: Server> Encode<HandleStore<S>> for MarkedSpan<S> {
45    fn encode(self, w: &mut Buffer, s: &mut HandleStore<S>) {
46        s.span.alloc(self).encode(w, s);
47    }
48}
49
50impl<S: Server> Decode<'_, '_, HandleStore<S>> for MarkedSpan<S> {
51    fn decode(r: &mut &[u8], s: &mut HandleStore<S>) -> Self {
52        s.span.copy(handle::Handle::decode(r, &mut ()))
53    }
54}
55
/// Server half of one bridge run.
///
/// It holds the handle store that requests are decoded against (and replies
/// encoded into), plus the `Server` those requests are forwarded to.
struct Dispatcher<S: Server> {
    handle_store: HandleStore<S>,
    server: S,
}
60
// Generates the server-side API from the method list supplied by `with_api!`
// (one `fn name(args) -> ret;` entry per bridge method):
// - the `Server` trait, with one method per entry;
// - `DispatcherTrait` and its impl for `Dispatcher<S>`, whose `dispatch`
//   decodes a request, calls the matching `Server` method and encodes the
//   result back into the buffer.
macro_rules! define_server_dispatcher_impl {
    (
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)*;)*
    ) => {
        /// The server side of the proc-macro bridge.
        ///
        /// The associated types are the server's own representations of token
        /// streams, spans and symbols.
        pub trait Server {
            type TokenStream: 'static + Clone;
            type Span: 'static + Copy + Eq + Hash;
            type Symbol: 'static;

            /// Returns the expansion globals that `run_server` encodes, ahead
            /// of the input, into the first buffer handed to the client.
            fn globals(&mut self) -> ExpnGlobals<Self::Span>;

            /// Interns a symbol string received over RPC.
            fn intern_symbol(ident: &str) -> Self::Symbol;

            /// Recovers the string value of a symbol, and invokes a callback with it.
            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));

            // One method per API entry passed in by `with_api!`.
            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?;)*
        }

        // FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
        /// Handles client requests one at a time; `ExecutionStrategy`
        /// implementations call `dispatch` for every request the client makes.
        pub trait DispatcherTrait {
            // HACK(eddyb) these are here to allow `Self::$name` to work below.
            type TokenStream;
            type Span;
            type Symbol;

            /// Decodes one request from `buf`, handles it, and returns a buffer
            /// holding the encoded reply.
            fn dispatch(&mut self, buf: Buffer) -> Buffer;
        }

        impl<S: Server> DispatcherTrait for Dispatcher<S> {
            type TokenStream = MarkedTokenStream<S>;
            type Span = MarkedSpan<S>;
            type Symbol = MarkedSymbol<S>;

            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
                let Dispatcher { handle_store, server } = self;

                // The request is an `ApiTags` method tag followed by that
                // method's arguments, decoded in declaration order.
                let mut reader = &buf[..];
                match ApiTags::decode(&mut reader, &mut ()) {
                    $(ApiTags::$method => {
                        // Decode and unmark each argument, call the server
                        // method, then mark the return value (only for methods
                        // that declare one) so it can be encoded against
                        // `handle_store`.
                        let mut call_method = || {
                            $(let $arg = <$arg_ty>::decode(&mut reader, handle_store).unmark();)*
                            let r = server.$method($($arg),*);
                            $(
                                let r: $ret_ty = Mark::mark(r);
                            )*
                            r
                        };
                        // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
                        // If client and server happen to use the same `std`,
                        // `catch_unwind` asserts that the panic counter was 0,
                        // even when the closure passed to it didn't panic.
                        // Outside that case, a panic in the server method is
                        // caught and encoded as `Err(PanicMessage)`.
                        let r = if thread::panicking() {
                            Ok(call_method())
                        } else {
                            panic::catch_unwind(panic::AssertUnwindSafe(call_method))
                                .map_err(PanicMessage::from)
                        };

                        // Reuse the request buffer's allocation for the reply.
                        buf.clear();
                        r.encode(&mut buf, handle_store);
                    })*
                }
                buf
            }
        }
    }
}
// Instantiate the server API. `Self` is passed as the prefix for the API's
// types, which is why `DispatcherTrait` declares matching associated types
// (see the HACK note inside the macro).
with_api!(Self, define_server_dispatcher_impl);
131
/// How the proc-macro client is executed relative to the server.
///
/// An implementation does three things:
/// - calls `run_client` with a `BridgeConfig` built from `input` and
///   `force_show_panics`;
/// - passes every request the client sends to `dispatcher`;
/// - returns the buffer that `run_client` returned.
///
/// The implementations in this file are `SameThread`, `CrossThread` and
/// `MaybeCrossThread` (which delegates to one of the other two).
pub trait ExecutionStrategy {
    fn run_bridge_and_client(
        &self,
        dispatcher: &mut impl DispatcherTrait,
        input: Buffer,
        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
        force_show_panics: bool,
    ) -> Buffer;
}
141
thread_local! {
    /// While running a proc-macro with the same-thread executor, this flag will
    /// be set, forcing nested proc-macro invocations (e.g. due to
    /// `TokenStream::expand_expr`) to be run using a cross-thread executor.
    ///
    /// This is required as the thread-local state in the proc_macro client does
    /// not handle being re-entered, and will invalidate all `Symbol`s when
    /// entering a nested macro.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Keeps `ALREADY_RUNNING_SAME_THREAD` (see also its documentation) set to
/// `true` for as long as the guard is alive, preventing same-thread reentrance.
///
/// Dropping the guard resets the flag to `false`. The type is `#[must_use]`
/// because a guard that is created and immediately discarded would clear the
/// flag before the run it is meant to protect even starts. An example is a
/// bare `RunningSameThreadGuard::new();` statement.
#[must_use = "the same-thread flag is cleared as soon as this guard is dropped"]
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Sets the flag and returns the guard that will clear it.
    ///
    /// # Panics
    ///
    /// Panics if the flag is already set on this thread, i.e. a same-thread
    /// execution is already in progress. The flag is left set in that case,
    /// since it still belongs to the outer guard.
    fn new() -> Self {
        let already_running = ALREADY_RUNNING_SAME_THREAD.replace(true);
        assert!(
            !already_running,
            "same-thread nesting (\"reentrance\") of proc macro executions is not supported"
        );
        RunningSameThreadGuard(())
    }
}

impl Drop for RunningSameThreadGuard {
    fn drop(&mut self) {
        // `new` only succeeds when the flag was previously `false`, so resetting
        // it to `false` restores the state from before this guard existed.
        ALREADY_RUNNING_SAME_THREAD.set(false);
    }
}
173
/// Execution strategy that chooses between `SameThread` and `CrossThread<P>`
/// on every run.
pub struct MaybeCrossThread<P> {
    /// When `true`, every run uses the cross-thread executor.
    cross_thread: bool,
    marker: PhantomData<P>,
}

impl<P> MaybeCrossThread<P> {
    /// Creates the strategy. With `cross_thread == false`, runs happen on the
    /// current thread unless that would nest same-thread executions.
    pub const fn new(cross_thread: bool) -> Self {
        Self { marker: PhantomData, cross_thread }
    }
}
184
185impl<P> ExecutionStrategy for MaybeCrossThread<P>
186where
187    P: MessagePipe<Buffer> + Send + 'static,
188{
189    fn run_bridge_and_client(
190        &self,
191        dispatcher: &mut impl DispatcherTrait,
192        input: Buffer,
193        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
194        force_show_panics: bool,
195    ) -> Buffer {
196        if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() {
197            <CrossThread<P>>::new().run_bridge_and_client(
198                dispatcher,
199                input,
200                run_client,
201                force_show_panics,
202            )
203        } else {
204            SameThread.run_bridge_and_client(dispatcher, input, run_client, force_show_panics)
205        }
206    }
207}
208
209pub struct SameThread;
210
211impl ExecutionStrategy for SameThread {
212    fn run_bridge_and_client(
213        &self,
214        dispatcher: &mut impl DispatcherTrait,
215        input: Buffer,
216        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
217        force_show_panics: bool,
218    ) -> Buffer {
219        let _guard = RunningSameThreadGuard::new();
220
221        let mut dispatch = |buf| dispatcher.dispatch(buf);
222
223        run_client(BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics })
224    }
225}
226
/// Execution strategy that runs the client on a freshly spawned thread. It
/// services the client's requests on the current thread, exchanging buffers
/// over a `MessagePipe` of type `P`.
pub struct CrossThread<P>(PhantomData<P>);

impl<P> CrossThread<P> {
    /// Creates the strategy; it carries no state beyond the pipe type `P`.
    pub const fn new() -> Self {
        Self(PhantomData)
    }
}
234
235impl<P> ExecutionStrategy for CrossThread<P>
236where
237    P: MessagePipe<Buffer> + Send + 'static,
238{
239    fn run_bridge_and_client(
240        &self,
241        dispatcher: &mut impl DispatcherTrait,
242        input: Buffer,
243        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
244        force_show_panics: bool,
245    ) -> Buffer {
246        let (mut server, mut client) = P::new();
247
248        let join_handle = thread::spawn(move || {
249            let mut dispatch = |b: Buffer| -> Buffer {
250                client.send(b);
251                client.recv().expect("server died while client waiting for reply")
252            };
253
254            run_client(BridgeConfig { input, dispatch: (&mut dispatch).into(), force_show_panics })
255        });
256
257        while let Some(b) = server.recv() {
258            server.send(dispatcher.dispatch(b));
259        }
260
261        join_handle.join().unwrap()
262    }
263}
264
/// A message pipe used for communicating between server and client threads.
///
/// `CrossThread` keeps one endpoint on the server thread and moves the other
/// into the spawned client thread. Its server loop stops once `recv` returns
/// `None`, i.e. after the client thread has dropped its endpoint.
pub trait MessagePipe<T>: Sized {
    /// Creates a new pair of endpoints for the message pipe.
    fn new() -> (Self, Self);

    /// Sends a message to the other endpoint of this pipe.
    fn send(&mut self, value: T);

    /// Receives a message from the other endpoint of this pipe.
    ///
    /// Returns `None` if the other end of the pipe has been destroyed, and no
    /// message was received.
    fn recv(&mut self) -> Option<T>;
}
279
280fn run_server<
281    S: Server,
282    I: Encode<HandleStore<S>>,
283    O: for<'a, 's> Decode<'a, 's, HandleStore<S>>,
284>(
285    strategy: &impl ExecutionStrategy,
286    handle_counters: &'static client::HandleCounters,
287    server: S,
288    input: I,
289    run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
290    force_show_panics: bool,
291) -> Result<O, PanicMessage> {
292    let mut dispatcher = Dispatcher { handle_store: HandleStore::new(handle_counters), server };
293
294    let globals = dispatcher.server.globals();
295
296    let mut buf = Buffer::new();
297    (<ExpnGlobals<MarkedSpan<S>> as Mark>::mark(globals), input)
298        .encode(&mut buf, &mut dispatcher.handle_store);
299
300    buf = strategy.run_bridge_and_client(&mut dispatcher, buf, run_client, force_show_panics);
301
302    Result::decode(&mut &buf[..], &mut dispatcher.handle_store)
303}
304
305impl client::Client<crate::TokenStream, crate::TokenStream> {
306    pub fn run<S>(
307        &self,
308        strategy: &impl ExecutionStrategy,
309        server: S,
310        input: S::TokenStream,
311        force_show_panics: bool,
312    ) -> Result<S::TokenStream, PanicMessage>
313    where
314        S: Server,
315        S::TokenStream: Default,
316    {
317        let client::Client { handle_counters, run, _marker } = *self;
318        run_server(
319            strategy,
320            handle_counters,
321            server,
322            <MarkedTokenStream<S>>::mark(input),
323            run,
324            force_show_panics,
325        )
326        .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
327    }
328}
329
330impl client::Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> {
331    pub fn run<S>(
332        &self,
333        strategy: &impl ExecutionStrategy,
334        server: S,
335        input: S::TokenStream,
336        input2: S::TokenStream,
337        force_show_panics: bool,
338    ) -> Result<S::TokenStream, PanicMessage>
339    where
340        S: Server,
341        S::TokenStream: Default,
342    {
343        let client::Client { handle_counters, run, _marker } = *self;
344        run_server(
345            strategy,
346            handle_counters,
347            server,
348            (<MarkedTokenStream<S>>::mark(input), <MarkedTokenStream<S>>::mark(input2)),
349            run,
350            force_show_panics,
351        )
352        .map(|s| <Option<MarkedTokenStream<S>>>::unmark(s).unwrap_or_default())
353    }
354}