1use std::{
2 future::{ready, Future, Ready},
3 marker::PhantomData,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll},
7};
8
9use actix_service::{forward_ready, Service, Transform};
10use actix_web::{
11 body::MessageBody,
12 dev::{ServiceRequest, ServiceResponse},
13 Error,
14};
15use futures_core::ready;
16use pin_project_lite::pin_project;
17
18pub fn map_response<F>(mapper_fn: F) -> MapResMiddleware<F> {
51 MapResMiddleware {
52 mw_fn: Rc::new(mapper_fn),
53 }
54}
55
56pub struct MapResMiddleware<F> {
58 mw_fn: Rc<F>,
59}
60
61impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MapResMiddleware<F>
62where
63 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
64 F: Fn(ServiceResponse<B>) -> Fut,
65 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
66 B2: MessageBody,
67{
68 type Response = ServiceResponse<B2>;
69 type Error = Error;
70 type Transform = MapResService<S, F, B>;
71 type InitError = ();
72 type Future = Ready<Result<Self::Transform, Self::InitError>>;
73
74 fn new_transform(&self, service: S) -> Self::Future {
75 ready(Ok(MapResService {
76 service,
77 mw_fn: Rc::clone(&self.mw_fn),
78 _phantom: PhantomData,
79 }))
80 }
81}
82
83pub struct MapResService<S, F, B> {
85 service: S,
86 mw_fn: Rc<F>,
87 _phantom: PhantomData<(B,)>,
88}
89
90impl<S, F, Fut, B, B2> Service<ServiceRequest> for MapResService<S, F, B>
91where
92 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
93 F: Fn(ServiceResponse<B>) -> Fut,
94 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
95 B2: MessageBody,
96{
97 type Response = ServiceResponse<B2>;
98 type Error = Error;
99 type Future = MapResFut<S::Future, F, Fut>;
100
101 forward_ready!(service);
102
103 fn call(&self, req: ServiceRequest) -> Self::Future {
104 let mw_fn = Rc::clone(&self.mw_fn);
105 let fut = self.service.call(req);
106
107 MapResFut {
108 mw_fn,
109 state: MapResFutState::Svc { fut },
110 }
111 }
112}
113
114pin_project! {
115 pub struct MapResFut<SvcFut, F, FnFut> {
116 mw_fn: Rc<F>,
117 #[pin]
118 state: MapResFutState<SvcFut, FnFut>,
119 }
120}
121
122pin_project! {
123 #[project = MapResFutStateProj]
124 enum MapResFutState<SvcFut, FnFut> {
125 Svc { #[pin] fut: SvcFut },
126 Fn { #[pin] fut: FnFut },
127 }
128}
129
130impl<SvcFut, B, F, FnFut, B2> Future for MapResFut<SvcFut, F, FnFut>
131where
132 SvcFut: Future<Output = Result<ServiceResponse<B>, Error>>,
133 F: Fn(ServiceResponse<B>) -> FnFut,
134 FnFut: Future<Output = Result<ServiceResponse<B2>, Error>>,
135{
136 type Output = Result<ServiceResponse<B2>, Error>;
137
138 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 let mut this = self.as_mut().project();
140
141 match this.state.as_mut().project() {
142 MapResFutStateProj::Svc { fut } => {
143 let res = ready!(fut.poll(cx))?;
144
145 let fut = (this.mw_fn)(res);
146 this.state.set(MapResFutState::Fn { fut });
147 self.poll(cx)
148 }
149
150 MapResFutStateProj::Fn { fut } => fut.poll(cx),
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use actix_web::{
        http::header::{self, HeaderValue},
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    /// Mapper that returns the response unchanged.
    async fn passthrough(
        res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(res)
    }

    /// Mapper that inserts a `Warning: 42` header.
    async fn insert_warning_header(
        mut res: ServiceResponse<impl MessageBody>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let warning = HeaderValue::from_static("42");
        res.headers_mut().insert(header::WARNING, warning);
        Ok(res)
    }

    /// Mapper that changes the body type by wrapping it in the left side of an `EitherBody`.
    async fn wrap_in_left_body(
        res: ServiceResponse<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        Ok(res.map_into_left_body::<()>())
    }

    #[actix_web::test]
    async fn compat_compat() {
        // Only checks that these combinations type-check when wrapped in `Compat`.
        let _ = App::new().wrap(Compat::new(map_response(passthrough)));
        let _ = App::new().wrap(Compat::new(map_response(wrap_in_left_body)));
    }

    #[actix_web::test]
    async fn feels_good() {
        let app = App::new()
            .default_service(web::to(HttpResponse::Ok))
            .wrap(map_response(|res| async move { Ok(res) }))
            .wrap(map_response(passthrough))
            .wrap(map_response(insert_warning_header))
            .wrap(Logger::default())
            .wrap(map_response(wrap_in_left_body));
        let app = test::init_service(app).await;

        let res = test::call_service(&app, test::TestRequest::default().to_request()).await;
        assert!(res.headers().contains_key(header::WARNING));
    }
}