use std::{
future::Future,
pin::Pin,
rc::Rc,
task::{ready, Context, Poll},
};
use actix_service::{Service, Transform};
use actix_web::{
body::EitherBody,
dev::{ServiceRequest, ServiceResponse},
http::StatusCode,
Error, Result,
};
use ahash::AHashMap;
use futures_core::future::LocalBoxFuture;
use pin_project_lite::pin_project;
type ErrorHandlerRes<B> = Result<ServiceResponse<EitherBody<B>>>;
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> LocalBoxFuture<'static, ErrorHandlerRes<B>>;
type Handlers<B> = Rc<AHashMap<StatusCode, Box<ErrorHandler<B>>>>;
/// Middleware that runs a registered async handler whenever the wrapped service responds
/// with a matching status code.
///
/// Responses whose status has no registered handler pass through unchanged (body wrapped as
/// the left `EitherBody` variant). Errors from the wrapped service are propagated as-is.
pub struct ErrorHandlers<B> {
    /// Handlers keyed by the status code that triggers them.
    handlers: Handlers<B>,
}
impl<B> Default for ErrorHandlers<B> {
fn default() -> Self {
ErrorHandlers {
handlers: Default::default(),
}
}
}
impl<B> ErrorHandlers<B> {
    /// Creates an `ErrorHandlers` with no handlers registered.
    ///
    /// Equivalent to [`ErrorHandlers::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers an async `handler` to run whenever the wrapped service responds with `status`.
    ///
    /// The handler takes ownership of the inner [`ServiceResponse`]. The response (or error)
    /// its future resolves to is returned in place of the original. Registering a handler for a
    /// status that already has one replaces the earlier handler.
    ///
    /// # Panics
    ///
    /// Panics if called while a middleware built from this value (via
    /// [`Transform::new_transform`]) is still alive. At that point the handler map is shared
    /// through its `Rc` and can no longer be mutated in place.
    pub fn handler<F, Fut>(mut self, status: StatusCode, handler: F) -> Self
    where
        F: Fn(ServiceResponse<B>) -> Fut + 'static,
        Fut: Future<Output = ErrorHandlerRes<B>> + 'static,
    {
        Rc::get_mut(&mut self.handlers)
            .expect(
                "ErrorHandlers::handler called after the handler map was shared; \
                 register all handlers before constructing the middleware",
            )
            // Box the handler's future so every entry in the map has one uniform type.
            .insert(status, Box::new(move |res| Box::pin(handler(res))));
        self
    }
}
impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Transform = ErrorHandlersMiddleware<S, B>;
type InitError = ();
type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
let handlers = self.handlers.clone();
Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) })
}
}
/// Service produced by [`ErrorHandlers`]. It wraps the inner service `S` and holds a shared
/// handle to the registered handler map.
#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
    /// The wrapped service whose responses are inspected.
    service: S,
    /// Handlers keyed by status code, shared with the `ErrorHandlers` this was built from.
    handlers: Handlers<B>,
}
impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<EitherBody<B>>;
type Error = Error;
type Future = ErrorHandlersFuture<S::Future, B>;
actix_service::forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let handlers = self.handlers.clone();
let fut = self.service.call(req);
ErrorHandlersFuture::ServiceFuture { fut, handlers }
}
}
// Response future returned by `ErrorHandlersMiddleware::call`.
//
// Two-state machine:
// - `ServiceFuture`: waiting on the wrapped service. `handlers` is kept so that the
//   response's status can be looked up once it arrives.
// - `ErrorHandlerFuture`: a matching handler was found, and its boxed future is being polled.
//
// Plain `//` comments are used here because pin-project-lite's macro grammar only accepts a
// restricted set of attributes on variants and fields.
pin_project! {
#[project = ErrorHandlersProj]
pub enum ErrorHandlersFuture<Fut, B>
where
Fut: Future,
{
// Awaiting the inner service; `fut` is structurally pinned.
ServiceFuture {
#[pin]
fut: Fut,
handlers: Handlers<B>,
},
// Awaiting a registered error handler. `LocalBoxFuture` is already a pinned box, so
// this field is not marked `#[pin]`.
ErrorHandlerFuture {
fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
},
}
}
impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
where
    Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
{
    type Output = Result<ServiceResponse<EitherBody<B>>, Error>;

    /// Drives the two-phase state machine.
    ///
    /// In `ServiceFuture`, the inner service's future is polled, and an `Err` from it is
    /// returned unchanged. On `Ok`, the response's status is looked up in the handler map:
    /// - If no handler matches, the response is returned with its body wrapped as the left
    ///   `EitherBody` variant.
    /// - If a handler matches, its future replaces the current state and is polled in this
    ///   same call.
    ///
    /// In `ErrorHandlerFuture`, the handler's future is polled and its output is returned
    /// directly.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let handler_fut = match self.as_mut().project() {
                ErrorHandlersProj::ErrorHandlerFuture { fut } => {
                    return fut.as_mut().poll(cx);
                }
                ErrorHandlersProj::ServiceFuture { fut, handlers } => {
                    let res = ready!(fut.poll(cx))?;
                    match handlers.get(&res.status()) {
                        None => return Poll::Ready(Ok(res.map_into_left_body())),
                        Some(handler) => handler(res),
                    }
                }
            };

            // Enter the handler phase; the next loop iteration polls the handler's future.
            self.as_mut()
                .set(ErrorHandlersFuture::ErrorHandlerFuture { fut: handler_fut });
        }
    }
}
// Tests for `ErrorHandlers`. Each test wraps `test::status_service(..)` (presumably an inner
// service that always responds with the given status; see actix-web's `test` module) and
// registers a handler for that status.
#[cfg(test)]
mod tests {
use actix_service::IntoService;
use actix_web::{
body,
http::{
header::{HeaderValue, CONTENT_TYPE},
StatusCode,
},
test::{self, TestRequest},
};
use bytes::Bytes;
use super::*;
// A handler that mutates headers and keeps the original body (left variant) should have
// its header change visible in the final response.
#[actix_web::test]
async fn add_header_error_handler() {
// `Result` is required by the `ErrorHandlers::handler` signature even though this
// handler always returns `Ok`.
#[allow(clippy::unnecessary_wraps)]
async fn error_handler<B>(
mut res: ServiceResponse<B>,
) -> Result<ServiceResponse<EitherBody<B>>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(res.map_into_left_body())
}
let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
// NOTE(review): this test appears to duplicate `add_header_error_handler`. The only
// visible difference is the `B: 'static` bound on the handler; confirm whether both are needed.
#[actix_web::test]
async fn add_header_error_handler_async() {
// `Result` is required by the `ErrorHandlers::handler` signature even though this
// handler always returns `Ok`.
#[allow(clippy::unnecessary_wraps)]
async fn error_handler<B: 'static>(
mut res: ServiceResponse<B>,
) -> Result<ServiceResponse<EitherBody<B>>> {
res.response_mut()
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
Ok(res.map_into_left_body())
}
let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
}
// A handler may replace the body with a different type by returning the right
// `EitherBody` variant (here a boxed `Bytes` body).
#[actix_web::test]
async fn changes_body_type() {
// `Result` is required by the `ErrorHandlers::handler` signature even though this
// handler always returns `Ok`.
#[allow(clippy::unnecessary_wraps)]
async fn error_handler<B>(
res: ServiceResponse<B>,
) -> Result<ServiceResponse<EitherBody<B>>> {
let (req, res) = res.into_parts();
let res = res.set_body(Bytes::from("sorry, that's no bueno"));
let res = ServiceResponse::new(req, res)
.map_into_boxed_body()
.map_into_right_body();
Ok(res)
}
let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);
let mw = ErrorHandlers::new()
.handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
}
// An `Err` returned by a handler must surface as the middleware's error. `mw.call` is used
// directly here so the error can be inspected.
#[actix_web::test]
async fn error_thrown() {
// NOTE(review): this handler always returns `Err`, so the lint allowance looks
// unnecessary here; it was presumably kept for symmetry with the other tests.
#[allow(clippy::unnecessary_wraps)]
async fn error_handler<B>(
_res: ServiceResponse<B>,
) -> Result<ServiceResponse<EitherBody<B>>> {
Err(actix_web::error::ErrorInternalServerError(
"error in error handler",
))
}
let srv = test::status_service(StatusCode::BAD_REQUEST);
let mw = ErrorHandlers::new()
.handler(StatusCode::BAD_REQUEST, error_handler)
.new_transform(srv.into_service())
.await
.unwrap();
let err = mw
.call(TestRequest::default().to_srv_request())
.await
.unwrap_err();
let res = err.error_response();
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
body::to_bytes(res.into_body()).await.unwrap(),
"error in error handler"
);
}
}