actix_web_lab/
middleware_map_response.rs

1use std::{
2    future::{ready, Future, Ready},
3    marker::PhantomData,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use actix_service::{forward_ready, Service, Transform};
10use actix_web::{
11    body::MessageBody,
12    dev::{ServiceRequest, ServiceResponse},
13    Error,
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18/// Creates a middleware from an async function that is used as a mapping function for a
19/// [`ServiceResponse`].
20///
21/// # Examples
22/// Adds header:
23/// ```
24/// # use actix_web_lab::middleware::map_response;
25/// use actix_web::{body::MessageBody, dev::ServiceResponse, http::header};
26///
27/// async fn add_header(
28///     mut res: ServiceResponse<impl MessageBody>,
29/// ) -> actix_web::Result<ServiceResponse<impl MessageBody>> {
30///     res.headers_mut()
31///         .insert(header::WARNING, header::HeaderValue::from_static("42"));
32///
33///     Ok(res)
34/// }
35/// # actix_web::App::new().wrap(map_response(add_header));
36/// ```
37///
38/// Maps body:
39/// ```
40/// # use actix_web_lab::middleware::map_response;
41/// use actix_web::{body::MessageBody, dev::ServiceResponse};
42///
43/// async fn mutate_body_type(
44///     res: ServiceResponse<impl MessageBody + 'static>,
45/// ) -> actix_web::Result<ServiceResponse<impl MessageBody>> {
46///     Ok(res.map_into_left_body::<()>())
47/// }
48/// # actix_web::App::new().wrap(map_response(mutate_body_type));
49/// ```
50pub fn map_response<F>(mapper_fn: F) -> MapResMiddleware<F> {
51    MapResMiddleware {
52        mw_fn: Rc::new(mapper_fn),
53    }
54}
55
/// Middleware factory returned by [`map_response`].
///
/// Keeps the user-supplied mapping function behind an [`Rc`], so each service produced by
/// [`Transform::new_transform`] receives a cheap shared handle to it.
pub struct MapResMiddleware<Mapper> {
    mw_fn: Rc<Mapper>,
}
60
61impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResMiddleware<F>
62where
63    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
64    F: Fn(ServiceResponse<B>) -> Fut,
65    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
66    B2: MessageBody,
67{
68    type Response = ServiceResponse<B2>;
69    type Error = Error;
70    type Transform = MapResService<S, F, B>;
71    type InitError = ();
72    type Future = Ready<Result<Self::Transform, Self::InitError>>;
73
74    fn new_transform(&self, service: S) -> Self::Future {
75        ready(Ok(MapResService {
76            service,
77            mw_fn: Rc::clone(&self.mw_fn),
78            _phantom: PhantomData,
79        }))
80    }
81}
82
/// Middleware service for [`map_response`].
///
/// Forwards each request to the wrapped service `S`, then passes the resulting
/// [`ServiceResponse<B>`](ServiceResponse) through the shared mapping function `F`.
pub struct MapResService<S, F, B> {
    /// The wrapped (inner) service.
    service: S,
    /// Mapping function shared with the [`MapResMiddleware`] that created this service.
    mw_fn: Rc<F>,
    /// Records the inner service's response body type `B`, which is not stored directly.
    _phantom: PhantomData<(B,)>,
}
89
90impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResService<S, F, B>
91where
92    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
93    F: Fn(ServiceResponse<B>) -> Fut,
94    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
95    B2: MessageBody,
96{
97    type Response = ServiceResponse<B2>;
98    type Error = Error;
99    type Future = MapResFut<S::Future, F, Fut>;
100
101    forward_ready!(service);
102
103    fn call(&self, req: ServiceRequest) -> Self::Future {
104        let mw_fn = Rc::clone(&self.mw_fn);
105        let fut = self.service.call(req);
106
107        MapResFut {
108            mw_fn,
109            state: MapResFutState::Svc { fut },
110        }
111    }
112}
113
114pin_project! {
115    pub struct MapResFut<SvcFut, F, FnFut> {
116        mw_fn: Rc<F>,
117        #[pin]
118        state: MapResFutState<SvcFut, FnFut>,
119    }
120}
121
122pin_project! {
123    #[project = MapResFutStateProj]
124    enum MapResFutState<SvcFut, FnFut> {
125        Svc { #[pin] fut: SvcFut },
126        Fn { #[pin] fut: FnFut },
127    }
128}
129
130impl<SvcFut, B, F, FnFut, B2> Future for MapResFut<SvcFut, F, FnFut>
131where
132    SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
133    F: Fn(ServiceResponse<B>) -> FnFut,
134    FnFut: Future<Output = Result<ServiceResponse<B2>, Error>>,
135{
136    type Output = Result<ServiceResponse<B2>, Error>;
137
138    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139        let mut this = self.as_mut().project();
140
141        match this.state.as_mut().project() {
142            MapResFutStateProj::Svc { fut } => {
143                let res = ready!(fut.poll(cx))?;
144
145                let fut = (this.mw_fn)(res);
146                this.state.set(MapResFutState::Fn { fut });
147                self.poll(cx)
148            }
149
150            MapResFutStateProj::Fn { fut } => fut.poll(cx),
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use actix_web::{
        http::header::{self, HeaderValue},
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    /// Mapper that returns the response it was given, unchanged.
    async fn passthrough(
        res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(res)
    }

    /// Mapper that inserts a `Warning: 42` header into the response.
    async fn stamp_warning(
        mut res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let warning = HeaderValue::from_static("42");
        res.headers_mut().insert(header::WARNING, warning);

        Ok(res)
    }

    /// Mapper that changes the response's body type via `map_into_left_body`.
    async fn swap_body_type(
        res: ServiceResponse<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let mapped = res.map_into_left_body::<()>();

        Ok(mapped)
    }

    /// `map_response` middleware must be accepted by `Compat`, both when the body type is kept
    /// and when it is changed.
    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(map_response(passthrough)));
        let _ = App::new().wrap(Compat::new(map_response(swap_body_type)));
    }

    /// Stacks several `map_response` layers (closure, no-op, header-adding, body-mapping)
    /// around `Logger` and checks that the header added by one of them reaches the response.
    #[actix_web::test]
    async fn feels_good() {
        let builder = App::new()
            .default_service(web::to(HttpResponse::Ok))
            .wrap(map_response(|res| async move { Ok(res) }))
            .wrap(map_response(passthrough))
            .wrap(map_response(stamp_warning))
            .wrap(Logger::default())
            .wrap(map_response(swap_body_type));
        let app = test::init_service(builder).await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;

        assert!(res.headers().contains_key(header::WARNING));
    }
}