actix_web_httpauth/
middleware.rs

1//! HTTP Authentication middleware.
2
3use std::{
4    future::Future,
5    marker::PhantomData,
6    pin::Pin,
7    rc::Rc,
8    sync::Arc,
9    task::{Context, Poll},
10};
11
12use actix_web::{
13    body::{EitherBody, MessageBody},
14    dev::{Service, ServiceRequest, ServiceResponse, Transform},
15    Error, FromRequest,
16};
17use futures_core::ready;
18use futures_util::future::{self, LocalBoxFuture, TryFutureExt as _};
19
20use crate::extractors::{basic, bearer};
21
22/// Middleware for checking HTTP authentication.
23///
24/// If there is no `Authorization` header in the request, this middleware returns an error
25/// immediately, without calling the `F` callback.
26///
27/// Otherwise, it will pass both the request and the parsed credentials into it. In case of
28/// successful validation `F` callback is required to return the `ServiceRequest` back.
29#[derive(Debug, Clone)]
30pub struct HttpAuthentication<T, F>
31where
32    T: FromRequest,
33{
34    process_fn: Arc<F>,
35    _extractor: PhantomData<T>,
36}
37
38impl<T, F, O> HttpAuthentication<T, F>
39where
40    T: FromRequest,
41    F: Fn(ServiceRequest, T) -> O,
42    O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
43{
44    /// Construct `HttpAuthentication` middleware with the provided auth extractor `T` and
45    /// validation callback `F`.
46    ///
47    /// This function can be used to implement optional authentication and/or custom responses to
48    /// missing authentication.
49    ///
50    /// # Examples
51    ///
52    /// ## Required Basic Auth
53    ///
54    /// ```no_run
55    /// # use actix_web_httpauth::extractors::basic::BasicAuth;
56    /// # use actix_web::dev::ServiceRequest;
57    /// async fn validator(
58    ///     req: ServiceRequest,
59    ///     credentials: BasicAuth,
60    /// ) -> Result<ServiceRequest, (actix_web::Error, ServiceRequest)> {
61    ///     eprintln!("{credentials:?}");
62    ///
63    ///     if credentials.user_id().contains('x') {
64    ///         return Err((actix_web::error::ErrorBadRequest("user ID contains x"), req));
65    ///     }
66    ///
67    ///     Ok(req)
68    /// }
69    /// # actix_web_httpauth::middleware::HttpAuthentication::with_fn(validator);
70    /// ```
71    ///
72    /// ## Optional Bearer Auth
73    ///
74    /// ```no_run
75    /// # use actix_web_httpauth::extractors::bearer::BearerAuth;
76    /// # use actix_web::dev::ServiceRequest;
77    /// async fn validator(
78    ///     req: ServiceRequest,
79    ///     credentials: Option<BearerAuth>,
80    /// ) -> Result<ServiceRequest, (actix_web::Error, ServiceRequest)> {
81    ///     let Some(credentials) = credentials else {
82    ///         return Err((actix_web::error::ErrorBadRequest("no bearer header"), req));
83    ///     };
84    ///
85    ///     eprintln!("{credentials:?}");
86    ///
87    ///     if credentials.token().contains('x') {
88    ///         return Err((actix_web::error::ErrorBadRequest("token contains x"), req));
89    ///     }
90    ///
91    ///     Ok(req)
92    /// }
93    /// # actix_web_httpauth::middleware::HttpAuthentication::with_fn(validator);
94    /// ```
95    pub fn with_fn(process_fn: F) -> HttpAuthentication<T, F> {
96        HttpAuthentication {
97            process_fn: Arc::new(process_fn),
98            _extractor: PhantomData,
99        }
100    }
101}
102
103impl<F, O> HttpAuthentication<basic::BasicAuth, F>
104where
105    F: Fn(ServiceRequest, basic::BasicAuth) -> O,
106    O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
107{
108    /// Construct `HttpAuthentication` middleware for the HTTP "Basic" authentication scheme.
109    ///
110    /// # Examples
111    /// ```
112    /// # use actix_web::{Error, dev::ServiceRequest};
113    /// # use actix_web_httpauth::{extractors::basic::BasicAuth, middleware::HttpAuthentication};
114    /// // In this example validator returns immediately, but since it is required to return
115    /// // anything that implements `IntoFuture` trait, it can be extended to query database or to
116    /// // do something else in a async manner.
117    /// async fn validator(
118    ///     req: ServiceRequest,
119    ///     credentials: BasicAuth,
120    /// ) -> Result<ServiceRequest, (Error, ServiceRequest)> {
121    ///     // All users are great and more than welcome!
122    ///     Ok(req)
123    /// }
124    ///
125    /// let middleware = HttpAuthentication::basic(validator);
126    /// ```
127    pub fn basic(process_fn: F) -> Self {
128        Self::with_fn(process_fn)
129    }
130}
131
132impl<F, O> HttpAuthentication<bearer::BearerAuth, F>
133where
134    F: Fn(ServiceRequest, bearer::BearerAuth) -> O,
135    O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
136{
137    /// Construct `HttpAuthentication` middleware for the HTTP "Bearer" authentication scheme.
138    ///
139    /// # Examples
140    /// ```
141    /// # use actix_web::{Error, dev::ServiceRequest};
142    /// # use actix_web_httpauth::{
143    /// #     extractors::{AuthenticationError, AuthExtractorConfig, bearer::{self, BearerAuth}},
144    /// #     middleware::HttpAuthentication,
145    /// # };
146    /// async fn validator(
147    ///     req: ServiceRequest,
148    ///     credentials: BearerAuth
149    /// ) -> Result<ServiceRequest, (Error, ServiceRequest)> {
150    ///     if credentials.token() == "mF_9.B5f-4.1JqM" {
151    ///         Ok(req)
152    ///     } else {
153    ///         let config = req.app_data::<bearer::Config>()
154    ///             .cloned()
155    ///             .unwrap_or_default()
156    ///             .scope("urn:example:channel=HBO&urn:example:rating=G,PG-13");
157    ///
158    ///         Err((AuthenticationError::from(config).into(), req))
159    ///     }
160    /// }
161    ///
162    /// let middleware = HttpAuthentication::bearer(validator);
163    /// ```
164    pub fn bearer(process_fn: F) -> Self {
165        Self::with_fn(process_fn)
166    }
167}
168
169impl<S, B, T, F, O> Transform<S, ServiceRequest> for HttpAuthentication<T, F>
170where
171    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
172    S::Future: 'static,
173    F: Fn(ServiceRequest, T) -> O + 'static,
174    O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>> + 'static,
175    T: FromRequest + 'static,
176    B: MessageBody + 'static,
177{
178    type Response = ServiceResponse<EitherBody<B>>;
179    type Error = Error;
180    type Transform = AuthenticationMiddleware<S, F, T>;
181    type InitError = ();
182    type Future = future::Ready<Result<Self::Transform, Self::InitError>>;
183
184    fn new_transform(&self, service: S) -> Self::Future {
185        future::ok(AuthenticationMiddleware {
186            service: Rc::new(service),
187            process_fn: self.process_fn.clone(),
188            _extractor: PhantomData,
189        })
190    }
191}
192
193#[doc(hidden)]
194pub struct AuthenticationMiddleware<S, F, T>
195where
196    T: FromRequest,
197{
198    service: Rc<S>,
199    process_fn: Arc<F>,
200    _extractor: PhantomData<T>,
201}
202
203impl<S, B, F, T, O> Service<ServiceRequest> for AuthenticationMiddleware<S, F, T>
204where
205    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
206    S::Future: 'static,
207    F: Fn(ServiceRequest, T) -> O + 'static,
208    O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>> + 'static,
209    T: FromRequest + 'static,
210    B: MessageBody + 'static,
211{
212    type Response = ServiceResponse<EitherBody<B>>;
213    type Error = S::Error;
214    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
215
216    actix_web::dev::forward_ready!(service);
217
218    fn call(&self, req: ServiceRequest) -> Self::Future {
219        let process_fn = Arc::clone(&self.process_fn);
220        let service = Rc::clone(&self.service);
221
222        Box::pin(async move {
223            let (req, credentials) = match Extract::<T>::new(req).await {
224                Ok(req) => req,
225                Err((err, req)) => {
226                    return Ok(req.error_response(err).map_into_right_body());
227                }
228            };
229
230            let req = match process_fn(req, credentials).await {
231                Ok(req) => req,
232                Err((err, req)) => {
233                    return Ok(req.error_response(err).map_into_right_body());
234                }
235            };
236
237            service.call(req).await.map(|res| res.map_into_left_body())
238        })
239    }
240}
241
242struct Extract<T> {
243    req: Option<ServiceRequest>,
244    fut: Option<LocalBoxFuture<'static, Result<T, Error>>>,
245    _extractor: PhantomData<fn() -> T>,
246}
247
248impl<T> Extract<T> {
249    pub fn new(req: ServiceRequest) -> Self {
250        Extract {
251            req: Some(req),
252            fut: None,
253            _extractor: PhantomData,
254        }
255    }
256}
257
258impl<T> Future for Extract<T>
259where
260    T: FromRequest,
261    T::Future: 'static,
262    T::Error: 'static,
263{
264    type Output = Result<(ServiceRequest, T), (Error, ServiceRequest)>;
265
266    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
267        if self.fut.is_none() {
268            let req = self.req.as_mut().expect("Extract future was polled twice!");
269            let fut = req.extract::<T>().map_err(Into::into);
270            self.fut = Some(Box::pin(fut));
271        }
272
273        let fut = self
274            .fut
275            .as_mut()
276            .expect("Extraction future should be initialized at this point");
277
278        let credentials = ready!(fut.as_mut().poll(ctx)).map_err(|err| {
279            (
280                err,
281                // returning request allows a proper error response to be created
282                self.req.take().expect("Extract future was polled twice!"),
283            )
284        })?;
285
286        let req = self.req.take().expect("Extract future was polled twice!");
287        Poll::Ready(Ok((req, credentials)))
288    }
289}
290
#[cfg(test)]
mod tests {
    use actix_service::into_service;
    use actix_web::{
        error::{self, ErrorForbidden},
        http::StatusCode,
        test::TestRequest,
        web, App, HttpResponse,
    };

    use super::*;
    use crate::extractors::{basic::BasicAuth, bearer::BearerAuth};

    /// Builds a GET service request that carries the single given header.
    fn srv_request_with(header: (&'static str, &'static str)) -> ServiceRequest {
        TestRequest::get().append_header(header).to_srv_request()
    }

    /// Validator used by the `App` and scope tests; it turns every request away.
    async fn reject_all(
        req: ServiceRequest,
        _credentials: BasicAuth,
    ) -> Result<ServiceRequest, (actix_web::Error, ServiceRequest)> {
        Err((ErrorForbidden("You are not welcome!"), req))
    }

    /// Regression test for https://github.com/actix/actix-extras/issues/10
    #[actix_web::test]
    async fn test_middleware_panic() {
        let middleware = AuthenticationMiddleware {
            service: Rc::new(into_service(|_: ServiceRequest| async move {
                actix_web::rt::time::sleep(std::time::Duration::from_secs(1)).await;
                Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
            })),
            process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
            _extractor: PhantomData,
        };

        let outcome = middleware
            .call(srv_request_with(("Authorization", "Bearer 1")))
            .await;

        // Readiness is polled after a finished call; the result itself is irrelevant here.
        let _ready = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;

        assert!(outcome.is_err());
    }

    /// Regression test for https://github.com/actix/actix-extras/issues/10
    #[actix_web::test]
    async fn test_middleware_panic_several_orders() {
        let middleware = AuthenticationMiddleware {
            service: Rc::new(into_service(|_: ServiceRequest| async move {
                actix_web::rt::time::sleep(std::time::Duration::from_secs(1)).await;
                Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
            })),
            process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
            _extractor: PhantomData,
        };

        // Three calls, each awaited to completion before the next one starts.
        let mut outcomes = Vec::with_capacity(3);
        for _ in 0..3 {
            let req = srv_request_with(("Authorization", "Bearer 1"));
            outcomes.push(middleware.call(req).await);
        }

        let _ready = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;

        assert!(outcomes.iter().all(|outcome| outcome.is_err()));
    }

    #[actix_web::test]
    async fn test_middleware_opt_extractor() {
        let middleware = AuthenticationMiddleware {
            service: Rc::new(into_service(|req: ServiceRequest| async move {
                Ok::<ServiceResponse, _>(req.into_response(HttpResponse::Ok().finish()))
            })),
            process_fn: Arc::new(|req, auth: Option<BearerAuth>| {
                assert!(auth.is_none());
                async { Ok(req) }
            }),
            _extractor: PhantomData,
        };

        // The header name is deliberately not `Authorization`.
        let outcome = middleware
            .call(srv_request_with(("Authorization996", "Bearer 1")))
            .await;

        let _ready = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;

        assert!(outcome.is_ok());
    }

    #[actix_web::test]
    async fn test_middleware_res_extractor() {
        let middleware = AuthenticationMiddleware {
            service: Rc::new(into_service(|req: ServiceRequest| async move {
                Ok::<ServiceResponse, _>(req.into_response(HttpResponse::Ok().finish()))
            })),
            process_fn: Arc::new(
                |req, auth: Result<BearerAuth, <BearerAuth as FromRequest>::Error>| {
                    assert!(auth.is_err());
                    async { Ok(req) }
                },
            ),
            _extractor: PhantomData,
        };

        // Malformed header value: the callback asserts that extraction reported an error.
        let outcome = middleware
            .call(srv_request_with(("Authorization", "BearerLOL")))
            .await;

        let _ready = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;

        assert!(outcome.is_ok());
    }

    #[actix_web::test]
    async fn test_middleware_works_with_app() {
        let app = App::new()
            .wrap(HttpAuthentication::basic(reject_all))
            .route("/", web::get().to(HttpResponse::Ok));
        let srv = actix_web::test::init_service(app).await;

        let req = TestRequest::with_uri("/")
            .append_header(("Authorization", "Basic DontCare"))
            .to_request();

        // NOTE(review): the expected status is 401, not the validator's 403; presumably the
        // extractor already rejects this header value before the validator runs. Confirm.
        let resp = srv.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[actix_web::test]
    async fn test_middleware_works_with_scope() {
        let middleware = actix_web::middleware::Compat::new(HttpAuthentication::basic(reject_all));

        let scope = web::scope("/")
            .wrap(middleware)
            .route("/", web::get().to(HttpResponse::Ok));
        let srv = actix_web::test::init_service(App::new().service(scope)).await;

        let req = TestRequest::with_uri("/")
            .append_header(("Authorization", "Basic DontCare"))
            .to_request();

        // NOTE(review): the expected status is 401, not the validator's 403; presumably the
        // extractor already rejects this header value before the validator runs. Confirm.
        let resp = srv.call(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}