1use anyhow::{bail, Error};
2use proc_macro2::{Span, TokenStream};
3use quote::quote;
4use syn::punctuated::Punctuated;
5use syn::{Expr, ExprLit, Ident, Lit, Meta, MetaNameValue, Path, Token};
6
7use crate::field::{scalar, set_option, tag_attr};
8
/// Which Rust collection a protobuf map field is generated as.
#[derive(Debug, Clone)]
pub enum MapTy {
    /// `HashMap`, selected by the `map` or `hash_map` attribute.
    HashMap,
    /// `BTreeMap`, selected by the `btree_map` attribute.
    BTreeMap,
}
14
15impl MapTy {
16 fn from_str(s: &str) -> Option<MapTy> {
17 match s {
18 "map" | "hash_map" => Some(MapTy::HashMap),
19 "btree_map" => Some(MapTy::BTreeMap),
20 _ => None,
21 }
22 }
23
24 fn module(&self) -> Ident {
25 match *self {
26 MapTy::HashMap => Ident::new("hash_map", Span::call_site()),
27 MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()),
28 }
29 }
30
31 fn lib(&self) -> TokenStream {
32 match self {
33 MapTy::HashMap => quote! { std },
34 MapTy::BTreeMap => quote! { prost::alloc },
35 }
36 }
37}
38
39fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
40 let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty));
41 scalar::Field {
42 ty,
43 kind,
44 tag: 0, }
46}
47
48#[derive(Clone)]
49pub struct Field {
50 pub map_ty: MapTy,
51 pub key_ty: scalar::Ty,
52 pub value_ty: ValueTy,
53 pub tag: u32,
54}
55
56impl Field {
57 pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
58 let mut types = None;
59 let mut tag = None;
60
61 for attr in attrs {
62 if let Some(t) = tag_attr(attr)? {
63 set_option(&mut tag, t, "duplicate tag attributes")?;
64 } else if let Some(map_ty) = attr
65 .path()
66 .get_ident()
67 .and_then(|i| MapTy::from_str(&i.to_string()))
68 {
69 let (k, v): (String, String) = match attr {
70 Meta::NameValue(MetaNameValue {
71 value:
72 Expr::Lit(ExprLit {
73 lit: Lit::Str(lit), ..
74 }),
75 ..
76 }) => {
77 let items = lit.value();
78 let mut items = items.split(',').map(ToString::to_string);
79 let k = items.next().unwrap();
80 let v = match items.next() {
81 Some(k) => k,
82 None => bail!("invalid map attribute: must have key and value types"),
83 };
84 if items.next().is_some() {
85 bail!("invalid map attribute: {attr:?}");
86 }
87 (k, v)
88 }
89 Meta::List(meta_list) => {
90 let nested = meta_list
91 .parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)?
92 .into_iter()
93 .collect::<Vec<_>>();
94 if nested.len() != 2 {
95 bail!("invalid map attribute: must contain key and value types");
96 }
97 (nested[0].to_string(), nested[1].to_string())
98 }
99 _ => return Ok(None),
100 };
101 set_option(
102 &mut types,
103 (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?),
104 "duplicate map type attribute",
105 )?;
106 } else {
107 return Ok(None);
108 }
109 }
110
111 Ok(match (types, tag.or(inferred_tag)) {
112 (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field {
113 map_ty,
114 key_ty,
115 value_ty,
116 tag,
117 }),
118 _ => None,
119 })
120 }
121
122 pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
123 Field::new(attrs, None)
124 }
125
126 pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
128 let tag = self.tag;
129 let key_mod = self.key_ty.module();
130 let ke = quote!(#prost_path::encoding::#key_mod::encode);
131 let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
132 let module = self.map_ty.module();
133 match &self.value_ty {
134 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
135 let default = quote!(#ty::default() as i32);
136 quote! {
137 #prost_path::encoding::#module::encode_with_default(
138 #ke,
139 #kl,
140 #prost_path::encoding::int32::encode,
141 #prost_path::encoding::int32::encoded_len,
142 &(#default),
143 #tag,
144 &#ident,
145 buf,
146 );
147 }
148 }
149 ValueTy::Scalar(value_ty) => {
150 let val_mod = value_ty.module();
151 let ve = quote!(#prost_path::encoding::#val_mod::encode);
152 let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
153 quote! {
154 #prost_path::encoding::#module::encode(
155 #ke,
156 #kl,
157 #ve,
158 #vl,
159 #tag,
160 &#ident,
161 buf,
162 );
163 }
164 }
165 ValueTy::Message => quote! {
166 #prost_path::encoding::#module::encode(
167 #ke,
168 #kl,
169 #prost_path::encoding::message::encode,
170 #prost_path::encoding::message::encoded_len,
171 #tag,
172 &#ident,
173 buf,
174 );
175 },
176 }
177 }
178
179 pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
182 let key_mod = self.key_ty.module();
183 let km = quote!(#prost_path::encoding::#key_mod::merge);
184 let module = self.map_ty.module();
185 match &self.value_ty {
186 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
187 let default = quote!(#ty::default() as i32);
188 quote! {
189 #prost_path::encoding::#module::merge_with_default(
190 #km,
191 #prost_path::encoding::int32::merge,
192 #default,
193 &mut #ident,
194 buf,
195 ctx,
196 )
197 }
198 }
199 ValueTy::Scalar(value_ty) => {
200 let val_mod = value_ty.module();
201 let vm = quote!(#prost_path::encoding::#val_mod::merge);
202 quote!(#prost_path::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
203 }
204 ValueTy::Message => quote! {
205 #prost_path::encoding::#module::merge(
206 #km,
207 #prost_path::encoding::message::merge,
208 &mut #ident,
209 buf,
210 ctx,
211 )
212 },
213 }
214 }
215
216 pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
218 let tag = self.tag;
219 let key_mod = self.key_ty.module();
220 let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
221 let module = self.map_ty.module();
222 match &self.value_ty {
223 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
224 let default = quote!(#ty::default() as i32);
225 quote! {
226 #prost_path::encoding::#module::encoded_len_with_default(
227 #kl,
228 #prost_path::encoding::int32::encoded_len,
229 &(#default),
230 #tag,
231 &#ident,
232 )
233 }
234 }
235 ValueTy::Scalar(value_ty) => {
236 let val_mod = value_ty.module();
237 let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
238 quote!(#prost_path::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
239 }
240 ValueTy::Message => quote! {
241 #prost_path::encoding::#module::encoded_len(
242 #kl,
243 #prost_path::encoding::message::encoded_len,
244 #tag,
245 &#ident,
246 )
247 },
248 }
249 }
250
251 pub fn clear(&self, ident: TokenStream) -> TokenStream {
252 quote!(#ident.clear())
253 }
254
255 pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option<TokenStream> {
257 if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
258 let key_ty = self.key_ty.rust_type(prost_path);
259 let key_ref_ty = self.key_ty.rust_ref_type();
260
261 let get = Ident::new(&format!("get_{ident}"), Span::call_site());
262 let insert = Ident::new(&format!("insert_{ident}"), Span::call_site());
263 let take_ref = if self.key_ty.is_numeric() {
264 quote!(&)
265 } else {
266 quote!()
267 };
268
269 let get_doc = format!(
270 "Returns the enum value for the corresponding key in `{ident}`, \
271 or `None` if the entry does not exist or it is not a valid enum value."
272 );
273 let insert_doc = format!("Inserts a key value pair into `{ident}`.");
274 Some(quote! {
275 #[doc=#get_doc]
276 pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
277 self.#ident.get(#take_ref key).cloned().and_then(|x| {
278 let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
279 result.ok()
280 })
281 }
282 #[doc=#insert_doc]
283 pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
284 self.#ident.insert(key, value as i32).and_then(|x| {
285 let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
286 result.ok()
287 })
288 }
289 })
290 } else {
291 None
292 }
293 }
294
295 pub fn debug(&self, prost_path: &Path, wrapper_name: TokenStream) -> TokenStream {
300 let type_name = match self.map_ty {
301 MapTy::HashMap => Ident::new("HashMap", Span::call_site()),
302 MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()),
303 };
304
305 let key_wrapper = fake_scalar(self.key_ty.clone()).debug(prost_path, quote!(KeyWrapper));
307 let key = self.key_ty.rust_type(prost_path);
308 let value_wrapper = self.value_ty.debug(prost_path);
309 let libname = self.map_ty.lib();
310 let fmt = quote! {
311 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
312 #key_wrapper
313 #value_wrapper
314 let mut builder = f.debug_map();
315 for (k, v) in self.0 {
316 builder.entry(&KeyWrapper(k), &ValueWrapper(v));
317 }
318 builder.finish()
319 }
320 };
321 match &self.value_ty {
322 ValueTy::Scalar(ty) => {
323 if let scalar::Ty::Bytes(_) = *ty {
324 return quote! {
325 struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug);
326 impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
327 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
328 self.0.fmt(f)
329 }
330 }
331 };
332 }
333
334 let value = ty.rust_type(prost_path);
335 quote! {
336 struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>);
337 impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
338 #fmt
339 }
340 }
341 }
342 ValueTy::Message => quote! {
343 struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>);
344 impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V>
345 where
346 V: ::core::fmt::Debug + 'a,
347 {
348 #fmt
349 }
350 },
351 }
352 }
353}
354
355fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
356 let ty = scalar::Ty::from_str(s)?;
357 match ty {
358 scalar::Ty::Int32
359 | scalar::Ty::Int64
360 | scalar::Ty::Uint32
361 | scalar::Ty::Uint64
362 | scalar::Ty::Sint32
363 | scalar::Ty::Sint64
364 | scalar::Ty::Fixed32
365 | scalar::Ty::Fixed64
366 | scalar::Ty::Sfixed32
367 | scalar::Ty::Sfixed64
368 | scalar::Ty::Bool
369 | scalar::Ty::String => Ok(ty),
370 _ => bail!("invalid map key type: {s}"),
371 }
372}
373
374#[derive(Clone, Debug, PartialEq, Eq)]
376pub enum ValueTy {
377 Scalar(scalar::Ty),
378 Message,
379}
380
381impl ValueTy {
382 fn from_str(s: &str) -> Result<ValueTy, Error> {
383 if let Ok(ty) = scalar::Ty::from_str(s) {
384 Ok(ValueTy::Scalar(ty))
385 } else if s.trim() == "message" {
386 Ok(ValueTy::Message)
387 } else {
388 bail!("invalid map value type: {s}");
389 }
390 }
391
392 fn debug(&self, prost_path: &Path) -> TokenStream {
397 match self {
398 ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(prost_path, quote!(ValueWrapper)),
399 ValueTy::Message => quote!(
400 fn ValueWrapper<T>(v: T) -> T {
401 v
402 }
403 ),
404 }
405 }
406}