actix_web_lab/
err_handler.rs

1//! For middleware documentation, see [`ErrorHandlers`].
2
3use std::{
4    future::Future,
5    pin::Pin,
6    rc::Rc,
7    task::{ready, Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use actix_web::{
12    body::EitherBody,
13    dev::{ServiceRequest, ServiceResponse},
14    http::StatusCode,
15    Error, Result,
16};
17use ahash::AHashMap;
18use futures_core::future::LocalBoxFuture;
19use pin_project_lite::pin_project;
20
21type ErrorHandlerRes<B> = Result<ServiceResponse<EitherBody<B>>>;
22type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> LocalBoxFuture<'static, ErrorHandlerRes<B>>;
23type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
24
25/// Middleware for registering custom status code based error handlers.
26///
27/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler
28/// for a given status code. Handlers can modify existing responses or create completely new ones.
29///
30/// # Examples
31/// ```
32/// use actix_web::{
33///     body::EitherBody,
34///     dev::ServiceResponse,
35///     http::{header, StatusCode},
36///     web, App, HttpResponse, Result,
37/// };
38/// use actix_web_lab::middleware::ErrorHandlers;
39///
40/// async fn add_error_header<B>(
41///     mut res: ServiceResponse<B>,
42/// ) -> Result<ServiceResponse<EitherBody<B>>> {
43///     res.response_mut().headers_mut().insert(
44///         header::CONTENT_TYPE,
45///         header::HeaderValue::from_static("Error"),
46///     );
47///     Ok(res.map_into_left_body())
48/// }
49///
50/// let app = App::new()
51///     .wrap(ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, add_error_header))
52///     .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError)));
53/// ```
54pub struct ErrorHandlers<B> {
55    handlers: Handlers<B>,
56}
57
58impl<B> Default for ErrorHandlers<B> {
59    fn default() -> Self {
60        ErrorHandlers {
61            handlers: Default::default(),
62        }
63    }
64}
65
66impl<B> ErrorHandlers<B> {
67    /// Construct new `ErrorHandlers` instance.
68    pub fn new() -> Self {
69        ErrorHandlers::default()
70    }
71
72    /// Register error handler for specified status code.
73    pub fn handler<F, Fut>(mut self, status: StatusCode, handler: F) -> Self
74    where
75        F: Fn(ServiceResponse<B>) -> Fut + 'static,
76        Fut: Future<Output = ErrorHandlerRes<B>> + 'static,
77    {
78        Rc::get_mut(&mut self.handlers)
79            .unwrap()
80            .insert(status, Box::new(move |res| Box::pin((handler)(res))));
81        self
82    }
83}
84
85impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
86where
87    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
88    S::Future: 'static,
89    B: 'static,
90{
91    type Response = ServiceResponse<EitherBody<B>>;
92    type Error = Error;
93    type Transform = ErrorHandlersMiddleware<S, B>;
94    type InitError = ();
95    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
96
97    fn new_transform(&self, service: S) -> Self::Future {
98        let handlers = self.handlers.clone();
99        Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
100    }
101}
102
103#[doc(hidden)]
104pub struct ErrorHandlersMiddleware<S, B> {
105    service: S,
106    handlers: Handlers<B>,
107}
108
109impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
110where
111    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
112    S::Future: 'static,
113    B: 'static,
114{
115    type Response = ServiceResponse<EitherBody<B>>;
116    type Error = Error;
117    type Future = ErrorHandlersFuture<S::Future, B>;
118
119    actix_service::forward_ready!(service);
120
121    fn call(&self, req: ServiceRequest) -> Self::Future {
122        let handlers = self.handlers.clone();
123        let fut = self.service.call(req);
124        ErrorHandlersFuture::ServiceFuture { fut, handlers }
125    }
126}
127
128pin_project! {
129    #[project = ErrorHandlersProj]
130    pub enum ErrorHandlersFuture<Fut, B>
131    where
132        Fut: Future,
133    {
134        ServiceFuture {
135            #[pin]
136            fut: Fut,
137            handlers: Handlers<B>,
138        },
139        ErrorHandlerFuture {
140            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
141        },
142    }
143}
144
145impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
146where
147    Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
148{
149    type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
150
151    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152        match self.as_mut().project() {
153            ErrorHandlersProj::ServiceFuture { fut, handlers } => {
154                let res = ready!(fut.poll(cx))?;
155
156                match handlers.get(&res.status()) {
157                    Some(handler) => {
158                        let fut = handler(res);
159
160                        self.as_mut()
161                            .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
162
163                        self.poll(cx)
164                    }
165
166                    None => Poll::Ready(Ok(res.map_into_left_body())),
167                }
168            }
169
170            ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        body,
        http::{
            header::{HeaderValue, CONTENT_TYPE},
            StatusCode,
        },
        test::{self, TestRequest},
    };
    use bytes::Bytes;

    use super::*;

    #[actix_web::test]
    async fn add_header_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_web::test]
    async fn add_header_error_handler_async() {
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    #[actix_web::test]
    async fn changes_body_type() {
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(res)
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    #[actix_web::test]
    async fn error_thrown() {
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            _res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            Err(actix_web::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }

    #[actix_web::test]
    async fn unhandled_status_passes_through() {
        // The handler signature requires a `Result` even though this handler never fails.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(res.map_into_left_body())
        }

        // A handler is registered for 500, but the inner service responds with 400, so the
        // response must come back untouched.
        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        assert!(resp.headers().get(CONTENT_TYPE).is_none());
    }
}