1use alloc::collections::BTreeMap;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::mem;
12use core::str;
13
14use ::bytes::{Buf, BufMut, Bytes};
15
16use crate::error::DecodeErrorKind;
17use crate::DecodeError;
18use crate::Message;
19
20pub mod varint;
21pub use varint::{decode_varint, encode_varint, encoded_len_varint};
22
23pub mod length_delimiter;
24pub use length_delimiter::{
25 decode_length_delimiter, encode_length_delimiter, length_delimiter_len,
26};
27
28pub mod wire_type;
29pub use wire_type::{check_wire_type, WireType};
30
/// Per-call state that is handed to every decode/merge function.
///
/// Unless the crate is built with the `no-recursion-limit` feature, the context
/// carries the recursion budget that is still available; with the feature
/// enabled the struct is empty and its `Default` impl is derived.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// Number of nesting levels that may still be entered before
    /// `limit_reached` starts returning an error.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
47
48#[cfg(not(feature = "no-recursion-limit"))]
49impl Default for DecodeContext {
50 #[inline]
51 fn default() -> DecodeContext {
52 DecodeContext {
53 recurse_count: crate::RECURSION_LIMIT,
54 }
55 }
56}
57
58impl DecodeContext {
59 #[cfg(not(feature = "no-recursion-limit"))]
65 #[inline]
66 pub(crate) fn enter_recursion(&self) -> DecodeContext {
67 DecodeContext {
68 recurse_count: self.recurse_count - 1,
69 }
70 }
71
72 #[cfg(feature = "no-recursion-limit")]
73 #[inline]
74 pub(crate) fn enter_recursion(&self) -> DecodeContext {
75 DecodeContext {}
76 }
77
78 #[cfg(not(feature = "no-recursion-limit"))]
84 #[inline]
85 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
86 if self.recurse_count == 0 {
87 Err(DecodeErrorKind::RecursionLimitReached.into())
88 } else {
89 Ok(())
90 }
91 }
92
93 #[cfg(feature = "no-recursion-limit")]
94 #[inline]
95 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
96 Ok(())
97 }
98}
99
/// Smallest field tag accepted by `decode_key` (a tag of zero is rejected).
pub const MIN_TAG: u32 = 1;
/// Largest field tag: a key packs the tag above three wire-type bits inside a
/// `u32`, which leaves 29 bits for the tag itself.
pub const MAX_TAG: u32 = u32::MAX >> 3;
102
103#[inline]
106pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut impl BufMut) {
107 debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
108 let key = (tag << 3) | wire_type as u32;
109 encode_varint(u64::from(key), buf);
110}
111
112#[inline(always)]
115pub fn decode_key(buf: &mut impl Buf) -> Result<(u32, WireType), DecodeError> {
116 let key = decode_varint(buf)?;
117 if key > u64::from(u32::MAX) {
118 return Err(DecodeErrorKind::InvalidKey { key }.into());
119 }
120 let wire_type = WireType::try_from(key & 0x07)?;
121 let tag = key as u32 >> 3;
122
123 if tag < MIN_TAG {
124 return Err(DecodeErrorKind::InvalidTag.into());
125 }
126
127 Ok((tag, wire_type))
128}
129
130#[inline]
133pub const fn key_len(tag: u32) -> usize {
134 encoded_len_varint((tag << 3) as u64)
135}
136
137pub fn merge_loop<T, M, B>(
140 value: &mut T,
141 buf: &mut B,
142 ctx: DecodeContext,
143 mut merge: M,
144) -> Result<(), DecodeError>
145where
146 M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
147 B: Buf,
148{
149 let len = decode_varint(buf)?;
150 let remaining = buf.remaining();
151 if len > remaining as u64 {
152 return Err(DecodeErrorKind::BufferUnderflow.into());
153 }
154
155 let limit = remaining - len as usize;
156 while buf.remaining() > limit {
157 merge(value, buf, ctx.clone())?;
158 }
159
160 if buf.remaining() != limit {
161 return Err(DecodeErrorKind::DelimitedLengthExceeded.into());
162 }
163 Ok(())
164}
165
166pub fn skip_field(
167 wire_type: WireType,
168 tag: u32,
169 buf: &mut impl Buf,
170 ctx: DecodeContext,
171) -> Result<(), DecodeError> {
172 ctx.limit_reached()?;
173 let len = match wire_type {
174 WireType::Varint => decode_varint(buf).map(|_| 0)?,
175 WireType::ThirtyTwoBit => 4,
176 WireType::SixtyFourBit => 8,
177 WireType::LengthDelimited => decode_varint(buf)?,
178 WireType::StartGroup => loop {
179 let (inner_tag, inner_wire_type) = decode_key(buf)?;
180 match inner_wire_type {
181 WireType::EndGroup => {
182 if inner_tag != tag {
183 return Err(DecodeErrorKind::UnexpectedEndGroupTag.into());
184 }
185 break 0;
186 }
187 _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
188 }
189 },
190 WireType::EndGroup => return Err(DecodeErrorKind::UnexpectedEndGroupTag.into()),
191 };
192
193 if len > buf.remaining() as u64 {
194 return Err(DecodeErrorKind::BufferUnderflow.into());
195 }
196
197 buf.advance(len as usize);
198 Ok(())
199}
200
/// Generates `encode_repeated` for element type `$ty`.
///
/// The generated function calls the `encode` function that is in scope at the
/// expansion site once per element, reusing the same tag every time.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
            values.iter().for_each(|value| encode(tag, value, buf));
        }
    };
}
211
/// Generates the repeated-field merge function `$merge_repeated` for a numeric
/// type `$ty` whose single-value decoder is `$merge` and whose element wire
/// type is `$wire_type`.
///
/// The generated function accepts both encodings: a `LengthDelimited` field is
/// read as a packed run of elements, anything else must match `$wire_type` and
/// contributes exactly one element.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: a single element carrying the element wire type.
                check_wire_type($wire_type, wire_type)?;
                let mut element = <$ty>::default();
                $merge(wire_type, &mut element, buf, ctx)?;
                values.push(element);
                return Ok(());
            }

            // Packed: elements are stored back to back behind a length prefix.
            merge_loop(values, buf, ctx, |out, buf, ctx| {
                let mut element = <$ty>::default();
                $merge($wire_type, &mut element, buf, ctx)?;
                out.push(element);
                Ok(())
            })
        }
    };
}
243
/// Generates a module of encoding functions for a type that is carried on the
/// wire as a varint.
///
/// The long form takes two conversions: `to_uint64` turns a `&$ty` into the
/// `u64` that is written, `from_uint64` turns a decoded `u64` back into `$ty`.
/// The short form uses plain `as` casts for both directions.
macro_rules! varint {
    // Short form: delegate to the long form with `as`-cast conversions.
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the key followed by the value as a varint.
            pub fn encode(tag: u32, $to_uint64_value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, WireType::Varint, buf);
                let wire_value: u64 = $to_uint64;
                encode_varint(wire_value, buf);
            }

            /// Decodes one varint into `value`; the wire type must be `Varint`.
            pub fn merge(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut impl Buf,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError> {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes all values as one packed field; writes nothing for an
            /// empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let payload_len = values
                    .iter()
                    .fold(0usize, |acc, $to_uint64_value| acc + encoded_len_varint($to_uint64));
                encode_varint(payload_len as u64, buf);

                values
                    .iter()
                    .for_each(|$to_uint64_value| encode_varint($to_uint64, buf));
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Length of key plus value for a single field.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                let value_len = encoded_len_varint($to_uint64);
                key_len(tag) + value_len
            }

            /// Length of the unpacked encoding: one key per element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len()
                    + values
                        .iter()
                        .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                        .sum::<usize>()
            }

            /// Length of the packed encoding; zero for an empty slice.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    return 0;
                }
                let payload_len = values
                    .iter()
                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                    .sum::<usize>();
                key_len(tag) + encoded_len_varint(payload_len as u64) + payload_len
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }

    );
}
351varint!(bool, bool,
352 to_uint64(value) u64::from(*value),
353 from_uint64(value) value != 0);
354varint!(i32, int32);
355varint!(i64, int64);
356varint!(u32, uint32);
357varint!(u64, uint64);
358varint!(i32, sint32,
359to_uint64(value) {
360 ((value << 1) ^ (value >> 31)) as u32 as u64
361},
362from_uint64(value) {
363 let value = value as u32;
364 ((value >> 1) as i32) ^ (-((value & 1) as i32))
365});
366varint!(i64, sint64,
367to_uint64(value) {
368 ((value << 1) ^ (value >> 63)) as u64
369},
370from_uint64(value) {
371 ((value >> 1) as i64) ^ (-((value & 1) as i64))
372});
373
/// Generates a module of encoding functions for a fixed-width numeric type.
///
/// `$width` is the number of bytes one value occupies, `$wire_type` the wire
/// type used for a single value, and `$put` / `$get` the buffer methods that
/// write and read one value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the key followed by the raw value.
            pub fn encode(tag: u32, value: &$ty, buf: &mut impl BufMut) {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Reads one value into `value`.
            ///
            /// Fails on a wire-type mismatch, or with `BufferUnderflow` when
            /// fewer than `$width` bytes remain.
            pub fn merge(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut impl Buf,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError> {
                check_wire_type($wire_type, wire_type)?;
                if buf.remaining() >= $width {
                    *value = buf.$get();
                    Ok(())
                } else {
                    Err(DecodeErrorKind::BufferUnderflow.into())
                }
            }

            encode_repeated!($ty);

            /// Writes all values as one packed field; writes nothing for an
            /// empty slice.
            pub fn encode_packed(tag: u32, values: &[$ty], buf: &mut impl BufMut) {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(values.len() as u64 * $width, buf);

                for &value in values {
                    buf.$put(value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Length of key plus value; independent of the value itself.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Length of the unpacked encoding: one key per element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                let per_element = key_len(tag) + $width;
                per_element * values.len()
            }

            /// Length of the packed encoding; zero for an empty slice.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    return 0;
                }
                let payload_len = $width * values.len();
                key_len(tag) + encoded_len_varint(payload_len as u64) + payload_len
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
473fixed_width!(
474 f32,
475 4,
476 WireType::ThirtyTwoBit,
477 float,
478 put_f32_le,
479 get_f32_le
480);
481fixed_width!(
482 f64,
483 8,
484 WireType::SixtyFourBit,
485 double,
486 put_f64_le,
487 get_f64_le
488);
489fixed_width!(
490 u32,
491 4,
492 WireType::ThirtyTwoBit,
493 fixed32,
494 put_u32_le,
495 get_u32_le
496);
497fixed_width!(
498 u64,
499 8,
500 WireType::SixtyFourBit,
501 fixed64,
502 put_u64_le,
503 get_u64_le
504);
505fixed_width!(
506 i32,
507 4,
508 WireType::ThirtyTwoBit,
509 sfixed32,
510 put_i32_le,
511 get_i32_le
512);
513fixed_width!(
514 i64,
515 8,
516 WireType::SixtyFourBit,
517 sfixed64,
518 put_i64_le,
519 get_i64_le
520);
521
/// Generates the repeated-field and length helpers shared by length-delimited
/// value types.
///
/// The expansion site must provide `encode` and `merge` for `$ty`, and `$ty`
/// must offer a `len()` that reports its payload size in bytes.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        /// Decodes one element and appends it to `values`.
        ///
        /// Nothing is appended when the wire type is wrong or decoding fails.
        pub fn merge_repeated(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut impl Buf,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError> {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        /// Length of key, length prefix and payload for one value.
        #[inline]
        #[allow(clippy::ptr_arg)]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let prefix_len = encoded_len_varint(value.len() as u64);
            key_len(tag) + prefix_len + value.len()
        }

        /// Length of the repeated encoding: one key and prefix per element.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys_len = key_len(tag) * values.len();
            keys_len
                + values
                    .iter()
                    .map(|element| encoded_len_varint(element.len() as u64) + element.len())
                    .sum::<usize>()
        }
    };
}
556
557pub mod string {
558 use super::*;
559
560 pub fn encode(tag: u32, value: &String, buf: &mut impl BufMut) {
561 encode_key(tag, WireType::LengthDelimited, buf);
562 encode_varint(value.len() as u64, buf);
563 buf.put_slice(value.as_bytes());
564 }
565
566 pub fn merge(
567 wire_type: WireType,
568 value: &mut String,
569 buf: &mut impl Buf,
570 ctx: DecodeContext,
571 ) -> Result<(), DecodeError> {
572 unsafe {
586 struct DropGuard<'a>(&'a mut Vec<u8>);
587 impl Drop for DropGuard<'_> {
588 #[inline]
589 fn drop(&mut self) {
590 self.0.clear();
591 }
592 }
593
594 let drop_guard = DropGuard(value.as_mut_vec());
595 bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
596 match str::from_utf8(drop_guard.0) {
597 Ok(_) => {
598 mem::forget(drop_guard);
600 Ok(())
601 }
602 Err(_) => Err(DecodeErrorKind::InvalidString.into()),
603 }
604 }
605 }
606
607 length_delimited!(String);
608
609 #[cfg(test)]
610 mod test {
611 use proptest::prelude::*;
612
613 use super::super::test::{check_collection_type, check_type};
614 use super::*;
615
616 proptest! {
617 #[test]
618 fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
619 super::test::check_type(value, tag, WireType::LengthDelimited,
620 encode, merge, encoded_len)?;
621 }
622 #[test]
623 fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
624 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
625 encode_repeated, merge_repeated,
626 encoded_len_repeated)?;
627 }
628 }
629 }
630}
631
632pub trait BytesAdapter: sealed::BytesAdapter {}
633
634mod sealed {
635 use super::{Buf, BufMut};
636
637 pub trait BytesAdapter: Default + Sized + 'static {
638 fn len(&self) -> usize;
639
640 fn replace_with(&mut self, buf: impl Buf);
642
643 fn append_to(&self, buf: &mut impl BufMut);
645
646 fn is_empty(&self) -> bool {
647 self.len() == 0
648 }
649 }
650}
651
652impl BytesAdapter for Bytes {}
653
654impl sealed::BytesAdapter for Bytes {
655 fn len(&self) -> usize {
656 Buf::remaining(self)
657 }
658
659 fn replace_with(&mut self, mut buf: impl Buf) {
660 *self = buf.copy_to_bytes(buf.remaining());
661 }
662
663 fn append_to(&self, buf: &mut impl BufMut) {
664 buf.put(self.clone())
665 }
666}
667
668impl BytesAdapter for Vec<u8> {}
669
670impl sealed::BytesAdapter for Vec<u8> {
671 fn len(&self) -> usize {
672 Vec::len(self)
673 }
674
675 fn replace_with(&mut self, buf: impl Buf) {
676 self.clear();
677 self.reserve(buf.remaining());
678 self.put(buf);
679 }
680
681 fn append_to(&self, buf: &mut impl BufMut) {
682 buf.put(self.as_slice())
683 }
684}
685
686pub mod bytes {
687 use crate::error::DecodeErrorKind;
688
689 use super::*;
690
691 pub fn encode(tag: u32, value: &impl BytesAdapter, buf: &mut impl BufMut) {
692 encode_key(tag, WireType::LengthDelimited, buf);
693 encode_varint(value.len() as u64, buf);
694 value.append_to(buf);
695 }
696
697 pub fn merge(
698 wire_type: WireType,
699 value: &mut impl BytesAdapter,
700 buf: &mut impl Buf,
701 _ctx: DecodeContext,
702 ) -> Result<(), DecodeError> {
703 check_wire_type(WireType::LengthDelimited, wire_type)?;
704 let len = decode_varint(buf)?;
705 if len > buf.remaining() as u64 {
706 return Err(DecodeErrorKind::BufferUnderflow.into());
707 }
708 let len = len as usize;
709
710 value.replace_with(buf.copy_to_bytes(len));
723 Ok(())
724 }
725
726 pub(super) fn merge_one_copy(
727 wire_type: WireType,
728 value: &mut impl BytesAdapter,
729 buf: &mut impl Buf,
730 _ctx: DecodeContext,
731 ) -> Result<(), DecodeError> {
732 check_wire_type(WireType::LengthDelimited, wire_type)?;
733 let len = decode_varint(buf)?;
734 if len > buf.remaining() as u64 {
735 return Err(DecodeErrorKind::BufferUnderflow.into());
736 }
737 let len = len as usize;
738
739 value.replace_with(buf.take(len));
741 Ok(())
742 }
743
744 length_delimited!(impl BytesAdapter);
745
746 #[cfg(test)]
747 mod test {
748 use proptest::prelude::*;
749
750 use super::super::test::{check_collection_type, check_type};
751 use super::*;
752
753 proptest! {
754 #[test]
755 fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
756 super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
757 encode, merge, encoded_len)?;
758 }
759
760 #[test]
761 fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
762 let value = Bytes::from(value);
763 super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
764 encode, merge, encoded_len)?;
765 }
766
767 #[test]
768 fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
769 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
770 encode_repeated, merge_repeated,
771 encoded_len_repeated)?;
772 }
773
774 #[test]
775 fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
776 let value = value.into_iter().map(Bytes::from).collect();
777 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
778 encode_repeated, merge_repeated,
779 encoded_len_repeated)?;
780 }
781 }
782 }
783}
784
785pub mod message {
786 use super::*;
787
788 pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
789 where
790 M: Message,
791 {
792 encode_key(tag, WireType::LengthDelimited, buf);
793 encode_varint(msg.encoded_len() as u64, buf);
794 msg.encode_raw(buf);
795 }
796
797 pub fn merge<M, B>(
798 wire_type: WireType,
799 msg: &mut M,
800 buf: &mut B,
801 ctx: DecodeContext,
802 ) -> Result<(), DecodeError>
803 where
804 M: Message,
805 B: Buf,
806 {
807 check_wire_type(WireType::LengthDelimited, wire_type)?;
808 ctx.limit_reached()?;
809 merge_loop(
810 msg,
811 buf,
812 ctx.enter_recursion(),
813 |msg: &mut M, buf: &mut B, ctx| {
814 let (tag, wire_type) = decode_key(buf)?;
815 msg.merge_field(tag, wire_type, buf, ctx)
816 },
817 )
818 }
819
820 pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
821 where
822 M: Message,
823 {
824 for msg in messages {
825 encode(tag, msg, buf);
826 }
827 }
828
829 pub fn merge_repeated<M>(
830 wire_type: WireType,
831 messages: &mut Vec<M>,
832 buf: &mut impl Buf,
833 ctx: DecodeContext,
834 ) -> Result<(), DecodeError>
835 where
836 M: Message + Default,
837 {
838 check_wire_type(WireType::LengthDelimited, wire_type)?;
839 let mut msg = M::default();
840 merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
841 messages.push(msg);
842 Ok(())
843 }
844
845 #[inline]
846 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
847 where
848 M: Message,
849 {
850 let len = msg.encoded_len();
851 key_len(tag) + encoded_len_varint(len as u64) + len
852 }
853
854 #[inline]
855 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
856 where
857 M: Message,
858 {
859 key_len(tag) * messages.len()
860 + messages
861 .iter()
862 .map(Message::encoded_len)
863 .map(|len| len + encoded_len_varint(len as u64))
864 .sum::<usize>()
865 }
866}
867
868pub mod group {
869 use crate::error::DecodeErrorKind;
870
871 use super::*;
872
873 pub fn encode<M>(tag: u32, msg: &M, buf: &mut impl BufMut)
874 where
875 M: Message,
876 {
877 encode_key(tag, WireType::StartGroup, buf);
878 msg.encode_raw(buf);
879 encode_key(tag, WireType::EndGroup, buf);
880 }
881
882 pub fn merge<M>(
883 tag: u32,
884 wire_type: WireType,
885 msg: &mut M,
886 buf: &mut impl Buf,
887 ctx: DecodeContext,
888 ) -> Result<(), DecodeError>
889 where
890 M: Message,
891 {
892 check_wire_type(WireType::StartGroup, wire_type)?;
893
894 ctx.limit_reached()?;
895 loop {
896 let (field_tag, field_wire_type) = decode_key(buf)?;
897 if field_wire_type == WireType::EndGroup {
898 if field_tag != tag {
899 return Err(DecodeErrorKind::UnexpectedEndGroupTag.into());
900 }
901 return Ok(());
902 }
903
904 M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
905 }
906 }
907
908 pub fn encode_repeated<M>(tag: u32, messages: &[M], buf: &mut impl BufMut)
909 where
910 M: Message,
911 {
912 for msg in messages {
913 encode(tag, msg, buf);
914 }
915 }
916
917 pub fn merge_repeated<M>(
918 tag: u32,
919 wire_type: WireType,
920 messages: &mut Vec<M>,
921 buf: &mut impl Buf,
922 ctx: DecodeContext,
923 ) -> Result<(), DecodeError>
924 where
925 M: Message + Default,
926 {
927 check_wire_type(WireType::StartGroup, wire_type)?;
928 let mut msg = M::default();
929 merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
930 messages.push(msg);
931 Ok(())
932 }
933
934 #[inline]
935 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
936 where
937 M: Message,
938 {
939 2 * key_len(tag) + msg.encoded_len()
940 }
941
942 #[inline]
943 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
944 where
945 M: Message,
946 {
947 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
948 }
949}
950
/// Generates the map-field encoding functions for the map type `$map_ty`.
///
/// Every map entry is written as its own length-delimited field whose body
/// holds the key as field 1 and the value as field 2. A key equal to
/// `K::default()` or a value equal to the supplied default is left out of the
/// entry body.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Encodes `values`, treating `V::default()` as the value default.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            let val_default = V::default();
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &val_default,
                tag,
                values,
                buf,
            )
        }

        /// Merges one map entry, starting the value from `V::default()`.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let val_default = V::default();
            merge_with_default(key_merge, val_merge, val_default, values, buf, ctx)
        }

        /// Encoded length of `values`, treating `V::default()` as the value
        /// default.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            let val_default = V::default();
            encoded_len_with_default(key_encoded_len, val_encoded_len, &val_default, tag, values)
        }

        /// Encodes `values` with an explicit value default.
        ///
        /// This variant exists for value types that cannot provide the
        /// default through `Default`, which is why `V` is only `PartialEq`.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            for (entry_key, entry_val) in values {
                let skip_key = entry_key == &K::default();
                let skip_val = entry_val == val_default;

                // The entry body only contains the parts that are written.
                let key_part = if skip_key { 0 } else { key_encoded_len(1, entry_key) };
                let val_part = if skip_val { 0 } else { val_encoded_len(2, entry_val) };

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint((key_part + val_part) as u64, buf);
                if !skip_key {
                    key_encode(1, entry_key, buf);
                }
                if !skip_val {
                    val_encode(2, entry_val, buf);
                }
            }
        }

        /// Merges one map entry with an explicit starting value.
        ///
        /// The key starts as `K::default()` and the value as `val_default`.
        /// Field 1 is merged into the key, field 2 into the value and any
        /// other field is skipped. The pair is inserted into `values` once
        /// the whole entry has been read; on error nothing is inserted.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = K::default();
            let mut val = val_default;
            ctx.limit_reached()?;

            let mut entry = (&mut key, &mut val);
            merge_loop(
                &mut entry,
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut entry_key, ref mut entry_val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    if tag == 1 {
                        key_merge(wire_type, entry_key, buf, ctx)
                    } else if tag == 2 {
                        val_merge(wire_type, entry_val, buf, ctx)
                    } else {
                        skip_field(wire_type, tag, buf, ctx)
                    }
                },
            )?;
            values.insert(key, val);

            Ok(())
        }

        /// Encoded length of `values` with an explicit value default.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            // Length prefix plus body of a single entry, without its key.
            let entry_len = |entry_key: &K, entry_val: &V| -> usize {
                let key_part = if entry_key == &K::default() {
                    0
                } else {
                    key_encoded_len(1, entry_key)
                };
                let val_part = if entry_val == val_default {
                    0
                } else {
                    val_encoded_len(2, entry_val)
                };
                let body_len = key_part + val_part;
                encoded_len_varint(body_len as u64) + body_len
            };

            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(entry_key, entry_val)| entry_len(entry_key, entry_val))
                    .sum::<usize>()
        }
    };
}
1137
1138#[cfg(feature = "std")]
1139pub mod hash_map {
1140 use std::collections::HashMap;
1141 map!(HashMap);
1142}
1143
1144pub mod btree_map {
1145 map!(BTreeMap);
1146}
1147
#[cfg(test)]
mod test {
    #[cfg(not(feature = "std"))]
    use alloc::string::ToString;
    use core::borrow::Borrow;
    use core::fmt::Debug;

    use ::bytes::BytesMut;
    use proptest::{prelude::*, test_runner::TestCaseResult};

    use super::*;

    /// Round-trips a single field: encodes `value`, checks the reported
    /// length, the key and (for fixed-width wire types) the payload size, then
    /// decodes it again and compares the result with the input.
    pub fn check_type<T, B>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: fn(u32, &B, &mut BytesMut),
        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        encoded_len: fn(u32, &B) -> usize,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut encoded = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut encoded);

        let mut buf = encoded.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        // Nothing was written (e.g. an empty packed field): nothing to decode.
        if !buf.has_remaining() {
            return Ok(());
        }

        let (decoded_tag, decoded_wire_type) =
            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
        prop_assert_eq!(
            tag,
            decoded_tag,
            "decoded tag does not match; expected: {}, actual: {}",
            tag,
            decoded_tag
        );

        prop_assert_eq!(
            wire_type,
            decoded_wire_type,
            "decoded wire type does not match; expected: {:?}, actual: {:?}",
            wire_type,
            decoded_wire_type,
        );

        // Fixed-width payloads must be exactly as wide as their wire type.
        if wire_type == WireType::SixtyFourBit && buf.remaining() != 8 {
            return Err(TestCaseError::fail(format!(
                "64bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            )));
        }
        if wire_type == WireType::ThirtyTwoBit && buf.remaining() != 4 {
            return Err(TestCaseError::fail(format!(
                "32bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            )));
        }

        let mut roundtrip_value = T::default();
        merge(
            wire_type,
            &mut roundtrip_value,
            &mut buf,
            DecodeContext::default(),
        )
        .map_err(|error| TestCaseError::fail(error.to_string()))?;

        prop_assert!(
            !buf.has_remaining(),
            "expected buffer to be empty, remaining: {}",
            buf.remaining()
        );

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Round-trips a collection: encodes `value`, checks the reported length,
    /// then decodes field after field until the buffer is empty and compares
    /// the accumulated result with the input.
    pub fn check_collection_type<T, B, E, M, L>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: E,
        mut merge: M,
        encoded_len: L,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
        E: FnOnce(u32, &B, &mut BytesMut),
        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        L: FnOnce(u32, &B) -> usize,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut encoded = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut encoded);

        let mut buf = encoded.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        let mut roundtrip_value = Default::default();
        loop {
            if !buf.has_remaining() {
                break;
            }

            let (decoded_tag, decoded_wire_type) =
                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;

            prop_assert_eq!(
                tag,
                decoded_tag,
                "decoded tag does not match; expected: {}, actual: {}",
                tag,
                decoded_tag
            );

            prop_assert_eq!(
                wire_type,
                decoded_wire_type,
                "decoded wire type does not match; expected: {:?}, actual: {:?}",
                wire_type,
                decoded_wire_type
            );

            merge(
                wire_type,
                &mut roundtrip_value,
                &mut buf,
                DecodeContext::default(),
            )
            .map_err(|error| TestCaseError::fail(error.to_string()))?;
        }

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// A string payload that is not valid UTF-8 must be rejected and must
    /// leave the target string empty.
    #[test]
    fn string_merge_invalid_utf8() {
        let mut target = String::new();
        // Length prefix 2, followed by two continuation bytes.
        let input = b"\x02\x80\x80";

        let outcome = string::merge(
            WireType::LengthDelimited,
            &mut target,
            &mut &input[..],
            DecodeContext::default(),
        );
        outcome.expect_err("must be an error");
        assert!(target.is_empty());
    }

    // Generates one round-trip test per (map type, key type, value type)
    // combination. The `@private` arms first fan out over the key list, then
    // over the value list.
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        // One module per key type.
        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        // One test per value type.
        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |field_tag, map, out| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    field_tag,
                                                                    map,
                                                                    out)
                                              },
                                              |wire_type, map, input, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   map,
                                                                   input,
                                                                   ctx)
                                              },
                                              |field_tag, map| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         field_tag,
                                                                         map)
                                              })?;
                    }
                }
            )*
        };
    }

    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);

    /// Decodes varints whose bytes are split across the two halves of a
    /// chained buffer, using values on both sides of each 7-bit boundary.
    #[test]
    fn split_varint_decoding() {
        let mut test_values = Vec::<u64>::with_capacity(10 * 2);
        test_values.push(128);
        for group in 2..9 {
            let boundary: u64 = 1 << (7 * group);
            test_values.push(boundary - 1);
            test_values.push(boundary);
        }

        for expected in test_values {
            let mut first = BytesMut::with_capacity(10);
            encode_varint(expected, &mut first);
            let total_len = first.len();
            let half_len = total_len / 2;
            let second = first.split_off(half_len);
            let mut chained = first.chain(second);

            // The chain exposes all bytes, but only the first half as one chunk.
            assert_eq!(chained.remaining(), total_len);
            assert_eq!(chained.chunk().len(), half_len);
            assert_eq!(expected, decode_varint(&mut chained).unwrap());
        }
    }
}