1use std::{hint::unreachable_unchecked, iter, mem, vec};
2
3use serde::de::{
4 self,
5 value::{Error, SeqDeserializer},
6 Deserializer, IntoDeserializer,
7};
8
/// Container for either one value or several values kept in insertion order.
///
/// [`ValOrVec::push`] turns a `Val` into a `Vec` holding the old value
/// followed by the new one.
#[derive(Debug)]
pub(crate) enum ValOrVec<T> {
    /// Exactly one value.
    Val(T),
    /// Multiple values, in the order they were pushed.
    Vec(Vec<T>),
}
14
15impl<T> ValOrVec<T> {
16 pub fn push(&mut self, new_val: T) {
17 match self {
18 Self::Val(_) => {
19 let old_self = mem::replace(self, ValOrVec::Vec(Vec::with_capacity(2)));
21
22 let old_val = match old_self {
23 Self::Val(v) => v,
24 _ => unsafe { unreachable_unchecked() },
26 };
27
28 let vec = match self {
29 ValOrVec::Vec(v) => v,
30 _ => unsafe { unreachable_unchecked() },
32 };
33
34 vec.push(old_val);
35 vec.push(new_val);
36 }
37 Self::Vec(vec) => vec.push(new_val),
38 }
39 }
40
41 fn deserialize_val<U, E, F>(self, f: F) -> Result<U, E>
42 where
43 F: FnOnce(T) -> Result<U, E>,
44 E: de::Error,
45 {
46 match self {
47 ValOrVec::Val(val) => f(val),
48 ValOrVec::Vec(_) => Err(de::Error::custom("unsupported value")),
49 }
50 }
51}
52
53impl<T> IntoIterator for ValOrVec<T> {
54 type Item = T;
55 type IntoIter = IntoIter<T>;
56
57 fn into_iter(self) -> Self::IntoIter {
58 IntoIter::new(self)
59 }
60}
61
/// Owning iterator produced by `ValOrVec`'s `IntoIterator` impl.
///
/// The variant mirrors the container it came from: `Val` yields the lone
/// value once, `Vec` yields the vector's elements in order.
#[derive(Debug)]
pub enum IntoIter<T> {
    /// Iterator over a single value.
    Val(iter::Once<T>),
    /// Iterator over multiple collected values.
    Vec(vec::IntoIter<T>),
}
66
67impl<T> IntoIter<T> {
68 fn new(vv: ValOrVec<T>) -> Self {
69 match vv {
70 ValOrVec::Val(val) => IntoIter::Val(iter::once(val)),
71 ValOrVec::Vec(vec) => IntoIter::Vec(vec.into_iter()),
72 }
73 }
74}
75
76impl<T> Iterator for IntoIter<T> {
77 type Item = T;
78
79 fn next(&mut self) -> Option<Self::Item> {
80 match self {
81 IntoIter::Val(iter) => iter.next(),
82 IntoIter::Vec(iter) => iter.next(),
83 }
84 }
85}
86
87impl<'de, T> IntoDeserializer<'de> for ValOrVec<T>
88where
89 T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>,
90{
91 type Deserializer = Self;
92
93 fn into_deserializer(self) -> Self::Deserializer {
94 self
95 }
96}
97
/// Expands to one `Deserializer` method per listed name. Each generated method
/// hands the wrapped single value to the same-named method via
/// `deserialize_val`, which errors for the `Vec` variant.
///
/// Every name in the invocation must be followed by a comma.
macro_rules! forward_to_part {
    ($($fn_name:ident,)*) => {
        $(
            fn $fn_name<V>(self, visitor: V) -> Result<V::Value, Self::Error>
            where
                V: de::Visitor<'de>,
            {
                self.deserialize_val(move |part| part.$fn_name(visitor))
            }
        )*
    };
}
109
110impl<'de, T> Deserializer<'de> for ValOrVec<T>
111where
112 T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>,
113{
114 type Error = Error;
115
116 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
117 where
118 V: de::Visitor<'de>,
119 {
120 match self {
121 Self::Val(val) => val.deserialize_any(visitor),
122 Self::Vec(_) => self.deserialize_seq(visitor),
123 }
124 }
125
126 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
127 where
128 V: de::Visitor<'de>,
129 {
130 visitor.visit_seq(SeqDeserializer::new(self.into_iter()))
131 }
132
133 fn deserialize_enum<V>(
134 self,
135 name: &'static str,
136 variants: &'static [&'static str],
137 visitor: V,
138 ) -> Result<V::Value, Self::Error>
139 where
140 V: de::Visitor<'de>,
141 {
142 self.deserialize_val(move |val| val.deserialize_enum(name, variants, visitor))
143 }
144
145 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
146 where
147 V: de::Visitor<'de>,
148 {
149 self.deserialize_val(move |val| val.deserialize_tuple(len, visitor))
150 }
151
152 fn deserialize_struct<V>(
153 self,
154 name: &'static str,
155 fields: &'static [&'static str],
156 visitor: V,
157 ) -> Result<V::Value, Self::Error>
158 where
159 V: de::Visitor<'de>,
160 {
161 self.deserialize_val(move |val| val.deserialize_struct(name, fields, visitor))
162 }
163
164 fn deserialize_unit_struct<V>(
165 self,
166 name: &'static str,
167 visitor: V,
168 ) -> Result<V::Value, Self::Error>
169 where
170 V: de::Visitor<'de>,
171 {
172 self.deserialize_val(move |val| val.deserialize_unit_struct(name, visitor))
173 }
174
175 fn deserialize_tuple_struct<V>(
176 self,
177 name: &'static str,
178 len: usize,
179 visitor: V,
180 ) -> Result<V::Value, Self::Error>
181 where
182 V: de::Visitor<'de>,
183 {
184 self.deserialize_val(move |val| val.deserialize_tuple_struct(name, len, visitor))
185 }
186
187 fn deserialize_newtype_struct<V>(
188 self,
189 name: &'static str,
190 visitor: V,
191 ) -> Result<V::Value, Self::Error>
192 where
193 V: de::Visitor<'de>,
194 {
195 self.deserialize_val(move |val| val.deserialize_newtype_struct(name, visitor))
196 }
197
198 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
199 where
200 V: de::Visitor<'de>,
201 {
202 visitor.visit_unit()
203 }
204
205 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
206 where
207 V: de::Visitor<'de>,
208 {
209 match self {
210 ValOrVec::Val(val) => val.deserialize_option(visitor),
211 ValOrVec::Vec(_) => visitor.visit_some(self),
212 }
213 }
214
215 forward_to_part! {
216 deserialize_bool,
217 deserialize_char,
218 deserialize_str,
219 deserialize_string,
220 deserialize_bytes,
221 deserialize_byte_buf,
222 deserialize_unit,
223 deserialize_u8,
224 deserialize_u16,
225 deserialize_u32,
226 deserialize_u64,
227 deserialize_i8,
228 deserialize_i16,
229 deserialize_i32,
230 deserialize_i64,
231 deserialize_f32,
232 deserialize_f64,
233 deserialize_identifier,
234 deserialize_map,
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use assert_matches2::assert_matches;

    use super::ValOrVec;

    #[test]
    fn cow_borrowed() {
        let mut x = ValOrVec::Val(Cow::Borrowed("a"));
        x.push(Cow::Borrowed("b"));
        x.push(Cow::Borrowed("c"));
        assert_matches!(x, ValOrVec::Vec(v));
        assert_eq!(v, vec!["a", "b", "c"]);
    }

    #[test]
    fn cow_owned() {
        let mut x = ValOrVec::Val(Cow::from("a".to_owned()));
        x.push(Cow::from("b".to_owned()));
        x.push(Cow::from("c".to_owned()));
        assert_matches!(x, ValOrVec::Vec(v));
        assert_eq!(v, vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]);
    }

    /// A single push must promote `Val` to a two-element `Vec`, old value first.
    #[test]
    fn single_push_promotes_to_vec() {
        let mut x = ValOrVec::Val(1);
        x.push(2);
        assert_matches!(x, ValOrVec::Vec(v));
        assert_eq!(v, vec![1, 2]);
    }

    /// Iteration yields the lone value, or every pushed value in order.
    #[test]
    fn into_iter_preserves_order() {
        let single: Vec<i32> = ValOrVec::Val(1).into_iter().collect();
        assert_eq!(single, vec![1]);

        let mut many = ValOrVec::Val(1);
        many.push(2);
        many.push(3);
        let collected: Vec<i32> = many.into_iter().collect();
        assert_eq!(collected, vec![1, 2, 3]);
    }
}