1use std::{
4 future::Future,
5 pin::Pin,
6 rc::Rc,
7 task::{ready, Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use actix_web::{
12 body::EitherBody,
13 dev::{ServiceRequest, ServiceResponse},
14 http::StatusCode,
15 Error, Result,
16};
17use ahash::AHashMap;
18use futures_core::future::LocalBoxFuture;
19use pin_project_lite::pin_project;
20
21type ErrorHandlerRes<B> = Result<ServiceResponse<EitherBody<B>>>;
22type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> LocalBoxFuture<'static, ErrorHandlerRes<B>>;
23type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
24
/// Middleware that lets you register custom handlers for responses with specific status codes.
///
/// Handlers are registered per [`StatusCode`] with [`ErrorHandlers::handler`]. When the wrapped
/// service produces a response whose status code has a registered handler, the handler is given
/// that response and whatever it resolves to is returned instead (including any error it
/// produces). Responses with no matching handler pass through unchanged, their body wrapped in
/// the left variant of [`EitherBody`]. Errors returned by the wrapped service itself are
/// propagated without consulting any handler.
pub struct ErrorHandlers<B> {
    // Status code -> handler table; the `Rc` is cloned into each middleware built from this value.
    handlers: Handlers<B>,
}
57
58impl<B> Default for ErrorHandlers<B> {
59 fn default() -> Self {
60 ErrorHandlers {
61 handlers: Default::default(),
62 }
63 }
64}
65
66impl<B> ErrorHandlers<B> {
67 pub fn new() -> Self {
69 ErrorHandlers::default()
70 }
71
72 pub fn handler<F, Fut>(mut self, status: StatusCode, handler: F) -> Self
74 where
75 F: Fn(ServiceResponse<B>) -> Fut + 'static,
76 Fut: Future<Output = ErrorHandlerRes<B>> + 'static,
77 {
78 Rc::get_mut(&mut self.handlers)
79 .unwrap()
80 .insert(status, Box::new(move |res| Box::pin((handler)(res))));
81 self
82 }
83}
84
85impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
86where
87 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
88 S::Future: 'static,
89 B: 'static,
90{
91 type Response = ServiceResponse<EitherBody<B>>;
92 type Error = Error;
93 type Transform = ErrorHandlersMiddleware<S, B>;
94 type InitError = ();
95 type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
96
97 fn new_transform(&self, service: S) -> Self::Future {
98 let handlers = self.handlers.clone();
99 Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
100 }
101}
102
/// Service produced by the [`Transform`] implementation of [`ErrorHandlers`].
///
/// Forwards requests to the wrapped service and routes the resulting responses through the
/// handler registered for their status code, if any.
#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
    // The wrapped (inner) service that requests are forwarded to.
    service: S,
    // Handler table shared (via `Rc`) with the `ErrorHandlers` value this was built from.
    handlers: Handlers<B>,
}
108
109impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
110where
111 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
112 S::Future: 'static,
113 B: 'static,
114{
115 type Response = ServiceResponse<EitherBody<B>>;
116 type Error = Error;
117 type Future = ErrorHandlersFuture<S::Future, B>;
118
119 actix_service::forward_ready!(service);
120
121 fn call(&self, req: ServiceRequest) -> Self::Future {
122 let handlers = self.handlers.clone();
123 let fut = self.service.call(req);
124 ErrorHandlersFuture::ServiceFuture { fut, handlers }
125 }
126}
127
pin_project! {
    // Response future returned by `ErrorHandlersMiddleware::call`.
    //
    // Starts in `ServiceFuture`, awaiting the wrapped service's response, and switches to
    // `ErrorHandlerFuture` once the handler registered for that response's status code has
    // been invoked.
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Waiting on the wrapped service.
        ServiceFuture {
            // Inner service future; structurally pinned.
            #[pin]
            fut: Fut,
            // Handler table consulted once the inner response is available.
            handlers: Handlers<B>,
        },
        // Waiting on the future produced by the matching error handler.
        ErrorHandlerFuture {
            // Already boxed and pinned (`LocalBoxFuture`), so no structural pinning is needed.
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
144
145impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
146where
147 Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
148{
149 type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
150
151 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
152 match self.as_mut().project() {
153 ErrorHandlersProj::ServiceFuture { fut, handlers } => {
154 let res = ready!(fut.poll(cx))?;
155
156 match handlers.get(&res.status()) {
157 Some(handler) => {
158 let fut = handler(res);
159
160 self.as_mut()
161 .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
162
163 self.poll(cx)
164 }
165
166 None => Poll::Ready(Ok(res.map_into_left_body())),
167 }
168 }
169
170 ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        body,
        http::{
            header::{HeaderValue, CONTENT_TYPE},
            StatusCode,
        },
        test::{self, TestRequest},
    };
    use bytes::Bytes;

    use super::*;

    #[actix_web::test]
    async fn add_header_error_handler() {
        // Handlers must return `Result` to match the registration signature, even when
        // they never fail.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let marker = HeaderValue::from_static("0001");
            res.response_mut().headers_mut().insert(CONTENT_TYPE, marker);
            Ok(res.map_into_left_body())
        }

        let inner = test::status_service(StatusCode::INTERNAL_SERVER_ERROR).into_service();
        let middleware = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(inner)
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let response = test::call_service(&middleware, request).await;
        let content_type = response.headers().get(CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "0001");
    }

    #[actix_web::test]
    async fn add_header_error_handler_async() {
        // Same as above, but with an explicit `'static` bound on the body type.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let marker = HeaderValue::from_static("0001");
            res.response_mut().headers_mut().insert(CONTENT_TYPE, marker);
            Ok(res.map_into_left_body())
        }

        let inner = test::status_service(StatusCode::INTERNAL_SERVER_ERROR).into_service();
        let middleware = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(inner)
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let response = test::call_service(&middleware, request).await;
        let content_type = response.headers().get(CONTENT_TYPE).unwrap();
        assert_eq!(content_type, "0001");
    }

    #[actix_web::test]
    async fn changes_body_type() {
        // Replaces the body entirely, so the response is returned in the right variant.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            let (http_req, http_res) = res.into_parts();
            let http_res = http_res.set_body(Bytes::from("sorry, that's no bueno"));

            Ok(ServiceResponse::new(http_req, http_res)
                .map_into_boxed_body()
                .map_into_right_body())
        }

        let inner = test::status_service(StatusCode::INTERNAL_SERVER_ERROR).into_service();
        let middleware = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(inner)
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let response = test::call_service(&middleware, request).await;
        let payload = test::read_body(response).await;
        assert_eq!(payload, "sorry, that's no bueno");
    }

    #[actix_web::test]
    async fn error_thrown() {
        // A handler that fails: its error must surface from the middleware call.
        #[allow(clippy::unnecessary_wraps)]
        async fn error_handler<B>(
            _res: ServiceResponse<B>,
        ) -> Result<ServiceResponse<EitherBody<B>>> {
            Err(actix_web::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let inner = test::status_service(StatusCode::BAD_REQUEST).into_service();
        let middleware = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(inner)
            .await
            .unwrap();

        let request = TestRequest::default().to_srv_request();
        let error = middleware.call(request).await.unwrap_err();
        let error_response = error.error_response();

        assert_eq!(error_response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        let payload = body::to_bytes(error_response.into_body()).await.unwrap();
        assert_eq!(payload, "error in error handler");
    }
}