prost_derive/field/
map.rs

1use anyhow::{bail, Error};
2use proc_macro2::{Span, TokenStream};
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token};
6
7use crate::field::{scalar, set_option, tag_attr};
8
/// The collection type backing a map field.
///
/// The enum is fieldless, so it is cheap to copy and compare; the derive set
/// mirrors the comparison traits available on [`ValueTy`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MapTy {
    /// A hash map; selected by the `map` and `hash_map` attribute names.
    HashMap,
    /// A B-tree map; selected by the `btree_map` attribute name.
    BTreeMap,
}
14
15impl MapTy {
16    fn from_str(s: &str) -> Option<MapTy> {
17        match s {
18            "map" | "hash_map" => Some(MapTy::HashMap),
19            "btree_map" => Some(MapTy::BTreeMap),
20            _ => None,
21        }
22    }
23
24    fn module(&self) -> Ident {
25        match *self {
26            MapTy::HashMap => Ident::new("hash_map", Span::call_site()),
27            MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()),
28        }
29    }
30
31    fn lib(&self) -> TokenStream {
32        match self {
33            MapTy::HashMap => quote! { std },
34            MapTy::BTreeMap => quote! { prost::alloc },
35        }
36    }
37}
38
39fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
40    let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty));
41    scalar::Field {
42        ty,
43        kind,
44        tag: 0, // Not used here
45    }
46}
47
48#[derive(Clone)]
49pub struct Field {
50    pub map_ty: MapTy,
51    pub key_ty: scalar::Ty,
52    pub value_ty: ValueTy,
53    pub tag: u32,
54}
55
56impl Field {
57    pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
58        let mut types = None;
59        let mut tag = None;
60
61        for attr in attrs {
62            if let Some(t) = tag_attr(attr)? {
63                set_option(&mut tag, t, "duplicate tag attributes")?;
64            } else if let Some(map_ty) = attr
65                .path()
66                .get_ident()
67                .and_then(|i| MapTy::from_str(&i.to_string()))
68            {
69                let (k, v): (String, String) = match attr {
70                    Meta::NameValue(MetaNameValue {
71                        value:
72                            Expr::Lit(ExprLit {
73                                lit: Lit::Str(lit), ..
74                            }),
75                        ..
76                    }) => {
77                        let items = lit.value();
78                        let mut items = items.split(',').map(ToString::to_string);
79                        let k = items.next().unwrap();
80                        let v = match items.next() {
81                            Some(k) => k,
82                            None => bail!("invalid map attribute: must have key and value types"),
83                        };
84                        if items.next().is_some() {
85                            bail!("invalid map attribute: {attr:?}");
86                        }
87                        (k, v)
88                    }
89                    Meta::List(meta_list) => {
90                        let nested = meta_list
91                            .parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?
92                            .into_iter()
93                            .collect::<Vec<_>>();
94                        if nested.len() != 2 {
95                            bail!("invalid map attribute: must contain key and value types");
96                        }
97                        (nested[0].to_string(), nested[1].to_string())
98                    }
99                    _ => return Ok(None),
100                };
101                set_option(
102                    &mut types,
103                    (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?),
104                    "duplicate map type attribute",
105                )?;
106            } else {
107                return Ok(None);
108            }
109        }
110
111        Ok(match (types, tag.or(inferred_tag)) {
112            (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field {
113                map_ty,
114                key_ty,
115                value_ty,
116                tag,
117            }),
118            _ => None,
119        })
120    }
121
122    pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
123        Field::new(attrs, None)
124    }
125
126    /// Returns a statement which encodes the map field.
127    pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
128        let tag = self.tag;
129        let key_mod = self.key_ty.module();
130        let ke = quote!(#prost_path::encoding::#key_mod::encode);
131        let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
132        let module = self.map_ty.module();
133        match &self.value_ty {
134            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
135                let default = quote!(#ty::default() as i32);
136                quote! {
137                    #prost_path::encoding::#module::encode_with_default(
138                        #ke,
139                        #kl,
140                        #prost_path::encoding::int32::encode,
141                        #prost_path::encoding::int32::encoded_len,
142                        &(#default),
143                        #tag,
144                        &#ident,
145                        buf,
146                    );
147                }
148            }
149            ValueTy::Scalar(value_ty) => {
150                let val_mod = value_ty.module();
151                let ve = quote!(#prost_path::encoding::#val_mod::encode);
152                let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
153                quote! {
154                    #prost_path::encoding::#module::encode(
155                        #ke,
156                        #kl,
157                        #ve,
158                        #vl,
159                        #tag,
160                        &#ident,
161                        buf,
162                    );
163                }
164            }
165            ValueTy::Message => quote! {
166                #prost_path::encoding::#module::encode(
167                    #ke,
168                    #kl,
169                    #prost_path::encoding::message::encode,
170                    #prost_path::encoding::message::encoded_len,
171                    #tag,
172                    &#ident,
173                    buf,
174                );
175            },
176        }
177    }
178
179    /// Returns an expression which evaluates to the result of merging a decoded key value pair
180    /// into the map.
181    pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
182        let key_mod = self.key_ty.module();
183        let km = quote!(#prost_path::encoding::#key_mod::merge);
184        let module = self.map_ty.module();
185        match &self.value_ty {
186            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
187                let default = quote!(#ty::default() as i32);
188                quote! {
189                    #prost_path::encoding::#module::merge_with_default(
190                        #km,
191                        #prost_path::encoding::int32::merge,
192                        #default,
193                        &mut #ident,
194                        buf,
195                        ctx,
196                    )
197                }
198            }
199            ValueTy::Scalar(value_ty) => {
200                let val_mod = value_ty.module();
201                let vm = quote!(#prost_path::encoding::#val_mod::merge);
202                quote!(#prost_path::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
203            }
204            ValueTy::Message => quote! {
205                #prost_path::encoding::#module::merge(
206                    #km,
207                    #prost_path::encoding::message::merge,
208                    &mut #ident,
209                    buf,
210                    ctx,
211                )
212            },
213        }
214    }
215
216    /// Returns an expression which evaluates to the encoded length of the map.
217    pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
218        let tag = self.tag;
219        let key_mod = self.key_ty.module();
220        let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
221        let module = self.map_ty.module();
222        match &self.value_ty {
223            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
224                let default = quote!(#ty::default() as i32);
225                quote! {
226                    #prost_path::encoding::#module::encoded_len_with_default(
227                        #kl,
228                        #prost_path::encoding::int32::encoded_len,
229                        &(#default),
230                        #tag,
231                        &#ident,
232                    )
233                }
234            }
235            ValueTy::Scalar(value_ty) => {
236                let val_mod = value_ty.module();
237                let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
238                quote!(#prost_path::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
239            }
240            ValueTy::Message => quote! {
241                #prost_path::encoding::#module::encoded_len(
242                    #kl,
243                    #prost_path::encoding::message::encoded_len,
244                    #tag,
245                    &#ident,
246                )
247            },
248        }
249    }
250
251    pub fn clear(&self, ident: TokenStream) -> TokenStream {
252        quote!(#ident.clear())
253    }
254
255    /// Returns methods to embed in the message.
256    pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option<TokenStream> {
257        if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
258            let key_ty = self.key_ty.rust_type(prost_path);
259            let key_ref_ty = self.key_ty.rust_ref_type();
260
261            let get = Ident::new(&format!("get_{ident}"), Span::call_site());
262            let insert = Ident::new(&format!("insert_{ident}"), Span::call_site());
263            let take_ref = if self.key_ty.is_numeric() {
264                quote!(&)
265            } else {
266                quote!()
267            };
268
269            let get_doc = format!(
270                "Returns the enum value for the corresponding key in `{ident}`, \
271                 or `None` if the entry does not exist or it is not a valid enum value."
272            );
273            let insert_doc = format!("Inserts a key value pair into `{ident}`.");
274            Some(quote! {
275                #[doc=#get_doc]
276                pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
277                    self.#ident.get(#take_ref key).cloned().and_then(|x| {
278                        let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
279                        result.ok()
280                    })
281                }
282                #[doc=#insert_doc]
283                pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
284                    self.#ident.insert(key, value as i32).and_then(|x| {
285                        let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
286                        result.ok()
287                    })
288                }
289            })
290        } else {
291            None
292        }
293    }
294
295    /// Returns a newtype wrapper around the map, implementing nicer Debug
296    ///
297    /// The Debug tries to convert any enumerations met into the variants if possible, instead of
298    /// outputting the raw numbers.
299    pub fn debug(&self, prost_path: &Path, wrapper_name: TokenStream) -> TokenStream {
300        let type_name = match self.map_ty {
301            MapTy::HashMap => Ident::new("HashMap", Span::call_site()),
302            MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()),
303        };
304
305        // A fake field for generating the debug wrapper
306        let key_wrapper = fake_scalar(self.key_ty.clone()).debug(prost_path, quote!(KeyWrapper));
307        let key = self.key_ty.rust_type(prost_path);
308        let value_wrapper = self.value_ty.debug(prost_path);
309        let libname = self.map_ty.lib();
310        let fmt = quote! {
311            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
312                #key_wrapper
313                #value_wrapper
314                let mut builder = f.debug_map();
315                for (k, v) in self.0 {
316                    builder.entry(&KeyWrapper(k), &ValueWrapper(v));
317                }
318                builder.finish()
319            }
320        };
321        match &self.value_ty {
322            ValueTy::Scalar(ty) => {
323                if let scalar::Ty::Bytes(_) = *ty {
324                    return quote! {
325                        struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug);
326                        impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
327                            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
328                                self.0.fmt(f)
329                            }
330                        }
331                    };
332                }
333
334                let value = ty.rust_type(prost_path);
335                quote! {
336                    struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>);
337                    impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
338                        #fmt
339                    }
340                }
341            }
342            ValueTy::Message => quote! {
343                struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>);
344                impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V>
345                where
346                    V: ::core::fmt::Debug + 'a,
347                {
348                    #fmt
349                }
350            },
351        }
352    }
353}
354
355fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
356    let ty = scalar::Ty::from_str(s)?;
357    match ty {
358        scalar::Ty::Int32
359        | scalar::Ty::Int64
360        | scalar::Ty::Uint32
361        | scalar::Ty::Uint64
362        | scalar::Ty::Sint32
363        | scalar::Ty::Sint64
364        | scalar::Ty::Fixed32
365        | scalar::Ty::Fixed64
366        | scalar::Ty::Sfixed32
367        | scalar::Ty::Sfixed64
368        | scalar::Ty::Bool
369        | scalar::Ty::String => Ok(ty),
370        _ => bail!("invalid map key type: {s}"),
371    }
372}
373
374/// A map value type.
375#[derive(Clone, Debug, PartialEq, Eq)]
376pub enum ValueTy {
377    Scalar(scalar::Ty),
378    Message,
379}
380
381impl ValueTy {
382    fn from_str(s: &str) -> Result<ValueTy, Error> {
383        if let Ok(ty) = scalar::Ty::from_str(s) {
384            Ok(ValueTy::Scalar(ty))
385        } else if s.trim() == "message" {
386            Ok(ValueTy::Message)
387        } else {
388            bail!("invalid map value type: {s}");
389        }
390    }
391
392    /// Returns a newtype wrapper around the ValueTy for nicer debug.
393    ///
394    /// If the contained value is enumeration, it tries to convert it to the variant. If not, it
395    /// just forwards the implementation.
396    fn debug(&self, prost_path: &Path) -> TokenStream {
397        match self {
398            ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(prost_path, quote!(ValueWrapper)),
399            ValueTy::Message => quote!(
400                fn ValueWrapper<T>(v: T) -> T {
401                    v
402                }
403            ),
404        }
405    }
406}