1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use actix_utils::future::{ready, Ready};
12use actix_web::{
13 body::EitherBody,
14 dev::{ServiceRequest, ServiceResponse},
15 http::{
16 header,
17 uri::{PathAndQuery, Uri},
18 StatusCode,
19 },
20 middleware::TrailingSlash,
21 Error, HttpResponse,
22};
23use bytes::Bytes;
24use pin_project_lite::pin_project;
25use regex::Regex;
26
27#[derive(Debug, Clone, Copy)]
76pub struct NormalizePath {
77 trailing_slash_behavior: TrailingSlash,
79
80 use_redirects: Option<StatusCode>,
82}
83
84impl Default for NormalizePath {
85 fn default() -> Self {
86 Self {
87 trailing_slash_behavior: TrailingSlash::Trim,
88 use_redirects: None,
89 }
90 }
91}
92
93impl NormalizePath {
94 pub fn new(behavior: TrailingSlash) -> Self {
96 Self {
97 trailing_slash_behavior: behavior,
98 use_redirects: None,
99 }
100 }
101
102 pub fn trim() -> Self {
106 Self::new(TrailingSlash::Trim)
107 }
108
109 pub fn use_redirects(mut self) -> Self {
118 self.use_redirects = Some(StatusCode::TEMPORARY_REDIRECT);
119 self
120 }
121
122 pub fn use_redirects_with(mut self, status_code: StatusCode) -> Self {
130 assert!(status_code.is_redirection());
131 self.use_redirects = Some(status_code);
132 self
133 }
134}
135
136impl<S, B> Transform<S, ServiceRequest> for NormalizePath
137where
138 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
139 S::Future: 'static,
140{
141 type Response = ServiceResponse<EitherBody<B, ()>>;
142 type Error = Error;
143 type Transform = NormalizePathService<S>;
144 type InitError = ();
145 type Future = Ready<Result<Self::Transform, Self::InitError>>;
146
147 fn new_transform(&self, service: S) -> Self::Future {
148 ready(Ok(NormalizePathService {
149 service,
150 merge_slash: Regex::new("//+").unwrap(),
151 trailing_slash_behavior: self.trailing_slash_behavior,
152 use_redirects: self.use_redirects,
153 }))
154 }
155}
156
157pub struct NormalizePathService<S> {
158 service: S,
159 merge_slash: Regex,
160 trailing_slash_behavior: TrailingSlash,
161 use_redirects: Option<StatusCode>,
162}
163
164impl<S, B> Service<ServiceRequest> for NormalizePathService<S>
165where
166 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
167 S::Future: 'static,
168{
169 type Response = ServiceResponse<EitherBody<B, ()>>;
170 type Error = Error;
171 type Future = NormalizePathFuture<S, B>;
172
173 actix_service::forward_ready!(service);
174
175 fn call(&self, mut req: ServiceRequest) -> Self::Future {
176 let head = req.head_mut();
177
178 let mut path_altered = false;
179 let original_path = head.uri.path();
180
181 if !original_path.is_empty() {
184 let path = match self.trailing_slash_behavior {
187 TrailingSlash::Always => format!("{original_path}/"),
188 TrailingSlash::MergeOnly => original_path.to_string(),
189 TrailingSlash::Trim => original_path.trim_end_matches('/').to_string(),
190 ts_behavior => panic!("unknown trailing slash behavior: {ts_behavior:?}"),
191 };
192
193 let path = self.merge_slash.replace_all(&path, "/");
195
196 let path = if path.is_empty() { "/" } else { path.as_ref() };
199
200 if path != original_path {
212 let mut parts = head.uri.clone().into_parts();
213 let query = parts.path_and_query.as_ref().and_then(|pq| pq.query());
214
215 let path = match query {
216 Some(query) => Bytes::from(format!("{path}?{query}")),
217 None => Bytes::copy_from_slice(path.as_bytes()),
218 };
219 parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap());
220
221 let uri = Uri::from_parts(parts).unwrap();
222 req.match_info_mut().get_mut().update(&uri);
223 req.head_mut().uri = uri;
224
225 path_altered = true;
226 }
227 }
228
229 match self.use_redirects {
230 Some(code) if path_altered => {
231 let mut res = HttpResponse::with_body(code, ());
232 res.headers_mut().insert(
233 header::LOCATION,
234 req.head_mut().uri.to_string().parse().unwrap(),
235 );
236 NormalizePathFuture::redirect(req.into_response(res))
237 }
238
239 _ => NormalizePathFuture::service(self.service.call(req)),
240 }
241 }
242}
243
pin_project! {
    /// Future returned by [`NormalizePathService`].
    ///
    /// Resolves either to a prepared redirect response or to the wrapped service's response.
    pub struct NormalizePathFuture<S: Service<ServiceRequest>, B> {
        // Pinned because the `Service` variant holds the wrapped service's future.
        #[pin] inner: Inner<S, B>,
    }
}
249
250impl<S: Service<ServiceRequest>, B> NormalizePathFuture<S, B> {
251 fn service(fut: S::Future) -> Self {
252 Self {
253 inner: Inner::Service {
254 fut,
255 _body: PhantomData,
256 },
257 }
258 }
259
260 fn redirect(res: ServiceResponse<()>) -> Self {
261 Self {
262 inner: Inner::Redirect { res: Some(res) },
263 }
264 }
265}
266
pin_project! {
    // State of a `NormalizePathFuture`.
    #[project = InnerProj]
    enum Inner<S: Service<ServiceRequest>, B> {
        // A redirect response ready to return. It is `take`n on poll, so this is `None`
        // after completion.
        Redirect { res: Option<ServiceResponse<()>>, },
        // The request was forwarded to the wrapped service; poll its future.
        Service {
            #[pin] fut: S::Future,
            // Ties the response body type `B` to this future without storing a value.
            _body: PhantomData<B>,
        },
    }
}
277
278impl<S, B> Future for NormalizePathFuture<S, B>
279where
280 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
281{
282 type Output = Result<ServiceResponse<EitherBody<B, ()>>, Error>;
283
284 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
285 let this = self.project();
286
287 match this.inner.project() {
288 InnerProj::Redirect { res } => {
289 Poll::Ready(Ok(res.take().unwrap().map_into_right_body()))
290 }
291
292 InnerProj::Service { fut, .. } => {
293 let res = ready!(fut.poll(cx))?;
294 Poll::Ready(Ok(res.map_into_left_body()))
295 }
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        dev::ServiceRequest,
        guard::fn_guard,
        test::{self, call_service, init_service, TestRequest},
        web, App, HttpRequest, HttpResponse,
    };

    use super::*;

    #[actix_web::test]
    async fn default_is_trim_no_redirect() {
        let app = init_service(App::new().wrap(NormalizePath::default()).service(
            web::resource("/test").to(|req: HttpRequest| async move { req.path().to_owned() }),
        ))
        .await;

        let req = TestRequest::with_uri("/test/").to_request();
        let res = call_service(&app, req).await;
        assert!(res.status().is_success());
        assert_eq!(test::read_body(res).await, "/test");
    }

    #[actix_web::test]
    async fn trim_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = vec![
            "/",
            "/?query=test",
            "///",
            "/v1//something",
            "/v1//something////",
            "//v1/something",
            "//v1//////something",
            "/v2//something?query=test",
            "/v2//something////?query=test",
            "//v2/something?query=test",
            "//v2//////something?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn always_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = vec![
            "/",
            "///",
            "/v1/something",
            "/v1/something/",
            "/v1/something////",
            "//v1//something",
            "//v1//something//",
            "/v2/something?query=test",
            "/v2/something/?query=test",
            "/v2/something////?query=test",
            "//v2//something?query=test",
            "//v2//something//?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn trim_root_trailing_slashes_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Trim))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    /// Unlike `always_trailing_slashes`, which only checks that routing succeeds, this checks
    /// the exact path the handler observes after normalization.
    #[actix_web::test]
    async fn ensure_trailing_slash() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(
                    web::resource("/v1/something/")
                        .to(|req: HttpRequest| async move { req.path().to_owned() }),
                ),
        )
        .await;

        let test_uris = vec![
            "/v1/something",
            "/v1/something/",
            "/v1/something////",
            "//v1//something",
            "//v1//something//",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
            assert_eq!(
                test::read_body(res).await,
                "/v1/something/",
                "Failed uri: {uri}"
            );
        }
    }

    #[actix_web::test]
    async fn ensure_root_trailing_slash_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = vec!["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn keep_trailing_slash_unchanged() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::MergeOnly))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(web::resource("/v1/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let tests = vec![
            ("/", true),
            ("/?query=test", true),
            ("///", true),
            ("/v1/something////", false),
            ("/v1/something/", false),
            ("//v1//something", true),
            ("/v1/", true),
            ("/v1", false),
            ("/v1////", true),
            ("//v1//", true),
            ("///v1", false),
            ("/v2/something?query=test", true),
            ("/v2/something/?query=test", false),
            ("/v2/something//?query=test", false),
            ("//v2//something?query=test", true),
        ];

        for (uri, success) in tests {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert_eq!(res.status().is_success(), success, "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn no_path() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::default())
                .service(web::resource("/").to(HttpResponse::Ok)),
        )
        .await;

        let req = TestRequest::with_uri("eh").to_request();
        let res = call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[actix_web::test]
    async fn test_in_place_normalization() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let test_uris = vec![
            "/v1//something////",
            "///v1/something",
            "//v1///something",
            "/v1//something",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_srv_request();
            let res = normalize.call(req).await.unwrap();
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn should_normalize_nothing() {
        const URI: &str = "/v1/something";

        let srv = |req: ServiceRequest| {
            assert_eq!(URI, req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri(URI).to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_normalize_no_trail() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_return_redirects_when_configured() {
        let normalize = NormalizePath::trim()
            .use_redirects()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(res.headers().get(header::LOCATION).unwrap(), "/v1/something");

        let normalize = NormalizePath::trim()
            .use_redirects_with(StatusCode::PERMANENT_REDIRECT)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers().get(header::LOCATION).unwrap(), "/v1/something");
    }

    /// The redirect target must keep the original query string.
    #[actix_web::test]
    async fn redirect_location_preserves_query() {
        let normalize = NormalizePath::trim()
            .use_redirects()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("//v2//something//?query=test").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(
            res.headers().get(header::LOCATION).unwrap(),
            "/v2/something?query=test"
        );
    }

    #[actix_web::test]
    async fn trim_with_redirect() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim().use_redirects())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = vec![
            ("/", false),
            ("///", true),
            ("/v1/something", false),
            ("/v1/something/", true),
            ("/v1/something////", true),
            ("//v1//something", true),
            ("//v1//something//", true),
            ("/v2/something?query=test", false),
            ("/v2/something/?query=test", true),
            ("/v2/something////?query=test", true),
            ("//v2//something?query=test", true),
            ("//v2//something//?query=test", true),
        ];

        for (uri, should_redirect) in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;

            if should_redirect {
                assert!(res.status().is_redirection(), "URI did not redirect: {uri}");
            } else {
                assert!(res.status().is_success(), "Failed URI: {uri}");
            }
        }
    }
}