serde_html_form/de/
val_or_vec.rs

1use std::{hint::unreachable_unchecked, iter, mem, vec};
2
3use serde::de::{
4    self,
5    value::{Error, SeqDeserializer},
6    Deserializer, IntoDeserializer,
7};
8
/// A container holding either one value or a list of values.
///
/// A `Val` is turned into a `Vec` by `push` when a second value is added.
#[derive(Debug)]
pub(crate) enum ValOrVec<T> {
    /// A single value.
    Val(T),
    /// A list of values.
    Vec(Vec<T>),
}
14
15impl<T> ValOrVec<T> {
16    pub fn push(&mut self, new_val: T) {
17        match self {
18            Self::Val(_) => {
19                // Change self to a Vec variant and take ownership of the previous value
20                let old_self = mem::replace(self, ValOrVec::Vec(Vec::with_capacity(2)));
21
22                let old_val = match old_self {
23                    Self::Val(v) => v,
24                    // Safety: We would not be in the outer branch otherwise
25                    _ => unsafe { unreachable_unchecked() },
26                };
27
28                let vec = match self {
29                    ValOrVec::Vec(v) => v,
30                    // Safety: We set self to Vec with the mem::replace above
31                    _ => unsafe { unreachable_unchecked() },
32                };
33
34                vec.push(old_val);
35                vec.push(new_val);
36            }
37            Self::Vec(vec) => vec.push(new_val),
38        }
39    }
40
41    fn deserialize_val<U, E, F>(self, f: F) -> Result<U, E>
42    where
43        F: FnOnce(T) -> Result<U, E>,
44        E: de::Error,
45    {
46        match self {
47            ValOrVec::Val(val) => f(val),
48            ValOrVec::Vec(_) => Err(de::Error::custom("unsupported value")),
49        }
50    }
51}
52
53impl<T> IntoIterator for ValOrVec<T> {
54    type Item = T;
55    type IntoIter = IntoIter<T>;
56
57    fn into_iter(self) -> Self::IntoIter {
58        IntoIter::new(self)
59    }
60}
61
/// Owning iterator over the contents of a `ValOrVec`.
pub enum IntoIter<T> {
    /// Iterates over a single value.
    Val(iter::Once<T>),
    /// Iterates over the elements of a list.
    Vec(vec::IntoIter<T>),
}
66
67impl<T> IntoIter<T> {
68    fn new(vv: ValOrVec<T>) -> Self {
69        match vv {
70            ValOrVec::Val(val) => IntoIter::Val(iter::once(val)),
71            ValOrVec::Vec(vec) => IntoIter::Vec(vec.into_iter()),
72        }
73    }
74}
75
76impl<T> Iterator for IntoIter<T> {
77    type Item = T;
78
79    fn next(&mut self) -> Option<Self::Item> {
80        match self {
81            IntoIter::Val(iter) => iter.next(),
82            IntoIter::Vec(iter) => iter.next(),
83        }
84    }
85}
86
87impl<'de, T> IntoDeserializer<'de> for ValOrVec<T>
88where
89    T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>,
90{
91    type Deserializer = Self;
92
93    fn into_deserializer(self) -> Self::Deserializer {
94        self
95    }
96}
97
/// Generates visitor-only `Deserializer` methods that hand the call over to
/// the single contained value via `deserialize_val`.
///
/// Takes a list of method names, each followed by a comma.
macro_rules! forward_to_part {
    ($($method:ident,)*) => {
        $(
            fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
            where
                V: de::Visitor<'de>,
            {
                self.deserialize_val(move |part| part.$method(visitor))
            }
        )*
    };
}
109
110impl<'de, T> Deserializer<'de> for ValOrVec<T>
111where
112    T: IntoDeserializer<'de> + Deserializer<'de, Error = Error>,
113{
114    type Error = Error;
115
116    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
117    where
118        V: de::Visitor<'de>,
119    {
120        match self {
121            Self::Val(val) => val.deserialize_any(visitor),
122            Self::Vec(_) => self.deserialize_seq(visitor),
123        }
124    }
125
126    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
127    where
128        V: de::Visitor<'de>,
129    {
130        visitor.visit_seq(SeqDeserializer::new(self.into_iter()))
131    }
132
133    fn deserialize_enum<V>(
134        self,
135        name: &'static str,
136        variants: &'static [&'static str],
137        visitor: V,
138    ) -> Result<V::Value, Self::Error>
139    where
140        V: de::Visitor<'de>,
141    {
142        self.deserialize_val(move |val| val.deserialize_enum(name, variants, visitor))
143    }
144
145    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
146    where
147        V: de::Visitor<'de>,
148    {
149        self.deserialize_val(move |val| val.deserialize_tuple(len, visitor))
150    }
151
152    fn deserialize_struct<V>(
153        self,
154        name: &'static str,
155        fields: &'static [&'static str],
156        visitor: V,
157    ) -> Result<V::Value, Self::Error>
158    where
159        V: de::Visitor<'de>,
160    {
161        self.deserialize_val(move |val| val.deserialize_struct(name, fields, visitor))
162    }
163
164    fn deserialize_unit_struct<V>(
165        self,
166        name: &'static str,
167        visitor: V,
168    ) -> Result<V::Value, Self::Error>
169    where
170        V: de::Visitor<'de>,
171    {
172        self.deserialize_val(move |val| val.deserialize_unit_struct(name, visitor))
173    }
174
175    fn deserialize_tuple_struct<V>(
176        self,
177        name: &'static str,
178        len: usize,
179        visitor: V,
180    ) -> Result<V::Value, Self::Error>
181    where
182        V: de::Visitor<'de>,
183    {
184        self.deserialize_val(move |val| val.deserialize_tuple_struct(name, len, visitor))
185    }
186
187    fn deserialize_newtype_struct<V>(
188        self,
189        name: &'static str,
190        visitor: V,
191    ) -> Result<V::Value, Self::Error>
192    where
193        V: de::Visitor<'de>,
194    {
195        self.deserialize_val(move |val| val.deserialize_newtype_struct(name, visitor))
196    }
197
198    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
199    where
200        V: de::Visitor<'de>,
201    {
202        visitor.visit_unit()
203    }
204
205    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
206    where
207        V: de::Visitor<'de>,
208    {
209        match self {
210            ValOrVec::Val(val) => val.deserialize_option(visitor),
211            ValOrVec::Vec(_) => visitor.visit_some(self),
212        }
213    }
214
215    forward_to_part! {
216        deserialize_bool,
217        deserialize_char,
218        deserialize_str,
219        deserialize_string,
220        deserialize_bytes,
221        deserialize_byte_buf,
222        deserialize_unit,
223        deserialize_u8,
224        deserialize_u16,
225        deserialize_u32,
226        deserialize_u64,
227        deserialize_i8,
228        deserialize_i16,
229        deserialize_i32,
230        deserialize_i64,
231        deserialize_f32,
232        deserialize_f64,
233        deserialize_identifier,
234        deserialize_map,
235    }
236}
237
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use assert_matches2::assert_matches;

    use super::ValOrVec;

    /// Pushing onto a single borrowed value produces a list holding every
    /// value in push order.
    #[test]
    fn cow_borrowed() {
        let mut values = ValOrVec::Val(Cow::Borrowed("a"));
        for extra in ["b", "c"] {
            values.push(Cow::Borrowed(extra));
        }

        assert_matches!(values, ValOrVec::Vec(items));
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    /// Same as `cow_borrowed`, but with owned strings.
    #[test]
    fn cow_owned() {
        let mut values = ValOrVec::Val(Cow::from("a".to_owned()));
        for extra in ["b", "c"] {
            values.push(Cow::from(extra.to_owned()));
        }

        assert_matches!(values, ValOrVec::Vec(items));
        assert_eq!(items, vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]);
    }
}
263}