prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.14.1")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
14    FieldsUnnamed, Ident, Index, Variant,
15};
16
17mod field;
18use crate::field::Field;
19
20fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
21    let input: DeriveInput = syn::parse2(input)?;
22
23    let ident = input.ident;
24
25    syn::custom_keyword!(skip_debug);
26    let skip_debug = input
27        .attrs
28        .into_iter()
29        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
30
31    let variant_data = match input.data {
32        Data::Struct(variant_data) => variant_data,
33        Data::Enum(..) => bail!("Message can not be derived for an enum"),
34        Data::Union(..) => bail!("Message can not be derived for a union"),
35    };
36
37    let generics = &input.generics;
38    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40    let (is_struct, fields) = match variant_data {
41        DataStruct {
42            fields: Fields::Named(FieldsNamed { named: fields, .. }),
43            ..
44        } => (true, fields.into_iter().collect()),
45        DataStruct {
46            fields:
47                Fields::Unnamed(FieldsUnnamed {
48                    unnamed: fields, ..
49                }),
50            ..
51        } => (false, fields.into_iter().collect()),
52        DataStruct {
53            fields: Fields::Unit,
54            ..
55        } => (false, Vec::new()),
56    };
57
58    let mut next_tag: u32 = 1;
59    let mut fields = fields
60        .into_iter()
61        .enumerate()
62        .flat_map(|(i, field)| {
63            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
64                let index = Index {
65                    index: i as u32,
66                    span: Span::call_site(),
67                };
68                quote!(#index)
69            });
70            match Field::new(field.attrs, Some(next_tag)) {
71                Ok(Some(field)) => {
72                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
73                    Some(Ok((field_ident, field)))
74                }
75                Ok(None) => None,
76                Err(err) => Some(Err(
77                    err.context(format!("invalid message field {}.{}", ident, field_ident))
78                )),
79            }
80        })
81        .collect::<Result<Vec<_>, _>>()?;
82
83    // We want Debug to be in declaration order
84    let unsorted_fields = fields.clone();
85
86    // Sort the fields by tag number so that fields will be encoded in tag order.
87    // TODO: This encodes oneof fields in the position of their lowest tag,
88    // regardless of the currently occupied variant, is that consequential?
89    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
90    fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
91    let fields = fields;
92
93    if let Some(duplicate_tag) = fields
94        .iter()
95        .flat_map(|(_, field)| field.tags())
96        .duplicates()
97        .next()
98    {
99        bail!(
100            "message {} has multiple fields with tag {}",
101            ident,
102            duplicate_tag
103        )
104    };
105
106    let encoded_len = fields
107        .iter()
108        .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
109
110    let encode = fields
111        .iter()
112        .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
113
114    let merge = fields.iter().map(|(field_ident, field)| {
115        let merge = field.merge(quote!(value));
116        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
117        let tags = Itertools::intersperse(tags, quote!(|));
118
119        quote! {
120            #(#tags)* => {
121                let mut value = &mut self.#field_ident;
122                #merge.map_err(|mut error| {
123                    error.push(STRUCT_NAME, stringify!(#field_ident));
124                    error
125                })
126            },
127        }
128    });
129
130    let struct_name = if fields.is_empty() {
131        quote!()
132    } else {
133        quote!(
134            const STRUCT_NAME: &'static str = stringify!(#ident);
135        )
136    };
137
138    let clear = fields
139        .iter()
140        .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
141
142    let default = if is_struct {
143        let default = fields.iter().map(|(field_ident, field)| {
144            let value = field.default();
145            quote!(#field_ident: #value,)
146        });
147        quote! {#ident {
148            #(#default)*
149        }}
150    } else {
151        let default = fields.iter().map(|(_, field)| {
152            let value = field.default();
153            quote!(#value,)
154        });
155        quote! {#ident (
156            #(#default)*
157        )}
158    };
159
160    let methods = fields
161        .iter()
162        .flat_map(|(field_ident, field)| field.methods(field_ident))
163        .collect::<Vec<_>>();
164    let methods = if methods.is_empty() {
165        quote!()
166    } else {
167        quote! {
168            #[allow(dead_code)]
169            impl #impl_generics #ident #ty_generics #where_clause {
170                #(#methods)*
171            }
172        }
173    };
174
175    let expanded = quote! {
176        impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
177            #[allow(unused_variables)]
178            fn encode_raw(&self, buf: &mut impl ::prost::bytes::BufMut) {
179                #(#encode)*
180            }
181
182            #[allow(unused_variables)]
183            fn merge_field(
184                &mut self,
185                tag: u32,
186                wire_type: ::prost::encoding::wire_type::WireType,
187                buf: &mut impl ::prost::bytes::Buf,
188                ctx: ::prost::encoding::DecodeContext,
189            ) -> ::core::result::Result<(), ::prost::DecodeError>
190            {
191                #struct_name
192                match tag {
193                    #(#merge)*
194                    _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
195                }
196            }
197
198            #[inline]
199            fn encoded_len(&self) -> usize {
200                0 #(+ #encoded_len)*
201            }
202
203            fn clear(&mut self) {
204                #(#clear;)*
205            }
206        }
207
208        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
209            fn default() -> Self {
210                #default
211            }
212        }
213    };
214    let expanded = if skip_debug {
215        expanded
216    } else {
217        let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
218            let wrapper = field.debug(quote!(self.#field_ident));
219            let call = if is_struct {
220                quote!(builder.field(stringify!(#field_ident), &wrapper))
221            } else {
222                quote!(builder.field(&wrapper))
223            };
224            quote! {
225                 let builder = {
226                     let wrapper = #wrapper;
227                     #call
228                 };
229            }
230        });
231        let debug_builder = if is_struct {
232            quote!(f.debug_struct(stringify!(#ident)))
233        } else {
234            quote!(f.debug_tuple(stringify!(#ident)))
235        };
236        quote! {
237            #expanded
238
239            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
240                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
241                    let mut builder = #debug_builder;
242                    #(#debugs;)*
243                    builder.finish()
244                }
245            }
246        }
247    };
248
249    let expanded = quote! {
250        #expanded
251
252        #methods
253    };
254
255    Ok(expanded)
256}
257
258#[proc_macro_derive(Message, attributes(prost))]
259pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
260    try_message(input.into()).unwrap().into()
261}
262
263fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
264    let input: DeriveInput = syn::parse2(input)?;
265    let ident = input.ident;
266
267    let generics = &input.generics;
268    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269
270    let punctuated_variants = match input.data {
271        Data::Enum(DataEnum { variants, .. }) => variants,
272        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
273        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
274    };
275
276    // Map the variants into 'fields'.
277    let mut variants: Vec<(Ident, Expr)> = Vec::new();
278    for Variant {
279        ident,
280        fields,
281        discriminant,
282        ..
283    } in punctuated_variants
284    {
285        match fields {
286            Fields::Unit => (),
287            Fields::Named(_) | Fields::Unnamed(_) => {
288                bail!("Enumeration variants may not have fields")
289            }
290        }
291
292        match discriminant {
293            Some((_, expr)) => variants.push((ident, expr)),
294            None => bail!("Enumeration variants must have a discriminant"),
295        }
296    }
297
298    if variants.is_empty() {
299        panic!("Enumeration must have at least one variant");
300    }
301
302    let default = variants[0].0.clone();
303
304    let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
305    let from = variants
306        .iter()
307        .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
308
309    let try_from = variants
310        .iter()
311        .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
312
313    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
314    let from_i32_doc = format!(
315        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
316        ident
317    );
318
319    let expanded = quote! {
320        impl #impl_generics #ident #ty_generics #where_clause {
321            #[doc=#is_valid_doc]
322            pub fn is_valid(value: i32) -> bool {
323                match value {
324                    #(#is_valid,)*
325                    _ => false,
326                }
327            }
328
329            #[deprecated = "Use the TryFrom<i32> implementation instead"]
330            #[doc=#from_i32_doc]
331            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
332                match value {
333                    #(#from,)*
334                    _ => ::core::option::Option::None,
335                }
336            }
337        }
338
339        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
340            fn default() -> #ident {
341                #ident::#default
342            }
343        }
344
345        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
346            fn from(value: #ident) -> i32 {
347                value as i32
348            }
349        }
350
351        impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
352            type Error = ::prost::UnknownEnumValue;
353
354            fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::UnknownEnumValue> {
355                match value {
356                    #(#try_from,)*
357                    _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)),
358                }
359            }
360        }
361    };
362
363    Ok(expanded)
364}
365
366#[proc_macro_derive(Enumeration, attributes(prost))]
367pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
368    try_enumeration(input.into()).unwrap().into()
369}
370
371fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
372    let input: DeriveInput = syn::parse2(input)?;
373
374    let ident = input.ident;
375
376    syn::custom_keyword!(skip_debug);
377    let skip_debug = input
378        .attrs
379        .into_iter()
380        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
381
382    let variants = match input.data {
383        Data::Enum(DataEnum { variants, .. }) => variants,
384        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
385        Data::Union(..) => bail!("Oneof can not be derived for a union"),
386    };
387
388    let generics = &input.generics;
389    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
390
391    // Map the variants into 'fields'.
392    let mut fields: Vec<(Ident, Field)> = Vec::new();
393    for Variant {
394        attrs,
395        ident: variant_ident,
396        fields: variant_fields,
397        ..
398    } in variants
399    {
400        let variant_fields = match variant_fields {
401            Fields::Unit => Punctuated::new(),
402            Fields::Named(FieldsNamed { named: fields, .. })
403            | Fields::Unnamed(FieldsUnnamed {
404                unnamed: fields, ..
405            }) => fields,
406        };
407        if variant_fields.len() != 1 {
408            bail!("Oneof enum variants must have a single field");
409        }
410        match Field::new_oneof(attrs)? {
411            Some(field) => fields.push((variant_ident, field)),
412            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
413        }
414    }
415
416    // Oneof variants cannot be oneofs themselves, so it's impossible to have a field with multiple
417    // tags.
418    assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));
419
420    if let Some(duplicate_tag) = fields
421        .iter()
422        .flat_map(|(_, field)| field.tags())
423        .duplicates()
424        .next()
425    {
426        bail!(
427            "invalid oneof {}: multiple variants have tag {}",
428            ident,
429            duplicate_tag
430        );
431    }
432
433    let encode = fields.iter().map(|(variant_ident, field)| {
434        let encode = field.encode(quote!(*value));
435        quote!(#ident::#variant_ident(ref value) => { #encode })
436    });
437
438    let merge = fields.iter().map(|(variant_ident, field)| {
439        let tag = field.tags()[0];
440        let merge = field.merge(quote!(value));
441        quote! {
442            #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field {
443                #merge
444            } else {
445                let mut owned_value = ::core::default::Default::default();
446                let value = &mut owned_value;
447                #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
448            }
449        }
450    });
451
452    let encoded_len = fields.iter().map(|(variant_ident, field)| {
453        let encoded_len = field.encoded_len(quote!(*value));
454        quote!(#ident::#variant_ident(ref value) => #encoded_len)
455    });
456
457    let expanded = quote! {
458        impl #impl_generics #ident #ty_generics #where_clause {
459            /// Encodes the message to a buffer.
460            pub fn encode(&self, buf: &mut impl ::prost::bytes::BufMut) {
461                match *self {
462                    #(#encode,)*
463                }
464            }
465
466            /// Decodes an instance of the message from a buffer, and merges it into self.
467            pub fn merge(
468                field: &mut ::core::option::Option<#ident #ty_generics>,
469                tag: u32,
470                wire_type: ::prost::encoding::wire_type::WireType,
471                buf: &mut impl ::prost::bytes::Buf,
472                ctx: ::prost::encoding::DecodeContext,
473            ) -> ::core::result::Result<(), ::prost::DecodeError>
474            {
475                match tag {
476                    #(#merge,)*
477                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
478                }
479            }
480
481            /// Returns the encoded length of the message without a length delimiter.
482            #[inline]
483            pub fn encoded_len(&self) -> usize {
484                match *self {
485                    #(#encoded_len,)*
486                }
487            }
488        }
489
490    };
491    let expanded = if skip_debug {
492        expanded
493    } else {
494        let debug = fields.iter().map(|(variant_ident, field)| {
495            let wrapper = field.debug(quote!(*value));
496            quote!(#ident::#variant_ident(ref value) => {
497                let wrapper = #wrapper;
498                f.debug_tuple(stringify!(#variant_ident))
499                    .field(&wrapper)
500                    .finish()
501            })
502        });
503        quote! {
504            #expanded
505
506            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
507                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
508                    match *self {
509                        #(#debug,)*
510                    }
511                }
512            }
513        }
514    };
515
516    Ok(expanded)
517}
518
519#[proc_macro_derive(Oneof, attributes(prost))]
520pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
521    try_oneof(input.into()).unwrap().into()
522}
523
#[cfg(test)]
mod test {
    use crate::{try_message, try_oneof};
    use quote::quote;

    /// Extracts the error from a derive result and renders it with `Display`, panicking with
    /// `reason` if the derive unexpectedly succeeded.
    fn rendered_error<T: core::fmt::Debug, E: core::fmt::Display>(
        result: Result<T, E>,
        reason: &str,
    ) -> String {
        result.expect_err(reason).to_string()
    }

    #[test]
    fn test_rejects_colliding_message_fields() {
        // The oneof's tag list overlaps the plain field's tag 1.
        let result = try_message(quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        ));
        let rendered = rendered_error(result, "did not reject colliding message fields");
        assert_eq!(rendered, "message Invalid has multiple fields with tag 1");
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        // Variants `A` and `C` both claim tag 1.
        let result = try_oneof(quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        ));
        let rendered = rendered_error(result, "did not reject colliding oneof variants");
        assert_eq!(rendered, "invalid oneof Invalid: multiple variants have tag 1");
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        const REASON: &str = "did not reject multiple tags on oneof variant";

        // Two `tag` keys inside a single attribute.
        let same_attribute = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "1", tag = "2")]
                A(bool),
            }
        ));
        assert_eq!(
            rendered_error(same_attribute, REASON),
            "duplicate tag attributes: 1 and 2"
        );

        // `tag` keys spread across two separate attributes.
        let split_attributes = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "3")]
                #[prost(tag = "4")]
                A(bool),
            }
        ));
        assert_eq!(
            rendered_error(split_attributes, REASON),
            "duplicate tag attributes: 3 and 4"
        );

        // The plural `tags` key is not accepted on a oneof variant.
        let plural_key = try_oneof(quote!(
            enum What {
                #[prost(bool, tags = "5,6")]
                A(bool),
            }
        ));
        assert_eq!(
            rendered_error(plural_key, REASON),
            "unknown attribute(s): #[prost(tags = \"5,6\")]"
        );
    }
}