rustfmt_config_proc_macro/
item_enum.rs

1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4
5use crate::attrs::*;
6use crate::utils::*;
7
/// The comma-separated list of variants of an enum, as stored in `syn::ItemEnum::variants`.
type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>;
9
10/// Defines and implements `config_type` enum.
11pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
12    let syn::ItemEnum {
13        vis,
14        enum_token,
15        ident,
16        generics,
17        variants,
18        ..
19    } = em;
20
21    let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
22    let mod_name = syn::Ident::new(&mod_name_str, ident.span());
23    let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
24
25    let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
26    let impl_from_str = impl_from_str(&em.ident, &em.variants);
27    let impl_display = impl_display(&em.ident, &em.variants);
28    let impl_serde = impl_serde(&em.ident, &em.variants);
29    let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
30
31    Ok(quote! {
32        #[allow(non_snake_case)]
33        mod #mod_name {
34            #[derive(Debug, Copy, Clone, Eq, PartialEq)]
35            pub #enum_token #ident #generics { #variants }
36            #impl_display
37            #impl_doc_hint
38            #impl_from_str
39            #impl_serde
40            #impl_deserialize
41        }
42        #vis use #mod_name::#ident;
43    })
44}
45
46/// Remove attributes specific to `config_proc_macro` from enum variant fields.
47fn process_variant(variant: &syn::Variant) -> TokenStream {
48    let metas = variant
49        .attrs
50        .iter()
51        .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr) && !is_unstable_variant(attr));
52    let attrs = fold_quote(metas, |meta| quote!(#meta));
53    let syn::Variant { ident, fields, .. } = variant;
54    quote!(#attrs #ident #fields)
55}
56
57/// Return the correct syntax to pattern match on the enum variant, discarding all
58/// internal field data.
59fn fields_in_variant(variant: &syn::Variant) -> TokenStream {
60    // With thanks to https://stackoverflow.com/a/65182902
61    match &variant.fields {
62        syn::Fields::Unnamed(_) => quote_spanned! { variant.span() => (..) },
63        syn::Fields::Unit => quote_spanned! { variant.span() => },
64        syn::Fields::Named(_) => quote_spanned! { variant.span() => {..} },
65    }
66}
67
68fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
69    let doc_hint = variants
70        .iter()
71        .map(doc_hint_of_variant)
72        .collect::<Vec<_>>()
73        .join("|");
74    let doc_hint = format!("[{}]", doc_hint);
75
76    let variant_stables = variants
77        .iter()
78        .map(|v| (&v.ident, fields_in_variant(&v), !unstable_of_variant(v)));
79    let match_patterns = fold_quote(variant_stables, |(v, fields, stable)| {
80        quote! {
81            #ident::#v #fields => #stable,
82        }
83    });
84    quote! {
85        use crate::config::ConfigType;
86        impl ConfigType for #ident {
87            fn doc_hint() -> String {
88                #doc_hint.to_owned()
89            }
90            fn stable_variant(&self) -> bool {
91                match self {
92                    #match_patterns
93                }
94            }
95        }
96    }
97}
98
99fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream {
100    let vs = variants
101        .iter()
102        .filter(|v| is_unit(v))
103        .map(|v| (config_value_of_variant(v), &v.ident));
104    let match_patterns = fold_quote(vs, |(s, v)| {
105        quote! {
106            #ident::#v => write!(f, "{}", #s),
107        }
108    });
109    quote! {
110        use std::fmt;
111        impl fmt::Display for #ident {
112            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113                match self {
114                    #match_patterns
115                    _ => unimplemented!(),
116                }
117            }
118        }
119    }
120}
121
122fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
123    let vs = variants
124        .iter()
125        .filter(|v| is_unit(v))
126        .map(|v| (config_value_of_variant(v), &v.ident));
127    let if_patterns = fold_quote(vs, |(s, v)| {
128        quote! {
129            if #s.eq_ignore_ascii_case(s) {
130                return Ok(#ident::#v);
131            }
132        }
133    });
134    let mut err_msg = String::from("Bad variant, expected one of:");
135    for v in variants.iter().filter(|v| is_unit(v)) {
136        err_msg.push_str(&format!(" `{}`", v.ident));
137    }
138
139    quote! {
140        impl ::std::str::FromStr for #ident {
141            type Err = &'static str;
142
143            fn from_str(s: &str) -> Result<Self, Self::Err> {
144                #if_patterns
145                return Err(#err_msg);
146            }
147        }
148    }
149}
150
151fn doc_hint_of_variant(variant: &syn::Variant) -> String {
152    let mut text = find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string());
153    if unstable_of_variant(&variant) {
154        text.push_str(" (unstable)")
155    };
156    text
157}
158
159fn config_value_of_variant(variant: &syn::Variant) -> String {
160    find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
161}
162
163fn unstable_of_variant(variant: &syn::Variant) -> bool {
164    any_unstable_variant(&variant.attrs)
165}
166
167fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
168    let arms = fold_quote(variants.iter(), |v| {
169        let v_ident = &v.ident;
170        let pattern = match v.fields {
171            syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
172            syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
173            syn::Fields::Unit => quote!(#ident::#v_ident),
174        };
175        let option_value = config_value_of_variant(v);
176        quote! {
177            #pattern => serializer.serialize_str(&#option_value),
178        }
179    });
180
181    quote! {
182        impl ::serde::ser::Serialize for #ident {
183            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
184            where
185                S: ::serde::ser::Serializer,
186            {
187                use serde::ser::Error;
188                match self {
189                    #arms
190                    _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
191                }
192            }
193        }
194    }
195}
196
197// Currently only unit variants are supported.
198fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
199    let supported_vs = variants.iter().filter(|v| is_unit(v));
200    let if_patterns = fold_quote(supported_vs, |v| {
201        let config_value = config_value_of_variant(v);
202        let variant_ident = &v.ident;
203        quote! {
204            if #config_value.eq_ignore_ascii_case(s) {
205                return Ok(#ident::#variant_ident);
206            }
207        }
208    });
209
210    let supported_vs = variants.iter().filter(|v| is_unit(v));
211    let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
212
213    quote! {
214        impl<'de> serde::de::Deserialize<'de> for #ident {
215            fn deserialize<D>(d: D) -> Result<Self, D::Error>
216            where
217                D: serde::Deserializer<'de>,
218            {
219                use serde::de::{Error, Visitor};
220                use std::marker::PhantomData;
221                use std::fmt;
222                struct StringOnly<T>(PhantomData<T>);
223                impl<'de, T> Visitor<'de> for StringOnly<T>
224                where T: serde::Deserializer<'de> {
225                    type Value = String;
226                    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
227                        formatter.write_str("string")
228                    }
229                    fn visit_str<E>(self, value: &str) -> Result<String, E> {
230                        Ok(String::from(value))
231                    }
232                }
233                let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
234
235                #if_patterns
236
237                static ALLOWED: &'static[&str] = &[#allowed];
238                Err(D::Error::unknown_variant(&s, ALLOWED))
239            }
240        }
241    }
242}