1use crate::utils::{AttrParams, DeriveType, State};
2use convert_case::{Case, Casing};
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5use syn::{DeriveInput, Fields, Ident, Result};
6
7pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
8 let state = State::with_attr_params(
9 input,
10 trait_name,
11 quote!(),
12 String::from("unwrap"),
13 AttrParams {
14 enum_: vec!["ignore"],
15 variant: vec!["ignore"],
16 struct_: vec!["ignore"],
17 field: vec!["ignore"],
18 },
19 )?;
20 assert!(
21 state.derive_type == DeriveType::Enum,
22 "Unwrap can only be derived for enums"
23 );
24
25 let enum_name = &input.ident;
26 let (imp_generics, type_generics, where_clause) = input.generics.split_for_impl();
27
28 let mut funcs = vec![];
29 for variant_state in state.enabled_variant_data().variant_states {
30 let variant = variant_state.variant.unwrap();
31 let fn_name = Ident::new(
32 &format_ident!("unwrap_{}", variant.ident)
33 .to_string()
34 .to_case(Case::Snake),
35 variant.ident.span(),
36 );
37 let variant_ident = &variant.ident;
38
39 let (data_pattern, ret_value, ret_type) = match variant.fields {
40 Fields::Named(_) => panic!("cannot unwrap anonymous records"),
41 Fields::Unnamed(ref fields) => {
42 let data_pattern =
43 (0..fields.unnamed.len()).fold(vec![], |mut a, n| {
44 a.push(format_ident!("field_{}", n));
45 a
46 });
47 let ret_type = &fields.unnamed;
48 (
49 quote! { (#(#data_pattern),*) },
50 quote! { (#(#data_pattern),*) },
51 quote! { (#ret_type) },
52 )
53 }
54 Fields::Unit => (quote! {}, quote! { () }, quote! { () }),
55 };
56
57 let other_arms = state.variant_states.iter().map(|variant| {
58 variant.variant.unwrap()
59 }).filter(|variant| {
60 &variant.ident != variant_ident
61 }).map(|variant| {
62 let data_pattern = match variant.fields {
63 Fields::Named(_) => quote! { {..} },
64 Fields::Unnamed(_) => quote! { (..) },
65 Fields::Unit => quote! {},
66 };
67 let variant_ident = &variant.ident;
68 quote! { #enum_name :: #variant_ident #data_pattern =>
69 panic!(concat!("called `", stringify!(#enum_name), "::", stringify!(#fn_name),
70 "()` on a `", stringify!(#variant_ident), "` value"))
71 }
72 });
73
74 let func = quote! {
75 #[track_caller]
76 pub fn #fn_name(self) -> #ret_type {
77 match self {
78 #enum_name ::#variant_ident #data_pattern => #ret_value,
79 #(#other_arms),*
80 }
81 }
82 };
83 funcs.push(func);
84 }
85
86 let imp = quote! {
87 impl #imp_generics #enum_name #type_generics #where_clause{
88 #(#funcs)*
89 }
90 };
91
92 Ok(imp)
93}