// proc_macro/bridge/server.rs

//! Server-side traits.
2
use std::cell::Cell;
use std::marker::PhantomData;

use super::*;
7
8macro_rules! define_server_handles {
9    (
10        'owned: $($oty:ident,)*
11        'interned: $($ity:ident,)*
12    ) => {
13        #[allow(non_snake_case)]
14        pub(super) struct HandleStore<S: Types> {
15            $($oty: handle::OwnedStore<S::$oty>,)*
16            $($ity: handle::InternedStore<S::$ity>,)*
17        }
18
19        impl<S: Types> HandleStore<S> {
20            fn new(handle_counters: &'static client::HandleCounters) -> Self {
21                HandleStore {
22                    $($oty: handle::OwnedStore::new(&handle_counters.$oty),)*
23                    $($ity: handle::InternedStore::new(&handle_counters.$ity),)*
24                }
25            }
26        }
27
28        $(
29            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$oty, client::$oty> {
30                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
31                    s.$oty.alloc(self).encode(w, s);
32                }
33            }
34
35            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
36                for Marked<S::$oty, client::$oty>
37            {
38                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
39                    s.$oty.take(handle::Handle::decode(r, &mut ()))
40                }
41            }
42
43            impl<'s, S: Types> Decode<'_, 's, HandleStore<MarkedTypes<S>>>
44                for &'s Marked<S::$oty, client::$oty>
45            {
46                fn decode(r: &mut Reader<'_>, s: &'s HandleStore<MarkedTypes<S>>) -> Self {
47                    &s.$oty[handle::Handle::decode(r, &mut ())]
48                }
49            }
50
51            impl<'s, S: Types> DecodeMut<'_, 's, HandleStore<MarkedTypes<S>>>
52                for &'s mut Marked<S::$oty, client::$oty>
53            {
54                fn decode(
55                    r: &mut Reader<'_>,
56                    s: &'s mut HandleStore<MarkedTypes<S>>
57                ) -> Self {
58                    &mut s.$oty[handle::Handle::decode(r, &mut ())]
59                }
60            }
61        )*
62
63        $(
64            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$ity, client::$ity> {
65                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
66                    s.$ity.alloc(self).encode(w, s);
67                }
68            }
69
70            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
71                for Marked<S::$ity, client::$ity>
72            {
73                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
74                    s.$ity.copy(handle::Handle::decode(r, &mut ()))
75                }
76            }
77        )*
78    }
79}
80with_api_handle_types!(define_server_handles);
81
/// The concrete types a server supplies for each kind of value that crosses
/// the bridge.
///
/// Every associated type must be `'static`; the extra bounds are the ones
/// listed on each item below.
pub trait Types {
    /// Server-side type for the `FreeFunctions` API group.
    type FreeFunctions: 'static;
    /// Server-side token stream type; must be `Clone`.
    type TokenStream: 'static + Clone;
    /// Server-side source file type; must be `Clone`.
    type SourceFile: 'static + Clone;
    /// Server-side span type; must be `Copy`, `Eq` and `Hash`.
    type Span: 'static + Copy + Eq + Hash;
    /// Server-side symbol type.
    type Symbol: 'static;
}
89
/// Declare an associated fn of one of the traits below, adding necessary
/// default bodies.
///
/// * `fn drop(&mut self, arg: T)` gets a default body that calls `mem::drop(arg)`.
/// * `fn clone(&mut self, arg: T) -> R` gets a default body that returns `arg.clone()`.
/// * Any other input is emitted unchanged followed by `;`, i.e. as a required
///   item without a body.
macro_rules! associated_fn {
    (fn drop(&mut self, $arg:ident: $arg_ty:ty)) =>
        (fn drop(&mut self, $arg: $arg_ty) { mem::drop($arg) });

    (fn clone(&mut self, $arg:ident: $arg_ty:ty) -> $ret_ty:ty) =>
        (fn clone(&mut self, $arg: $arg_ty) -> $ret_ty { $arg.clone() });

    // Fallback: pass the tokens through and terminate the declaration.
    ($($item:tt)*) => ($($item)*;)
}
101
102macro_rules! declare_server_traits {
103    ($($name:ident {
104        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
105    }),* $(,)?) => {
106        $(pub trait $name: Types {
107            $(associated_fn!(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?);)*
108        })*
109
110        pub trait Server: Types $(+ $name)* {
111            fn globals(&mut self) -> ExpnGlobals<Self::Span>;
112
113            /// Intern a symbol received from RPC
114            fn intern_symbol(ident: &str) -> Self::Symbol;
115
116            /// Recover the string value of a symbol, and invoke a callback with it.
117            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));
118        }
119    }
120}
121with_api!(Self, self_, declare_server_traits);
122
123pub(super) struct MarkedTypes<S: Types>(S);
124
125impl<S: Server> Server for MarkedTypes<S> {
126    fn globals(&mut self) -> ExpnGlobals<Self::Span> {
127        <_>::mark(Server::globals(&mut self.0))
128    }
129    fn intern_symbol(ident: &str) -> Self::Symbol {
130        <_>::mark(S::intern_symbol(ident))
131    }
132    fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str)) {
133        S::with_symbol_string(symbol.unmark(), f)
134    }
135}
136
137macro_rules! define_mark_types_impls {
138    ($($name:ident {
139        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
140    }),* $(,)?) => {
141        impl<S: Types> Types for MarkedTypes<S> {
142            $(type $name = Marked<S::$name, client::$name>;)*
143        }
144
145        $(impl<S: $name> $name for MarkedTypes<S> {
146            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)? {
147                <_>::mark($name::$method(&mut self.0, $($arg.unmark()),*))
148            })*
149        })*
150    }
151}
152with_api!(Self, self_, define_mark_types_impls);
153
154struct Dispatcher<S: Types> {
155    handle_store: HandleStore<S>,
156    server: S,
157}
158
159macro_rules! define_dispatcher_impl {
160    ($($name:ident {
161        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
162    }),* $(,)?) => {
163        // FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
164        pub trait DispatcherTrait {
165            // HACK(eddyb) these are here to allow `Self::$name` to work below.
166            $(type $name;)*
167
168            fn dispatch(&mut self, buf: Buffer) -> Buffer;
169        }
170
171        impl<S: Server> DispatcherTrait for Dispatcher<MarkedTypes<S>> {
172            $(type $name = <MarkedTypes<S> as Types>::$name;)*
173
174            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
175                let Dispatcher { handle_store, server } = self;
176
177                let mut reader = &buf[..];
178                match api_tags::Method::decode(&mut reader, &mut ()) {
179                    $(api_tags::Method::$name(m) => match m {
180                        $(api_tags::$name::$method => {
181                            let mut call_method = || {
182                                reverse_decode!(reader, handle_store; $($arg: $arg_ty),*);
183                                $name::$method(server, $($arg),*)
184                            };
185                            // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
186                            // If client and server happen to use the same `std`,
187                            // `catch_unwind` asserts that the panic counter was 0,
188                            // even when the closure passed to it didn't panic.
189                            let r = if thread::panicking() {
190                                Ok(call_method())
191                            } else {
192                                panic::catch_unwind(panic::AssertUnwindSafe(call_method))
193                                    .map_err(PanicMessage::from)
194                            };
195
196                            buf.clear();
197                            r.encode(&mut buf, handle_store);
198                        })*
199                    }),*
200                }
201                buf
202            }
203        }
204    }
205}
206with_api!(Self, self_, define_dispatcher_impl);
207
208pub trait ExecutionStrategy {
209    fn run_bridge_and_client(
210        &self,
211        dispatcher: &mut impl DispatcherTrait,
212        input: Buffer,
213        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
214        force_show_panics: bool,
215    ) -> Buffer;
216}
217
thread_local! {
    /// While running a proc-macro with the same-thread executor, this flag will
    /// be set, forcing nested proc-macro invocations (e.g. due to
    /// `TokenStream::expand_expr`) to be run using a cross-thread executor.
    ///
    /// This is required as the thread-local state in the proc_macro client does
    /// not handle being re-entered, and will invalidate all `Symbol`s when
    /// entering a nested macro.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Keep `ALREADY_RUNNING_SAME_THREAD` (see also its documentation)
/// set to `true`, preventing same-thread reentrance.
///
/// The flag is set by [`RunningSameThreadGuard::new`] and reset to `false`
/// when the guard is dropped.
#[must_use = "the flag is cleared again as soon as the guard is dropped"]
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Sets the flag for the current thread and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the flag was already set, i.e. a same-thread execution is
    /// already in progress on this thread. The flag stays set in that case.
    fn new() -> Self {
        let already_running = ALREADY_RUNNING_SAME_THREAD.replace(true);
        assert!(
            !already_running,
            "same-thread nesting (\"reentrance\") of proc macro executions is not supported"
        );
        RunningSameThreadGuard(())
    }
}

impl Drop for RunningSameThreadGuard {
    /// Clears the flag so that later same-thread executions are allowed again.
    fn drop(&mut self) {
        ALREADY_RUNNING_SAME_THREAD.set(false);
    }
}
249
/// Execution strategy that picks between same-thread and cross-thread
/// execution each time it is run.
///
/// `P` is the message pipe type used when the cross-thread path is taken.
pub struct MaybeCrossThread<P> {
    /// When `true`, the cross-thread strategy is always used.
    cross_thread: bool,
    /// Records `P` without storing a value of it.
    marker: PhantomData<P>,
}

impl<P> MaybeCrossThread<P> {
    /// Creates a strategy; `cross_thread` requests cross-thread execution unconditionally.
    pub const fn new(cross_thread: bool) -> Self {
        MaybeCrossThread { cross_thread, marker: PhantomData }
    }
}
260
261impl<P> ExecutionStrategy for MaybeCrossThread<P>
262where
263    P: MessagePipe<Buffer> + Send + 'static,
264{
265    fn run_bridge_and_client(
266        &self,
267        dispatcher: &mut impl DispatcherTrait,
268        input: Buffer,
269        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
270        force_show_panics: bool,
271    ) -> Buffer {
272        if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() {
273            <CrossThread<P>>::new().run_bridge_and_client(
274                dispatcher,
275                input,
276                run_client,
277                force_show_panics,
278            )
279        } else {
280            SameThread.run_bridge_and_client(dispatcher, input, run_client, force_show_panics)
281        }
282    }
283}
284
285pub struct SameThread;
286
287impl ExecutionStrategy for SameThread {
288    fn run_bridge_and_client(
289        &self,
290        dispatcher: &mut impl DispatcherTrait,
291        input: Buffer,
292        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
293        force_show_panics: bool,
294    ) -> Buffer {
295        let _guard = RunningSameThreadGuard::new();
296
297        let mut dispatch = |buf| dispatcher.dispatch(buf);
298
299        run_client(BridgeConfig {
300            input,
301            dispatch: (&mut dispatch).into(),
302            force_show_panics,
303            _marker: marker::PhantomData,
304        })
305    }
306}
307
/// Execution strategy that runs the client on a separate thread.
///
/// `P` is the message pipe type used to exchange buffers with that thread; no
/// value of `P` is stored in the strategy itself.
pub struct CrossThread<P>(PhantomData<P>);

impl<P> CrossThread<P> {
    /// Creates the strategy.
    pub const fn new() -> Self {
        CrossThread(PhantomData)
    }
}
315
316impl<P> ExecutionStrategy for CrossThread<P>
317where
318    P: MessagePipe<Buffer> + Send + 'static,
319{
320    fn run_bridge_and_client(
321        &self,
322        dispatcher: &mut impl DispatcherTrait,
323        input: Buffer,
324        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
325        force_show_panics: bool,
326    ) -> Buffer {
327        let (mut server, mut client) = P::new();
328
329        let join_handle = thread::spawn(move || {
330            let mut dispatch = |b: Buffer| -> Buffer {
331                client.send(b);
332                client.recv().expect("server died while client waiting for reply")
333            };
334
335            run_client(BridgeConfig {
336                input,
337                dispatch: (&mut dispatch).into(),
338                force_show_panics,
339                _marker: marker::PhantomData,
340            })
341        });
342
343        while let Some(b) = server.recv() {
344            server.send(dispatcher.dispatch(b));
345        }
346
347        join_handle.join().unwrap()
348    }
349}
350
/// A message pipe used for communicating between server and client threads.
///
/// A pipe has two endpoints of the same type; a value sent on one endpoint is
/// received on the other.
pub trait MessagePipe<T>: Sized {
    /// Creates a new pair of endpoints for the message pipe.
    fn new() -> (Self, Self);

    /// Send a message to the other endpoint of this pipe.
    fn send(&mut self, value: T);

    /// Receive a message from the other endpoint of this pipe.
    ///
    /// Returns `None` if the other end of the pipe has been destroyed, and no
    /// message was received.
    fn recv(&mut self) -> Option<T>;
}
365
366fn run_server<
367    S: Server,
368    I: Encode<HandleStore<MarkedTypes<S>>>,
369    O: for<'a, 's> DecodeMut<'a, 's, HandleStore<MarkedTypes<S>>>,
370>(
371    strategy: &impl ExecutionStrategy,
372    handle_counters: &'static client::HandleCounters,
373    server: S,
374    input: I,
375    run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
376    force_show_panics: bool,
377) -> Result<O, PanicMessage> {
378    let mut dispatcher =
379        Dispatcher { handle_store: HandleStore::new(handle_counters), server: MarkedTypes(server) };
380
381    let globals = dispatcher.server.globals();
382
383    let mut buf = Buffer::new();
384    (globals, input).encode(&mut buf, &mut dispatcher.handle_store);
385
386    buf = strategy.run_bridge_and_client(&mut dispatcher, buf, run_client, force_show_panics);
387
388    Result::decode(&mut &buf[..], &mut dispatcher.handle_store)
389}
390
391impl client::Client<crate::TokenStream, crate::TokenStream> {
392    pub fn run<S>(
393        &self,
394        strategy: &impl ExecutionStrategy,
395        server: S,
396        input: S::TokenStream,
397        force_show_panics: bool,
398    ) -> Result<S::TokenStream, PanicMessage>
399    where
400        S: Server,
401        S::TokenStream: Default,
402    {
403        let client::Client { handle_counters, run, _marker } = *self;
404        run_server(
405            strategy,
406            handle_counters,
407            server,
408            <MarkedTypes<S> as Types>::TokenStream::mark(input),
409            run,
410            force_show_panics,
411        )
412        .map(|s| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(s).unwrap_or_default())
413    }
414}
415
416impl client::Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> {
417    pub fn run<S>(
418        &self,
419        strategy: &impl ExecutionStrategy,
420        server: S,
421        input: S::TokenStream,
422        input2: S::TokenStream,
423        force_show_panics: bool,
424    ) -> Result<S::TokenStream, PanicMessage>
425    where
426        S: Server,
427        S::TokenStream: Default,
428    {
429        let client::Client { handle_counters, run, _marker } = *self;
430        run_server(
431            strategy,
432            handle_counters,
433            server,
434            (
435                <MarkedTypes<S> as Types>::TokenStream::mark(input),
436                <MarkedTypes<S> as Types>::TokenStream::mark(input2),
437            ),
438            run,
439            force_show_panics,
440        )
441        .map(|s| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(s).unwrap_or_default())
442    }
443}