prost/
encoding.rs

1//! Utility functions and types for encoding and decoding Protobuf types.
2//!
3//! This module contains the encoding and decoding primitives for Protobuf as described in
4//! <https://protobuf.dev/programming-guides/encoding/>.
5//!
6//! This module is `pub`, but is only for prost internal use. The `prost-derive` crate needs access for its `Message` implementations.
7
8use alloc::collections::BTreeMap;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::mem;
12use core::str;
13
14use ::bytes::{Buf, BufMut, Bytes};
15
16use crate::error::DecodeErrorKind;
17use crate::DecodeError;
18use crate::Message;
19
20pub mod varint;
21pub use varint::{decode_varint, encode_varint, encoded_len_varint};
22
23pub mod length_delimiter;
24pub use length_delimiter::{
25    decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
26};
27
28pub mod wire_type;
29pub use wire_type::{check_wire_type, WireType};
30
/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When handing
/// it to a function that decodes a nested object, pass the context returned by
/// `enter_recursion` rather than the current one.
///
/// When the `no-recursion-limit` feature is enabled the struct has no fields
/// and `Default` is derived; otherwise `Default` is implemented manually.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// How many times we can recurse in the current decode stack before we hit
    /// the recursion limit.
    ///
    /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
    /// customized. The recursion limit can be ignored by building the Prost
    /// crate with the `no-recursion-limit` feature.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
47
48#[cfg(not(feature = "no-recursion-limit"))]
49impl Default for DecodeContext {
50    #[inline]
51    fn default() -> DecodeContext {
52        DecodeContext {
53            recurse_count: crate::RECURSION_LIMIT,
54        }
55    }
56}
57
58impl DecodeContext {
59    /// Call this function before recursively decoding.
60    ///
61    /// There is no `exit` function since this function creates a new `DecodeContext`
62    /// to be used at the next level of recursion. Continue to use the old context
63    // at the previous level of recursion.
64    #[cfg(not(feature = "no-recursion-limit"))]
65    #[inline]
66    pub(crate) fn enter_recursion(&self) -> DecodeContext {
67        DecodeContext {
68            recurse_count: self.recurse_count - 1,
69        }
70    }
71
72    #[cfg(feature = "no-recursion-limit")]
73    #[inline]
74    pub(crate) fn enter_recursion(&self) -> DecodeContext {
75        DecodeContext {}
76    }
77
78    /// Checks whether the recursion limit has been reached in the stack of
79    /// decodes described by the `DecodeContext` at `self.ctx`.
80    ///
81    /// Returns `Ok<()>` if it is ok to continue recursing.
82    /// Returns `Err<DecodeError>` if the recursion limit has been reached.
83    #[cfg(not(feature = "no-recursion-limit"))]
84    #[inline]
85    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
86        if self.recurse_count == 0 {
87            Err(DecodeErrorKind::RecursionLimitReached.into())
88        } else {
89            Ok(())
90        }
91    }
92
93    #[cfg(feature = "no-recursion-limit")]
94    #[inline]
95    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
96        Ok(())
97    }
98}
99
/// Smallest valid field tag; `decode_key` rejects any decoded tag below it.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag (`2^29 - 1`); `encode_key` debug-asserts that tags
/// stay within `MIN_TAG..=MAX_TAG`.
pub const MAX_TAG: u32 = (1 << 29) - 1;
102
103/// Encodes a Protobuf field key, which consists of a wire type designator and
104/// the field tag.
105#[inline]
106pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut impl BufMut) {
107    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
108    let key = (tag << 3) | wire_type as u32;
109    encode_varint(u64::from(key), buf);
110}
111
112/// Decodes a Protobuf field key, which consists of a wire type designator and
113/// the field tag.
114#[inline(always)]
115pub fn decode_key(buf: &mut impl Buf) -> Result<(u32, WireType), DecodeError> {
116    let key = decode_varint(buf)?;
117    if key > u64::from(u32::MAX) {
118        return Err(DecodeErrorKind::InvalidKey { key }.into());
119    }
120    let wire_type = WireType::try_from(key & 0x07)?;
121    let tag = key as u32 >> 3;
122
123    if tag < MIN_TAG {
124        return Err(DecodeErrorKind::InvalidTag.into());
125    }
126
127    Ok((tag, wire_type))
128}
129
130/// Returns the width of an encoded Protobuf field key with the given tag.
131/// The returned width will be between 1 and 5 bytes (inclusive).
132#[inline]
133pub const fn key_len(tag: u32) -> usize {
134    encoded_len_varint((tag << 3) as u64)
135}
136
137/// Helper function which abstracts reading a length delimiter prefix followed
138/// by decoding values until the length of bytes is exhausted.
139pub fn merge_loop<T, M, B>(
140    value: &mut T,
141    buf: &mut B,
142    ctx: DecodeContext,
143    mut merge: M,
144) -> Result<(), DecodeError>
145where
146    M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
147    B: Buf,
148{
149    let len = decode_varint(buf)?;
150    let remaining = buf.remaining();
151    if len > remaining as u64 {
152        return Err(DecodeErrorKind::BufferUnderflow.into());
153    }
154
155    let limit = remaining - len as usize;
156    while buf.remaining() > limit {
157        merge(value, buf, ctx.clone())?;
158    }
159
160    if buf.remaining() != limit {
161        return Err(DecodeErrorKind::DelimitedLengthExceeded.into());
162    }
163    Ok(())
164}
165
166pub fn skip_field(
167    wire_type: WireType,
168    tag: u32,
169    buf: &mut impl Buf,
170    ctx: DecodeContext,
171) -> Result<(), DecodeError> {
172    ctx.limit_reached()?;
173    let len = match wire_type {
174        WireType::Varint => decode_varint(buf).map(|_| 0)?,
175        WireType::ThirtyTwoBit => 4,
176        WireType::SixtyFourBit => 8,
177        WireType::LengthDelimited => decode_varint(buf)?,
178        WireType::StartGroup => loop {
179            let (inner_tag, inner_wire_type) = decode_key(buf)?;
180            match inner_wire_type {
181                WireType::EndGroup => {
182                    if inner_tag != tag {
183                        return Err(DecodeErrorKind::UnexpectedEndGroupTag.into());
184                    }
185                    break 0;
186                }
187                _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
188            }
189        },
190        WireType::EndGroup => return Err(DecodeErrorKind::UnexpectedEndGroupTag.into()),
191    };
192
193    if len > buf.remaining() as u64 {
194        return Err(DecodeErrorKind::BufferUnderflow.into());
195    }
196
197    buf.advance(len as usize);
198    Ok(())
199}
200
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The generated function calls the `encode` function that is in scope at the
/// expansion site once per element.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
            for item in values.iter() {
                encode(tag, item, buf);
            }
        }
    };
}
211
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// The generated function accepts both the packed form (a length-delimited
/// run of values) and the unpacked form (a single value with the type's own
/// wire type), appending every decoded value to `values`.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: exactly one value follows the key.
                check_wire_type($wire_type, wire_type)?;
                let mut single = Default::default();
                $merge(wire_type, &mut single, buf, ctx)?;
                values.push(single);
                return Ok(());
            }

            // Packed: decode values until the delimited region is exhausted.
            merge_loop(values, buf, ctx, |out, src, inner_ctx| {
                let mut element = Default::default();
                $merge($wire_type, &mut element, src, inner_ctx)?;
                out.push(element);
                Ok(())
            })
        }
    };
}
243
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// The short form takes the Rust type and the module name and converts values
/// with plain `as` casts. The long form additionally takes two expressions:
/// `to_uint64` maps a `&$ty` (bound to the given identifier) to the `u64` that
/// is written as a varint, and `from_uint64` maps the decoded `u64` (bound to
/// the given identifier) back to `$ty`.
macro_rules! varint {
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key with the `Varint` wire type followed by
            /// the converted value as a varint.
            pub fn encode(tag: u32, $to_uint64_value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            /// Decodes one varint into `value`, overwriting it. Fails if
            /// `wire_type` is not `Varint` or the varint cannot be decoded.
            pub fn merge(wire_type: WireType, value: &mut $ty, buf: &mut impl Buf, _ctx: DecodeContext) -> Result<(), DecodeError> {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes `values` in packed form: one length-delimited key, the
            /// total byte length, then every value as a varint. Writes nothing
            /// when `values` is empty.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The length prefix must be known up front, so the values are
                // measured in a first pass and written in a second one.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Byte length of the key plus the value encoded as a varint.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            /// Byte length of `values` in unpacked form (one key per value).
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            /// Byte length of `values` in packed form; zero for an empty
            /// slice, matching `encode_packed` which writes nothing then.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool`: encoded as 0 or 1; any non-zero varint decodes to `true`.
varint!(bool, bool,
        to_uint64(value) u64::from(*value),
        from_uint64(value) value != 0);
// Plain integer types: converted to and from `u64` with `as` casts.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// `sint32`: ZigZag encoding over 32 bits, so values of small magnitude
// (positive or negative) map to small unsigned numbers.
varint!(i32, sint32,
to_uint64(value) {
    ((value << 1) ^ (value >> 31)) as u32 as u64
},
from_uint64(value) {
    // Only the low 32 bits are considered when decoding.
    let value = value as u32;
    ((value >> 1) as i32) ^ (-((value & 1) as i32))
});
// `sint64`: ZigZag encoding over 64 bits.
varint!(i64, sint64,
to_uint64(value) {
    ((value << 1) ^ (value >> 63)) as u64
},
from_uint64(value) {
    ((value >> 1) as i64) ^ (-((value & 1) as i64))
});
373
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// Arguments: the Rust type, its encoded width in bytes, the wire type used
/// for a single value, the name of the generated module, and the names of the
/// buffer methods used to write and read one value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key followed by the value via the `$put`
            /// buffer method.
            pub fn encode(tag: u32, value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Reads one value into `value`, overwriting it. Fails if
            /// `wire_type` does not match or fewer than `$width` bytes remain.
            pub fn merge(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut impl Buf,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError> {
                check_wire_type($wire_type, wire_type)?;
                // Checked up front so a short buffer yields an error here
                // instead of reaching the `$get` call.
                if buf.remaining() < $width {
                    return Err(DecodeErrorKind::BufferUnderflow.into());
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes `values` in packed form: one length-delimited key, the
            /// total byte length, then every value back to back. Writes
            /// nothing when `values` is empty.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let len = values.len() as u64 * $width;
                encode_varint(len as u64, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Byte length of the key plus the fixed value width; the value
            /// itself does not influence the result.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Byte length of `values` in unpacked form (one key per value).
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            /// Byte length of `values` in packed form; zero for an empty
            /// slice, matching `encode_packed` which writes nothing then.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
// `float`: 4-byte little-endian `f32`.
fixed_width!(
    f32,
    4,
    WireType::ThirtyTwoBit,
    float,
    put_f32_le,
    get_f32_le
);
// `double`: 8-byte little-endian `f64`.
fixed_width!(
    f64,
    8,
    WireType::SixtyFourBit,
    double,
    put_f64_le,
    get_f64_le
);
// `fixed32`: 4-byte little-endian `u32`.
fixed_width!(
    u32,
    4,
    WireType::ThirtyTwoBit,
    fixed32,
    put_u32_le,
    get_u32_le
);
// `fixed64`: 8-byte little-endian `u64`.
fixed_width!(
    u64,
    8,
    WireType::SixtyFourBit,
    fixed64,
    put_u64_le,
    get_u64_le
);
// `sfixed32`: 4-byte little-endian `i32`.
fixed_width!(
    i32,
    4,
    WireType::ThirtyTwoBit,
    sfixed32,
    put_i32_le,
    get_i32_le
);
// `sfixed64`: 8-byte little-endian `i64`.
fixed_width!(
    i64,
    8,
    WireType::SixtyFourBit,
    sfixed64,
    put_i64_le,
    get_i64_le
);
521
/// Macro which emits encoding functions for a length-delimited type.
///
/// The generated functions rely on `encode` and `merge` functions being in
/// scope at the expansion site, and on the type offering a `len()` method.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        /// Decodes a single length-delimited element and appends it to
        /// `values`.
        pub fn merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        /// Byte length of the key, the length prefix and the payload.
        #[inline]
        #[allow(clippy::ptr_arg)]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload = value.len();
            key_len(tag) + encoded_len_varint(payload as u64) + payload
        }

        /// Byte length of all elements, each with its own key and length
        /// prefix.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys = key_len(tag) * values.len();
            let payloads = values
                .iter()
                .map(|item| {
                    let payload = item.len();
                    encoded_len_varint(payload as u64) + payload
                })
                .sum::<usize>();
            keys + payloads
        }
    };
}
556
pub mod string {
    use super::*;

    /// Encodes `value` as a length-delimited field: the key, the byte length
    /// as a varint, then the string's bytes.
    pub fn encode(tag: u32, value: &String, buf: &mut impl BufMut) {
        encode_key(tag, WireType::LengthDelimited, buf);
        encode_varint(value.len() as u64, buf);
        buf.put_slice(value.as_bytes());
    }

    /// Decodes a length-delimited field into `value`, replacing its contents.
    ///
    /// Returns `DecodeErrorKind::InvalidString` when the decoded bytes are not
    /// valid UTF-8, and propagates any error from `bytes::merge_one_copy`. In
    /// every error case `value` is left empty.
    pub fn merge(
        wire_type: WireType,
        value: &mut String,
        buf: &mut impl Buf,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError> {
        // SAFETY:
        //
        // `string::merge` reuses `bytes::merge_one_copy`, with an additional check of utf-8
        // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
        // string is cleared, so as to avoid leaking a string field with invalid data.
        //
        // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
        // alternative of temporarily swapping an empty `String` into the field, because it results
        // in up to 10% better performance on the protobuf message decoding benchmarks.
        //
        // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
        // the backing `String`. To enforce this, even in the event of a panic in
        // `bytes::merge_one_copy` or in the buf implementation, a drop guard is used.
        unsafe {
            // Clears the underlying bytes when dropped; only bypassed (via
            // `mem::forget`) once the contents are known to be valid utf-8.
            struct DropGuard<'a>(&'a mut Vec<u8>);
            impl Drop for DropGuard<'_> {
                #[inline]
                fn drop(&mut self) {
                    self.0.clear();
                }
            }

            let drop_guard = DropGuard(value.as_mut_vec());
            // An early return through `?` drops the guard and clears the string.
            bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
            match str::from_utf8(drop_guard.0) {
                Ok(_) => {
                    // Success; do not clear the bytes.
                    mem::forget(drop_guard);
                    Ok(())
                }
                // The guard is dropped on the way out, clearing the invalid bytes.
                Err(_) => Err(DecodeErrorKind::InvalidString.into()),
            }
        }
    }

    length_delimited!(String);

    #[cfg(test)]
    mod test {
        use proptest::prelude::*;

        use super::super::test::{check_collection_type, check_type};
        use super::*;

        proptest! {
            #[test]
            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_type(value, tag, WireType::LengthDelimited,
                                        encode, merge, encoded_len)?;
            }
            #[test]
            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
                                                   encode_repeated, merge_repeated,
                                                   encoded_len_repeated)?;
            }
        }
    }
}
631
/// Public marker trait for the byte containers accepted by the `bytes`
/// encoding functions.
///
/// All behavior comes from the supertrait in the private `sealed` module
/// (presumably so the trait cannot be implemented outside this crate). In this
/// file it is implemented for `Bytes` and `Vec<u8>`.
pub trait BytesAdapter: sealed::BytesAdapter {}
633
// Private module holding the supertrait of the public `BytesAdapter` trait.
mod sealed {
    use super::{Buf, BufMut};

    /// Operations the `bytes` encoding functions need from a byte container.
    pub trait BytesAdapter: Default + Sized + 'static {
        /// Number of bytes currently held by this container.
        fn len(&self) -> usize;

        /// Replace contents of this buffer with the contents of another buffer.
        fn replace_with(&mut self, buf: impl Buf);

        /// Appends this buffer to the (contents of) other buffer.
        fn append_to(&self, buf: &mut impl BufMut);

        /// Returns `true` when `len()` is zero.
        fn is_empty(&self) -> bool {
            self.len() == 0
        }
    }
}
651
652impl BytesAdapter for Bytes {}
653
654impl sealed::BytesAdapter for Bytes {
655    fn len(&self) -> usize {
656        Buf::remaining(self)
657    }
658
659    fn replace_with(&mut self, mut buf: impl Buf) {
660        *self = buf.copy_to_bytes(buf.remaining());
661    }
662
663    fn append_to(&self, buf: &mut impl BufMut) {
664        buf.put(self.clone())
665    }
666}
667
668impl BytesAdapter for Vec<u8> {}
669
670impl sealed::BytesAdapter for Vec<u8> {
671    fn len(&self) -> usize {
672        Vec::len(self)
673    }
674
675    fn replace_with(&mut self, buf: impl Buf) {
676        self.clear();
677        self.reserve(buf.remaining());
678        self.put(buf);
679    }
680
681    fn append_to(&self, buf: &mut impl BufMut) {
682        buf.put(self.as_slice())
683    }
684}
685
686pub mod bytes {
687    use crate::error::DecodeErrorKind;
688
689    use super::*;
690
691    pub fn encode(tag: u32, value: &impl BytesAdapter, buf: &mut impl BufMut) {
692        encode_key(tag, WireType::LengthDelimited, buf);
693        encode_varint(value.len() as u64, buf);
694        value.append_to(buf);
695    }
696
697    pub fn merge(
698        wire_type: WireType,
699        value: &mut impl BytesAdapter,
700        buf: &mut impl Buf,
701        _ctx: DecodeContext,
702    ) -> Result<(), DecodeError> {
703        check_wire_type(WireType::LengthDelimited, wire_type)?;
704        let len = decode_varint(buf)?;
705        if len > buf.remaining() as u64 {
706            return Err(DecodeErrorKind::BufferUnderflow.into());
707        }
708        let len = len as usize;
709
710        // Clear the existing value. This follows from the following rule in the encoding guide[1]:
711        //
712        // > Normally, an encoded message would never have more than one instance of a non-repeated
713        // > field. However, parsers are expected to handle the case in which they do. For numeric
714        // > types and strings, if the same field appears multiple times, the parser accepts the
715        // > last value it sees.
716        //
717        // [1]: https://protobuf.dev/programming-guides/encoding/#last-one-wins
718        //
719        // This is intended for A and B both being Bytes so it is zero-copy.
720        // Some combinations of A and B types may cause a double-copy,
721        // in which case merge_one_copy() should be used instead.
722        value.replace_with(buf.copy_to_bytes(len));
723        Ok(())
724    }
725
726    pub(super) fn merge_one_copy(
727        wire_type: WireType,
728        value: &mut impl BytesAdapter,
729        buf: &mut impl Buf,
730        _ctx: DecodeContext,
731    ) -> Result<(), DecodeError> {
732        check_wire_type(WireType::LengthDelimited, wire_type)?;
733        let len = decode_varint(buf)?;
734        if len > buf.remaining() as u64 {
735            return Err(DecodeErrorKind::BufferUnderflow.into());
736        }
737        let len = len as usize;
738
739        // If we must copy, make sure to copy only once.
740        value.replace_with(buf.take(len));
741        Ok(())
742    }
743
744    length_delimited!(impl BytesAdapter);
745
746    #[cfg(test)]
747    mod test {
748        use proptest::prelude::*;
749
750        use super::super::test::{check_collection_type, check_type};
751        use super::*;
752
753        proptest! {
754            #[test]
755            fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
756                super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
757                                                            encode, merge, encoded_len)?;
758            }
759
760            #[test]
761            fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
762                let value = Bytes::from(value);
763                super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
764                                                        encode, merge, encoded_len)?;
765            }
766
767            #[test]
768            fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
769                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
770                                                   encode_repeated, merge_repeated,
771                                                   encoded_len_repeated)?;
772            }
773
774            #[test]
775            fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
776                let value = value.into_iter().map(Bytes::from).collect();
777                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
778                                                   encode_repeated, merge_repeated,
779                                                   encoded_len_repeated)?;
780            }
781        }
782    }
783}
784
785pub mod message {
786    use super::*;
787
788    pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
789    where
790        M: Message,
791    {
792        encode_key(tag, WireType::LengthDelimited, buf);
793        encode_varint(msg.encoded_len() as u64, buf);
794        msg.encode_raw(buf);
795    }
796
797    pub fn merge<M, B>(
798        wire_type: WireType,
799        msg: &mut M,
800        buf: &mut B,
801        ctx: DecodeContext,
802    ) -> Result<(), DecodeError>
803    where
804        M: Message,
805        B: Buf,
806    {
807        check_wire_type(WireType::LengthDelimited, wire_type)?;
808        ctx.limit_reached()?;
809        merge_loop(
810            msg,
811            buf,
812            ctx.enter_recursion(),
813            |msg: &mut M, buf: &mut B, ctx| {
814                let (tag, wire_type) = decode_key(buf)?;
815                msg.merge_field(tag, wire_type, buf, ctx)
816            },
817        )
818    }
819
820    pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
821    where
822        M: Message,
823    {
824        for msg in messages {
825            encode(tag, msg, buf);
826        }
827    }
828
829    pub fn merge_repeated<M>(
830        wire_type: WireType,
831        messages: &mut Vec<M>,
832        buf: &mut impl Buf,
833        ctx: DecodeContext,
834    ) -> Result<(), DecodeError>
835    where
836        M: Message + Default,
837    {
838        check_wire_type(WireType::LengthDelimited, wire_type)?;
839        let mut msg = M::default();
840        merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
841        messages.push(msg);
842        Ok(())
843    }
844
845    #[inline]
846    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
847    where
848        M: Message,
849    {
850        let len = msg.encoded_len();
851        key_len(tag) + encoded_len_varint(len as u64) + len
852    }
853
854    #[inline]
855    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
856    where
857        M: Message,
858    {
859        key_len(tag) * messages.len()
860            + messages
861                .iter()
862                .map(Message::encoded_len)
863                .map(|len| len + encoded_len_varint(len as u64))
864                .sum::<usize>()
865    }
866}
867
868pub mod group {
869    use crate::error::DecodeErrorKind;
870
871    use super::*;
872
873    pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
874    where
875        M: Message,
876    {
877        encode_key(tag, WireType::StartGroup, buf);
878        msg.encode_raw(buf);
879        encode_key(tag, WireType::EndGroup, buf);
880    }
881
882    pub fn merge<M>(
883        tag: u32,
884        wire_type: WireType,
885        msg: &mut M,
886        buf: &mut impl Buf,
887        ctx: DecodeContext,
888    ) -> Result<(), DecodeError>
889    where
890        M: Message,
891    {
892        check_wire_type(WireType::StartGroup, wire_type)?;
893
894        ctx.limit_reached()?;
895        loop {
896            let (field_tag, field_wire_type) = decode_key(buf)?;
897            if field_wire_type == WireType::EndGroup {
898                if field_tag != tag {
899                    return Err(DecodeErrorKind::UnexpectedEndGroupTag.into());
900                }
901                return Ok(());
902            }
903
904            M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
905        }
906    }
907
908    pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
909    where
910        M: Message,
911    {
912        for msg in messages {
913            encode(tag, msg, buf);
914        }
915    }
916
917    pub fn merge_repeated<M>(
918        tag: u32,
919        wire_type: WireType,
920        messages: &mut Vec<M>,
921        buf: &mut impl Buf,
922        ctx: DecodeContext,
923    ) -> Result<(), DecodeError>
924    where
925        M: Message + Default,
926    {
927        check_wire_type(WireType::StartGroup, wire_type)?;
928        let mut msg = M::default();
929        merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
930        messages.push(msg);
931        Ok(())
932    }
933
934    #[inline]
935    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
936    where
937        M: Message,
938    {
939        2 * key_len(tag) + msg.encoded_len()
940    }
941
942    #[inline]
943    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
944    where
945        M: Message,
946    {
947        2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
948    }
949}
950
951/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
952/// generic over `HashMap` and `BTreeMap`.
953macro_rules! map {
954    ($map_ty:ident) => {
955        use crate::encoding::*;
956        use core::hash::Hash;
957
958        /// Generic protobuf map encode function.
959        pub fn encode<K, V, B, KE, KL, VE, VL>(
960            key_encode: KE,
961            key_encoded_len: KL,
962            val_encode: VE,
963            val_encoded_len: VL,
964            tag: u32,
965            values: &$map_ty<K, V>,
966            buf: &mut B,
967        ) where
968            K: Default + Eq + Hash + Ord,
969            V: Default + PartialEq,
970            B: BufMut,
971            KE: Fn(u32, &K, &mut B),
972            KL: Fn(u32, &K) -> usize,
973            VE: Fn(u32, &V, &mut B),
974            VL: Fn(u32, &V) -> usize,
975        {
976            encode_with_default(
977                key_encode,
978                key_encoded_len,
979                val_encode,
980                val_encoded_len,
981                &V::default(),
982                tag,
983                values,
984                buf,
985            )
986        }
987
988        /// Generic protobuf map merge function.
989        pub fn merge<K, V, B, KM, VM>(
990            key_merge: KM,
991            val_merge: VM,
992            values: &mut $map_ty<K, V>,
993            buf: &mut B,
994            ctx: DecodeContext,
995        ) -> Result<(), DecodeError>
996        where
997            K: Default + Eq + Hash + Ord,
998            V: Default,
999            B: Buf,
1000            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
1001            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
1002        {
1003            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
1004        }
1005
1006        /// Generic protobuf map encode function.
1007        pub fn encoded_len<K, V, KL, VL>(
1008            key_encoded_len: KL,
1009            val_encoded_len: VL,
1010            tag: u32,
1011            values: &$map_ty<K, V>,
1012        ) -> usize
1013        where
1014            K: Default + Eq + Hash + Ord,
1015            V: Default + PartialEq,
1016            KL: Fn(u32, &K) -> usize,
1017            VL: Fn(u32, &V) -> usize,
1018        {
1019            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
1020        }
1021
1022        /// Generic protobuf map encode function with an overridden value default.
1023        ///
1024        /// This is necessary because enumeration values can have a default value other
1025        /// than 0 in proto2.
1026        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
1027            key_encode: KE,
1028            key_encoded_len: KL,
1029            val_encode: VE,
1030            val_encoded_len: VL,
1031            val_default: &V,
1032            tag: u32,
1033            values: &$map_ty<K, V>,
1034            buf: &mut B,
1035        ) where
1036            K: Default + Eq + Hash + Ord,
1037            V: PartialEq,
1038            B: BufMut,
1039            KE: Fn(u32, &K, &mut B),
1040            KL: Fn(u32, &K) -> usize,
1041            VE: Fn(u32, &V, &mut B),
1042            VL: Fn(u32, &V) -> usize,
1043        {
1044            for (key, val) in values.iter() {
1045                let skip_key = key == &K::default();
1046                let skip_val = val == val_default;
1047
1048                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
1049                    + (if skip_val { 0 } else { val_encoded_len(2, val) });
1050
1051                encode_key(tag, WireType::LengthDelimited, buf);
1052                encode_varint(len as u64, buf);
1053                if !skip_key {
1054                    key_encode(1, key, buf);
1055                }
1056                if !skip_val {
1057                    val_encode(2, val, buf);
1058                }
1059            }
1060        }
1061
1062        /// Generic protobuf map merge function with an overridden value default.
1063        ///
1064        /// This is necessary because enumeration values can have a default value other
1065        /// than 0 in proto2.
1066        pub fn merge_with_default<K, V, B, KM, VM>(
1067            key_merge: KM,
1068            val_merge: VM,
1069            val_default: V,
1070            values: &mut $map_ty<K, V>,
1071            buf: &mut B,
1072            ctx: DecodeContext,
1073        ) -> Result<(), DecodeError>
1074        where
1075            K: Default + Eq + Hash + Ord,
1076            B: Buf,
1077            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
1078            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
1079        {
1080            let mut key = Default::default();
1081            let mut val = val_default;
1082            ctx.limit_reached()?;
1083            merge_loop(
1084                &mut (&mut key, &mut val),
1085                buf,
1086                ctx.enter_recursion(),
1087                |&mut (ref mut key, ref mut val), buf, ctx| {
1088                    let (tag, wire_type) = decode_key(buf)?;
1089                    match tag {
1090                        1 => key_merge(wire_type, key, buf, ctx),
1091                        2 => val_merge(wire_type, val, buf, ctx),
1092                        _ => skip_field(wire_type, tag, buf, ctx),
1093                    }
1094                },
1095            )?;
1096            values.insert(key, val);
1097
1098            Ok(())
1099        }
1100
1101        /// Generic protobuf map encode function with an overridden value default.
1102        ///
1103        /// This is necessary because enumeration values can have a default value other
1104        /// than 0 in proto2.
1105        pub fn encoded_len_with_default<K, V, KL, VL>(
1106            key_encoded_len: KL,
1107            val_encoded_len: VL,
1108            val_default: &V,
1109            tag: u32,
1110            values: &$map_ty<K, V>,
1111        ) -> usize
1112        where
1113            K: Default + Eq + Hash + Ord,
1114            V: PartialEq,
1115            KL: Fn(u32, &K) -> usize,
1116            VL: Fn(u32, &V) -> usize,
1117        {
1118            key_len(tag) * values.len()
1119                + values
1120                    .iter()
1121                    .map(|(key, val)| {
1122                        let len = (if key == &K::default() {
1123                            0
1124                        } else {
1125                            key_encoded_len(1, key)
1126                        }) + (if val == val_default {
1127                            0
1128                        } else {
1129                            val_encoded_len(2, val)
1130                        });
1131                        encoded_len_varint(len as u64) + len
1132                    })
1133                    .sum::<usize>()
1134        }
1135    };
1136}
1137
/// Map encoding functions for `std::collections::HashMap`, generated by the `map!` macro.
///
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod hash_map {
    use std::collections::HashMap;
    map!(HashMap);
}
1143
/// Map encoding functions for `BTreeMap`, generated by the `map!` macro.
pub mod btree_map {
    map!(BTreeMap);
}
1147
1148#[cfg(test)]
1149mod test {
1150    #[cfg(not(feature = "std"))]
1151    use alloc::string::ToString;
1152    use core::borrow::Borrow;
1153    use core::fmt::Debug;
1154
1155    use ::bytes::BytesMut;
1156    use proptest::{prelude::*, test_runner::TestCaseResult};
1157
1158    use super::*;
1159
1160    pub fn check_type<T, B>(
1161        value: T,
1162        tag: u32,
1163        wire_type: WireType,
1164        encode: fn(u32, &B, &mut BytesMut),
1165        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1166        encoded_len: fn(u32, &B) -> usize,
1167    ) -> TestCaseResult
1168    where
1169        T: Debug + Default + PartialEq + Borrow<B>,
1170        B: ?Sized,
1171    {
1172        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1173
1174        let expected_len = encoded_len(tag, value.borrow());
1175
1176        let mut buf = BytesMut::with_capacity(expected_len);
1177        encode(tag, value.borrow(), &mut buf);
1178
1179        let mut buf = buf.freeze();
1180
1181        prop_assert_eq!(
1182            buf.remaining(),
1183            expected_len,
1184            "encoded_len wrong; expected: {}, actual: {}",
1185            expected_len,
1186            buf.remaining()
1187        );
1188
1189        if !buf.has_remaining() {
1190            // Short circuit for empty packed values.
1191            return Ok(());
1192        }
1193
1194        let (decoded_tag, decoded_wire_type) =
1195            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1196        prop_assert_eq!(
1197            tag,
1198            decoded_tag,
1199            "decoded tag does not match; expected: {}, actual: {}",
1200            tag,
1201            decoded_tag
1202        );
1203
1204        prop_assert_eq!(
1205            wire_type,
1206            decoded_wire_type,
1207            "decoded wire type does not match; expected: {:?}, actual: {:?}",
1208            wire_type,
1209            decoded_wire_type,
1210        );
1211
1212        match wire_type {
1213            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
1214                "64bit wire type illegal remaining: {}, tag: {}",
1215                buf.remaining(),
1216                tag
1217            ))),
1218            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
1219                "32bit wire type illegal remaining: {}, tag: {}",
1220                buf.remaining(),
1221                tag
1222            ))),
1223            _ => Ok(()),
1224        }?;
1225
1226        let mut roundtrip_value = T::default();
1227        merge(
1228            wire_type,
1229            &mut roundtrip_value,
1230            &mut buf,
1231            DecodeContext::default(),
1232        )
1233        .map_err(|error| TestCaseError::fail(error.to_string()))?;
1234
1235        prop_assert!(
1236            !buf.has_remaining(),
1237            "expected buffer to be empty, remaining: {}",
1238            buf.remaining()
1239        );
1240
1241        prop_assert_eq!(value, roundtrip_value);
1242
1243        Ok(())
1244    }
1245
1246    pub fn check_collection_type<T, B, E, M, L>(
1247        value: T,
1248        tag: u32,
1249        wire_type: WireType,
1250        encode: E,
1251        mut merge: M,
1252        encoded_len: L,
1253    ) -> TestCaseResult
1254    where
1255        T: Debug + Default + PartialEq + Borrow<B>,
1256        B: ?Sized,
1257        E: FnOnce(u32, &B, &mut BytesMut),
1258        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1259        L: FnOnce(u32, &B) -> usize,
1260    {
1261        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1262
1263        let expected_len = encoded_len(tag, value.borrow());
1264
1265        let mut buf = BytesMut::with_capacity(expected_len);
1266        encode(tag, value.borrow(), &mut buf);
1267
1268        let mut buf = buf.freeze();
1269
1270        prop_assert_eq!(
1271            buf.remaining(),
1272            expected_len,
1273            "encoded_len wrong; expected: {}, actual: {}",
1274            expected_len,
1275            buf.remaining()
1276        );
1277
1278        let mut roundtrip_value = Default::default();
1279        while buf.has_remaining() {
1280            let (decoded_tag, decoded_wire_type) =
1281                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1282
1283            prop_assert_eq!(
1284                tag,
1285                decoded_tag,
1286                "decoded tag does not match; expected: {}, actual: {}",
1287                tag,
1288                decoded_tag
1289            );
1290
1291            prop_assert_eq!(
1292                wire_type,
1293                decoded_wire_type,
1294                "decoded wire type does not match; expected: {:?}, actual: {:?}",
1295                wire_type,
1296                decoded_wire_type
1297            );
1298
1299            merge(
1300                wire_type,
1301                &mut roundtrip_value,
1302                &mut buf,
1303                DecodeContext::default(),
1304            )
1305            .map_err(|error| TestCaseError::fail(error.to_string()))?;
1306        }
1307
1308        prop_assert_eq!(value, roundtrip_value);
1309
1310        Ok(())
1311    }
1312
1313    #[test]
1314    fn string_merge_invalid_utf8() {
1315        let mut s = String::new();
1316        let buf = b"\x02\x80\x80";
1317
1318        let r = string::merge(
1319            WireType::LengthDelimited,
1320            &mut s,
1321            &mut &buf[..],
1322            DecodeContext::default(),
1323        );
1324        r.expect_err("must be an error");
1325        assert!(s.is_empty());
1326    }
1327
    /// This big bowl o' macro soup generates an encoding property test for each combination of map
    /// type, scalar map key, and value type.
    ///
    /// The macro recurses through three arms: the public `keys:`/`vals:` arm creates one test
    /// module per map type, the first `@private` arm creates one module per key type, and the
    /// second `@private` arm emits one `proptest!` test per value type.
    /// TODO: these tests take a long time to compile, can this be improved?
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        // Entry point: fan out over the two map types. `$keys` and `$vals` are forwarded
        // as opaque token trees and only taken apart by the `@private` arms below.
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        // Fan out over the key list: one module per key, named after the key's protobuf
        // type, which brings the map type and test helpers into scope for the tests.
        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        // Fan out over the value list for a single key: one property test per value,
        // named after the value's protobuf type.
        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        // The three closures adapt the map module's encode / merge /
                        // encoded_len functions to the shape `check_collection_type` expects.
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |tag, values, buf| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    tag,
                                                                    values,
                                                                    buf)
                                              },
                                              |wire_type, values, buf, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   values,
                                                                   buf,
                                                                   ctx)
                                              },
                                              |tag, values| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         tag,
                                                                         values)
                                              })?;
                    }
                }
             )*
        };
    }
1398
    // Instantiates the map round-trip property tests: 12 key types x 15 value types,
    // generated once for each of the two map types handled by `map_tests!`. Each pair
    // is (Rust type, name of the encoding module used for that type).
    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);
1431
1432    #[test]
1433    /// `decode_varint` accepts a `Buf`, which can be multiple concatenated buffers.
1434    /// This test ensures that future optimizations don't break the
1435    /// `decode_varint` for non-continuous memory.
1436    fn split_varint_decoding() {
1437        let mut test_values = Vec::<u64>::with_capacity(10 * 2);
1438        test_values.push(128);
1439        for i in 2..9 {
1440            test_values.push((1 << (7 * i)) - 1);
1441            test_values.push(1 << (7 * i));
1442        }
1443
1444        for v in test_values {
1445            let mut buf = BytesMut::with_capacity(10);
1446            encode_varint(v, &mut buf);
1447            let half_len = buf.len() / 2;
1448            let len = buf.len();
1449            // this weird sequence here splits the buffer into two instances of Bytes
1450            // which we then stitch together with `bytes::buf::Buf::chain`
1451            // which ensures the varint bytes are not in a single chunk
1452            let b2 = buf.split_off(half_len);
1453            let mut c = buf.chain(b2);
1454
1455            // make sure all the bytes are inside
1456            assert_eq!(c.remaining(), len);
1457            // make sure the first chunk is split as we expected
1458            assert_eq!(c.chunk().len(), half_len);
1459            assert_eq!(v, decode_varint(&mut c).unwrap());
1460        }
1461    }
1462}