Skip to main content

proc_macro/bridge/
rpc.rs

1//! Serialization for client-server communication.
2
3use std::any::Any;
4use std::io::Write;
5use std::num::NonZero;
6
7use super::buffer::Buffer;
8
9pub(super) trait Encode<S>: Sized {
10    fn encode(self, w: &mut Buffer, s: &mut S);
11}
12
13pub(super) trait Decode<'a, 's, S>: Sized {
14    fn decode(r: &mut &'a [u8], s: &'s mut S) -> Self;
15}
16
/// Generates `Encode` and `Decode` impls for a type.
///
/// Three input forms are supported:
///
/// - `le $ty $size`: an integer type written as little-endian bytes into a
///   fixed `$size`-byte slot. Narrower types are zero-padded, so exactly
///   `$size` bytes are written and consumed regardless of
///   `size_of::<$ty>()`. A type wider than `$size` is rejected at compile
///   time in both directions.
/// - `struct Name<T, ..> { field, .. }`: fields are encoded, and decoded, in
///   the order they are listed.
/// - `enum Name<T, ..> { Variant(field), .. }`: a one-byte tag (the variant's
///   index in the list, starting at 0) followed by the variant's field, if
///   any. Decoding an unknown tag panics via `unreachable!()`.
///   NOTE(review): the matcher accepts repeated `(field)` groups, but only
///   zero or one field per variant forms a valid variant pattern.
macro_rules! rpc_encode_decode {
    (le $ty:ident $size:literal) => {
        impl<S> Encode<S> for $ty {
            fn encode(self, w: &mut Buffer, _: &mut S) {
                const N: usize = size_of::<$ty>();
                // Same compile-time guard as in `decode`: without it, a type
                // wider than its wire slot would only surface as a panic when
                // slicing `bytes[..N]` below.
                const {
                    assert!(N <= $size);
                }

                // We can pad with zeros without changing the value because of
                // little endian encoding.
                let mut bytes = [0; $size];
                bytes[..N].copy_from_slice(&self.to_le_bytes());

                w.extend_from_array(&bytes);
            }
        }

        impl<S> Decode<'_, '_, S> for $ty {
            fn decode(r: &mut &[u8], _: &mut S) -> Self {
                const N: usize = size_of::<$ty>();
                const {
                    assert!(N <= $size);
                }

                // Only the low `N` bytes carry the value; the remainder of the
                // `$size`-byte slot is padding and is skipped unchecked.
                let mut bytes = [0; N];
                bytes.copy_from_slice(&r[..N]);
                *r = &r[$size..];

                Self::from_le_bytes(bytes)
            }
        }
    };
    (struct $name:ident $(<$($T:ident),+>)? { $($field:ident),* $(,)? }) => {
        impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
            fn encode(self, w: &mut Buffer, s: &mut S) {
                $(self.$field.encode(w, s);)*
            }
        }

        impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
            for $name $(<$($T),+>)?
        {
            fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                $name {
                    $($field: Decode::decode(r, s)),*
                }
            }
        }
    };
    (enum $name:ident $(<$($T:ident),+>)? { $($variant:ident $(($field:ident))*),* $(,)? }) => {
        #[allow(non_upper_case_globals, non_camel_case_types)]
        const _: () = {
            // `repr(u8)` assigns each variant its declaration index, which
            // becomes the one-byte wire tag.
            #[repr(u8)] enum Tag { $($variant),* }

            // One `u8` const per variant so the tags can be used as patterns.
            $(const $variant: u8 = Tag::$variant as u8;)*

            impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
                fn encode(self, w: &mut Buffer, s: &mut S) {
                    match self {
                        $($name::$variant $(($field))* => {
                            $variant.encode(w, s);
                            $($field.encode(w, s);)*
                        })*
                    }
                }
            }

            impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
                for $name $(<$($T),+>)?
            {
                fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                    match u8::decode(r, s) {
                        $($variant => {
                            $(let $field = Decode::decode(r, s);)*
                            $name::$variant $(($field))*
                        })*
                        _ => unreachable!(),
                    }
                }
            }
        };
    }
}
98
99impl<S> Encode<S> for () {
100    fn encode(self, _: &mut Buffer, _: &mut S) {}
101}
102
103impl<S> Decode<'_, '_, S> for () {
104    fn decode(_: &mut &[u8], _: &mut S) -> Self {}
105}
106
107impl<S> Encode<S> for u8 {
108    fn encode(self, w: &mut Buffer, _: &mut S) {
109        w.push(self);
110    }
111}
112
113impl<S> Decode<'_, '_, S> for u8 {
114    fn decode(r: &mut &[u8], _: &mut S) -> Self {
115        let x = r[0];
116        *r = &r[1..];
117        x
118    }
119}
120
121rpc_encode_decode!(le u32 4);
122rpc_encode_decode!(le usize 8);
123
124impl<S> Encode<S> for bool {
125    fn encode(self, w: &mut Buffer, s: &mut S) {
126        (self as u8).encode(w, s);
127    }
128}
129
130impl<S> Decode<'_, '_, S> for bool {
131    fn decode(r: &mut &[u8], s: &mut S) -> Self {
132        match u8::decode(r, s) {
133            0 => false,
134            1 => true,
135            _ => unreachable!(),
136        }
137    }
138}
139
140impl<S> Encode<S> for NonZero<u32> {
141    fn encode(self, w: &mut Buffer, s: &mut S) {
142        self.get().encode(w, s);
143    }
144}
145
146impl<S> Decode<'_, '_, S> for NonZero<u32> {
147    fn decode(r: &mut &[u8], s: &mut S) -> Self {
148        Self::new(u32::decode(r, s)).unwrap()
149    }
150}
151
152impl<S, A: Encode<S>, B: Encode<S>> Encode<S> for (A, B) {
153    fn encode(self, w: &mut Buffer, s: &mut S) {
154        self.0.encode(w, s);
155        self.1.encode(w, s);
156    }
157}
158
159impl<'a, S, A: for<'s> Decode<'a, 's, S>, B: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S>
160    for (A, B)
161{
162    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
163        (Decode::decode(r, s), Decode::decode(r, s))
164    }
165}
166
167impl<S> Encode<S> for &str {
168    fn encode(self, w: &mut Buffer, s: &mut S) {
169        let bytes = self.as_bytes();
170        bytes.len().encode(w, s);
171        w.write_all(bytes).unwrap();
172    }
173}
174
175impl<'a, S> Decode<'a, '_, S> for &'a str {
176    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
177        let len = usize::decode(r, s);
178        let xs = &r[..len];
179        *r = &r[len..];
180        str::from_utf8(xs).unwrap()
181    }
182}
183
184impl<S> Encode<S> for String {
185    fn encode(self, w: &mut Buffer, s: &mut S) {
186        self[..].encode(w, s);
187    }
188}
189
190impl<S> Decode<'_, '_, S> for String {
191    fn decode(r: &mut &[u8], s: &mut S) -> Self {
192        <&str>::decode(r, s).to_string()
193    }
194}
195
196impl<S, T: Encode<S>> Encode<S> for Vec<T> {
197    fn encode(self, w: &mut Buffer, s: &mut S) {
198        self.len().encode(w, s);
199        for x in self {
200            x.encode(w, s);
201        }
202    }
203}
204
205impl<'a, S, T: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S> for Vec<T> {
206    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
207        let len = usize::decode(r, s);
208        let mut vec = Vec::with_capacity(len);
209        for _ in 0..len {
210            vec.push(T::decode(r, s));
211        }
212        vec
213    }
214}
215
/// A simplified panic payload that keeps only the message text.
///
/// Payloads of type `&'static str` or `String` are preserved; a payload of any
/// other type is collapsed to [`PanicMessage::Unknown`].
pub enum PanicMessage {
    StaticStr(&'static str),
    String(String),
    Unknown,
}

impl From<Box<dyn Any + Send>> for PanicMessage {
    /// Classifies a boxed payload by trying `&'static str` first, then `String`.
    fn from(payload: Box<dyn Any + Send + 'static>) -> Self {
        match payload.downcast::<&'static str>() {
            Ok(text) => PanicMessage::StaticStr(*text),
            Err(payload) => match payload.downcast::<String>() {
                Ok(text) => PanicMessage::String(*text),
                Err(_) => PanicMessage::Unknown,
            },
        }
    }
}

impl From<PanicMessage> for Box<dyn Any + Send> {
    /// Rebuilds a boxed payload from the message.
    ///
    /// `Unknown` becomes an opaque marker value that downcasts to neither
    /// `&str` nor `String`.
    fn from(val: PanicMessage) -> Self {
        // Declared inside the function, so nothing outside can name it.
        struct UnknownPanicMessage;

        match val {
            PanicMessage::StaticStr(text) => Box::new(text),
            PanicMessage::String(text) => Box::new(text),
            PanicMessage::Unknown => Box::new(UnknownPanicMessage),
        }
    }
}

impl PanicMessage {
    /// Borrows the message text, or returns `None` for `Unknown`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            PanicMessage::StaticStr(text) => Some(*text),
            PanicMessage::String(text) => Some(text.as_str()),
            PanicMessage::Unknown => None,
        }
    }

    /// Converts the message into an owned `String`, or returns `None` for
    /// `Unknown`. Only the `StaticStr` case allocates a copy.
    pub fn into_string(self) -> Option<String> {
        match self {
            PanicMessage::String(text) => Some(text),
            PanicMessage::StaticStr(text) => Some(String::from(text)),
            PanicMessage::Unknown => None,
        }
    }
}
266
267impl<S> Encode<S> for PanicMessage {
268    fn encode(self, w: &mut Buffer, s: &mut S) {
269        self.as_str().encode(w, s);
270    }
271}
272
273impl<S> Decode<'_, '_, S> for PanicMessage {
274    fn decode(r: &mut &[u8], s: &mut S) -> Self {
275        match Option::<String>::decode(r, s) {
276            Some(s) => PanicMessage::String(s),
277            None => PanicMessage::Unknown,
278        }
279    }
280}