Skip to main content

rustc_proc_macro/bridge/
rpc.rs

1//! Serialization for client-server communication.
2
3use std::any::Any;
4use std::io::Write;
5use std::num::NonZero;
6
7use super::buffer::Buffer;
8
/// A value that can be serialized into a [`Buffer`].
///
/// `S` is a caller-supplied state type; the impls in this file either ignore
/// it or hand it on unchanged to the `encode` calls of their components.
pub(super) trait Encode<S>: Sized {
    /// Consumes `self` and writes its encoded form into `w`, with mutable
    /// access to the state `s`.
    fn encode(self, w: &mut Buffer, s: &mut S);
}
12
/// A value that can be deserialized from a byte slice.
///
/// `'a` is the lifetime of the input bytes, so a decoded value may borrow
/// from the input (the `&'a str` impl in this file does). `'s` is the
/// lifetime of the mutable borrow of the state `S`.
pub(super) trait Decode<'a, 's, S>: Sized {
    /// Reads one value from the front of `*r`. The impls in this file
    /// re-point `*r` at the bytes that follow what they consumed.
    fn decode(r: &mut &'a [u8], s: &'s mut S) -> Self;
}
16
// Generates a matching pair of `Encode`/`Decode` impls. Three input forms
// are accepted:
//
// * `le $ty`: a type providing `to_le_bytes`/`from_le_bytes`; it is sent as
//   exactly `size_of::<$ty>()` little-endian bytes.
// * `struct Name<T, ..> { a, b, .. }`: the listed fields are encoded, and
//   later decoded, one after another in the order given.
// * `enum Name<T, ..> { A, B(x), .. }`: a one-byte tag selecting the variant,
//   followed by the encoded payload bound in the variant's pattern (if any).
macro_rules! rpc_encode_decode {
    (le $ty:ty) => {
        impl<S> Encode<S> for $ty {
            fn encode(self, w: &mut Buffer, _: &mut S) {
                w.extend_from_array(&self.to_le_bytes());
            }
        }

        impl<S> Decode<'_, '_, S> for $ty {
            fn decode(r: &mut &[u8], _: &mut S) -> Self {
                const N: usize = size_of::<$ty>();

                // Copy out the first `N` bytes (the slice index panics if
                // fewer remain), then advance the reader past them.
                let mut bytes = [0; N];
                bytes.copy_from_slice(&r[..N]);
                *r = &r[N..];

                Self::from_le_bytes(bytes)
            }
        }
    };
    (struct $name:ident $(<$($T:ident),+>)? { $($field:ident),* $(,)? }) => {
        // Every generic parameter of the struct must itself be encodable.
        impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
            fn encode(self, w: &mut Buffer, s: &mut S) {
                $(self.$field.encode(w, s);)*
            }
        }

        impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
            for $name $(<$($T),+>)?
        {
            fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                // Field initializers appear in the same order `encode` wrote
                // the fields.
                $name {
                    $($field: Decode::decode(r, s)),*
                }
            }
        }
    };
    (enum $name:ident $(<$($T:ident),+>)? { $($variant:ident $(($field:ident))*),* $(,)? }) => {
        // The anonymous `const _` keeps `Tag` and the per-variant constants
        // local to this expansion. The constants reuse the variant names,
        // hence the naming-lint allowances.
        #[allow(non_upper_case_globals, non_camel_case_types)]
        const _: () = {
            // Mirror enum with the same variant names; its `u8` discriminants
            // serve as the tag bytes.
            #[repr(u8)] enum Tag { $($variant),* }

            // One `u8` constant per variant so the tag can be encoded as a
            // plain byte and matched as a constant pattern when decoding.
            $(const $variant: u8 = Tag::$variant as u8;)*

            impl<S, $($($T: Encode<S>),+)?> Encode<S> for $name $(<$($T),+>)? {
                fn encode(self, w: &mut Buffer, s: &mut S) {
                    match self {
                        $($name::$variant $(($field))* => {
                            // Tag byte first, then the payload.
                            $variant.encode(w, s);
                            $($field.encode(w, s);)*
                        })*
                    }
                }
            }

            impl<'a, S, $($($T: for<'s> Decode<'a, 's, S>),+)?> Decode<'a, '_, S>
                for $name $(<$($T),+>)?
            {
                fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
                    match u8::decode(r, s) {
                        $($variant => {
                            $(let $field = Decode::decode(r, s);)*
                            $name::$variant $(($field))*
                        })*
                        // A tag byte that matches no variant panics here.
                        _ => unreachable!(),
                    }
                }
            }
        };
    }
}
88
89impl<S> Encode<S> for () {
90    fn encode(self, _: &mut Buffer, _: &mut S) {}
91}
92
93impl<S> Decode<'_, '_, S> for () {
94    fn decode(_: &mut &[u8], _: &mut S) -> Self {}
95}
96
97impl<S> Encode<S> for u8 {
98    fn encode(self, w: &mut Buffer, _: &mut S) {
99        w.push(self);
100    }
101}
102
103impl<S> Decode<'_, '_, S> for u8 {
104    fn decode(r: &mut &[u8], _: &mut S) -> Self {
105        let x = r[0];
106        *r = &r[1..];
107        x
108    }
109}
110
111impl<S> Encode<S> for u32 {
    fn encode(self, w: &mut Buffer, _: &mut S) {
        w.extend_from_array(&self.to_le_bytes());
    }
}
impl<S> Decode<'_, '_, S> for u32 {
    fn decode(r: &mut &[u8], _: &mut S) -> Self {
        const N: usize = size_of::<u32>();
        let mut bytes = [0; N];
        bytes.copy_from_slice(&r[..N]);
        *r = &r[N..];
        Self::from_le_bytes(bytes)
    }
}rpc_encode_decode!(le u32);
112#[cfg(target_pointer_width = "64")]
113impl<S> Encode<S> for usize {
    fn encode(self, w: &mut Buffer, _: &mut S) {
        w.extend_from_array(&self.to_le_bytes());
    }
}
impl<S> Decode<'_, '_, S> for usize {
    fn decode(r: &mut &[u8], _: &mut S) -> Self {
        const N: usize = size_of::<usize>();
        let mut bytes = [0; N];
        bytes.copy_from_slice(&r[..N]);
        *r = &r[N..];
        Self::from_le_bytes(bytes)
    }
}rpc_encode_decode!(le usize);
114
/// Number of bytes a `usize` occupies in the encoded stream on targets whose
/// pointers are not 64 bits wide: `encode` zero-pads the value up to this
/// width and `decode` advances the reader by this many bytes.
#[cfg(not(target_pointer_width = "64"))]
const MAX_USIZE_SIZE: usize = 8;
117
#[cfg(not(target_pointer_width = "64"))]
impl<S> Encode<S> for usize {
    /// Writes the value as `MAX_USIZE_SIZE` little-endian bytes.
    fn encode(self, w: &mut Buffer, _: &mut S) {
        let native = self.to_le_bytes();

        // In little-endian order the most significant bytes come last, so
        // trailing zero bytes widen the encoding without changing the value.
        let mut padded = [0u8; MAX_USIZE_SIZE];
        padded[..native.len()].copy_from_slice(&native);

        w.extend_from_array(&padded);
    }
}
131
#[cfg(not(target_pointer_width = "64"))]
impl<S> Decode<'_, '_, S> for usize {
    /// Reads the low `size_of::<usize>()` bytes of a `MAX_USIZE_SIZE`-byte
    /// little-endian field and advances the reader past the whole field.
    fn decode(r: &mut &[u8], _: &mut S) -> Self {
        const NATIVE: usize = size_of::<usize>();
        // Compile-time guard: the native width must fit in the wire field.
        const {
            assert!(NATIVE <= MAX_USIZE_SIZE);
        }

        let input = *r;
        let mut native = [0u8; NATIVE];
        native.copy_from_slice(&input[..NATIVE]);
        // Skip the full field, including the bytes beyond the native width.
        *r = &input[MAX_USIZE_SIZE..];

        usize::from_le_bytes(native)
    }
}
147
148impl<S> Encode<S> for bool {
149    fn encode(self, w: &mut Buffer, s: &mut S) {
150        (self as u8).encode(w, s);
151    }
152}
153
154impl<S> Decode<'_, '_, S> for bool {
155    fn decode(r: &mut &[u8], s: &mut S) -> Self {
156        match u8::decode(r, s) {
157            0 => false,
158            1 => true,
159            _ => ::core::panicking::panic("internal error: entered unreachable code")unreachable!(),
160        }
161    }
162}
163
164impl<S> Encode<S> for NonZero<u32> {
165    fn encode(self, w: &mut Buffer, s: &mut S) {
166        self.get().encode(w, s);
167    }
168}
169
170impl<S> Decode<'_, '_, S> for NonZero<u32> {
171    fn decode(r: &mut &[u8], s: &mut S) -> Self {
172        Self::new(u32::decode(r, s)).unwrap()
173    }
174}
175
176impl<S, A: Encode<S>, B: Encode<S>> Encode<S> for (A, B) {
177    fn encode(self, w: &mut Buffer, s: &mut S) {
178        self.0.encode(w, s);
179        self.1.encode(w, s);
180    }
181}
182
183impl<'a, S, A: for<'s> Decode<'a, 's, S>, B: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S>
184    for (A, B)
185{
186    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
187        (Decode::decode(r, s), Decode::decode(r, s))
188    }
189}
190
191impl<S> Encode<S> for &str {
192    fn encode(self, w: &mut Buffer, s: &mut S) {
193        let bytes = self.as_bytes();
194        bytes.len().encode(w, s);
195        w.write_all(bytes).unwrap();
196    }
197}
198
199impl<'a, S> Decode<'a, '_, S> for &'a str {
200    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
201        let len = usize::decode(r, s);
202        let xs = &r[..len];
203        *r = &r[len..];
204        str::from_utf8(xs).unwrap()
205    }
206}
207
208impl<S> Encode<S> for String {
209    fn encode(self, w: &mut Buffer, s: &mut S) {
210        self[..].encode(w, s);
211    }
212}
213
214impl<S> Decode<'_, '_, S> for String {
215    fn decode(r: &mut &[u8], s: &mut S) -> Self {
216        <&str>::decode(r, s).to_string()
217    }
218}
219
220impl<S, T: Encode<S>> Encode<S> for Vec<T> {
221    fn encode(self, w: &mut Buffer, s: &mut S) {
222        self.len().encode(w, s);
223        for x in self {
224            x.encode(w, s);
225        }
226    }
227}
228
229impl<'a, S, T: for<'s> Decode<'a, 's, S>> Decode<'a, '_, S> for Vec<T> {
230    fn decode(r: &mut &'a [u8], s: &mut S) -> Self {
231        let len = usize::decode(r, s);
232        let mut vec = Vec::with_capacity(len);
233        for _ in 0..len {
234            vec.push(T::decode(r, s));
235        }
236        vec
237    }
238}
239
/// Simplified version of panic payloads, ignoring
/// types other than `&'static str` and `String`.
pub enum PanicMessage {
    /// The payload was a `&'static str`.
    StaticStr(&'static str),
    /// The payload was an owned `String`.
    String(String),
    /// The payload had some other type; nothing of it is retained.
    Unknown,
}
247
248impl From<Box<dyn Any + Send>> for PanicMessage {
249    fn from(payload: Box<dyn Any + Send + 'static>) -> Self {
250        if let Some(s) = payload.downcast_ref::<&'static str>() {
251            return PanicMessage::StaticStr(s);
252        }
253        if let Ok(s) = payload.downcast::<String>() {
254            return PanicMessage::String(*s);
255        }
256        PanicMessage::Unknown
257    }
258}
259
260impl From<PanicMessage> for Box<dyn Any + Send> {
261    fn from(val: PanicMessage) -> Self {
262        match val {
263            PanicMessage::StaticStr(s) => Box::new(s),
264            PanicMessage::String(s) => Box::new(s),
265            PanicMessage::Unknown => {
266                struct UnknownPanicMessage;
267                Box::new(UnknownPanicMessage)
268            }
269        }
270    }
271}
272
273impl PanicMessage {
274    pub fn as_str(&self) -> Option<&str> {
275        match self {
276            PanicMessage::StaticStr(s) => Some(s),
277            PanicMessage::String(s) => Some(s),
278            PanicMessage::Unknown => None,
279        }
280    }
281
282    pub fn into_string(self) -> Option<String> {
283        match self {
284            PanicMessage::StaticStr(s) => Some(s.into()),
285            PanicMessage::String(s) => Some(s),
286            PanicMessage::Unknown => None,
287        }
288    }
289}
290
291impl<S> Encode<S> for PanicMessage {
292    fn encode(self, w: &mut Buffer, s: &mut S) {
293        self.as_str().encode(w, s);
294    }
295}
296
297impl<S> Decode<'_, '_, S> for PanicMessage {
298    fn decode(r: &mut &[u8], s: &mut S) -> Self {
299        match Option::<String>::decode(r, s) {
300            Some(s) => PanicMessage::String(s),
301            None => PanicMessage::Unknown,
302        }
303    }
304}