rustc_index_macros/
newtype.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::parse::*;
4use syn::*;
5
6// We parse the input and emit the output in a single step.
7// This field stores the final macro output
8struct Newtype(TokenStream);
9
10impl Parse for Newtype {
11    fn parse(input: ParseStream<'_>) -> Result<Self> {
12        let mut attrs = input.call(Attribute::parse_outer)?;
13        let vis: Visibility = input.parse()?;
14        input.parse::<Token![struct]>()?;
15        let name: Ident = input.parse()?;
16
17        let body;
18        braced!(body in input);
19
20        // Any additional `#[derive]` macro paths to apply
21        let mut derive_paths: Vec<Path> = Vec::new();
22        let mut debug_format: Option<Lit> = None;
23        let mut max = None;
24        let mut consts = Vec::new();
25        let mut encodable = false;
26        let mut ord = false;
27        let mut gate_rustc_only = quote! {};
28        let mut gate_rustc_only_cfg = quote! { all() };
29
30        attrs.retain(|attr| match attr.path().get_ident() {
31            Some(ident) => match &*ident.to_string() {
32                "gate_rustc_only" => {
33                    gate_rustc_only = quote! { #[cfg(feature = "nightly")] };
34                    gate_rustc_only_cfg = quote! { feature = "nightly" };
35                    false
36                }
37                "encodable" => {
38                    encodable = true;
39                    false
40                }
41                "orderable" => {
42                    ord = true;
43                    false
44                }
45                "max" => {
46                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
47                    else {
48                        panic!("#[max = NUMBER] attribute requires max value");
49                    };
50
51                    if let Some(old) = max.replace(lit.lit.clone()) {
52                        panic!("Specified multiple max: {old:?}");
53                    }
54
55                    false
56                }
57                "debug_format" => {
58                    let Meta::NameValue(MetaNameValue { value: Expr::Lit(lit), .. }) = &attr.meta
59                    else {
60                        panic!("#[debug_format = FMT] attribute requires a format");
61                    };
62
63                    if let Some(old) = debug_format.replace(lit.lit.clone()) {
64                        panic!("Specified multiple debug format options: {old:?}");
65                    }
66
67                    false
68                }
69                _ => true,
70            },
71            _ => true,
72        });
73
74        loop {
75            // We've parsed everything that the user provided, so we're done
76            if body.is_empty() {
77                break;
78            }
79
80            // Otherwise, we are parsing a user-defined constant
81            let const_attrs = body.call(Attribute::parse_outer)?;
82            body.parse::<Token![const]>()?;
83            let const_name: Ident = body.parse()?;
84            body.parse::<Token![=]>()?;
85            let const_val: Expr = body.parse()?;
86            body.parse::<Token![;]>()?;
87            consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
88        }
89
90        let debug_format =
91            debug_format.unwrap_or_else(|| Lit::Str(LitStr::new("{}", Span::call_site())));
92
93        // shave off 256 indices at the end to allow space for packing these indices into enums
94        let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
95
96        let encodable_impls = if encodable {
97            quote! {
98                #gate_rustc_only
99                impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
100                    fn decode(d: &mut D) -> Self {
101                        Self::from_u32(d.read_u32())
102                    }
103                }
104                #gate_rustc_only
105                impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
106                    fn encode(&self, e: &mut E) {
107                        e.emit_u32(self.as_u32());
108                    }
109                }
110            }
111        } else {
112            quote! {}
113        };
114
115        if ord {
116            derive_paths.push(parse_quote!(Ord));
117            derive_paths.push(parse_quote!(PartialOrd));
118        }
119
120        let step = if ord {
121            quote! {
122                #gate_rustc_only
123                impl ::std::iter::Step for #name {
124                    #[inline]
125                    fn steps_between(start: &Self, end: &Self) -> (usize, Option<usize>) {
126                        <usize as ::std::iter::Step>::steps_between(
127                            &Self::index(*start),
128                            &Self::index(*end),
129                        )
130                    }
131
132                    #[inline]
133                    fn forward_checked(start: Self, u: usize) -> Option<Self> {
134                        Self::index(start).checked_add(u).map(Self::from_usize)
135                    }
136
137                    #[inline]
138                    fn backward_checked(start: Self, u: usize) -> Option<Self> {
139                        Self::index(start).checked_sub(u).map(Self::from_usize)
140                    }
141                }
142            }
143        } else {
144            quote! {}
145        };
146
147        let debug_impl = quote! {
148            impl ::std::fmt::Debug for #name {
149                fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
150                    write!(fmt, #debug_format, self.as_u32())
151                }
152            }
153        };
154
155        Ok(Self(quote! {
156            #(#attrs)*
157            #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
158            #[cfg_attr(#gate_rustc_only_cfg, rustc_layout_scalar_valid_range_end(#max))]
159            #[cfg_attr(#gate_rustc_only_cfg, rustc_pass_by_value)]
160            #vis struct #name {
161                private_use_as_methods_instead: u32,
162            }
163
164            #(#consts)*
165
166            impl #name {
167                /// Maximum value the index can take, as a `u32`.
168                #vis const MAX_AS_U32: u32  = #max;
169
170                /// Maximum value the index can take.
171                #vis const MAX: Self = Self::from_u32(#max);
172
173                /// Zero value of the index.
174                #vis const ZERO: Self = Self::from_u32(0);
175
176                /// Creates a new index from a given `usize`.
177                ///
178                /// # Panics
179                ///
180                /// Will panic if `value` exceeds `MAX`.
181                #[inline]
182                #vis const fn from_usize(value: usize) -> Self {
183                    assert!(value <= (#max as usize));
184                    // SAFETY: We just checked that `value <= max`.
185                    unsafe {
186                        Self::from_u32_unchecked(value as u32)
187                    }
188                }
189
190                /// Creates a new index from a given `u32`.
191                ///
192                /// # Panics
193                ///
194                /// Will panic if `value` exceeds `MAX`.
195                #[inline]
196                #vis const fn from_u32(value: u32) -> Self {
197                    assert!(value <= #max);
198                    // SAFETY: We just checked that `value <= max`.
199                    unsafe {
200                        Self::from_u32_unchecked(value)
201                    }
202                }
203
204                /// Creates a new index from a given `u16`.
205                ///
206                /// # Panics
207                ///
208                /// Will panic if `value` exceeds `MAX`.
209                #[inline]
210                #vis const fn from_u16(value: u16) -> Self {
211                    let value = value as u32;
212                    assert!(value <= #max);
213                    // SAFETY: We just checked that `value <= max`.
214                    unsafe {
215                        Self::from_u32_unchecked(value)
216                    }
217                }
218
219                /// Creates a new index from a given `u32`.
220                ///
221                /// # Safety
222                ///
223                /// The provided value must be less than or equal to the maximum value for the newtype.
224                /// Providing a value outside this range is undefined due to layout restrictions.
225                ///
226                /// Prefer using `from_u32`.
227                #[inline]
228                #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
229                    Self { private_use_as_methods_instead: value }
230                }
231
232                /// Extracts the value of this index as a `usize`.
233                #[inline]
234                #vis const fn index(self) -> usize {
235                    self.as_usize()
236                }
237
238                /// Extracts the value of this index as a `u32`.
239                #[inline]
240                #vis const fn as_u32(self) -> u32 {
241                    self.private_use_as_methods_instead
242                }
243
244                /// Extracts the value of this index as a `usize`.
245                #[inline]
246                #vis const fn as_usize(self) -> usize {
247                    self.as_u32() as usize
248                }
249            }
250
251            impl std::ops::Add<usize> for #name {
252                type Output = Self;
253
254                #[inline]
255                fn add(self, other: usize) -> Self {
256                    Self::from_usize(self.index() + other)
257                }
258            }
259
260            impl rustc_index::Idx for #name {
261                #[inline]
262                fn new(value: usize) -> Self {
263                    Self::from_usize(value)
264                }
265
266                #[inline]
267                fn index(self) -> usize {
268                    self.as_usize()
269                }
270            }
271
272            #step
273
274            impl From<#name> for u32 {
275                #[inline]
276                fn from(v: #name) -> u32 {
277                    v.as_u32()
278                }
279            }
280
281            impl From<#name> for usize {
282                #[inline]
283                fn from(v: #name) -> usize {
284                    v.as_usize()
285                }
286            }
287
288            impl From<usize> for #name {
289                #[inline]
290                fn from(value: usize) -> Self {
291                    Self::from_usize(value)
292                }
293            }
294
295            impl From<u32> for #name {
296                #[inline]
297                fn from(value: u32) -> Self {
298                    Self::from_u32(value)
299                }
300            }
301
302            #encodable_impls
303            #debug_impl
304        }))
305    }
306}
307
308pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
309    let input = parse_macro_input!(input as Newtype);
310    input.0.into()
311}