actix_web_lab/
middleware_map_response_body.rs

1use std::{
2    future::{ready, Future, Ready},
3    marker::PhantomData,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll},
7};
8
9use actix_service::{forward_ready, Service, Transform};
10use actix_web::{
11    body::MessageBody,
12    dev::{ServiceRequest, ServiceResponse},
13    Error, HttpRequest, HttpResponse,
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18/// Creates a middleware from an async function that is used as a mapping function for an
19/// [`impl MessageBody`][MessageBody].
20///
21/// # Examples
22/// Completely replaces the body:
23/// ```
24/// # use actix_web_lab::middleware::map_response_body;
25/// use actix_web::{body::MessageBody, HttpRequest};
26///
27/// async fn replace_body(
28///     _req: HttpRequest,
29///     _: impl MessageBody,
30/// ) -> actix_web::Result<impl MessageBody> {
31///     Ok("foo".to_owned())
32/// }
33/// # actix_web::App::new().wrap(map_response_body(replace_body));
34/// ```
35///
36/// Appends some bytes to the body:
37/// ```
38/// # use actix_web_lab::middleware::map_response_body;
39/// use actix_web::{
40///     body::{self, MessageBody},
41///     web::{BufMut as _, BytesMut},
42///     HttpRequest,
43/// };
44///
45/// async fn append_bytes(
46///     _req: HttpRequest,
47///     body: impl MessageBody,
48/// ) -> actix_web::Result<impl MessageBody> {
49///     let buf = body::to_bytes(body).await.ok().unwrap();
50///
51///     let mut body = BytesMut::from(&buf[..]);
52///     body.put_slice(b" - hope you like things ruining your payload format");
53///
54///     Ok(body)
55/// }
56/// # actix_web::App::new().wrap(map_response_body(append_bytes));
57/// ```
58pub fn map_response_body<F>(mapper_fn: F) -> MapResBodyMiddleware<F> {
59    MapResBodyMiddleware {
60        mw_fn: Rc::new(mapper_fn),
61    }
62}
63
/// Middleware transform for [`map_response_body`].
///
/// Holds the user's mapping function behind an [`Rc`] so it can be handed to each service this
/// transform creates.
pub struct MapResBodyMiddleware<F> {
    // Shared mapping function, cloned (by reference count) into every created service.
    mw_fn: Rc<F>,
}
68
69impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResBodyMiddleware<F>
70where
71    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
72    F: Fn(HttpRequest, B) -> Fut,
73    Fut: Future<Output = Result<B2, Error>>,
74    B2: MessageBody,
75{
76    type Response = ServiceResponse<B2>;
77    type Error = Error;
78    type Transform = MapResBodyService<S, F, B>;
79    type InitError = ();
80    type Future = Ready<Result<Self::Transform, Self::InitError>>;
81
82    fn new_transform(&self, service: S) -> Self::Future {
83        ready(Ok(MapResBodyService {
84            service,
85            mw_fn: Rc::clone(&self.mw_fn),
86            _phantom: PhantomData,
87        }))
88    }
89}
90
/// Middleware service for [`map_response_body`].
///
/// Forwards each request to the wrapped service `S` and returns a [`MapResBodyFut`] that runs the
/// resulting response body through the shared mapper function `F`.
pub struct MapResBodyService<S, F, B> {
    // The wrapped (inner) service.
    service: S,
    // Mapper function, shared with the `MapResBodyMiddleware` that created this service.
    mw_fn: Rc<F>,
    // Ties the inner service's body type `B` to this type without storing a value of it.
    _phantom: PhantomData<(B,)>,
}
97
98impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResBodyService<S, F, B>
99where
100    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
101    F: Fn(HttpRequest, B) -> Fut,
102    Fut: Future<Output = Result<B2, Error>>,
103    B2: MessageBody,
104{
105    type Response = ServiceResponse<B2>;
106    type Error = Error;
107    type Future = MapResBodyFut<S::Future, F, Fut>;
108
109    forward_ready!(service);
110
111    fn call(&self, req: ServiceRequest) -> Self::Future {
112        let mw_fn = Rc::clone(&self.mw_fn);
113        let fut = self.service.call(req);
114
115        MapResBodyFut {
116            mw_fn,
117            state: MapResBodyFutState::Svc { fut },
118        }
119    }
120}
121
122pin_project! {
123    pub struct MapResBodyFut<SvcFut, F, FnFut> {
124        mw_fn: Rc<F>,
125        #[pin]
126        state: MapResBodyFutState<SvcFut, FnFut>,
127    }
128}
129
130pin_project! {
131    #[project = MapResBodyFutStateProj]
132    enum MapResBodyFutState<SvcFut, FnFut> {
133        Svc { #[pin] fut: SvcFut },
134
135        Fn {
136            #[pin]
137            fut: FnFut,
138
139            req: Option<HttpRequest>,
140            res: Option<HttpResponse<()>>
141        },
142    }
143}
144
145impl<SvcFut, B, F, FnFut, B2> Future for MapResBodyFut<SvcFut, F, FnFut>
146where
147    SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
148    F: Fn(HttpRequest, B) -> FnFut,
149    FnFut: Future<Output = Result<B2, Error>>,
150{
151    type Output = Result<ServiceResponse<B2>, Error>;
152
153    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154        let mut this = self.as_mut().project();
155
156        match this.state.as_mut().project() {
157            MapResBodyFutStateProj::Svc { fut } => {
158                let res = ready!(fut.poll(cx))?;
159
160                let (req, res) = res.into_parts();
161                let (res, body) = res.into_parts();
162
163                let fut = (this.mw_fn)(req.clone(), body);
164                this.state.set(MapResBodyFutState::Fn {
165                    fut,
166                    req: Some(req),
167                    res: Some(res),
168                });
169
170                self.poll(cx)
171            }
172
173            MapResBodyFutStateProj::Fn { fut, req, res } => {
174                let body = ready!(fut.poll(cx))?;
175
176                let req = req.take().unwrap();
177                let res = res.take().unwrap();
178
179                let res = res.set_body(body);
180                let res = ServiceResponse::new(req, res);
181
182                Poll::Ready(Ok(res))
183            }
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use actix_web::{
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    /// Mapper that hands the body back untouched.
    async fn noop(_req: HttpRequest, body: impl MessageBody) -> Result<impl MessageBody, Error> {
        Ok(body)
    }

    /// Mapper that discards the body and substitutes a `String`, changing the body type.
    async fn mutate_body_type(
        _req: HttpRequest,
        _body: impl MessageBody + 'static,
    ) -> Result<impl MessageBody, Error> {
        Ok("foo".to_owned())
    }

    #[actix_web::test]
    async fn compat_compat() {
        // Both mappers must be accepted when wrapped in `Compat`.
        let _ = App::new().wrap(Compat::new(map_response_body(noop)));
        let _ = App::new().wrap(Compat::new(map_response_body(mutate_body_type)));
    }

    #[actix_web::test]
    async fn feels_good() {
        // The outermost mapper replaces whatever body the inner layers produced.
        let app = App::new()
            .default_service(web::to(HttpResponse::Ok))
            .wrap(map_response_body(|_req, body| async move { Ok(body) }))
            .wrap(map_response_body(noop))
            .wrap(Logger::default())
            .wrap(map_response_body(mutate_body_type));
        let app = test::init_service(app).await;

        let body = test::call_and_read_body(&app, test::TestRequest::default().to_request()).await;
        assert_eq!(body, "foo");
    }
}