1use std::{
2 future::{ready, Ready},
3 marker::PhantomData,
4 rc::Rc,
5};
6
7use actix_service::{
8 boxed::{self, BoxFuture, RcService},
9 forward_ready, Service, Transform,
10};
11use actix_web::{
12 body::MessageBody,
13 dev::{ServiceRequest, ServiceResponse},
14 Error, FromRequest,
15};
16use futures_core::{future::LocalBoxFuture, Future};
17
18pub fn from_fn<F, Es>(mw_fn: F) -> MiddlewareFn<F, Es> {
82 MiddlewareFn {
83 mw_fn: Rc::new(mw_fn),
84 _phantom: PhantomData,
85 }
86}
87
/// Middleware transform produced by [`from_fn`].
///
/// Holds the user function behind an [`Rc`] so that each service created by `new_transform`
/// shares the same function. `Es` is a type-level marker for the extractor tuple the function
/// takes (`()` when it takes none).
pub struct MiddlewareFn<F, Es> {
    _phantom: PhantomData<Es>,
    mw_fn: Rc<F>,
}
93
94impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MiddlewareFn<F, ()>
95where
96 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
97 F: Fn(ServiceRequest, Next<B>) -> Fut + 'static,
98 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
99 B2: MessageBody,
100{
101 type Response = ServiceResponse<B2>;
102 type Error = Error;
103 type Transform = MiddlewareFnService<F, B, ()>;
104 type InitError = ();
105 type Future = Ready<Result<Self::Transform, Self::InitError>>;
106
107 fn new_transform(&self, service: S) -> Self::Future {
108 ready(Ok(MiddlewareFnService {
109 service: boxed::rc_service(service),
110 mw_fn: Rc::clone(&self.mw_fn),
111 _phantom: PhantomData,
112 }))
113 }
114}
115
116pub struct MiddlewareFnService<F, B, Es> {
118 service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
119 mw_fn: Rc<F>,
120 _phantom: PhantomData<(B, Es)>,
121}
122
123impl<F, Fut, B, B2> Service<ServiceRequest> for MiddlewareFnService<F, B, ()>
124where
125 F: Fn(ServiceRequest, Next<B>) -> Fut,
126 Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
127 B2: MessageBody,
128{
129 type Response = ServiceResponse<B2>;
130 type Error = Error;
131 type Future = Fut;
132
133 forward_ready!(service);
134
135 fn call(&self, req: ServiceRequest) -> Self::Future {
136 (self.mw_fn)(
137 req,
138 Next::<B> {
139 service: Rc::clone(&self.service),
140 },
141 )
142 }
143}
144
/// Generates `Transform` and `Service` impls for middleware functions that take one or more
/// `FromRequest` extractors before the `ServiceRequest` and `Next<B>` parameters.
///
/// Each identifier passed in (`E1`, `E2`, ...) is used both as an extractor type parameter and,
/// inside `call`, as the name of the local binding holding the extracted value.
macro_rules! impl_middleware_fn_service {
    ($($ext_type:ident),*) => {
        // The trailing comma in `($($ext_type),*,)` keeps a single extractor as the 1-tuple
        // `(E1,)` rather than a parenthesised type.
        impl<S, F, Fut, B, B2, $($ext_type),*> Transform<S, ServiceRequest> for MiddlewareFn<F, ($($ext_type),*,)>
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            F: Fn($($ext_type),*, ServiceRequest, Next<B>) -> Fut + 'static,
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B: MessageBody + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            type Transform = MiddlewareFnService<F, B, ($($ext_type,)*)>;
            type InitError = ();
            type Future = Ready<Result<Self::Transform, Self::InitError>>;

            fn new_transform(&self, service: S) -> Self::Future {
                ready(Ok(MiddlewareFnService {
                    service: boxed::rc_service(service),
                    mw_fn: Rc::clone(&self.mw_fn),
                    _phantom: PhantomData,
                }))
            }
        }

        impl<F, $($ext_type),*, Fut, B: 'static, B2> Service<ServiceRequest>
            for MiddlewareFnService<F, B, ($($ext_type),*,)>
        where
            F: Fn(
                $($ext_type),*,
                ServiceRequest,
                Next<B>
            ) -> Fut + 'static,
            $($ext_type: FromRequest + 'static,)*
            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
            B2: MessageBody + 'static,
        {
            type Response = ServiceResponse<B2>;
            type Error = Error;
            // The `async move` block below has an unnameable type, so it is boxed.
            type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

            forward_ready!(service);

            // The extracted values are bound to locals named after their type parameters
            // (`E1`, `E2`, ...), which would otherwise trip the snake_case naming lint.
            #[allow(nonstandard_style)]
            fn call(&self, mut req: ServiceRequest) -> Self::Future {
                // Cloned up front so the `'static` future owns its handles instead of
                // borrowing `self`.
                let mw_fn = Rc::clone(&self.mw_fn);
                let service = Rc::clone(&self.service);

                Box::pin(async move {
                    // All extractors run against `req` first; `?` returns an extraction error
                    // without calling `mw_fn`.
                    // NOTE(review): body-reading extractors presumably consume the payload, so
                    // downstream services would see an empty body — confirm and document.
                    let ($($ext_type,)*) = req.extract::<($($ext_type,)*)>().await?;

                    (mw_fn)($($ext_type),*, req, Next::<B> { service }).await
                })
            }
        }
    };
}
203
// Extractor-taking variants for 1 through 9 extractors; the zero-extractor case is implemented
// by hand above. Raising the ceiling only needs another invocation here, provided actix-web
// implements `FromRequest` for a tuple of that arity (required by `req.extract::<(..)>()`).
impl_middleware_fn_service!(E1);
impl_middleware_fn_service!(E1, E2);
impl_middleware_fn_service!(E1, E2, E3);
impl_middleware_fn_service!(E1, E2, E3, E4);
impl_middleware_fn_service!(E1, E2, E3, E4, E5);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8);
impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8, E9);
213
/// Handle to the rest of the service chain, passed to middleware functions built with
/// [`from_fn`].
///
/// `B` is the body type of the responses produced by the wrapped service.
pub struct Next<B> {
    // Type-erased wrapped service, so `Next` is generic only over the response body type.
    service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
}
218
219impl<B> Next<B> {
220 pub fn call(&self, req: ServiceRequest) -> <Self as Service<ServiceRequest>>::Future {
222 Service::call(self, req)
223 }
224}
225
226impl<B> Service<ServiceRequest> for Next<B> {
227 type Response = ServiceResponse<B>;
228 type Error = Error;
229 type Future = BoxFuture<Result<Self::Response, Self::Error>>;
230
231 forward_ready!(service);
232
233 fn call(&self, req: ServiceRequest) -> Self::Future {
234 self.service.call(req)
235 }
236}
237
#[cfg(test)]
mod tests {
    use actix_web::{
        http::{
            header::{self, HeaderValue},
            Method,
        },
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    /// Pass-through middleware.
    async fn noop<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
        next.call(req).await
    }

    /// Adds a `Warning` header to every response.
    async fn add_res_header<B>(
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        let mut res = next.call(req).await?;
        res.headers_mut()
            .insert(header::WARNING, HeaderValue::from_static("42"));
        Ok(res)
    }

    /// Changes the response body type to check that middleware may alter it.
    async fn mutate_body_type(
        req: ServiceRequest,
        next: Next<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let res = next.call(req).await?;
        Ok(res.map_into_left_body::<()>())
    }

    /// Middleware taking one extractor; exercises the macro-generated `(E1,)` impls.
    async fn echo_method(
        method: Method,
        req: ServiceRequest,
        next: Next<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let mut res = next.call(req).await?;
        res.headers_mut().insert(
            header::HeaderName::from_static("x-method"),
            HeaderValue::from_str(method.as_str())
                .expect("HTTP method names are valid header values"),
        );
        Ok(res)
    }

    struct MyMw(bool);

    impl MyMw {
        async fn mw_cb(
            &self,
            req: ServiceRequest,
            next: Next<impl MessageBody + 'static>,
        ) -> Result<ServiceResponse<impl MessageBody>, Error> {
            let mut res = match self.0 {
                true => req.into_response("short-circuited").map_into_right_body(),
                false => next.call(req).await?.map_into_left_body(),
            };
            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));
            Ok(res)
        }

        pub fn into_middleware<S, B>(
            self,
        ) -> impl Transform<
            S,
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Error = Error,
            InitError = (),
        >
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            B: MessageBody + 'static,
        {
            let this = Rc::new(self);
            from_fn(move |req, next| {
                let this = Rc::clone(&this);
                async move { Self::mw_cb(&this, req, next).await }
            })
        }
    }

    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(from_fn(noop)));
        let _ = App::new().wrap(Compat::new(from_fn(mutate_body_type)));
    }

    #[actix_web::test]
    async fn feels_good() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(mutate_body_type))
                .wrap(from_fn(add_res_header))
                .wrap(Logger::default())
                .wrap(from_fn(noop))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    #[actix_web::test]
    async fn closure_capture_and_return_from_fn() {
        let app = test::init_service(
            App::new()
                .wrap(Logger::default())
                .wrap(MyMw(true).into_middleware())
                .wrap(Logger::default()),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    #[actix_web::test]
    async fn extractor_value_reaches_mw_fn() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(echo_method))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        // `TestRequest::default()` issues a GET request.
        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert_eq!(
            res.headers().get("x-method"),
            Some(&HeaderValue::from_static("GET"))
        );
    }
}