actix_http/
extensions.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    fmt,
5    hash::{BuildHasherDefault, Hasher},
6};
7
/// A hasher for `TypeId`s that takes advantage of their known characteristics.
///
/// A `TypeId` already is a well-distributed hash of its type, so running it through a general
/// purpose hasher again is wasted work. When a value is fed in via [`Hasher::write_u64`] this
/// hasher just stores it and hands it back unchanged from [`Hasher::finish`].
///
/// Author of `anymap` crate has done research on the topic:
/// <https://github.com/chris-morgan/anymap/blob/2e9a5704/src/lib.rs#L599>
#[derive(Debug, Default)]
struct NoOpHasher(u64);

impl Hasher for NoOpHasher {
    /// Fallback for input that does not arrive through [`Hasher::write_u64`].
    ///
    /// How `TypeId` implements `Hash` is explicitly not stable across Rust releases. If it ever
    /// hashes through `write`, or through a default method that forwards here such as
    /// `write_u128`, panicking would break every lookup in a map using this hasher. Instead, the
    /// bytes are folded into the state in 8-byte words with a multiply-rotate mix. Equal input
    /// therefore always yields an equal hash.
    fn write(&mut self, bytes: &[u8]) {
        // Odd multiplier (the FxHash constant): multiplication by it is a bijection on `u64`,
        // so words that differ in the last step still produce different states.
        const MULTIPLIER: u64 = 0x517c_c1b7_2722_0a95;

        for chunk in bytes.chunks(8) {
            // Zero-pad a trailing partial chunk to a full word.
            let mut word = [0u8; 8];
            word[..chunk.len()].copy_from_slice(chunk);
            self.0 = (self.0.rotate_left(5) ^ u64::from_ne_bytes(word)).wrapping_mul(MULTIPLIER);
        }
    }

    /// Stores `i` verbatim as the hash state; this is the fast path for `TypeId` keys.
    fn write_u64(&mut self, i: u64) {
        self.0 = i;
    }

    /// Returns the current state unchanged.
    fn finish(&self) -> u64 {
        self.0
    }
}
28
29/// A type map for request extensions.
30///
31/// All entries into this map must be owned types (or static references).
32#[derive(Default)]
33pub struct Extensions {
34    /// Use AHasher with a std HashMap with for faster lookups on the small `TypeId` keys.
35    map: HashMap<TypeId, Box<dyn Any>, BuildHasherDefault<NoOpHasher>>,
36}
37
38impl Extensions {
39    /// Creates an empty `Extensions`.
40    #[inline]
41    pub fn new() -> Extensions {
42        Extensions {
43            map: HashMap::default(),
44        }
45    }
46
47    /// Insert an item into the map.
48    ///
49    /// If an item of this type was already stored, it will be replaced and returned.
50    ///
51    /// ```
52    /// # use actix_http::Extensions;
53    /// let mut map = Extensions::new();
54    /// assert_eq!(map.insert(""), None);
55    /// assert_eq!(map.insert(1u32), None);
56    /// assert_eq!(map.insert(2u32), Some(1u32));
57    /// assert_eq!(*map.get::<u32>().unwrap(), 2u32);
58    /// ```
59    pub fn insert<T: 'static>(&mut self, val: T) -> Option<T> {
60        self.map
61            .insert(TypeId::of::<T>(), Box::new(val))
62            .and_then(downcast_owned)
63    }
64
65    /// Check if map contains an item of a given type.
66    ///
67    /// ```
68    /// # use actix_http::Extensions;
69    /// let mut map = Extensions::new();
70    /// assert!(!map.contains::<u32>());
71    ///
72    /// assert_eq!(map.insert(1u32), None);
73    /// assert!(map.contains::<u32>());
74    /// ```
75    pub fn contains<T: 'static>(&self) -> bool {
76        self.map.contains_key(&TypeId::of::<T>())
77    }
78
79    /// Get a reference to an item of a given type.
80    ///
81    /// ```
82    /// # use actix_http::Extensions;
83    /// let mut map = Extensions::new();
84    /// map.insert(1u32);
85    /// assert_eq!(map.get::<u32>(), Some(&1u32));
86    /// ```
87    pub fn get<T: 'static>(&self) -> Option<&T> {
88        self.map
89            .get(&TypeId::of::<T>())
90            .and_then(|boxed| boxed.downcast_ref())
91    }
92
93    /// Get a mutable reference to an item of a given type.
94    ///
95    /// ```
96    /// # use actix_http::Extensions;
97    /// let mut map = Extensions::new();
98    /// map.insert(1u32);
99    /// assert_eq!(map.get_mut::<u32>(), Some(&mut 1u32));
100    /// ```
101    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
102        self.map
103            .get_mut(&TypeId::of::<T>())
104            .and_then(|boxed| boxed.downcast_mut())
105    }
106
107    /// Remove an item from the map of a given type.
108    ///
109    /// If an item of this type was already stored, it will be returned.
110    ///
111    /// ```
112    /// # use actix_http::Extensions;
113    /// let mut map = Extensions::new();
114    ///
115    /// map.insert(1u32);
116    /// assert_eq!(map.get::<u32>(), Some(&1u32));
117    ///
118    /// assert_eq!(map.remove::<u32>(), Some(1u32));
119    /// assert!(!map.contains::<u32>());
120    /// ```
121    pub fn remove<T: 'static>(&mut self) -> Option<T> {
122        self.map.remove(&TypeId::of::<T>()).and_then(downcast_owned)
123    }
124
125    /// Clear the `Extensions` of all inserted extensions.
126    ///
127    /// ```
128    /// # use actix_http::Extensions;
129    /// let mut map = Extensions::new();
130    ///
131    /// map.insert(1u32);
132    /// assert!(map.contains::<u32>());
133    ///
134    /// map.clear();
135    /// assert!(!map.contains::<u32>());
136    /// ```
137    #[inline]
138    pub fn clear(&mut self) {
139        self.map.clear();
140    }
141
142    /// Extends self with the items from another `Extensions`.
143    pub fn extend(&mut self, other: Extensions) {
144        self.map.extend(other.map);
145    }
146}
147
148impl fmt::Debug for Extensions {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        f.debug_struct("Extensions").finish()
151    }
152}
153
/// Unboxes a type-erased value back into an owned `T`.
///
/// Yields `None` when the boxed value is not actually a `T`.
fn downcast_owned<T: 'static>(boxed: Box<dyn Any>) -> Option<T> {
    boxed.downcast::<T>().map(|inner| *inner).ok()
}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();

        ext.insert::<i8>(123);
        assert!(ext.get::<i8>().is_some());

        let _ = ext.remove::<i8>();
        assert!(ext.get::<i8>().is_none());
    }

    #[test]
    fn test_clear() {
        /// Reports, in order, whether an `i8`, an `i16` and an `i32` are stored.
        fn presence(ext: &Extensions) -> [bool; 3] {
            [ext.contains::<i8>(), ext.contains::<i16>(), ext.contains::<i32>()]
        }

        let mut ext = Extensions::new();

        ext.insert::<i8>(8);
        ext.insert::<i16>(16);
        ext.insert::<i32>(32);
        assert_eq!(presence(&ext), [true; 3]);

        ext.clear();
        assert_eq!(presence(&ext), [false; 3]);

        // The map stays usable after being cleared.
        ext.insert::<i8>(10);
        assert_eq!(ext.get::<i8>().copied(), Some(10));
    }

    #[test]
    fn test_integers() {
        static A: u32 = 8;

        let mut ext = Extensions::new();

        ext.insert::<i8>(8);
        ext.insert::<i16>(16);
        ext.insert::<i32>(32);
        ext.insert::<i64>(64);
        ext.insert::<i128>(128);
        ext.insert::<u8>(8);
        ext.insert::<u16>(16);
        ext.insert::<u32>(32);
        ext.insert::<u64>(64);
        ext.insert::<u128>(128);
        ext.insert::<&'static u32>(&A);

        // Asserts that an item of every listed type can be fetched back.
        macro_rules! assert_present {
            ($($ty:ty),+ $(,)?) => {
                $( assert!(ext.get::<$ty>().is_some(), "missing {}", stringify!($ty)); )+
            };
        }

        assert_present!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, &'static u32);
    }

    #[test]
    fn test_composition() {
        struct Magi<T>(pub T);

        struct Madoka {
            pub god: bool,
        }

        struct Homura {
            pub attempts: usize,
        }

        struct Mami {
            pub guns: usize,
        }

        let mut ext = Extensions::new();

        ext.insert(Magi(Madoka { god: false }));
        ext.insert(Magi(Homura { attempts: 0 }));
        ext.insert(Magi(Mami { guns: 999 }));

        let madoka = ext.get::<Magi<Madoka>>().unwrap();
        let homura = ext.get::<Magi<Homura>>().unwrap();
        let mami = ext.get::<Magi<Mami>>().unwrap();

        assert!(!madoka.0.god);
        assert_eq!(homura.0.attempts, 0);
        assert_eq!(mami.0.guns, 999);
    }

    #[test]
    fn test_extensions() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut ext = Extensions::new();
        ext.insert(5i32);
        ext.insert(MyType(10));

        assert_eq!(ext.get::<i32>(), Some(&5));
        assert_eq!(ext.get_mut::<i32>(), Some(&mut 5));

        assert_eq!(ext.remove::<i32>(), Some(5));
        assert!(ext.get::<i32>().is_none());

        assert_eq!(ext.get::<bool>(), None);
        assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
    }

    #[test]
    fn test_extend() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut ext = Extensions::new();
        ext.insert(5i32);
        ext.insert(MyType(10));

        let mut incoming = Extensions::new();
        incoming.insert(15i32);
        incoming.insert(20u8);

        ext.extend(incoming);

        // Items from `incoming` replace existing items of the same type...
        assert_eq!(ext.get::<i32>(), Some(&15));
        assert_eq!(ext.get_mut::<i32>(), Some(&mut 15));
        assert_eq!(ext.remove::<i32>(), Some(15));
        assert!(ext.get::<i32>().is_none());

        assert_eq!(ext.get::<bool>(), None);

        // ...while types present on only one side are all kept.
        assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
        assert_eq!(ext.get::<u8>(), Some(&20));
        assert_eq!(ext.get_mut::<u8>(), Some(&mut 20));
    }
}