rustc_macros/
serialize.rs

1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::parse_quote;
4use syn::spanned::Spanned;
5
6pub(super) fn type_decodable_derive(
7    mut s: synstructure::Structure<'_>,
8) -> proc_macro2::TokenStream {
9    if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
10        s.add_impl_generic(parse_quote! { 'tcx });
11    }
12    let decoder_ty = quote! { __D };
13    s.add_impl_generic(parse_quote! { #decoder_ty: ::rustc_middle::ty::codec::TyDecoder<'tcx> });
14    s.add_bounds(synstructure::AddBounds::Fields);
15
16    decodable_body(s, decoder_ty)
17}
18
19pub(super) fn blob_decodable_derive(
20    mut s: synstructure::Structure<'_>,
21) -> proc_macro2::TokenStream {
22    let decoder_ty = quote! { __D };
23    s.add_impl_generic(parse_quote! { #decoder_ty: ::rustc_span::BlobDecoder });
24    s.add_bounds(synstructure::AddBounds::Generics);
25
26    decodable_body(s, decoder_ty)
27}
28
29pub(super) fn lazy_decodable_derive(
30    mut s: synstructure::Structure<'_>,
31) -> proc_macro2::TokenStream {
32    let decoder_ty = quote! { __D };
33    s.add_impl_generic(parse_quote! { #decoder_ty: LazyDecoder });
34    s.add_bounds(synstructure::AddBounds::Generics);
35
36    decodable_body(s, decoder_ty)
37}
38
39pub(super) fn decodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
40    let decoder_ty = quote! { __D };
41    s.add_impl_generic(parse_quote! { #decoder_ty: ::rustc_span::SpanDecoder });
42    s.add_bounds(synstructure::AddBounds::Generics);
43
44    decodable_body(s, decoder_ty)
45}
46
47pub(super) fn decodable_nocontext_derive(
48    mut s: synstructure::Structure<'_>,
49) -> proc_macro2::TokenStream {
50    let decoder_ty = quote! { __D };
51    s.add_impl_generic(parse_quote! { #decoder_ty: ::rustc_serialize::Decoder });
52    s.add_bounds(synstructure::AddBounds::Fields);
53
54    decodable_body(s, decoder_ty)
55}
56
57fn decodable_body(
58    s: synstructure::Structure<'_>,
59    decoder_ty: TokenStream,
60) -> proc_macro2::TokenStream {
61    if let syn::Data::Union(_) = s.ast().data {
62        panic!("cannot derive on union")
63    }
64    let ty_name = s.ast().ident.to_string();
65    let decode_body = match s.variants() {
66        [] => {
67            let message = format!("`{ty_name}` has no variants to decode");
68            quote! {
69                panic!(#message)
70            }
71        }
72        [vi] => vi.construct(|field, _index| decode_field(field)),
73        variants => {
74            let match_inner: TokenStream = variants
75                .iter()
76                .enumerate()
77                .map(|(idx, vi)| {
78                    let construct = vi.construct(|field, _index| decode_field(field));
79                    quote! { #idx => { #construct } }
80                })
81                .collect();
82            let message = format!(
83                "invalid enum variant tag while decoding `{}`, expected 0..{}, actual {{}}",
84                ty_name,
85                variants.len()
86            );
87            let tag = if variants.len() < u8::MAX as usize {
88                quote! {
89                    ::rustc_serialize::Decoder::read_u8(__decoder) as usize
90                }
91            } else {
92                quote! {
93                    ::rustc_serialize::Decoder::read_usize(__decoder)
94                }
95            };
96            quote! {
97                match #tag {
98                    #match_inner
99                    n => panic!(#message, n),
100                }
101            }
102        }
103    };
104
105    s.bound_impl(
106        quote!(::rustc_serialize::Decodable<#decoder_ty>),
107        quote! {
108            fn decode(__decoder: &mut #decoder_ty) -> Self {
109                #decode_body
110            }
111        },
112    )
113}
114
115fn decode_field(field: &syn::Field) -> proc_macro2::TokenStream {
116    let field_span = field.ident.as_ref().map_or(field.ty.span(), |ident| ident.span());
117
118    let decode_inner_method = if let syn::Type::Reference(_) = field.ty {
119        quote! { ::rustc_middle::ty::codec::RefDecodable::decode }
120    } else {
121        quote! { ::rustc_serialize::Decodable::decode }
122    };
123    let __decoder = quote! { __decoder };
124    // Use the span of the field for the method call, so
125    // that backtraces will point to the field.
126    quote_spanned! { field_span=> #decode_inner_method(#__decoder) }
127}
128
129pub(super) fn type_encodable_derive(
130    mut s: synstructure::Structure<'_>,
131) -> proc_macro2::TokenStream {
132    let encoder_ty = quote! { __E };
133    if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
134        s.add_impl_generic(parse_quote! { 'tcx });
135    }
136    s.add_impl_generic(parse_quote! { #encoder_ty: ::rustc_middle::ty::codec::TyEncoder<'tcx> });
137    s.add_bounds(synstructure::AddBounds::Fields);
138
139    encodable_body(s, encoder_ty, false)
140}
141
142pub(super) fn meta_encodable_derive(
143    mut s: synstructure::Structure<'_>,
144) -> proc_macro2::TokenStream {
145    if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
146        s.add_impl_generic(parse_quote! { 'tcx });
147    }
148    s.add_impl_generic(parse_quote! { '__a });
149    let encoder_ty = quote! { EncodeContext<'__a, 'tcx> };
150    s.add_bounds(synstructure::AddBounds::Generics);
151
152    encodable_body(s, encoder_ty, true)
153}
154
155pub(super) fn encodable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
156    let encoder_ty = quote! { __E };
157    s.add_impl_generic(parse_quote! { #encoder_ty: ::rustc_span::SpanEncoder });
158    s.add_bounds(synstructure::AddBounds::Generics);
159
160    encodable_body(s, encoder_ty, false)
161}
162
163pub(super) fn encodable_nocontext_derive(
164    mut s: synstructure::Structure<'_>,
165) -> proc_macro2::TokenStream {
166    let encoder_ty = quote! { __E };
167    s.add_impl_generic(parse_quote! { #encoder_ty: ::rustc_serialize::Encoder });
168    s.add_bounds(synstructure::AddBounds::Fields);
169
170    encodable_body(s, encoder_ty, false)
171}
172
173fn encodable_body(
174    mut s: synstructure::Structure<'_>,
175    encoder_ty: TokenStream,
176    allow_unreachable_code: bool,
177) -> proc_macro2::TokenStream {
178    if let syn::Data::Union(_) = s.ast().data {
179        panic!("cannot derive on union")
180    }
181
182    s.bind_with(|binding| {
183        // Handle the lack of a blanket reference impl.
184        if let syn::Type::Reference(_) = binding.ast().ty {
185            synstructure::BindStyle::Move
186        } else {
187            synstructure::BindStyle::Ref
188        }
189    });
190
191    let encode_body = match s.variants() {
192        [] => {
193            quote! {
194                match *self {}
195            }
196        }
197        [_] => {
198            let encode_inner = s.each_variant(|vi| {
199                vi.bindings()
200                    .iter()
201                    .map(|binding| {
202                        let bind_ident = &binding.binding;
203                        let result = quote! {
204                            ::rustc_serialize::Encodable::<#encoder_ty>::encode(
205                                #bind_ident,
206                                __encoder,
207                            );
208                        };
209                        result
210                    })
211                    .collect::<TokenStream>()
212            });
213            quote! {
214                match *self { #encode_inner }
215            }
216        }
217        _ => {
218            let disc = {
219                let mut variant_idx = 0usize;
220                let encode_inner = s.each_variant(|_| {
221                    let result = quote! {
222                        #variant_idx
223                    };
224                    variant_idx += 1;
225                    result
226                });
227                if variant_idx < u8::MAX as usize {
228                    quote! {
229                        let disc = match *self {
230                            #encode_inner
231                        };
232                        ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
233                    }
234                } else {
235                    quote! {
236                        let disc = match *self {
237                            #encode_inner
238                        };
239                        ::rustc_serialize::Encoder::emit_usize(__encoder, disc);
240                    }
241                }
242            };
243
244            let mut variant_idx = 0usize;
245            let encode_inner = s.each_variant(|vi| {
246                let encode_fields: TokenStream = vi
247                    .bindings()
248                    .iter()
249                    .map(|binding| {
250                        let bind_ident = &binding.binding;
251                        let result = quote! {
252                            ::rustc_serialize::Encodable::<#encoder_ty>::encode(
253                                #bind_ident,
254                                __encoder,
255                            );
256                        };
257                        result
258                    })
259                    .collect();
260                variant_idx += 1;
261                encode_fields
262            });
263            quote! {
264                #disc
265                match *self {
266                    #encode_inner
267                }
268            }
269        }
270    };
271
272    let lints = if allow_unreachable_code {
273        quote! { #![allow(unreachable_code)] }
274    } else {
275        quote! {}
276    };
277
278    s.bound_impl(
279        quote!(::rustc_serialize::Encodable<#encoder_ty>),
280        quote! {
281            fn encode(
282                &self,
283                __encoder: &mut #encoder_ty,
284            ) {
285                #lints
286                #encode_body
287            }
288        },
289    )
290}