1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/// Expands to the hidden `__private_get_type_id__` trait method that the
/// `downcast_ref`/`downcast_mut` methods generated by `downcast_dyn!` rely on.
///
/// Invoke it inside the body of the trait that `downcast_dyn!` is invoked for.
/// The expansion names `PrivateHelper` unqualified, so that type (declared by
/// `downcast_dyn!`) must be in scope where the trait is defined.
macro_rules! downcast_get_type_id {
    () => {
        /// A helper method to get the type ID of the type
        /// this trait is implemented on.
        ///
        /// This method is unsafe to *override*. `downcast_ref` and
        /// `downcast_mut` reinterpret the trait object as `T` whenever the
        /// returned `TypeId` equals `TypeId::of::<T>()`, so a lying override
        /// leads to undefined behavior.
        ///
        /// Rust has no notion of a trait method that is unsafe to
        /// *implement* (marking it `unsafe` makes it unsafe to *call*). As a
        /// workaround, the signature mentions `PrivateHelper`, a type with a
        /// private field, so safe code outside its module cannot construct
        /// one. The helper is also taken as a parameter, so safe code cannot
        /// obtain an instance just by calling an existing implementation.
        ///
        /// Caveat: the private constructor alone does NOT prevent forgery.
        /// An override receives a `PrivateHelper` by value and could return
        /// it alongside an arbitrary `TypeId`. Soundness therefore relies on
        /// `PrivateHelper` being unnameable by any code that might override
        /// this method. Never override it, and do not export `PrivateHelper`
        /// from the module that declares it.
        #[doc(hidden)]
        // Only reached through `downcast_ref`/`downcast_mut`; silence the
        // warning when a crate never downcasts.
        #[allow(dead_code)]
        fn __private_get_type_id__(&self, _: PrivateHelper) -> (::std::any::TypeId, PrivateHelper)
        where
            Self: 'static,
        {
            // Absolute `::std` path: this expands at the call site, where a
            // local item named `std` could otherwise shadow the std crate.
            (::std::any::TypeId::of::<Self>(), PrivateHelper(()))
        }
    };
}

/// Generates `downcast_ref` and `downcast_mut` on `dyn $name + 'static`,
/// together with the `PrivateHelper` type that the `__private_get_type_id__`
/// method (from `downcast_get_type_id!`) mentions in its signature.
///
/// `$name` must contain `downcast_get_type_id!()` in its body. Because that
/// method names `PrivateHelper` unqualified, this macro is normally invoked in
/// the same module as the trait.
macro_rules! downcast_dyn {
    ($name:ident) => {
        /// A type with a private field, for use with
        /// `__private_get_type_id__`. Only code in the module that invoked
        /// `downcast_dyn!` can construct one.
        ///
        /// The soundness of `downcast_ref`/`downcast_mut` depends on no code
        /// that can name this type overriding `__private_get_type_id__`. An
        /// override receives a `PrivateHelper` by value and could return it
        /// with a forged `TypeId`.
        #[doc(hidden)]
        #[allow(dead_code)]
        pub struct PrivateHelper(());

        impl dyn $name + 'static {
            /// Returns a shared reference to the underlying value if its
            /// concrete type is `T`, or `None` otherwise.
            #[allow(dead_code)]
            pub fn downcast_ref<T: $name + 'static>(&self) -> Option<&T> {
                if self.__private_get_type_id__(PrivateHelper(())).0
                    == ::std::any::TypeId::of::<T>()
                {
                    // SAFETY: the `TypeId` reported by the (non-overridden)
                    // default `__private_get_type_id__` equals `T`'s, so the
                    // erased value really is a `T`. Casting the
                    // `*const dyn $name` to `*const T` discards the vtable
                    // and keeps the data pointer.
                    unsafe { Some(&*(self as *const dyn $name as *const T)) }
                } else {
                    None
                }
            }

            /// Returns a mutable reference to the underlying value if its
            /// concrete type is `T`, or `None` otherwise.
            #[allow(dead_code)]
            pub fn downcast_mut<T: $name + 'static>(&mut self) -> Option<&mut T> {
                if self.__private_get_type_id__(PrivateHelper(())).0
                    == ::std::any::TypeId::of::<T>()
                {
                    // SAFETY: same type-identity argument as in
                    // `downcast_ref`. The cast goes straight from the unique
                    // `&mut self` borrow to `*mut`, never through a `*const`
                    // intermediate, so writing through the result is permitted.
                    unsafe { Some(&mut *(self as *mut dyn $name as *mut T)) }
                } else {
                    None
                }
            }
        }
    };
}

pub(crate) use {downcast_dyn, downcast_get_type_id};

#[cfg(test)]
mod tests {
    #![allow(clippy::upper_case_acronyms)]

    trait MB {
        downcast_get_type_id!();
    }

    downcast_dyn!(MB);

    impl MB for String {}
    impl MB for () {}

    // Downcasting is fully synchronous, so a plain `#[test]` suffices and no
    // async runtime is required.
    #[test]
    fn test_any_casting() {
        let mut body = String::from("hello cast");
        let resp_body: &mut dyn MB = &mut body;
        let body = resp_body.downcast_ref::<String>().unwrap();
        assert_eq!(body, "hello cast");
        let body = resp_body.downcast_mut::<String>().unwrap();
        body.push('!');
        let body = resp_body.downcast_ref::<String>().unwrap();
        assert_eq!(body, "hello cast!");
        let not_body = resp_body.downcast_ref::<()>();
        assert!(not_body.is_none());
    }

    // Covers the `None` path of `downcast_mut`, a zero-sized concrete type,
    // and downcasting through an owning `Box<dyn MB>`.
    #[test]
    fn test_mismatched_and_boxed_casting() {
        let unit: &mut dyn MB = &mut ();
        assert!(unit.downcast_ref::<()>().is_some());
        assert!(unit.downcast_ref::<String>().is_none());
        assert!(unit.downcast_mut::<String>().is_none());
        assert!(unit.downcast_mut::<()>().is_some());

        let boxed: Box<dyn MB> = Box::new(String::from("boxed"));
        assert_eq!(
            boxed.downcast_ref::<String>().map(String::as_str),
            Some("boxed")
        );
        assert!(boxed.downcast_ref::<()>().is_none());
    }
}