1use std::{
2 future::{ready, Future, Ready},
3 marker::PhantomData,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use actix_service::{forward_ready, Service, Transform};
10use actix_web::{
11 body::MessageBody,
12 dev::{ServiceRequest, ServiceResponse},
13 Error, HttpRequest, HttpResponse,
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18pub fn map_response_body<F>(mapper_fn: F) -> MapResBodyMiddleware<F> {
59 MapResBodyMiddleware {
60 mw_fn: Rc::new(mapper_fn),
61 }
62}
63
/// Middleware factory (a [`Transform`]) returned by [`map_response_body`].
///
/// Every service it creates receives an `Rc::clone` of the same mapping function.
pub struct MapResBodyMiddleware<F> {
    /// Shared handle to the user-supplied body-mapping function.
    mw_fn: Rc<F>,
}
68
69impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResBodyMiddleware<F>
70where
71 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
72 F: Fn(HttpRequest, B) -> Fut,
73 Fut: Future<Output = Result<B2, Error>>,
74 B2: MessageBody,
75{
76 type Response = ServiceResponse<B2>;
77 type Error = Error;
78 type Transform = MapResBodyService<S, F, B>;
79 type InitError = ();
80 type Future = Ready<Result<Self::Transform, Self::InitError>>;
81
82 fn new_transform(&self, service: S) -> Self::Future {
83 ready(Ok(MapResBodyService {
84 service,
85 mw_fn: Rc::clone(&self.mw_fn),
86 _phantom: PhantomData,
87 }))
88 }
89}
90
/// Service created by [`MapResBodyMiddleware`].
///
/// It forwards requests to the inner service `S` and maps the body of each response the inner
/// service produces.
pub struct MapResBodyService<S, F, B> {
    /// The wrapped (inner) service.
    service: S,
    /// Mapping function, shared with the factory that created this service.
    mw_fn: Rc<F>,
    /// Carries the inner response body type `B` without storing a value of it.
    _phantom: PhantomData<(B,)>,
}
97
98impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResBodyService<S, F, B>
99where
100 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
101 F: Fn(HttpRequest, B) -> Fut,
102 Fut: Future<Output = Result<B2, Error>>,
103 B2: MessageBody,
104{
105 type Response = ServiceResponse<B2>;
106 type Error = Error;
107 type Future = MapResBodyFut<S::Future, F, Fut>;
108
109 forward_ready!(service);
110
111 fn call(&self, req: ServiceRequest) -> Self::Future {
112 let mw_fn = Rc::clone(&self.mw_fn);
113 let fut = self.service.call(req);
114
115 MapResBodyFut {
116 mw_fn,
117 state: MapResBodyFutState::Svc { fut },
118 }
119 }
120}
121
pin_project! {
    /// Future returned by the `call` method of [`MapResBodyService`].
    ///
    /// It first awaits the inner service's response, then awaits the mapping function applied
    /// to that response's body. It resolves to a response with the mapped body.
    pub struct MapResBodyFut<SvcFut, F, FnFut> {
        // Mapping function, shared via `Rc` with the service that created this future.
        mw_fn: Rc<F>,
        // Current stage of the two-step pipeline; structurally pinned because each stage
        // holds a future that is polled through a pinned projection.
        #[pin]
        state: MapResBodyFutState<SvcFut, FnFut>,
    }
}
129
pin_project! {
    // Stages of a `MapResBodyFut`; `MapResBodyFutStateProj` is the generated pinned projection.
    #[project = MapResBodyFutStateProj]
    enum MapResBodyFutState<SvcFut, FnFut> {
        // Waiting on the wrapped service's response future.
        Svc { #[pin] fut: SvcFut },

        // Waiting on the mapping function's future. The request and the body-less response
        // head are parked here so they can be reassembled with the new body. They are
        // `Option`s so the final poll can move them out with `take()`.
        Fn {
            #[pin]
            fut: FnFut,

            req: Option<HttpRequest>,
            res: Option<HttpResponse<()>>
        },
    }
}
144
145impl<SvcFut, B, F, FnFut, B2> Future for MapResBodyFut<SvcFut, F, FnFut>
146where
147 SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
148 F: Fn(HttpRequest, B) -> FnFut,
149 FnFut: Future<Output = Result<B2, Error>>,
150{
151 type Output = Result<ServiceResponse<B2>, Error>;
152
153 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154 let mut this = self.as_mut().project();
155
156 match this.state.as_mut().project() {
157 MapResBodyFutStateProj::Svc { fut } => {
158 let res = ready!(fut.poll(cx))?;
159
160 let (req, res) = res.into_parts();
161 let (res, body) = res.into_parts();
162
163 let fut = (this.mw_fn)(req.clone(), body);
164 this.state.set(MapResBodyFutState::Fn {
165 fut,
166 req: Some(req),
167 res: Some(res),
168 });
169
170 self.poll(cx)
171 }
172
173 MapResBodyFutStateProj::Fn { fut, req, res } => {
174 let body = ready!(fut.poll(cx))?;
175
176 let req = req.take().unwrap();
177 let res = res.take().unwrap();
178
179 let res = res.set_body(body);
180 let res = ServiceResponse::new(req, res);
181
182 Poll::Ready(Ok(res))
183 }
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use actix_web::{
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    /// Mapper that returns the body untouched.
    async fn noop(_req: HttpRequest, body: impl MessageBody) -> Result<impl MessageBody, Error> {
        Ok(body)
    }

    /// Mapper that discards the incoming body and substitutes a `String`, changing the body type.
    async fn mutate_body_type(
        _req: HttpRequest,
        _body: impl MessageBody + 'static,
    ) -> Result<impl MessageBody, Error> {
        Ok(String::from("foo"))
    }

    #[actix_web::test]
    async fn compat_compat() {
        // Only checks that the middleware can be wrapped in `Compat`; nothing is executed.
        let _ = App::new().wrap(Compat::new(map_response_body(noop)));
        let _ = App::new().wrap(Compat::new(map_response_body(mutate_body_type)));
    }

    #[actix_web::test]
    async fn feels_good() {
        // Every mapper except `mutate_body_type` passes the body through unchanged, so the
        // response body is expected to be "foo".
        let app = App::new()
            .default_service(web::to(HttpResponse::Ok))
            .wrap(map_response_body(|_req, body| async move { Ok(body) }))
            .wrap(map_response_body(noop))
            .wrap(Logger::default())
            .wrap(map_response_body(mutate_body_type));
        let app = test::init_service(app).await;

        let res_body =
            test::call_and_read_body(&app, test::TestRequest::default().to_request()).await;
        assert_eq!(res_body, "foo");
    }
}