1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 rc::Rc,
8 sync::Arc,
9 task::{Context, Poll},
10};
11
12use actix_web::{
13 body::{EitherBody, MessageBody},
14 dev::{Service, ServiceRequest, ServiceResponse, Transform},
15 Error, FromRequest,
16};
17use futures_core::ready;
18use futures_util::future::{self, LocalBoxFuture, TryFutureExt as _};
19
20use crate::extractors::{basic, bearer};
21
22#[derive(Debug, Clone)]
30pub struct HttpAuthentication<T, F>
31where
32 T: FromRequest,
33{
34 process_fn: Arc<F>,
35 _extractor: PhantomData<T>,
36}
37
38impl<T, F, O> HttpAuthentication<T, F>
39where
40 T: FromRequest,
41 F: Fn(ServiceRequest, T) -> O,
42 O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
43{
44 pub fn with_fn(process_fn: F) -> HttpAuthentication<T, F> {
96 HttpAuthentication {
97 process_fn: Arc::new(process_fn),
98 _extractor: PhantomData,
99 }
100 }
101}
102
103impl<F, O> HttpAuthentication<basic::BasicAuth, F>
104where
105 F: Fn(ServiceRequest, basic::BasicAuth) -> O,
106 O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
107{
108 pub fn basic(process_fn: F) -> Self {
128 Self::with_fn(process_fn)
129 }
130}
131
132impl<F, O> HttpAuthentication<bearer::BearerAuth, F>
133where
134 F: Fn(ServiceRequest, bearer::BearerAuth) -> O,
135 O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>>,
136{
137 pub fn bearer(process_fn: F) -> Self {
165 Self::with_fn(process_fn)
166 }
167}
168
169impl<S, B, T, F, O> Transform<S, ServiceRequest> for HttpAuthentication<T, F>
170where
171 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
172 S::Future: 'static,
173 F: Fn(ServiceRequest, T) -> O + 'static,
174 O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>> + 'static,
175 T: FromRequest + 'static,
176 B: MessageBody + 'static,
177{
178 type Response = ServiceResponse<EitherBody<B>>;
179 type Error = Error;
180 type Transform = AuthenticationMiddleware<S, F, T>;
181 type InitError = ();
182 type Future = future::Ready<Result<Self::Transform, Self::InitError>>;
183
184 fn new_transform(&self, service: S) -> Self::Future {
185 future::ok(AuthenticationMiddleware {
186 service: Rc::new(service),
187 process_fn: self.process_fn.clone(),
188 _extractor: PhantomData,
189 })
190 }
191}
192
193#[doc(hidden)]
194pub struct AuthenticationMiddleware<S, F, T>
195where
196 T: FromRequest,
197{
198 service: Rc<S>,
199 process_fn: Arc<F>,
200 _extractor: PhantomData<T>,
201}
202
203impl<S, B, F, T, O> Service<ServiceRequest> for AuthenticationMiddleware<S, F, T>
204where
205 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
206 S::Future: 'static,
207 F: Fn(ServiceRequest, T) -> O + 'static,
208 O: Future<Output = Result<ServiceRequest, (Error, ServiceRequest)>> + 'static,
209 T: FromRequest + 'static,
210 B: MessageBody + 'static,
211{
212 type Response = ServiceResponse<EitherBody<B>>;
213 type Error = S::Error;
214 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
215
216 actix_web::dev::forward_ready!(service);
217
218 fn call(&self, req: ServiceRequest) -> Self::Future {
219 let process_fn = Arc::clone(&self.process_fn);
220 let service = Rc::clone(&self.service);
221
222 Box::pin(async move {
223 let (req, credentials) = match Extract::<T>::new(req).await {
224 Ok(req) => req,
225 Err((err, req)) => {
226 return Ok(req.error_response(err).map_into_right_body());
227 }
228 };
229
230 let req = match process_fn(req, credentials).await {
231 Ok(req) => req,
232 Err((err, req)) => {
233 return Ok(req.error_response(err).map_into_right_body());
234 }
235 };
236
237 service.call(req).await.map(|res| res.map_into_left_body())
238 })
239 }
240}
241
242struct Extract<T> {
243 req: Option<ServiceRequest>,
244 fut: Option<LocalBoxFuture<'static, Result<T, Error>>>,
245 _extractor: PhantomData<fn() -> T>,
246}
247
248impl<T> Extract<T> {
249 pub fn new(req: ServiceRequest) -> Self {
250 Extract {
251 req: Some(req),
252 fut: None,
253 _extractor: PhantomData,
254 }
255 }
256}
257
258impl<T> Future for Extract<T>
259where
260 T: FromRequest,
261 T::Future: 'static,
262 T::Error: 'static,
263{
264 type Output = Result<(ServiceRequest, T), (Error, ServiceRequest)>;
265
266 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
267 if self.fut.is_none() {
268 let req = self.req.as_mut().expect("Extract future was polled twice!");
269 let fut = req.extract::<T>().map_err(Into::into);
270 self.fut = Some(Box::pin(fut));
271 }
272
273 let fut = self
274 .fut
275 .as_mut()
276 .expect("Extraction future should be initialized at this point");
277
278 let credentials = ready!(fut.as_mut().poll(ctx)).map_err(|err| {
279 (
280 err,
281 self.req.take().expect("Extract future was polled twice!"),
283 )
284 })?;
285
286 let req = self.req.take().expect("Extract future was polled twice!");
287 Poll::Ready(Ok((req, credentials)))
288 }
289}
290
291#[cfg(test)]
292mod tests {
293 use actix_service::into_service;
294 use actix_web::{
295 error::{self, ErrorForbidden},
296 http::StatusCode,
297 test::TestRequest,
298 web, App, HttpResponse,
299 };
300
301 use super::*;
302 use crate::extractors::{basic::BasicAuth, bearer::BearerAuth};
303
304 #[actix_web::test]
306 async fn test_middleware_panic() {
307 let middleware = AuthenticationMiddleware {
308 service: Rc::new(into_service(|_: ServiceRequest| async move {
309 actix_web::rt::time::sleep(std::time::Duration::from_secs(1)).await;
310 Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
311 })),
312 process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
313 _extractor: PhantomData,
314 };
315
316 let req = TestRequest::get()
317 .append_header(("Authorization", "Bearer 1"))
318 .to_srv_request();
319
320 let f = middleware.call(req).await;
321
322 let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
323
324 assert!(f.is_err());
325 }
326
327 #[actix_web::test]
329 async fn test_middleware_panic_several_orders() {
330 let middleware = AuthenticationMiddleware {
331 service: Rc::new(into_service(|_: ServiceRequest| async move {
332 actix_web::rt::time::sleep(std::time::Duration::from_secs(1)).await;
333 Err::<ServiceResponse, _>(error::ErrorBadRequest("error"))
334 })),
335 process_fn: Arc::new(|req, _: BearerAuth| async { Ok(req) }),
336 _extractor: PhantomData,
337 };
338
339 let req = TestRequest::get()
340 .append_header(("Authorization", "Bearer 1"))
341 .to_srv_request();
342
343 let f1 = middleware.call(req).await;
344
345 let req = TestRequest::get()
346 .append_header(("Authorization", "Bearer 1"))
347 .to_srv_request();
348
349 let f2 = middleware.call(req).await;
350
351 let req = TestRequest::get()
352 .append_header(("Authorization", "Bearer 1"))
353 .to_srv_request();
354
355 let f3 = middleware.call(req).await;
356
357 let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
358
359 assert!(f1.is_err());
360 assert!(f2.is_err());
361 assert!(f3.is_err());
362 }
363
364 #[actix_web::test]
365 async fn test_middleware_opt_extractor() {
366 let middleware = AuthenticationMiddleware {
367 service: Rc::new(into_service(|req: ServiceRequest| async move {
368 Ok::<ServiceResponse, _>(req.into_response(HttpResponse::Ok().finish()))
369 })),
370 process_fn: Arc::new(|req, auth: Option<BearerAuth>| {
371 assert!(auth.is_none());
372 async { Ok(req) }
373 }),
374 _extractor: PhantomData,
375 };
376
377 let req = TestRequest::get()
378 .append_header(("Authorization996", "Bearer 1"))
379 .to_srv_request();
380
381 let f = middleware.call(req).await;
382
383 let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
384
385 assert!(f.is_ok());
386 }
387
388 #[actix_web::test]
389 async fn test_middleware_res_extractor() {
390 let middleware = AuthenticationMiddleware {
391 service: Rc::new(into_service(|req: ServiceRequest| async move {
392 Ok::<ServiceResponse, _>(req.into_response(HttpResponse::Ok().finish()))
393 })),
394 process_fn: Arc::new(
395 |req, auth: Result<BearerAuth, <BearerAuth as FromRequest>::Error>| {
396 assert!(auth.is_err());
397 async { Ok(req) }
398 },
399 ),
400 _extractor: PhantomData,
401 };
402
403 let req = TestRequest::get()
404 .append_header(("Authorization", "BearerLOL"))
405 .to_srv_request();
406
407 let f = middleware.call(req).await;
408
409 let _res = futures_util::future::lazy(|cx| middleware.poll_ready(cx)).await;
410
411 assert!(f.is_ok());
412 }
413
414 #[actix_web::test]
415 async fn test_middleware_works_with_app() {
416 async fn validator(
417 req: ServiceRequest,
418 _credentials: BasicAuth,
419 ) -> Result<ServiceRequest, (actix_web::Error, ServiceRequest)> {
420 Err((ErrorForbidden("You are not welcome!"), req))
421 }
422 let middleware = HttpAuthentication::basic(validator);
423
424 let srv = actix_web::test::init_service(
425 App::new()
426 .wrap(middleware)
427 .route("/", web::get().to(HttpResponse::Ok)),
428 )
429 .await;
430
431 let req = actix_web::test::TestRequest::with_uri("/")
432 .append_header(("Authorization", "Basic DontCare"))
433 .to_request();
434
435 let resp = srv.call(req).await.unwrap();
436 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
437 }
438
439 #[actix_web::test]
440 async fn test_middleware_works_with_scope() {
441 async fn validator(
442 req: ServiceRequest,
443 _credentials: BasicAuth,
444 ) -> Result<ServiceRequest, (actix_web::Error, ServiceRequest)> {
445 Err((ErrorForbidden("You are not welcome!"), req))
446 }
447 let middleware = actix_web::middleware::Compat::new(HttpAuthentication::basic(validator));
448
449 let srv = actix_web::test::init_service(
450 App::new().service(
451 web::scope("/")
452 .wrap(middleware)
453 .route("/", web::get().to(HttpResponse::Ok)),
454 ),
455 )
456 .await;
457
458 let req = actix_web::test::TestRequest::with_uri("/")
459 .append_header(("Authorization", "Basic DontCare"))
460 .to_request();
461
462 let resp = srv.call(req).await.unwrap();
463 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
464 }
465}