derive_more/
is_variant.rs

1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8    let state = State::with_attr_params(
9        input,
10        trait_name,
11        quote!(),
12        String::from("is_variant"),
13        AttrParams {
14            enum_: vec!["ignore"],
15            variant: vec!["ignore"],
16            struct_: vec!["ignore"],
17            field: vec!["ignore"],
18        },
19    )?;
20    assert!(
21        state.derive_type == DeriveType::Enum,
22        "IsVariant can only be derived for enums"
23    );
24
25    let enum_name = &input.ident;
26    let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
27
28    let mut funcs = vec![];
29    for variant_state in state.enabled_variant_data().variant_states {
30        let variant = variant_state.variant.unwrap();
31        let fn_name = Ident::new(
32            &format_ident!("is_{}", variant.ident)
33                .to_string()
34                .to_case(Case::Snake),
35            variant.ident.span(),
36        );
37        let variant_ident = &variant.ident;
38
39        let data_pattern = match variant.fields {
40            Fields::Named(_) => quote! { {..} },
41            Fields::Unnamed(_) => quote! { (..) },
42            Fields::Unit => quote! {},
43        };
44        let func = quote! {
45            pub fn #fn_name(&self) -> bool {
46                match self {
47                    #enum_name ::#variant_ident #data_pattern => true,
48                    _ => false
49                }
50            }
51        };
52        funcs.push(func);
53    }
54
55    let imp = quote! {
56        impl #imp_generics #enum_name #type_generics #where_clause{
57            #(#funcs)*
58        }
59    };
60
61    Ok(imp)
62}