derive_more/
unwrap.rs

1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8    let state = State::with_attr_params(
9        input,
10        trait_name,
11        quote!(),
12        String::from("unwrap"),
13        AttrParams {
14            enum_: vec!["ignore"],
15            variant: vec!["ignore"],
16            struct_: vec!["ignore"],
17            field: vec!["ignore"],
18        },
19    )?;
20    assert!(
21        state.derive_type == DeriveType::Enum,
22        "Unwrap can only be derived for enums"
23    );
24
25    let enum_name = &input.ident;
26    let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
27
28    let mut funcs = vec![];
29    for variant_state in state.enabled_variant_data().variant_states {
30        let variant = variant_state.variant.unwrap();
31        let fn_name = Ident::new(
32            &format_ident!("unwrap_{}", variant.ident)
33                .to_string()
34                .to_case(Case::Snake),
35            variant.ident.span(),
36        );
37        let variant_ident = &variant.ident;
38
39        let (data_pattern, ret_value, ret_type) = match variant.fields {
40            Fields::Named(_) => panic!("cannot unwrap anonymous records"),
41            Fields::Unnamed(ref fields) => {
42                let data_pattern =
43                    (0..fields.unnamed.len()).fold(vec![], |mut a, n| {
44                        a.push(format_ident!("field_{}", n));
45                        a
46                    });
47                let ret_type = &fields.unnamed;
48                (
49                    quote! { (#(#data_pattern),*) },
50                    quote! { (#(#data_pattern),*) },
51                    quote! { (#ret_type) },
52                )
53            }
54            Fields::Unit => (quote! {}, quote! { () }, quote! { () }),
55        };
56
57        let other_arms = state.variant_states.iter().map(|variant| {
58            variant.variant.unwrap()
59        }).filter(|variant| {
60            &variant.ident != variant_ident
61        }).map(|variant| {
62            let data_pattern = match variant.fields {
63                Fields::Named(_) => quote! { {..} },
64                Fields::Unnamed(_) => quote! { (..) },
65                Fields::Unit => quote! {},
66            };
67            let variant_ident = &variant.ident;
68            quote! { #enum_name :: #variant_ident #data_pattern =>
69                      panic!(concat!("called `", stringify!(#enum_name), "::", stringify!(#fn_name),
70                                     "()` on a `", stringify!(#variant_ident), "` value"))
71            }
72        });
73
74        let func = quote! {
75            #[track_caller]
76            pub fn #fn_name(self) -> #ret_type {
77                match self {
78                    #enum_name ::#variant_ident #data_pattern => #ret_value,
79                    #(#other_arms),*
80                }
81            }
82        };
83        funcs.push(func);
84    }
85
86    let imp = quote! {
87        impl #imp_generics #enum_name #type_generics #where_clause{
88            #(#funcs)*
89        }
90    };
91
92    Ok(imp)
93}