// prost_derive/field/mod.rs

1mod group;
2mod map;
3mod message;
4mod oneof;
5mod scalar;
6
7use std::fmt;
8use std::slice;
9
10use anyhow::{bail, Error};
11use proc_macro2::TokenStream;
12use quote::quote;
13use syn::punctuated::Punctuated;
14use syn::Path;
15use syn::{Attribute, Expr, ExprLit, Lit, LitBool, LitInt, Meta, MetaNameValue, Token};
16
17#[derive(Clone)]
18pub enum Field {
19    /// A scalar field.
20    Scalar(scalar::Field),
21    /// A message field.
22    Message(message::Field),
23    /// A map field.
24    Map(map::Field),
25    /// A oneof field.
26    Oneof(oneof::Field),
27    /// A group field.
28    Group(group::Field),
29}
30
31impl Field {
32    /// Creates a new `Field` from an iterator of field attributes.
33    ///
34    /// If the meta items are invalid, an error will be returned.
35    /// If the field should be ignored, `None` is returned.
36    pub fn new(attrs: Vec<Attribute>, inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
37        let attrs = prost_attrs(attrs)?;
38
39        // TODO: check for ignore attribute.
40
41        let field = if let Some(field) = scalar::Field::new(&attrs, inferred_tag)? {
42            Field::Scalar(field)
43        } else if let Some(field) = message::Field::new(&attrs, inferred_tag)? {
44            Field::Message(field)
45        } else if let Some(field) = map::Field::new(&attrs, inferred_tag)? {
46            Field::Map(field)
47        } else if let Some(field) = oneof::Field::new(&attrs)? {
48            Field::Oneof(field)
49        } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? {
50            Field::Group(field)
51        } else {
52            bail!("no type attribute");
53        };
54
55        Ok(Some(field))
56    }
57
58    /// Creates a new oneof `Field` from an iterator of field attributes.
59    ///
60    /// If the meta items are invalid, an error will be returned.
61    /// If the field should be ignored, `None` is returned.
62    pub fn new_oneof(attrs: Vec<Attribute>) -> Result<Option<Field>, Error> {
63        let attrs = prost_attrs(attrs)?;
64
65        // TODO: check for ignore attribute.
66
67        let field = if let Some(field) = scalar::Field::new_oneof(&attrs)? {
68            Field::Scalar(field)
69        } else if let Some(field) = message::Field::new_oneof(&attrs)? {
70            Field::Message(field)
71        } else if let Some(field) = map::Field::new_oneof(&attrs)? {
72            Field::Map(field)
73        } else if let Some(field) = group::Field::new_oneof(&attrs)? {
74            Field::Group(field)
75        } else {
76            bail!("no type attribute for oneof field");
77        };
78
79        Ok(Some(field))
80    }
81
82    pub fn tags(&self) -> Vec<u32> {
83        match *self {
84            Field::Scalar(ref scalar) => vec![scalar.tag],
85            Field::Message(ref message) => vec![message.tag],
86            Field::Map(ref map) => vec![map.tag],
87            Field::Oneof(ref oneof) => oneof.tags.clone(),
88            Field::Group(ref group) => vec![group.tag],
89        }
90    }
91
92    /// Returns a statement which encodes the field.
93    pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
94        match *self {
95            Field::Scalar(ref scalar) => scalar.encode(prost_path, ident),
96            Field::Message(ref message) => message.encode(prost_path, ident),
97            Field::Map(ref map) => map.encode(prost_path, ident),
98            Field::Oneof(ref oneof) => oneof.encode(ident),
99            Field::Group(ref group) => group.encode(prost_path, ident),
100        }
101    }
102
103    /// Returns an expression which evaluates to the result of merging a decoded
104    /// value into the field.
105    pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
106        match *self {
107            Field::Scalar(ref scalar) => scalar.merge(prost_path, ident),
108            Field::Message(ref message) => message.merge(prost_path, ident),
109            Field::Map(ref map) => map.merge(prost_path, ident),
110            Field::Oneof(ref oneof) => oneof.merge(ident),
111            Field::Group(ref group) => group.merge(prost_path, ident),
112        }
113    }
114
115    /// Returns an expression which evaluates to the encoded length of the field.
116    pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
117        match *self {
118            Field::Scalar(ref scalar) => scalar.encoded_len(prost_path, ident),
119            Field::Map(ref map) => map.encoded_len(prost_path, ident),
120            Field::Message(ref msg) => msg.encoded_len(prost_path, ident),
121            Field::Oneof(ref oneof) => oneof.encoded_len(ident),
122            Field::Group(ref group) => group.encoded_len(prost_path, ident),
123        }
124    }
125
126    /// Returns a statement which clears the field.
127    pub fn clear(&self, ident: TokenStream) -> TokenStream {
128        match *self {
129            Field::Scalar(ref scalar) => scalar.clear(ident),
130            Field::Message(ref message) => message.clear(ident),
131            Field::Map(ref map) => map.clear(ident),
132            Field::Oneof(ref oneof) => oneof.clear(ident),
133            Field::Group(ref group) => group.clear(ident),
134        }
135    }
136
137    pub fn default(&self, prost_path: &Path) -> TokenStream {
138        match *self {
139            Field::Scalar(ref scalar) => scalar.default(prost_path),
140            _ => quote!(::core::default::Default::default()),
141        }
142    }
143
144    /// Produces the fragment implementing debug for the given field.
145    pub fn debug(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
146        match *self {
147            Field::Scalar(ref scalar) => {
148                let wrapper = scalar.debug(prost_path, quote!(ScalarWrapper));
149                quote! {
150                    {
151                        #wrapper
152                        ScalarWrapper(&#ident)
153                    }
154                }
155            }
156            Field::Map(ref map) => {
157                let wrapper = map.debug(prost_path, quote!(MapWrapper));
158                quote! {
159                    {
160                        #wrapper
161                        MapWrapper(&#ident)
162                    }
163                }
164            }
165            _ => quote!(&#ident),
166        }
167    }
168
169    pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option<TokenStream> {
170        match *self {
171            Field::Scalar(ref scalar) => scalar.methods(ident),
172            Field::Map(ref map) => map.methods(prost_path, ident),
173            _ => None,
174        }
175    }
176}
177
/// The cardinality keyword that may accompany a field in `#[prost(...)]`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum Label {
    /// Marked with the `optional` keyword.
    Optional,
    /// Marked with the `required` keyword.
    Required,
    /// Marked with the `repeated` keyword.
    Repeated,
}
187
188impl Label {
189    fn as_str(self) -> &'static str {
190        match self {
191            Label::Optional => "optional",
192            Label::Required => "required",
193            Label::Repeated => "repeated",
194        }
195    }
196
197    fn variants() -> slice::Iter<'static, Label> {
198        const VARIANTS: &[Label] = &[Label::Optional, Label::Required, Label::Repeated];
199        VARIANTS.iter()
200    }
201
202    /// Parses a string into a field label.
203    /// If the string doesn't match a field label, `None` is returned.
204    fn from_attr(attr: &Meta) -> Option<Label> {
205        if let Meta::Path(ref path) = *attr {
206            for &label in Label::variants() {
207                if path.is_ident(label.as_str()) {
208                    return Some(label);
209                }
210            }
211        }
212        None
213    }
214}
215
216impl fmt::Debug for Label {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        f.write_str(self.as_str())
219    }
220}
221
222impl fmt::Display for Label {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        f.write_str(self.as_str())
225    }
226}
227
228/// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`.
229fn prost_attrs(attrs: Vec<Attribute>) -> Result<Vec<Meta>, Error> {
230    let mut result = Vec::new();
231    for attr in attrs.iter() {
232        if let Meta::List(meta_list) = &attr.meta {
233            if meta_list.path.is_ident("prost") {
234                result.extend(
235                    meta_list
236                        .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?
237                        .into_iter(),
238                )
239            }
240        }
241    }
242    Ok(result)
243}
244
245pub fn set_option<T>(option: &mut Option<T>, value: T, message: &str) -> Result<(), Error>
246where
247    T: fmt::Debug,
248{
249    if let Some(ref existing) = *option {
250        bail!("{message}: {existing:?} and {value:?}");
251    }
252    *option = Some(value);
253    Ok(())
254}
255
256pub fn set_bool(b: &mut bool, message: &str) -> Result<(), Error> {
257    if *b {
258        bail!("{message}");
259    } else {
260        *b = true;
261        Ok(())
262    }
263}
264
265/// Unpacks an attribute into a (key, boolean) pair, returning the boolean value.
266/// If the key doesn't match the attribute, `None` is returned.
267fn bool_attr(key: &str, attr: &Meta) -> Result<Option<bool>, Error> {
268    if !attr.path().is_ident(key) {
269        return Ok(None);
270    }
271    match *attr {
272        Meta::Path(..) => Ok(Some(true)),
273        Meta::List(ref meta_list) => Ok(Some(meta_list.parse_args::<LitBool>()?.value())),
274        Meta::NameValue(MetaNameValue {
275            value:
276                Expr::Lit(ExprLit {
277                    lit: Lit::Str(ref lit),
278                    ..
279                }),
280            ..
281        }) => lit
282            .value()
283            .parse::<bool>()
284            .map_err(Error::from)
285            .map(Option::Some),
286        Meta::NameValue(MetaNameValue {
287            value:
288                Expr::Lit(ExprLit {
289                    lit: Lit::Bool(LitBool { value, .. }),
290                    ..
291                }),
292            ..
293        }) => Ok(Some(value)),
294        _ => bail!("invalid {key} attribute"),
295    }
296}
297
298/// Checks if an attribute matches a word.
299fn word_attr(key: &str, attr: &Meta) -> bool {
300    if let Meta::Path(ref path) = *attr {
301        path.is_ident(key)
302    } else {
303        false
304    }
305}
306
307pub(super) fn tag_attr(attr: &Meta) -> Result<Option<u32>, Error> {
308    if !attr.path().is_ident("tag") {
309        return Ok(None);
310    }
311    match *attr {
312        Meta::List(ref meta_list) => Ok(Some(meta_list.parse_args::<LitInt>()?.base10_parse()?)),
313        Meta::NameValue(MetaNameValue {
314            value: Expr::Lit(ref expr),
315            ..
316        }) => match expr.lit {
317            Lit::Str(ref lit) => lit
318                .value()
319                .parse::<u32>()
320                .map_err(Error::from)
321                .map(Option::Some),
322            Lit::Int(ref lit) => Ok(Some(lit.base10_parse()?)),
323            _ => bail!("invalid tag attribute: {attr:?}"),
324        },
325        _ => bail!("invalid tag attribute: {attr:?}"),
326    }
327}
328
329fn tags_attr(attr: &Meta) -> Result<Option<Vec<u32>>, Error> {
330    if !attr.path().is_ident("tags") {
331        return Ok(None);
332    }
333    match *attr {
334        Meta::List(ref meta_list) => Ok(Some(
335            meta_list
336                .parse_args_with(Punctuated::<LitInt, Token![,]>::parse_terminated)?
337                .iter()
338                .map(LitInt::base10_parse)
339                .collect::<Result<Vec<_>, _>>()?,
340        )),
341        Meta::NameValue(MetaNameValue {
342            value:
343                Expr::Lit(ExprLit {
344                    lit: Lit::Str(ref lit),
345                    ..
346                }),
347            ..
348        }) => lit
349            .value()
350            .split(',')
351            .map(|s| s.trim().parse::<u32>().map_err(Error::from))
352            .collect::<Result<Vec<u32>, _>>()
353            .map(Some),
354        _ => bail!("invalid tag attribute: {attr:?}"),
355    }
356}