derive_more/
add_like.rs

1use crate::add_helpers::{struct_exprs, tuple_exprs};
2use crate::utils::{
3    add_extra_type_param_bound_op_output, field_idents, named_to_vec, numbered_vars,
4    unnamed_to_vec,
5};
6use proc_macro2::{Span, TokenStream};
7use quote::{quote, ToTokens};
8use std::iter;
9use syn::{Data, DataEnum, DeriveInput, Field, Fields, Ident};
10
11pub fn expand(input: &DeriveInput, trait_name: &str) -> TokenStream {
12    let trait_name = trait_name.trim_end_matches("Self");
13    let trait_ident = Ident::new(trait_name, Span::call_site());
14    let method_name = trait_name.to_lowercase();
15    let method_ident = Ident::new(&method_name, Span::call_site());
16    let input_type = &input.ident;
17
18    let generics = add_extra_type_param_bound_op_output(&input.generics, &trait_ident);
19    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20
21    let (output_type, block) = match input.data {
22        Data::Struct(ref data_struct) => match data_struct.fields {
23            Fields::Unnamed(ref fields) => (
24                quote!(#input_type#ty_generics),
25                tuple_content(input_type, &unnamed_to_vec(fields), &method_ident),
26            ),
27            Fields::Named(ref fields) => (
28                quote!(#input_type#ty_generics),
29                struct_content(input_type, &named_to_vec(fields), &method_ident),
30            ),
31            _ => panic!("Unit structs cannot use derive({})", trait_name),
32        },
33        Data::Enum(ref data_enum) => (
34            quote!(::core::result::Result<#input_type#ty_generics, &'static str>),
35            enum_content(input_type, data_enum, &method_ident),
36        ),
37
38        _ => panic!("Only structs and enums can use derive({})", trait_name),
39    };
40
41    quote!(
42        impl#impl_generics ::core::ops::#trait_ident for #input_type#ty_generics #where_clause {
43            type Output = #output_type;
44            #[inline]
45            fn #method_ident(self, rhs: #input_type#ty_generics) -> #output_type {
46                #block
47            }
48        }
49    )
50}
51
52fn tuple_content<T: ToTokens>(
53    input_type: &T,
54    fields: &[&Field],
55    method_ident: &Ident,
56) -> TokenStream {
57    let exprs = tuple_exprs(fields, method_ident);
58    quote!(#input_type(#(#exprs),*))
59}
60
61fn struct_content(
62    input_type: &Ident,
63    fields: &[&Field],
64    method_ident: &Ident,
65) -> TokenStream {
66    // It's safe to unwrap because struct fields always have an identifier
67    let exprs = struct_exprs(fields, method_ident);
68    let field_names = field_idents(fields);
69
70    quote!(#input_type{#(#field_names: #exprs),*})
71}
72
73#[allow(clippy::cognitive_complexity)]
74fn enum_content(
75    input_type: &Ident,
76    data_enum: &DataEnum,
77    method_ident: &Ident,
78) -> TokenStream {
79    let mut matches = vec![];
80    let mut method_iter = iter::repeat(method_ident);
81
82    for variant in &data_enum.variants {
83        let subtype = &variant.ident;
84        let subtype = quote!(#input_type::#subtype);
85
86        match variant.fields {
87            Fields::Unnamed(ref fields) => {
88                // The patern that is outputted should look like this:
89                // (Subtype(left_vars), TypePath(right_vars)) => Ok(TypePath(exprs))
90                let size = unnamed_to_vec(fields).len();
91                let l_vars = &numbered_vars(size, "l_");
92                let r_vars = &numbered_vars(size, "r_");
93                let method_iter = method_iter.by_ref();
94                let matcher = quote! {
95                    (#subtype(#(#l_vars),*),
96                     #subtype(#(#r_vars),*)) => {
97                        ::core::result::Result::Ok(#subtype(#(#l_vars.#method_iter(#r_vars)),*))
98                    }
99                };
100                matches.push(matcher);
101            }
102            Fields::Named(ref fields) => {
103                // The patern that is outputted should look like this:
104                // (Subtype{a: __l_a, ...}, Subtype{a: __r_a, ...} => {
105                //     Ok(Subtype{a: __l_a.add(__r_a), ...})
106                // }
107                let field_vec = named_to_vec(fields);
108                let size = field_vec.len();
109                let field_names = &field_idents(&field_vec);
110                let l_vars = &numbered_vars(size, "l_");
111                let r_vars = &numbered_vars(size, "r_");
112                let method_iter = method_iter.by_ref();
113                let matcher = quote! {
114                    (#subtype{#(#field_names: #l_vars),*},
115                     #subtype{#(#field_names: #r_vars),*}) => {
116                        ::core::result::Result::Ok(#subtype{#(#field_names: #l_vars.#method_iter(#r_vars)),*})
117                    }
118                };
119                matches.push(matcher);
120            }
121            Fields::Unit => {
122                let message = format!("Cannot {}() unit variants", method_ident);
123                matches.push(quote!((#subtype, #subtype) => ::core::result::Result::Err(#message)));
124            }
125        }
126    }
127
128    if data_enum.variants.len() > 1 {
129        // In the strange case where there's only one enum variant this is would be an unreachable
130        // match.
131        let message = format!("Trying to {} mismatched enum variants", method_ident);
132        matches.push(quote!(_ => ::core::result::Result::Err(#message)));
133    }
134    quote!(
135        match (self, rhs) {
136            #(#matches),*
137        }
138    )
139}