derive_more/
error.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{spanned::Spanned as _, Error, Result};
4
5use crate::utils::{
6    self, AttrParams, DeriveType, FullMetaInfo, HashSet, MetaInfo, MultiFieldData,
7    State,
8};
9
10pub fn expand(
11    input: &syn::DeriveInput,
12    trait_name: &'static str,
13) -> Result<TokenStream> {
14    let syn::DeriveInput {
15        ident, generics, ..
16    } = input;
17
18    let state = State::with_attr_params(
19        input,
20        trait_name,
21        quote!(::std::error),
22        trait_name.to_lowercase(),
23        allowed_attr_params(),
24    )?;
25
26    let type_params: HashSet<_> = generics
27        .params
28        .iter()
29        .filter_map(|generic| match generic {
30            syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
31            _ => None,
32        })
33        .collect();
34
35    let (bounds, source, backtrace) = match state.derive_type {
36        DeriveType::Named | DeriveType::Unnamed => render_struct(&type_params, &state)?,
37        DeriveType::Enum => render_enum(&type_params, &state)?,
38    };
39
40    let source = source.map(|source| {
41        quote! {
42            fn source(&self) -> Option<&(dyn ::std::error::Error + 'static)> {
43                #source
44            }
45        }
46    });
47
48    let backtrace = backtrace.map(|backtrace| {
49        quote! {
50            fn backtrace(&self) -> Option<&::std::backtrace::Backtrace> {
51                #backtrace
52            }
53        }
54    });
55
56    let mut generics = generics.clone();
57
58    if !type_params.is_empty() {
59        let generic_parameters = generics.params.iter();
60        generics = utils::add_extra_where_clauses(
61            &generics,
62            quote! {
63                where
64                    #ident<#(#generic_parameters),*>: ::std::fmt::Debug + ::std::fmt::Display
65            },
66        );
67    }
68
69    if !bounds.is_empty() {
70        let bounds = bounds.iter();
71        generics = utils::add_extra_where_clauses(
72            &generics,
73            quote! {
74                where
75                    #(#bounds: ::std::fmt::Debug + ::std::fmt::Display + ::std::error::Error + 'static),*
76            },
77        );
78    }
79
80    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
81
82    let render = quote! {
83        impl#impl_generics ::std::error::Error for #ident#ty_generics #where_clause {
84            #source
85            #backtrace
86        }
87    };
88
89    Ok(render)
90}
91
92fn render_struct(
93    type_params: &HashSet<syn::Ident>,
94    state: &State,
95) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
96    let parsed_fields = parse_fields(type_params, state)?;
97
98    let source = parsed_fields.render_source_as_struct();
99    let backtrace = parsed_fields.render_backtrace_as_struct();
100
101    Ok((parsed_fields.bounds, source, backtrace))
102}
103
104fn render_enum(
105    type_params: &HashSet<syn::Ident>,
106    state: &State,
107) -> Result<(HashSet<syn::Type>, Option<TokenStream>, Option<TokenStream>)> {
108    let mut bounds = HashSet::default();
109    let mut source_match_arms = Vec::new();
110    let mut backtrace_match_arms = Vec::new();
111
112    for variant in state.enabled_variant_data().variants {
113        let default_info = FullMetaInfo {
114            enabled: true,
115            ..FullMetaInfo::default()
116        };
117
118        let state = State::from_variant(
119            state.input,
120            state.trait_name,
121            state.trait_module.clone(),
122            state.trait_attr.clone(),
123            allowed_attr_params(),
124            variant,
125            default_info,
126        )?;
127
128        let parsed_fields = parse_fields(type_params, &state)?;
129
130        if let Some(expr) = parsed_fields.render_source_as_enum_variant_match_arm() {
131            source_match_arms.push(expr);
132        }
133
134        if let Some(expr) = parsed_fields.render_backtrace_as_enum_variant_match_arm() {
135            backtrace_match_arms.push(expr);
136        }
137
138        bounds.extend(parsed_fields.bounds);
139    }
140
141    let render = |match_arms: &mut Vec<TokenStream>| {
142        if !match_arms.is_empty() && match_arms.len() < state.variants.len() {
143            match_arms.push(quote!(_ => None));
144        }
145
146        if !match_arms.is_empty() {
147            let expr = quote! {
148                match self {
149                    #(#match_arms),*
150                }
151            };
152
153            Some(expr)
154        } else {
155            None
156        }
157    };
158
159    let source = render(&mut source_match_arms);
160    let backtrace = render(&mut backtrace_match_arms);
161
162    Ok((bounds, source, backtrace))
163}
164
165fn allowed_attr_params() -> AttrParams {
166    AttrParams {
167        enum_: vec!["ignore"],
168        struct_: vec!["ignore"],
169        variant: vec!["ignore"],
170        field: vec!["ignore", "source", "backtrace"],
171    }
172}
173
174struct ParsedFields<'input, 'state> {
175    data: MultiFieldData<'input, 'state>,
176    source: Option<usize>,
177    backtrace: Option<usize>,
178    bounds: HashSet<syn::Type>,
179}
180
181impl<'input, 'state> ParsedFields<'input, 'state> {
182    fn new(data: MultiFieldData<'input, 'state>) -> Self {
183        Self {
184            data,
185            source: None,
186            backtrace: None,
187            bounds: HashSet::default(),
188        }
189    }
190}
191
192impl<'input, 'state> ParsedFields<'input, 'state> {
193    fn render_source_as_struct(&self) -> Option<TokenStream> {
194        let source = self.source?;
195        let ident = &self.data.members[source];
196        Some(render_some(quote!(&#ident)))
197    }
198
199    fn render_source_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
200        let source = self.source?;
201        let pattern = self.data.matcher(&[source], &[quote!(source)]);
202        let expr = render_some(quote!(source));
203        Some(quote!(#pattern => #expr))
204    }
205
206    fn render_backtrace_as_struct(&self) -> Option<TokenStream> {
207        let backtrace = self.backtrace?;
208        let backtrace_expr = &self.data.members[backtrace];
209        Some(quote!(Some(&#backtrace_expr)))
210    }
211
212    fn render_backtrace_as_enum_variant_match_arm(&self) -> Option<TokenStream> {
213        let backtrace = self.backtrace?;
214        let pattern = self.data.matcher(&[backtrace], &[quote!(backtrace)]);
215        Some(quote!(#pattern => Some(backtrace)))
216    }
217}
218
219fn render_some<T>(expr: T) -> TokenStream
220where
221    T: quote::ToTokens,
222{
223    quote!(Some(#expr as &(dyn ::std::error::Error + 'static)))
224}
225
226fn parse_fields<'input, 'state>(
227    type_params: &HashSet<syn::Ident>,
228    state: &'state State<'input>,
229) -> Result<ParsedFields<'input, 'state>> {
230    let mut parsed_fields = match state.derive_type {
231        DeriveType::Named => {
232            parse_fields_impl(state, |attr, field, _| {
233                // Unwrapping is safe, cause fields in named struct
234                // always have an ident
235                let ident = field.ident.as_ref().unwrap();
236
237                match attr {
238                    "source" => ident == "source",
239                    "backtrace" => {
240                        ident == "backtrace"
241                            || is_type_path_ends_with_segment(&field.ty, "Backtrace")
242                    }
243                    _ => unreachable!(),
244                }
245            })
246        }
247
248        DeriveType::Unnamed => {
249            let mut parsed_fields =
250                parse_fields_impl(state, |attr, field, len| match attr {
251                    "source" => {
252                        len == 1
253                            && !is_type_path_ends_with_segment(&field.ty, "Backtrace")
254                    }
255                    "backtrace" => {
256                        is_type_path_ends_with_segment(&field.ty, "Backtrace")
257                    }
258                    _ => unreachable!(),
259                })?;
260
261            parsed_fields.source = parsed_fields
262                .source
263                .or_else(|| infer_source_field(&state.fields, &parsed_fields));
264
265            Ok(parsed_fields)
266        }
267
268        _ => unreachable!(),
269    }?;
270
271    if let Some(source) = parsed_fields.source {
272        add_bound_if_type_parameter_used_in_type(
273            &mut parsed_fields.bounds,
274            type_params,
275            &state.fields[source].ty,
276        );
277    }
278
279    Ok(parsed_fields)
280}
281
282/// Checks if `ty` is [`syn::Type::Path`] and ends with segment matching `tail`
283/// and doesn't contain any generic parameters.
284fn is_type_path_ends_with_segment(ty: &syn::Type, tail: &str) -> bool {
285    let ty = match ty {
286        syn::Type::Path(ty) => ty,
287        _ => return false,
288    };
289
290    // Unwrapping is safe, cause 'syn::TypePath.path.segments'
291    // have to have at least one segment
292    let segment = ty.path.segments.last().unwrap();
293
294    match segment.arguments {
295        syn::PathArguments::None => (),
296        _ => return false,
297    };
298
299    segment.ident == tail
300}
301
302fn infer_source_field(
303    fields: &[&syn::Field],
304    parsed_fields: &ParsedFields,
305) -> Option<usize> {
306    // if we have exactly two fields
307    if fields.len() != 2 {
308        return None;
309    }
310
311    // no source field was specified/inferred
312    if parsed_fields.source.is_some() {
313        return None;
314    }
315
316    // but one of the fields was specified/inferred as backtrace field
317    if let Some(backtrace) = parsed_fields.backtrace {
318        // then infer *other field* as source field
319        let source = (backtrace + 1) % 2;
320        // unless it was explicitly marked as non-source
321        if parsed_fields.data.infos[source].info.source != Some(false) {
322            return Some(source);
323        }
324    }
325
326    None
327}
328
329fn parse_fields_impl<'input, 'state, P>(
330    state: &'state State<'input>,
331    is_valid_default_field_for_attr: P,
332) -> Result<ParsedFields<'input, 'state>>
333where
334    P: Fn(&str, &syn::Field, usize) -> bool,
335{
336    let MultiFieldData { fields, infos, .. } = state.enabled_fields_data();
337
338    let iter = fields
339        .iter()
340        .zip(infos.iter().map(|info| &info.info))
341        .enumerate()
342        .map(|(index, (field, info))| (index, *field, info));
343
344    let source = parse_field_impl(
345        &is_valid_default_field_for_attr,
346        state.fields.len(),
347        iter.clone(),
348        "source",
349        |info| info.source,
350    )?;
351
352    let backtrace = parse_field_impl(
353        &is_valid_default_field_for_attr,
354        state.fields.len(),
355        iter.clone(),
356        "backtrace",
357        |info| info.backtrace,
358    )?;
359
360    let mut parsed_fields = ParsedFields::new(state.enabled_fields_data());
361
362    if let Some((index, _, _)) = source {
363        parsed_fields.source = Some(index);
364    }
365
366    if let Some((index, _, _)) = backtrace {
367        parsed_fields.backtrace = Some(index);
368    }
369
370    Ok(parsed_fields)
371}
372
373fn parse_field_impl<'a, P, V>(
374    is_valid_default_field_for_attr: &P,
375    len: usize,
376    iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)> + Clone,
377    attr: &str,
378    value: V,
379) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>>
380where
381    P: Fn(&str, &syn::Field, usize) -> bool,
382    V: Fn(&MetaInfo) -> Option<bool>,
383{
384    let explicit_fields = iter.clone().filter(|(_, _, info)| match value(info) {
385        Some(true) => true,
386        _ => false,
387    });
388
389    let inferred_fields = iter.filter(|(_, field, info)| match value(info) {
390        None => is_valid_default_field_for_attr(attr, field, len),
391        _ => false,
392    });
393
394    let field = assert_iter_contains_zero_or_one_item(
395        explicit_fields,
396        &format!(
397            "Multiple `{}` attributes specified. \
398             Single attribute per struct/enum variant allowed.",
399            attr
400        ),
401    )?;
402
403    let field = match field {
404        field @ Some(_) => field,
405        None => assert_iter_contains_zero_or_one_item(
406            inferred_fields,
407            "Conflicting fields found. Consider specifying some \
408             `#[error(...)]` attributes to resolve conflict.",
409        )?,
410    };
411
412    Ok(field)
413}
414
415fn assert_iter_contains_zero_or_one_item<'a>(
416    mut iter: impl Iterator<Item = (usize, &'a syn::Field, &'a MetaInfo)>,
417    error_msg: &str,
418) -> Result<Option<(usize, &'a syn::Field, &'a MetaInfo)>> {
419    let item = match iter.next() {
420        Some(item) => item,
421        None => return Ok(None),
422    };
423
424    if let Some((_, field, _)) = iter.next() {
425        return Err(Error::new(field.span(), error_msg));
426    }
427
428    Ok(Some(item))
429}
430
431fn add_bound_if_type_parameter_used_in_type(
432    bounds: &mut HashSet<syn::Type>,
433    type_params: &HashSet<syn::Ident>,
434    ty: &syn::Type,
435) {
436    if let Some(ty) = utils::get_if_type_parameter_used_in_type(type_params, ty) {
437        bounds.insert(ty);
438    }
439}