// derive_more/from.rs

1use std::iter;
2
3use proc_macro2::{Span, TokenStream};
4use quote::{quote, ToTokens};
5use syn::{parse::Result, DeriveInput, Ident, Index};
6
7use crate::utils::{
8    add_where_clauses_for_new_ident, AttrParams, DeriveType, HashMap, MultiFieldData,
9    RefType, State,
10};
11
12/// Provides the hook to expand `#[derive(From)]` into an implementation of `From`
13pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
14    let state = State::with_attr_params(
15        input,
16        trait_name,
17        quote!(::core::convert),
18        trait_name.to_lowercase(),
19        AttrParams {
20            enum_: vec!["forward", "ignore"],
21            variant: vec!["forward", "ignore", "types"],
22            struct_: vec!["forward", "types"],
23            field: vec!["forward"],
24        },
25    )?;
26    if state.derive_type == DeriveType::Enum {
27        Ok(enum_from(input, state))
28    } else {
29        Ok(struct_from(input, &state))
30    }
31}
32
33pub fn struct_from(input: &DeriveInput, state: &State) -> TokenStream {
34    let multi_field_data = state.enabled_fields_data();
35    let MultiFieldData {
36        fields,
37        variant_info,
38        infos,
39        input_type,
40        trait_path,
41        ..
42    } = multi_field_data.clone();
43
44    let additional_types = variant_info.additional_types(RefType::No);
45    let mut impls = Vec::with_capacity(additional_types.len() + 1);
46    for explicit_type in iter::once(None).chain(additional_types.iter().map(Some)) {
47        let mut new_generics = input.generics.clone();
48
49        let mut initializers = Vec::with_capacity(infos.len());
50        let mut from_types = Vec::with_capacity(infos.len());
51        for (i, (info, field)) in infos.iter().zip(fields.iter()).enumerate() {
52            let field_type = &field.ty;
53            let variable = if fields.len() == 1 {
54                quote! { original }
55            } else {
56                let tuple_index = Index::from(i);
57                quote! { original.#tuple_index }
58            };
59            if let Some(type_) = explicit_type {
60                initializers.push(quote! {
61                    <#field_type as #trait_path<#type_>>::from(#variable)
62                });
63                from_types.push(quote! { #type_ });
64            } else if info.forward {
65                let type_param =
66                    &Ident::new(&format!("__FromT{}", i), Span::call_site());
67                let sub_trait_path = quote! { #trait_path<#type_param> };
68                let type_where_clauses = quote! {
69                    where #field_type: #sub_trait_path
70                };
71                new_generics = add_where_clauses_for_new_ident(
72                    &new_generics,
73                    &[field],
74                    type_param,
75                    type_where_clauses,
76                    true,
77                );
78                let casted_trait = quote! { <#field_type as #sub_trait_path> };
79                initializers.push(quote! { #casted_trait::from(#variable) });
80                from_types.push(quote! { #type_param });
81            } else {
82                initializers.push(variable);
83                from_types.push(quote! { #field_type });
84            }
85        }
86
87        let body = multi_field_data.initializer(&initializers);
88        let (impl_generics, _, where_clause) = new_generics.split_for_impl();
89        let (_, ty_generics, _) = input.generics.split_for_impl();
90
91        impls.push(quote! {
92            #[automatically_derived]
93            impl#impl_generics #trait_path<(#(#from_types),*)> for
94                #input_type#ty_generics #where_clause {
95
96                #[inline]
97                fn from(original: (#(#from_types),*)) -> #input_type#ty_generics {
98                    #body
99                }
100            }
101        });
102    }
103
104    quote! { #( #impls )* }
105}
106
107fn enum_from(input: &DeriveInput, state: State) -> TokenStream {
108    let mut tokens = TokenStream::new();
109
110    let mut variants_per_types = HashMap::default();
111    for variant_state in state.enabled_variant_data().variant_states {
112        let multi_field_data = variant_state.enabled_fields_data();
113        let MultiFieldData { field_types, .. } = multi_field_data.clone();
114        variants_per_types
115            .entry(field_types.clone())
116            .or_insert_with(Vec::new)
117            .push(variant_state);
118    }
119    for (ref field_types, ref variant_states) in variants_per_types {
120        for variant_state in variant_states {
121            let multi_field_data = variant_state.enabled_fields_data();
122            let MultiFieldData {
123                variant_info,
124                infos,
125                ..
126            } = multi_field_data.clone();
127            // If there would be a conflict on a empty tuple derive, ignore the
128            // variants that are not explicitly enabled or have explicitly enabled
129            // or disabled fields
130            if field_types.is_empty()
131                && variant_states.len() > 1
132                && !std::iter::once(variant_info)
133                    .chain(infos)
134                    .any(|info| info.info.enabled.is_some())
135            {
136                continue;
137            }
138            struct_from(input, variant_state).to_tokens(&mut tokens);
139        }
140    }
141    tokens
142}