1#![doc(html_root_url = "https://docs.rs/prost-derive/0.14.1")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro2::{Span, TokenStream};
11use quote::quote;
12use syn::{
13 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
14 FieldsUnnamed, Ident, Index, Variant,
15};
16
17mod field;
18use crate::field::Field;
19
20fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
21 let input: DeriveInput = syn::parse2(input)?;
22
23 let ident = input.ident;
24
25 syn::custom_keyword!(skip_debug);
26 let skip_debug = input
27 .attrs
28 .into_iter()
29 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
30
31 let variant_data = match input.data {
32 Data::Struct(variant_data) => variant_data,
33 Data::Enum(..) => bail!("Message can not be derived for an enum"),
34 Data::Union(..) => bail!("Message can not be derived for a union"),
35 };
36
37 let generics = &input.generics;
38 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
39
40 let (is_struct, fields) = match variant_data {
41 DataStruct {
42 fields: Fields::Named(FieldsNamed { named: fields, .. }),
43 ..
44 } => (true, fields.into_iter().collect()),
45 DataStruct {
46 fields:
47 Fields::Unnamed(FieldsUnnamed {
48 unnamed: fields, ..
49 }),
50 ..
51 } => (false, fields.into_iter().collect()),
52 DataStruct {
53 fields: Fields::Unit,
54 ..
55 } => (false, Vec::new()),
56 };
57
58 let mut next_tag: u32 = 1;
59 let mut fields = fields
60 .into_iter()
61 .enumerate()
62 .flat_map(|(i, field)| {
63 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
64 let index = Index {
65 index: i as u32,
66 span: Span::call_site(),
67 };
68 quote!(#index)
69 });
70 match Field::new(field.attrs, Some(next_tag)) {
71 Ok(Some(field)) => {
72 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
73 Some(Ok((field_ident, field)))
74 }
75 Ok(None) => None,
76 Err(err) => Some(Err(
77 err.context(format!("invalid message field {}.{}", ident, field_ident))
78 )),
79 }
80 })
81 .collect::<Result<Vec<_>, _>>()?;
82
83 let unsorted_fields = fields.clone();
85
86 fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
91 let fields = fields;
92
93 if let Some(duplicate_tag) = fields
94 .iter()
95 .flat_map(|(_, field)| field.tags())
96 .duplicates()
97 .next()
98 {
99 bail!(
100 "message {} has multiple fields with tag {}",
101 ident,
102 duplicate_tag
103 )
104 };
105
106 let encoded_len = fields
107 .iter()
108 .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
109
110 let encode = fields
111 .iter()
112 .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
113
114 let merge = fields.iter().map(|(field_ident, field)| {
115 let merge = field.merge(quote!(value));
116 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
117 let tags = Itertools::intersperse(tags, quote!(|));
118
119 quote! {
120 #(#tags)* => {
121 let mut value = &mut self.#field_ident;
122 #merge.map_err(|mut error| {
123 error.push(STRUCT_NAME, stringify!(#field_ident));
124 error
125 })
126 },
127 }
128 });
129
130 let struct_name = if fields.is_empty() {
131 quote!()
132 } else {
133 quote!(
134 const STRUCT_NAME: &'static str = stringify!(#ident);
135 )
136 };
137
138 let clear = fields
139 .iter()
140 .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
141
142 let default = if is_struct {
143 let default = fields.iter().map(|(field_ident, field)| {
144 let value = field.default();
145 quote!(#field_ident: #value,)
146 });
147 quote! {#ident {
148 #(#default)*
149 }}
150 } else {
151 let default = fields.iter().map(|(_, field)| {
152 let value = field.default();
153 quote!(#value,)
154 });
155 quote! {#ident (
156 #(#default)*
157 )}
158 };
159
160 let methods = fields
161 .iter()
162 .flat_map(|(field_ident, field)| field.methods(field_ident))
163 .collect::<Vec<_>>();
164 let methods = if methods.is_empty() {
165 quote!()
166 } else {
167 quote! {
168 #[allow(dead_code)]
169 impl #impl_generics #ident #ty_generics #where_clause {
170 #(#methods)*
171 }
172 }
173 };
174
175 let expanded = quote! {
176 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
177 #[allow(unused_variables)]
178 fn encode_raw(&self, buf: &mut impl ::prost::bytes::BufMut) {
179 #(#encode)*
180 }
181
182 #[allow(unused_variables)]
183 fn merge_field(
184 &mut self,
185 tag: u32,
186 wire_type: ::prost::encoding::wire_type::WireType,
187 buf: &mut impl ::prost::bytes::Buf,
188 ctx: ::prost::encoding::DecodeContext,
189 ) -> ::core::result::Result<(), ::prost::DecodeError>
190 {
191 #struct_name
192 match tag {
193 #(#merge)*
194 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
195 }
196 }
197
198 #[inline]
199 fn encoded_len(&self) -> usize {
200 0 #(+ #encoded_len)*
201 }
202
203 fn clear(&mut self) {
204 #(#clear;)*
205 }
206 }
207
208 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
209 fn default() -> Self {
210 #default
211 }
212 }
213 };
214 let expanded = if skip_debug {
215 expanded
216 } else {
217 let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
218 let wrapper = field.debug(quote!(self.#field_ident));
219 let call = if is_struct {
220 quote!(builder.field(stringify!(#field_ident), &wrapper))
221 } else {
222 quote!(builder.field(&wrapper))
223 };
224 quote! {
225 let builder = {
226 let wrapper = #wrapper;
227 #call
228 };
229 }
230 });
231 let debug_builder = if is_struct {
232 quote!(f.debug_struct(stringify!(#ident)))
233 } else {
234 quote!(f.debug_tuple(stringify!(#ident)))
235 };
236 quote! {
237 #expanded
238
239 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
240 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
241 let mut builder = #debug_builder;
242 #(#debugs;)*
243 builder.finish()
244 }
245 }
246 }
247 };
248
249 let expanded = quote! {
250 #expanded
251
252 #methods
253 };
254
255 Ok(expanded)
256}
257
258#[proc_macro_derive(Message, attributes(prost))]
259pub fn message(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
260 try_message(input.into()).unwrap().into()
261}
262
263fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
264 let input: DeriveInput = syn::parse2(input)?;
265 let ident = input.ident;
266
267 let generics = &input.generics;
268 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
269
270 let punctuated_variants = match input.data {
271 Data::Enum(DataEnum { variants, .. }) => variants,
272 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
273 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
274 };
275
276 let mut variants: Vec<(Ident, Expr)> = Vec::new();
278 for Variant {
279 ident,
280 fields,
281 discriminant,
282 ..
283 } in punctuated_variants
284 {
285 match fields {
286 Fields::Unit => (),
287 Fields::Named(_) | Fields::Unnamed(_) => {
288 bail!("Enumeration variants may not have fields")
289 }
290 }
291
292 match discriminant {
293 Some((_, expr)) => variants.push((ident, expr)),
294 None => bail!("Enumeration variants must have a discriminant"),
295 }
296 }
297
298 if variants.is_empty() {
299 panic!("Enumeration must have at least one variant");
300 }
301
302 let default = variants[0].0.clone();
303
304 let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
305 let from = variants
306 .iter()
307 .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
308
309 let try_from = variants
310 .iter()
311 .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
312
313 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
314 let from_i32_doc = format!(
315 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
316 ident
317 );
318
319 let expanded = quote! {
320 impl #impl_generics #ident #ty_generics #where_clause {
321 #[doc=#is_valid_doc]
322 pub fn is_valid(value: i32) -> bool {
323 match value {
324 #(#is_valid,)*
325 _ => false,
326 }
327 }
328
329 #[deprecated = "Use the TryFrom<i32> implementation instead"]
330 #[doc=#from_i32_doc]
331 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
332 match value {
333 #(#from,)*
334 _ => ::core::option::Option::None,
335 }
336 }
337 }
338
339 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
340 fn default() -> #ident {
341 #ident::#default
342 }
343 }
344
345 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
346 fn from(value: #ident) -> i32 {
347 value as i32
348 }
349 }
350
351 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
352 type Error = ::prost::UnknownEnumValue;
353
354 fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::UnknownEnumValue> {
355 match value {
356 #(#try_from,)*
357 _ => ::core::result::Result::Err(::prost::UnknownEnumValue(value)),
358 }
359 }
360 }
361 };
362
363 Ok(expanded)
364}
365
366#[proc_macro_derive(Enumeration, attributes(prost))]
367pub fn enumeration(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
368 try_enumeration(input.into()).unwrap().into()
369}
370
371fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
372 let input: DeriveInput = syn::parse2(input)?;
373
374 let ident = input.ident;
375
376 syn::custom_keyword!(skip_debug);
377 let skip_debug = input
378 .attrs
379 .into_iter()
380 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
381
382 let variants = match input.data {
383 Data::Enum(DataEnum { variants, .. }) => variants,
384 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
385 Data::Union(..) => bail!("Oneof can not be derived for a union"),
386 };
387
388 let generics = &input.generics;
389 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
390
391 let mut fields: Vec<(Ident, Field)> = Vec::new();
393 for Variant {
394 attrs,
395 ident: variant_ident,
396 fields: variant_fields,
397 ..
398 } in variants
399 {
400 let variant_fields = match variant_fields {
401 Fields::Unit => Punctuated::new(),
402 Fields::Named(FieldsNamed { named: fields, .. })
403 | Fields::Unnamed(FieldsUnnamed {
404 unnamed: fields, ..
405 }) => fields,
406 };
407 if variant_fields.len() != 1 {
408 bail!("Oneof enum variants must have a single field");
409 }
410 match Field::new_oneof(attrs)? {
411 Some(field) => fields.push((variant_ident, field)),
412 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
413 }
414 }
415
416 assert!(fields.iter().all(|(_, field)| field.tags().len() == 1));
419
420 if let Some(duplicate_tag) = fields
421 .iter()
422 .flat_map(|(_, field)| field.tags())
423 .duplicates()
424 .next()
425 {
426 bail!(
427 "invalid oneof {}: multiple variants have tag {}",
428 ident,
429 duplicate_tag
430 );
431 }
432
433 let encode = fields.iter().map(|(variant_ident, field)| {
434 let encode = field.encode(quote!(*value));
435 quote!(#ident::#variant_ident(ref value) => { #encode })
436 });
437
438 let merge = fields.iter().map(|(variant_ident, field)| {
439 let tag = field.tags()[0];
440 let merge = field.merge(quote!(value));
441 quote! {
442 #tag => if let ::core::option::Option::Some(#ident::#variant_ident(value)) = field {
443 #merge
444 } else {
445 let mut owned_value = ::core::default::Default::default();
446 let value = &mut owned_value;
447 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
448 }
449 }
450 });
451
452 let encoded_len = fields.iter().map(|(variant_ident, field)| {
453 let encoded_len = field.encoded_len(quote!(*value));
454 quote!(#ident::#variant_ident(ref value) => #encoded_len)
455 });
456
457 let expanded = quote! {
458 impl #impl_generics #ident #ty_generics #where_clause {
459 pub fn encode(&self, buf: &mut impl ::prost::bytes::BufMut) {
461 match *self {
462 #(#encode,)*
463 }
464 }
465
466 pub fn merge(
468 field: &mut ::core::option::Option<#ident #ty_generics>,
469 tag: u32,
470 wire_type: ::prost::encoding::wire_type::WireType,
471 buf: &mut impl ::prost::bytes::Buf,
472 ctx: ::prost::encoding::DecodeContext,
473 ) -> ::core::result::Result<(), ::prost::DecodeError>
474 {
475 match tag {
476 #(#merge,)*
477 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
478 }
479 }
480
481 #[inline]
483 pub fn encoded_len(&self) -> usize {
484 match *self {
485 #(#encoded_len,)*
486 }
487 }
488 }
489
490 };
491 let expanded = if skip_debug {
492 expanded
493 } else {
494 let debug = fields.iter().map(|(variant_ident, field)| {
495 let wrapper = field.debug(quote!(*value));
496 quote!(#ident::#variant_ident(ref value) => {
497 let wrapper = #wrapper;
498 f.debug_tuple(stringify!(#variant_ident))
499 .field(&wrapper)
500 .finish()
501 })
502 });
503 quote! {
504 #expanded
505
506 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
507 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
508 match *self {
509 #(#debug,)*
510 }
511 }
512 }
513 }
514 };
515
516 Ok(expanded)
517}
518
519#[proc_macro_derive(Oneof, attributes(prost))]
520pub fn oneof(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
521 try_oneof(input.into()).unwrap().into()
522}
523
#[cfg(test)]
mod test {
    use crate::{try_enumeration, try_message, try_oneof};
    use quote::quote;

    #[test]
    fn test_rejects_colliding_message_fields() {
        let output = try_message(quote!(
            struct Invalid {
                #[prost(bool, tag = "1")]
                a: bool,
                #[prost(oneof = "super::Whatever", tags = "4, 5, 1")]
                b: Option<super::Whatever>,
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject colliding message fields")
                .to_string(),
            "message Invalid has multiple fields with tag 1"
        );
    }

    #[test]
    fn test_rejects_colliding_oneof_variants() {
        let output = try_oneof(quote!(
            pub enum Invalid {
                #[prost(bool, tag = "1")]
                A(bool),
                #[prost(bool, tag = "3")]
                B(bool),
                #[prost(bool, tag = "1")]
                C(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject colliding oneof variants")
                .to_string(),
            "invalid oneof Invalid: multiple variants have tag 1"
        );
    }

    #[test]
    fn test_rejects_multiple_tags_oneof_variant() {
        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "1", tag = "2")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "duplicate tag attributes: 1 and 2"
        );

        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tag = "3")]
                #[prost(tag = "4")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "duplicate tag attributes: 3 and 4"
        );

        let output = try_oneof(quote!(
            enum What {
                #[prost(bool, tags = "5,6")]
                A(bool),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject multiple tags on oneof variant")
                .to_string(),
            "unknown attribute(s): #[prost(tags = \"5,6\")]"
        );
    }

    #[test]
    fn test_rejects_enumeration_on_struct() {
        let output = try_enumeration(quote!(
            struct NotAnEnum {
                a: i32,
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject enumeration derived for a struct")
                .to_string(),
            "Enumeration can not be derived for a struct"
        );
    }

    #[test]
    fn test_rejects_enumeration_variant_with_fields() {
        let output = try_enumeration(quote!(
            enum Invalid {
                A(i32),
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject enumeration variant with fields")
                .to_string(),
            "Enumeration variants may not have fields"
        );
    }

    #[test]
    fn test_rejects_enumeration_variant_without_discriminant() {
        let output = try_enumeration(quote!(
            enum Invalid {
                A = 0,
                B,
            }
        ));
        assert_eq!(
            output
                .expect_err("did not reject enumeration variant without discriminant")
                .to_string(),
            "Enumeration variants must have a discriminant"
        );
    }
}