actix_web_lab/
panic_reporter.rs

1//! Panic reporter middleware.
2//!
3//! See [`PanicReporter`] for docs.
4
5use std::{
6    any::Any,
7    future::{ready, Ready},
8    panic::{self, AssertUnwindSafe},
9    rc::Rc,
10};
11
12use actix_web::dev::{forward_ready, Service, Transform};
13use futures_core::future::LocalBoxFuture;
14use futures_util::FutureExt as _;
15
/// Shared callback invoked with a `&(dyn Any + Send)` view of a panic.
///
/// Uses `Rc`, so the callback (and anything holding it) is confined to a single thread.
type PanicCallback = Rc<dyn Fn(&(dyn Any + Send)) + 'static>;
17
18/// A middleware that triggers a callback when the worker is panicking.
19///
20/// Mostly useful for logging or metrics publishing. The callback received the object with which
21/// panic was originally invoked to allow down-casting.
22///
23/// # Examples
24///
25/// ```no_run
26/// # use actix_web::App;
27/// use actix_web_lab::middleware::PanicReporter;
28/// # mod metrics {
29/// #   macro_rules! increment_counter {
30/// #       ($tt:tt) => {{}};
31/// #   }
32/// #   pub(crate) use increment_counter;
33/// # }
34///
35/// App::new().wrap(PanicReporter::new(|_| metrics::increment_counter!("panic")))
36///     # ;
37/// ```
38#[derive(Clone)]
39pub struct PanicReporter {
40    cb: PanicCallback,
41}
42
43impl PanicReporter {
44    /// Constructs new panic reporter middleware with `callback`.
45    pub fn new(callback: impl Fn(&(dyn Any + Send)) + 'static) -> Self {
46        Self {
47            cb: Rc::new(callback),
48        }
49    }
50}
51
52impl std::fmt::Debug for PanicReporter {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("PanicReporter")
55            .field("cb", &"<callback>")
56            .finish()
57    }
58}
59
60impl<S, Req> Transform<S, Req> for PanicReporter
61where
62    S: Service<Req>,
63    S::Future: 'static,
64{
65    type Response = S::Response;
66    type Error = S::Error;
67    type Transform = PanicReporterMiddleware<S>;
68    type InitError = ();
69    type Future = Ready<Result<Self::Transform, Self::InitError>>;
70
71    fn new_transform(&self, service: S) -> Self::Future {
72        ready(Ok(PanicReporterMiddleware {
73            service: Rc::new(service),
74            cb: Rc::clone(&self.cb),
75        }))
76    }
77}
78
79pub struct PanicReporterMiddleware<S> {
80    service: Rc<S>,
81    cb: PanicCallback,
82}
83
84impl<S, Req> Service<Req> for PanicReporterMiddleware<S>
85where
86    S: Service<Req>,
87    S::Future: 'static,
88{
89    type Response = S::Response;
90    type Error = S::Error;
91    type Future = LocalBoxFuture<'static, Result<S::Response, S::Error>>;
92
93    forward_ready!(service);
94
95    fn call(&self, req: Req) -> Self::Future {
96        let cb = Rc::clone(&self.cb);
97
98        // catch panics in service call
99        AssertUnwindSafe(self.service.call(req))
100            .catch_unwind()
101            .map(move |maybe_res| match maybe_res {
102                Ok(res) => res,
103                Err(panic_err) => {
104                    // invoke callback with panic arg
105                    (cb)(&panic_err);
106
107                    // continue unwinding
108                    panic::resume_unwind(panic_err)
109                }
110            })
111            .boxed_local()
112    }
113}
114
115#[cfg(test)]
116mod tests {
117    use std::sync::{
118        atomic::{AtomicBool, Ordering},
119        Arc,
120    };
121
122    use actix_web::{
123        dev::Service as _,
124        test,
125        web::{self, ServiceConfig},
126        App,
127    };
128
129    use super::*;
130
131    fn configure_test_app(cfg: &mut ServiceConfig) {
132        cfg.route("/", web::get().to(|| async { "content" })).route(
133            "/disco",
134            #[allow(unreachable_code)]
135            web::get().to(|| async {
136                panic!("the disco");
137                ""
138            }),
139        );
140    }
141
142    #[actix_web::test]
143    async fn report_when_panics_occur() {
144        let triggered = Arc::new(AtomicBool::new(false));
145
146        let app = App::new()
147            .wrap(PanicReporter::new({
148                let triggered = Arc::clone(&triggered);
149                move |_| {
150                    triggered.store(true, Ordering::SeqCst);
151                }
152            }))
153            .configure(configure_test_app);
154
155        let app = test::init_service(app).await;
156
157        let req = test::TestRequest::with_uri("/").to_request();
158        assert!(app.call(req).await.is_ok());
159        assert!(!triggered.load(Ordering::SeqCst));
160
161        let req = test::TestRequest::with_uri("/disco").to_request();
162        assert!(AssertUnwindSafe(app.call(req))
163            .catch_unwind()
164            .await
165            .is_err());
166        assert!(triggered.load(Ordering::SeqCst));
167    }
168}