actix_web_lab/
middleware_from_fn.rs

1use std::{
2    future::{ready, Ready},
3    marker::PhantomData,
4    rc::Rc,
5};
6
7use actix_service::{
8    boxed::{self, BoxFuture, RcService},
9    forward_ready, Service, Transform,
10};
11use actix_web::{
12    body::MessageBody,
13    dev::{ServiceRequest, ServiceResponse},
14    Error, FromRequest,
15};
16use futures_core::{future::LocalBoxFuture, Future};
17
/// Wraps an async function to be used as a middleware.
///
/// # Examples
/// The wrapped function should have the following form:
/// ```
/// # use actix_web::{
/// #     App, Error,
/// #     body::MessageBody,
/// #     dev::{ServiceRequest, ServiceResponse, Service as _},
/// # };
/// use actix_web_lab::middleware::Next;
///
/// async fn my_mw(
///     req: ServiceRequest,
///     next: Next<impl MessageBody>,
/// ) -> Result<ServiceResponse<impl MessageBody>, Error> {
///     // pre-processing
///     next.call(req).await
///     // post-processing
/// }
/// # actix_web::App::new().wrap(actix_web_lab::middleware::from_fn(my_mw));
/// ```
///
/// Then use in an app builder like this:
/// ```
/// use actix_web::{
///     App, Error,
///     dev::{ServiceRequest, ServiceResponse, Service as _},
/// };
/// use actix_web_lab::middleware::from_fn;
/// # use actix_web_lab::middleware::Next;
/// # async fn my_mw<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
/// #     next.call(req).await
/// # }
///
/// App::new()
///     .wrap(from_fn(my_mw))
/// # ;
/// ```
///
/// It is also possible to write a middleware that automatically uses extractors, similar to request
/// handlers, by declaring them as the first parameters (up to nine are supported). All extractors
/// are run against the request before the middleware function is called; if extraction fails, that
/// error is returned from the middleware and the function is not called:
/// ```
/// # use std::collections::HashMap;
/// # use actix_web::{
/// #     App, Error,
/// #     body::MessageBody,
/// #     dev::{ServiceRequest, ServiceResponse, Service as _},
/// #     web,
/// # };
/// use actix_web_lab::middleware::Next;
///
/// async fn my_extracting_mw(
///     string_body: String,
///     query: web::Query<HashMap<String, String>>,
///     req: ServiceRequest,
///     next: Next<impl MessageBody>,
/// ) -> Result<ServiceResponse<impl MessageBody>, Error> {
///     // pre-processing
///     next.call(req).await
///     // post-processing
/// }
/// # actix_web::App::new().wrap(actix_web_lab::middleware::from_fn(my_extracting_mw));
/// ```
pub fn from_fn<F, Es>(mw_fn: F) -> MiddlewareFn<F, Es> {
    MiddlewareFn {
        mw_fn: Rc::new(mw_fn),
        _phantom: PhantomData,
    }
}

/// Middleware transform for [`from_fn`].
///
/// `Es` is a tuple of the extractor types that the wrapped function takes before its
/// `ServiceRequest` parameter (`()` when it takes none). It only selects the matching
/// `Transform` implementation and is never stored.
pub struct MiddlewareFn<F, Es> {
    // Reference-counted so each service built by this transform shares the same function.
    mw_fn: Rc<F>,
    _phantom: PhantomData<Es>,
}
93
94impl<S, F, Fut, B, B2> Transform<S, ServiceRequest> for MiddlewareFn<F, ()>
95where
96    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
97    F: Fn(ServiceRequest, Next<B>) -> Fut + 'static,
98    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
99    B2: MessageBody,
100{
101    type Response = ServiceResponse<B2>;
102    type Error = Error;
103    type Transform = MiddlewareFnService<F, B, ()>;
104    type InitError = ();
105    type Future = Ready<Result<Self::Transform, Self::InitError>>;
106
107    fn new_transform(&self, service: S) -> Self::Future {
108        ready(Ok(MiddlewareFnService {
109            service: boxed::rc_service(service),
110            mw_fn: Rc::clone(&self.mw_fn),
111            _phantom: PhantomData,
112        }))
113    }
114}
115
116/// Middleware service for [`from_fn`].
117pub struct MiddlewareFnService<F, B, Es> {
118    service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
119    mw_fn: Rc<F>,
120    _phantom: PhantomData<(B, Es)>,
121}
122
123impl<F, Fut, B, B2> Service<ServiceRequest> for MiddlewareFnService<F, B, ()>
124where
125    F: Fn(ServiceRequest, Next<B>) -> Fut,
126    Fut: Future<Output = Result<ServiceResponse<B2>, Error>>,
127    B2: MessageBody,
128{
129    type Response = ServiceResponse<B2>;
130    type Error = Error;
131    type Future = Fut;
132
133    forward_ready!(service);
134
135    fn call(&self, req: ServiceRequest) -> Self::Future {
136        (self.mw_fn)(
137            req,
138            Next::<B> {
139                service: Rc::clone(&self.service),
140            },
141        )
142    }
143}
144
145macro_rules! impl_middleware_fn_service {
146    ($($ext_type:ident),*) => {
147        impl<S, F, Fut, B, B2, $($ext_type),*> Transform<S, ServiceRequest> for MiddlewareFn<F, ($($ext_type),*,)>
148        where
149            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
150            F: Fn($($ext_type),*, ServiceRequest, Next<B>) -> Fut + 'static,
151            $($ext_type: FromRequest + 'static,)*
152            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
153            B: MessageBody + 'static,
154            B2: MessageBody + 'static,
155        {
156            type Response = ServiceResponse<B2>;
157            type Error = Error;
158            type Transform = MiddlewareFnService<F, B, ($($ext_type,)*)>;
159            type InitError = ();
160            type Future = Ready<Result<Self::Transform, Self::InitError>>;
161
162            fn new_transform(&self, service: S) -> Self::Future {
163                ready(Ok(MiddlewareFnService {
164                    service: boxed::rc_service(service),
165                    mw_fn: Rc::clone(&self.mw_fn),
166                    _phantom: PhantomData,
167                }))
168            }
169        }
170
171        impl<F, $($ext_type),*, Fut, B: 'static, B2> Service<ServiceRequest>
172            for MiddlewareFnService<F, B, ($($ext_type),*,)>
173        where
174            F: Fn(
175                $($ext_type),*,
176                ServiceRequest,
177                Next<B>
178            ) -> Fut + 'static,
179            $($ext_type: FromRequest + 'static,)*
180            Fut: Future<Output = Result<ServiceResponse<B2>, Error>> + 'static,
181            B2: MessageBody + 'static,
182        {
183            type Response = ServiceResponse<B2>;
184            type Error = Error;
185            type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
186
187            forward_ready!(service);
188
189            #[allow(nonstandard_style)]
190            fn call(&self, mut req: ServiceRequest) -> Self::Future {
191                let mw_fn = Rc::clone(&self.mw_fn);
192                let service = Rc::clone(&self.service);
193
194                Box::pin(async move {
195                    let ($($ext_type,)*) = req.extract::<($($ext_type,)*)>().await?;
196
197                    (mw_fn)($($ext_type),*, req, Next::<B> { service }).await
198                })
199            }
200        }
201    };
202}
203
204impl_middleware_fn_service!(E1);
205impl_middleware_fn_service!(E1, E2);
206impl_middleware_fn_service!(E1, E2, E3);
207impl_middleware_fn_service!(E1, E2, E3, E4);
208impl_middleware_fn_service!(E1, E2, E3, E4, E5);
209impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6);
210impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7);
211impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8);
212impl_middleware_fn_service!(E1, E2, E3, E4, E5, E6, E7, E8, E9);
213
214/// Wraps the "next" service in the middleware chain.
215pub struct Next<B> {
216    service: RcService<ServiceRequest, ServiceResponse<B>, Error>,
217}
218
219impl<B> Next<B> {
220    /// Equivalent to `Service::call(self, req)`.
221    pub fn call(&self, req: ServiceRequest) -> <Self as Service<ServiceRequest>>::Future {
222        Service::call(self, req)
223    }
224}
225
226impl<B> Service<ServiceRequest> for Next<B> {
227    type Response = ServiceResponse<B>;
228    type Error = Error;
229    type Future = BoxFuture<Result<Self::Response, Self::Error>>;
230
231    forward_ready!(service);
232
233    fn call(&self, req: ServiceRequest) -> Self::Future {
234        self.service.call(req)
235    }
236}
237
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use actix_web::{
        http::header::{self, HeaderValue},
        middleware::{Compat, Logger},
        test, web, App, HttpResponse,
    };

    use super::*;

    async fn noop<B>(req: ServiceRequest, next: Next<B>) -> Result<ServiceResponse<B>, Error> {
        next.call(req).await
    }

    async fn add_res_header<B>(
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        let mut res = next.call(req).await?;
        res.headers_mut()
            .insert(header::WARNING, HeaderValue::from_static("42"));
        Ok(res)
    }

    async fn mutate_body_type(
        req: ServiceRequest,
        next: Next<impl MessageBody + 'static>,
    ) -> Result<ServiceResponse<impl MessageBody>, Error> {
        let res = next.call(req).await?;
        Ok(res.map_into_left_body::<()>())
    }

    /// Adds a `Warning` header only when the query string contains a `warn` key.
    ///
    /// This reaches the extractor-taking `from_fn` implementations (generated by
    /// `impl_middleware_fn_service!`), which the plain `(ServiceRequest, Next)` helpers above
    /// never exercise.
    async fn warn_if_query_flag<B>(
        query: web::Query<HashMap<String, String>>,
        req: ServiceRequest,
        next: Next<B>,
    ) -> Result<ServiceResponse<B>, Error> {
        let mut res = next.call(req).await?;
        if query.contains_key("warn") {
            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));
        }
        Ok(res)
    }

    struct MyMw(bool);

    impl MyMw {
        async fn mw_cb(
            &self,
            req: ServiceRequest,
            next: Next<impl MessageBody + 'static>,
        ) -> Result<ServiceResponse<impl MessageBody>, Error> {
            let mut res = match self.0 {
                true => req.into_response("short-circuited").map_into_right_body(),
                false => next.call(req).await?.map_into_left_body(),
            };
            res.headers_mut()
                .insert(header::WARNING, HeaderValue::from_static("42"));
            Ok(res)
        }

        pub fn into_middleware<S, B>(
            self,
        ) -> impl Transform<
            S,
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Error = Error,
            InitError = (),
        >
        where
            S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
            B: MessageBody + 'static,
        {
            let this = Rc::new(self);
            from_fn(move |req, next| {
                let this = Rc::clone(&this);
                async move { Self::mw_cb(&this, req, next).await }
            })
        }
    }

    #[actix_web::test]
    async fn compat_compat() {
        let _ = App::new().wrap(Compat::new(from_fn(noop)));
        let _ = App::new().wrap(Compat::new(from_fn(mutate_body_type)));
    }

    #[actix_web::test]
    async fn feels_good() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(mutate_body_type))
                .wrap(from_fn(add_res_header))
                .wrap(Logger::default())
                .wrap(from_fn(noop))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    #[actix_web::test]
    async fn closure_capture_and_return_from_fn() {
        let app = test::init_service(
            App::new()
                .wrap(Logger::default())
                .wrap(MyMw(true).into_middleware())
                .wrap(Logger::default()),
        )
        .await;

        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));
    }

    #[actix_web::test]
    async fn extractor_middleware() {
        let app = test::init_service(
            App::new()
                .wrap(from_fn(warn_if_query_flag))
                .default_service(web::to(HttpResponse::NotFound)),
        )
        .await;

        // The extracted query is visible to the middleware function.
        let req = test::TestRequest::with_uri("/?warn=1").to_request();
        let res = test::call_service(&app, req).await;
        assert!(res.headers().contains_key(header::WARNING));

        // Without the flag the middleware passes the response through untouched.
        let req = test::TestRequest::default().to_request();
        let res = test::call_service(&app, req).await;
        assert!(!res.headers().contains_key(header::WARNING));
    }
}