derive_more/
not_like.rs

use crate::utils::{
    add_extra_type_param_bound_op_output, named_to_vec, unnamed_to_vec,
};
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use std::iter;
use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident, Index};
8
9pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
10    let trait_ident = Ident::new(trait_name, Span::call_site());
11    let method_name = trait_name.to_lowercase();
12    let method_ident = &Ident::new(&method_name, Span::call_site());
13    let input_type = &input.ident;
14
15    let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
16    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
17
18    let (output_type, block) = match input.data {
19        Data::Struct(ref data_struct) => match data_struct.fields {
20            Fields::Unnamed(ref fields) => (
21                quote!(#input_type#ty_generics),
22                tuple_content(input_type, &unnamed_to_vec(fields), method_ident),
23            ),
24            Fields::Named(ref fields) => (
25                quote!(#input_type#ty_generics),
26                struct_content(input_type, &named_to_vec(fields), method_ident),
27            ),
28            _ => panic!("Unit structs cannot use derive({})", trait_name),
29        },
30        Data::Enum(ref data_enum) => {
31            enum_output_type_and_content(input, data_enum, method_ident)
32        }
33
34        _ => panic!("Only structs and enums can use derive({})", trait_name),
35    };
36
37    quote!(
38        impl#impl_generics ::core::ops::#trait_ident for #input_type#ty_generics #where_clause {
39            type Output = #output_type;
40            #[inline]
41            fn #method_ident(self) -> #output_type {
42                #block
43            }
44        }
45    )
46}
47
48fn tuple_content<T: ToTokens>(
49    input_type: &T,
50    fields: &[&Field],
51    method_ident: &Ident,
52) -> TokenStream {
53    let mut exprs = vec![];
54
55    for i in 0..fields.len() {
56        let i = Index::from(i);
57        // generates `self.0.add()`
58        let expr = quote!(self.#i.#method_ident());
59        exprs.push(expr);
60    }
61
62    quote!(#input_type(#(#exprs),*))
63}
64
65fn struct_content(
66    input_type: &Ident,
67    fields: &[&Field],
68    method_ident: &Ident,
69) -> TokenStream {
70    let mut exprs = vec![];
71
72    for field in fields {
73        // It's safe to unwrap because struct fields always have an identifier
74        let field_id = field.ident.as_ref();
75        // generates `x: self.x.not()`
76        let expr = quote!(#field_id: self.#field_id.#method_ident());
77        exprs.push(expr)
78    }
79
80    quote!(#input_type{#(#exprs),*})
81}
82
83fn enum_output_type_and_content(
84    input: &DeriveInput,
85    data_enum: &DataEnum,
86    method_ident: &Ident,
87) -> (TokenStream, TokenStream) {
88    let input_type = &input.ident;
89    let (_, ty_generics, _) = input.generics.split_for_impl();
90    let mut matches = vec![];
91    let mut method_iter = iter::repeat(method_ident);
92    // If the enum contains unit types that means it can error.
93    let has_unit_type = data_enum.variants.iter().any(|v| v.fields == Fields::Unit);
94
95    for variant in &data_enum.variants {
96        let subtype = &variant.ident;
97        let subtype = quote!(#input_type::#subtype);
98
99        match variant.fields {
100            Fields::Unnamed(ref fields) => {
101                // The patern that is outputted should look like this:
102                // (Subtype(vars)) => Ok(TypePath(exprs))
103                let size = unnamed_to_vec(fields).len();
104                let vars: &Vec<_> = &(0..size)
105                    .map(|i| Ident::new(&format!("__{}", i), Span::call_site()))
106                    .collect();
107                let method_iter = method_iter.by_ref();
108                let mut body = quote!(#subtype(#(#vars.#method_iter()),*));
109                if has_unit_type {
110                    body = quote!(::core::result::Result::Ok(#body))
111                }
112                let matcher = quote! {
113                    #subtype(#(#vars),*) => {
114                        #body
115                    }
116                };
117                matches.push(matcher);
118            }
119            Fields::Named(ref fields) => {
120                // The patern that is outputted should look like this:
121                // (Subtype{a: __l_a, ...} => {
122                //     Ok(Subtype{a: __l_a.neg(__r_a), ...})
123                // }
124                let field_vec = named_to_vec(fields);
125                let size = field_vec.len();
126                let field_names: &Vec<_> = &field_vec
127                    .iter()
128                    .map(|f| f.ident.as_ref().unwrap())
129                    .collect();
130                let vars: &Vec<_> = &(0..size)
131                    .map(|i| Ident::new(&format!("__{}", i), Span::call_site()))
132                    .collect();
133                let method_iter = method_iter.by_ref();
134                let mut body =
135                    quote!(#subtype{#(#field_names: #vars.#method_iter()),*});
136                if has_unit_type {
137                    body = quote!(::core::result::Result::Ok(#body))
138                }
139                let matcher = quote! {
140                    #subtype{#(#field_names: #vars),*} => {
141                        #body
142                    }
143                };
144                matches.push(matcher);
145            }
146            Fields::Unit => {
147                let message = format!("Cannot {}() unit variants", method_ident);
148                matches.push(quote!(#subtype => ::core::result::Result::Err(#message)));
149            }
150        }
151    }
152
153    let body = quote!(
154        match self {
155            #(#matches),*
156        }
157    );
158
159    let output_type = if has_unit_type {
160        quote!(::core::result::Result<#input_type#ty_generics, &'static str>)
161    } else {
162        quote!(#input_type#ty_generics)
163    };
164
165    (output_type, body)
166}