derive_more/
sum_like.rs

1use crate::utils::{
2    add_extra_ty_param_bound, add_extra_where_clauses, MultiFieldData, State,
3};
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::{DeriveInput, Ident, Result};
7
8pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
9    let state = State::new(
10        input,
11        trait_name,
12        quote!(::core::iter),
13        trait_name.to_lowercase(),
14    )?;
15    let multi_field_data = state.enabled_fields_data();
16    let MultiFieldData {
17        input_type,
18        field_types,
19        trait_path,
20        method_ident,
21        ..
22    } = multi_field_data.clone();
23
24    let op_trait_name = if trait_name == "Sum" { "Add" } else { "Mul" };
25    let op_trait_ident = Ident::new(op_trait_name, Span::call_site());
26    let op_path = quote!(::core::ops::#op_trait_ident);
27    let op_method_ident =
28        Ident::new(&(op_trait_name.to_lowercase()), Span::call_site());
29    let has_type_params = input.generics.type_params().next().is_none();
30    let generics = if has_type_params {
31        input.generics.clone()
32    } else {
33        let (_, ty_generics, _) = input.generics.split_for_impl();
34        let generics = add_extra_ty_param_bound(&input.generics, trait_path);
35        let operator_where_clause = quote! {
36            where #input_type#ty_generics: #op_path<Output=#input_type#ty_generics>
37        };
38        add_extra_where_clauses(&generics, operator_where_clause)
39    };
40    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42    let initializers: Vec<_> = field_types
43        .iter()
44        .map(|field_type| quote!(#trait_path::#method_ident(::core::iter::empty::<#field_type>())))
45        .collect();
46    let identity = multi_field_data.initializer(&initializers);
47
48    Ok(quote!(
49        impl#impl_generics #trait_path for #input_type#ty_generics #where_clause {
50            #[inline]
51            fn #method_ident<I: ::core::iter::Iterator<Item = Self>>(iter: I) -> Self {
52                iter.fold(#identity, #op_path::#op_method_ident)
53            }
54        }
55    ))
56}