rustfmt_config_proc_macro/
item_enum.rs
1use proc_macro2::TokenStream;
2use quote::{quote, quote_spanned};
3use syn::spanned::Spanned;
4
5use crate::attrs::*;
6use crate::utils::*;
7
8type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>;
9
10pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
12 let syn::ItemEnum {
13 vis,
14 enum_token,
15 ident,
16 generics,
17 variants,
18 ..
19 } = em;
20
21 let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
22 let mod_name = syn::Ident::new(&mod_name_str, ident.span());
23 let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
24
25 let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
26 let impl_from_str = impl_from_str(&em.ident, &em.variants);
27 let impl_display = impl_display(&em.ident, &em.variants);
28 let impl_serde = impl_serde(&em.ident, &em.variants);
29 let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
30
31 Ok(quote! {
32 #[allow(non_snake_case)]
33 mod #mod_name {
34 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
35 pub #enum_token #ident #generics { #variants }
36 #impl_display
37 #impl_doc_hint
38 #impl_from_str
39 #impl_serde
40 #impl_deserialize
41 }
42 #vis use #mod_name::#ident;
43 })
44}
45
46fn process_variant(variant: &syn::Variant) -> TokenStream {
48 let metas = variant
49 .attrs
50 .iter()
51 .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr) && !is_unstable_variant(attr));
52 let attrs = fold_quote(metas, |meta| quote!(#meta));
53 let syn::Variant { ident, fields, .. } = variant;
54 quote!(#attrs #ident #fields)
55}
56
57fn fields_in_variant(variant: &syn::Variant) -> TokenStream {
60 match &variant.fields {
62 syn::Fields::Unnamed(_) => quote_spanned! { variant.span() => (..) },
63 syn::Fields::Unit => quote_spanned! { variant.span() => },
64 syn::Fields::Named(_) => quote_spanned! { variant.span() => {..} },
65 }
66}
67
68fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
69 let doc_hint = variants
70 .iter()
71 .map(doc_hint_of_variant)
72 .collect::<Vec<_>>()
73 .join("|");
74 let doc_hint = format!("[{}]", doc_hint);
75
76 let variant_stables = variants
77 .iter()
78 .map(|v| (&v.ident, fields_in_variant(&v), !unstable_of_variant(v)));
79 let match_patterns = fold_quote(variant_stables, |(v, fields, stable)| {
80 quote! {
81 #ident::#v #fields => #stable,
82 }
83 });
84 quote! {
85 use crate::config::ConfigType;
86 impl ConfigType for #ident {
87 fn doc_hint() -> String {
88 #doc_hint.to_owned()
89 }
90 fn stable_variant(&self) -> bool {
91 match self {
92 #match_patterns
93 }
94 }
95 }
96 }
97}
98
99fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream {
100 let vs = variants
101 .iter()
102 .filter(|v| is_unit(v))
103 .map(|v| (config_value_of_variant(v), &v.ident));
104 let match_patterns = fold_quote(vs, |(s, v)| {
105 quote! {
106 #ident::#v => write!(f, "{}", #s),
107 }
108 });
109 quote! {
110 use std::fmt;
111 impl fmt::Display for #ident {
112 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
113 match self {
114 #match_patterns
115 _ => unimplemented!(),
116 }
117 }
118 }
119 }
120}
121
122fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
123 let vs = variants
124 .iter()
125 .filter(|v| is_unit(v))
126 .map(|v| (config_value_of_variant(v), &v.ident));
127 let if_patterns = fold_quote(vs, |(s, v)| {
128 quote! {
129 if #s.eq_ignore_ascii_case(s) {
130 return Ok(#ident::#v);
131 }
132 }
133 });
134 let mut err_msg = String::from("Bad variant, expected one of:");
135 for v in variants.iter().filter(|v| is_unit(v)) {
136 err_msg.push_str(&format!(" `{}`", v.ident));
137 }
138
139 quote! {
140 impl ::std::str::FromStr for #ident {
141 type Err = &'static str;
142
143 fn from_str(s: &str) -> Result<Self, Self::Err> {
144 #if_patterns
145 return Err(#err_msg);
146 }
147 }
148 }
149}
150
151fn doc_hint_of_variant(variant: &syn::Variant) -> String {
152 let mut text = find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string());
153 if unstable_of_variant(&variant) {
154 text.push_str(" (unstable)")
155 };
156 text
157}
158
159fn config_value_of_variant(variant: &syn::Variant) -> String {
160 find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
161}
162
163fn unstable_of_variant(variant: &syn::Variant) -> bool {
164 any_unstable_variant(&variant.attrs)
165}
166
167fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
168 let arms = fold_quote(variants.iter(), |v| {
169 let v_ident = &v.ident;
170 let pattern = match v.fields {
171 syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
172 syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
173 syn::Fields::Unit => quote!(#ident::#v_ident),
174 };
175 let option_value = config_value_of_variant(v);
176 quote! {
177 #pattern => serializer.serialize_str(&#option_value),
178 }
179 });
180
181 quote! {
182 impl ::serde::ser::Serialize for #ident {
183 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
184 where
185 S: ::serde::ser::Serializer,
186 {
187 use serde::ser::Error;
188 match self {
189 #arms
190 _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
191 }
192 }
193 }
194 }
195}
196
197fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
199 let supported_vs = variants.iter().filter(|v| is_unit(v));
200 let if_patterns = fold_quote(supported_vs, |v| {
201 let config_value = config_value_of_variant(v);
202 let variant_ident = &v.ident;
203 quote! {
204 if #config_value.eq_ignore_ascii_case(s) {
205 return Ok(#ident::#variant_ident);
206 }
207 }
208 });
209
210 let supported_vs = variants.iter().filter(|v| is_unit(v));
211 let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
212
213 quote! {
214 impl<'de> serde::de::Deserialize<'de> for #ident {
215 fn deserialize<D>(d: D) -> Result<Self, D::Error>
216 where
217 D: serde::Deserializer<'de>,
218 {
219 use serde::de::{Error, Visitor};
220 use std::marker::PhantomData;
221 use std::fmt;
222 struct StringOnly<T>(PhantomData<T>);
223 impl<'de, T> Visitor<'de> for StringOnly<T>
224 where T: serde::Deserializer<'de> {
225 type Value = String;
226 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
227 formatter.write_str("string")
228 }
229 fn visit_str<E>(self, value: &str) -> Result<String, E> {
230 Ok(String::from(value))
231 }
232 }
233 let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
234
235 #if_patterns
236
237 static ALLOWED: &'static[&str] = &[#allowed];
238 Err(D::Error::unknown_variant(&s, ALLOWED))
239 }
240 }
241 }
242}