actix_web_lab/
normalize_path.rs

1//! For middleware documentation, see [`NormalizePath`].
2
3use std::{
4    future::Future,
5    marker::PhantomData,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use actix_utils::future::{ready, Ready};
12use actix_web::{
13    body::EitherBody,
14    dev::{ServiceRequest, ServiceResponse},
15    http::{
16        header,
17        uri::{PathAndQuery, Uri},
18        StatusCode,
19    },
20    middleware::TrailingSlash,
21    Error, HttpResponse,
22};
23use bytes::Bytes;
24use pin_project_lite::pin_project;
25use regex::Regex;
26
27/// Middleware for normalizing a request's path so that routes can be matched more flexibly.
28///
29/// # Normalization Steps
30/// - Merges consecutive slashes into one. (For example, `/path//one` always becomes `/path/one`.)
31/// - Appends a trailing slash if one is not present, removes one if present, or keeps trailing
32///   slashes as-is, depending on which [`TrailingSlash`] variant is supplied
33///   to [`new`](NormalizePath::new()).
34///
35/// # Default Behavior
36/// The default constructor chooses to strip trailing slashes from the end of paths with them
37/// ([`TrailingSlash::Trim`]). The implication is that route definitions should be defined without
38/// trailing slashes or else they will be inaccessible (or vice versa when using the
39/// `TrailingSlash::Always` behavior), as shown in the example tests below.
40///
41/// # Examples
42/// ```
43/// use actix_web::{middleware, web, App};
44///
45/// # actix_web::rt::System::new().block_on(async {
46/// let app = App::new()
47///     .wrap(middleware::NormalizePath::trim())
48///     .route("/test", web::get().to(|| async { "test" }))
49///     .route("/unmatchable/", web::get().to(|| async { "unmatchable" }));
50///
51/// use actix_web::{
52///     http::StatusCode,
53///     test::{call_service, init_service, TestRequest},
54/// };
55///
56/// let app = init_service(app).await;
57///
58/// let req = TestRequest::with_uri("/test").to_request();
59/// let res = call_service(&app, req).await;
60/// assert_eq!(res.status(), StatusCode::OK);
61///
62/// let req = TestRequest::with_uri("/test/").to_request();
63/// let res = call_service(&app, req).await;
64/// assert_eq!(res.status(), StatusCode::OK);
65///
66/// let req = TestRequest::with_uri("/unmatchable").to_request();
67/// let res = call_service(&app, req).await;
68/// assert_eq!(res.status(), StatusCode::NOT_FOUND);
69///
70/// let req = TestRequest::with_uri("/unmatchable/").to_request();
71/// let res = call_service(&app, req).await;
72/// assert_eq!(res.status(), StatusCode::NOT_FOUND);
73/// # })
74/// ```
75#[derive(Debug, Clone, Copy)]
76pub struct NormalizePath {
77    /// Controls path normalization behavior.
78    trailing_slash_behavior: TrailingSlash,
79
80    /// Returns redirects for non-normalized paths if `Some`.
81    use_redirects: Option<StatusCode>,
82}
83
84impl Default for NormalizePath {
85    fn default() -> Self {
86        Self {
87            trailing_slash_behavior: TrailingSlash::Trim,
88            use_redirects: None,
89        }
90    }
91}
92
93impl NormalizePath {
94    /// Create new `NormalizePath` middleware with the specified trailing slash style.
95    pub fn new(behavior: TrailingSlash) -> Self {
96        Self {
97            trailing_slash_behavior: behavior,
98            use_redirects: None,
99        }
100    }
101
102    /// Constructs a new `NormalizePath` middleware with [trim](TrailingSlash::Trim) semantics.
103    ///
104    /// Use this instead of `NormalizePath::default()` to avoid deprecation warning.
105    pub fn trim() -> Self {
106        Self::new(TrailingSlash::Trim)
107    }
108
109    /// Configures middleware to respond to requests with non-normalized paths with a 307 redirect.
110    ///
111    /// If configured
112    ///
113    /// For example, a request with the path `/api//v1/foo/` would receive a response with a
114    /// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.)
115    ///
116    /// To customize the status code, use [`use_redirects_with`](Self::use_redirects_with).
117    pub fn use_redirects(mut self) -> Self {
118        self.use_redirects = Some(StatusCode::TEMPORARY_REDIRECT);
119        self
120    }
121
122    /// Configures middleware to respond to requests with non-normalized paths with a redirect.
123    ///
124    /// For example, a request with the path `/api//v1/foo/` would receive a 307 response with a
125    /// `Location: /api/v1/foo` header (assuming `Trim` trailing slash behavior.)
126    ///
127    /// # Panics
128    /// Panics if `status_code` is not a redirect (300-399).
129    pub fn use_redirects_with(mut self, status_code: StatusCode) -> Self {
130        assert!(status_code.is_redirection());
131        self.use_redirects = Some(status_code);
132        self
133    }
134}
135
136impl<S, B> Transform<S, ServiceRequest> for NormalizePath
137where
138    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
139    S::Future: 'static,
140{
141    type Response = ServiceResponse<EitherBody<B, ()>>;
142    type Error = Error;
143    type Transform = NormalizePathService<S>;
144    type InitError = ();
145    type Future = Ready<Result<Self::Transform, Self::InitError>>;
146
147    fn new_transform(&self, service: S) -> Self::Future {
148        ready(Ok(NormalizePathService {
149            service,
150            merge_slash: Regex::new("//+").unwrap(),
151            trailing_slash_behavior: self.trailing_slash_behavior,
152            use_redirects: self.use_redirects,
153        }))
154    }
155}
156
157pub struct NormalizePathService<S> {
158    service: S,
159    merge_slash: Regex,
160    trailing_slash_behavior: TrailingSlash,
161    use_redirects: Option<StatusCode>,
162}
163
164impl<S, B> Service<ServiceRequest> for NormalizePathService<S>
165where
166    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
167    S::Future: 'static,
168{
169    type Response = ServiceResponse<EitherBody<B, ()>>;
170    type Error = Error;
171    type Future = NormalizePathFuture<S, B>;
172
173    actix_service::forward_ready!(service);
174
175    fn call(&self, mut req: ServiceRequest) -> Self::Future {
176        let head = req.head_mut();
177
178        let mut path_altered = false;
179        let original_path = head.uri.path();
180
181        // An empty path here means that the URI has no valid path. We skip normalization in this
182        // case, because adding a path can make the URI invalid
183        if !original_path.is_empty() {
184            // Either adds a string to the end (duplicates will be removed anyways) or trims all
185            // slashes from the end
186            let path = match self.trailing_slash_behavior {
187                TrailingSlash::Always => format!("{original_path}/"),
188                TrailingSlash::MergeOnly => original_path.to_string(),
189                TrailingSlash::Trim => original_path.trim_end_matches('/').to_string(),
190                ts_behavior => panic!("unknown trailing slash behavior: {ts_behavior:?}"),
191            };
192
193            // normalize multiple /'s to one /
194            let path = self.merge_slash.replace_all(&path, "/");
195
196            // Ensure root paths are still resolvable. If resulting path is blank after previous
197            // step it means the path was one or more slashes. Reduce to single slash.
198            let path = if path.is_empty() { "/" } else { path.as_ref() };
199
200            // Check whether the path has been changed
201            //
202            // This check was previously implemented as string length comparison
203            //
204            // That approach fails when a trailing slash is added,
205            // and a duplicate slash is removed,
206            // since the length of the strings remains the same
207            //
208            // For example, the path "/v1//s" will be normalized to "/v1/s/"
209            // Both of the paths have the same length,
210            // so the change can not be deduced from the length comparison
211            if path != original_path {
212                let mut parts = head.uri.clone().into_parts();
213                let query = parts.path_and_query.as_ref().and_then(|pq| pq.query());
214
215                let path = match query {
216                    Some(query) => Bytes::from(format!("{path}?{query}")),
217                    None => Bytes::copy_from_slice(path.as_bytes()),
218                };
219                parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap());
220
221                let uri = Uri::from_parts(parts).unwrap();
222                req.match_info_mut().get_mut().update(&uri);
223                req.head_mut().uri = uri;
224
225                path_altered = true;
226            }
227        }
228
229        match self.use_redirects {
230            Some(code) if path_altered => {
231                let mut res = HttpResponse::with_body(code, ());
232                res.headers_mut().insert(
233                    header::LOCATION,
234                    req.head_mut().uri.to_string().parse().unwrap(),
235                );
236                NormalizePathFuture::redirect(req.into_response(res))
237            }
238
239            _ => NormalizePathFuture::service(self.service.call(req)),
240        }
241    }
242}
243
244pin_project! {
245    pub struct NormalizePathFuture<S: Service<ServiceRequest>, B> {
246        #[pin] inner: Inner<S, B>,
247    }
248}
249
250impl<S: Service<ServiceRequest>, B> NormalizePathFuture<S, B> {
251    fn service(fut: S::Future) -> Self {
252        Self {
253            inner: Inner::Service {
254                fut,
255                _body: PhantomData,
256            },
257        }
258    }
259
260    fn redirect(res: ServiceResponse<()>) -> Self {
261        Self {
262            inner: Inner::Redirect { res: Some(res) },
263        }
264    }
265}
266
267pin_project! {
268    #[project = InnerProj]
269    enum Inner<S: Service<ServiceRequest>, B> {
270        Redirect { res: Option<ServiceResponse<()>>, },
271        Service {
272            #[pin] fut: S::Future,
273            _body: PhantomData<B>,
274        },
275    }
276}
277
278impl<S, B> Future for NormalizePathFuture<S, B>
279where
280    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
281{
282    type Output = Result<ServiceResponse<EitherBody<B, ()>>, Error>;
283
284    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
285        let this = self.project();
286
287        match this.inner.project() {
288            InnerProj::Redirect { res } => {
289                Poll::Ready(Ok(res.take().unwrap().map_into_right_body()))
290            }
291
292            InnerProj::Service { fut, .. } => {
293                let res = ready!(fut.poll(cx))?;
294                Poll::Ready(Ok(res.map_into_left_body()))
295            }
296        }
297    }
298}
299
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        dev::ServiceRequest,
        guard::fn_guard,
        test::{self, call_service, init_service, TestRequest},
        web, App, HttpRequest, HttpResponse,
    };

    use super::*;

    #[actix_web::test]
    async fn default_is_trim_no_redirect() {
        let app = init_service(App::new().wrap(NormalizePath::default()).service(
            web::resource("/test").to(|req: HttpRequest| async move { req.path().to_owned() }),
        ))
        .await;

        let req = TestRequest::with_uri("/test/").to_request();
        let res = call_service(&app, req).await;
        assert!(res.status().is_success());
        assert_eq!(test::read_body(res).await, "/test");
    }

    #[actix_web::test]
    async fn trim_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = [
            "/",
            "/?query=test",
            "///",
            "/v1//something",
            "/v1//something////",
            "//v1/something",
            "//v1//////something",
            "/v2//something?query=test",
            "/v2//something////?query=test",
            "//v2/something?query=test",
            "//v2//////something?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn always_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = [
            "/",
            "///",
            "/v1/something",
            "/v1/something/",
            "/v1/something////",
            "//v1//something",
            "//v1//something//",
            "/v2/something?query=test",
            "/v2/something/?query=test",
            "/v2/something////?query=test",
            "//v2//something?query=test",
            "//v2//something//?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn trim_root_trailing_slashes_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Trim))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = ["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn ensure_root_trailing_slash_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = ["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn keep_trailing_slash_unchanged() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::MergeOnly))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(web::resource("/v1/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let tests = [
            ("/", true), // root paths should still work
            ("/?query=test", true),
            ("///", true),
            ("/v1/something////", false),
            ("/v1/something/", false),
            ("//v1//something", true),
            ("/v1/", true),
            ("/v1", false),
            ("/v1////", true),
            ("//v1//", true),
            ("///v1", false),
            ("/v2/something?query=test", true),
            ("/v2/something/?query=test", false),
            ("/v2/something//?query=test", false),
            ("//v2//something?query=test", true),
        ];

        for (uri, success) in tests {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert_eq!(res.status().is_success(), success, "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn no_path() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::default())
                .service(web::resource("/").to(HttpResponse::Ok)),
        )
        .await;

        // This URI will be interpreted as an authority form, i.e. there is no path nor scheme
        // (https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3)
        let req = TestRequest::with_uri("eh").to_request();
        let res = call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[actix_web::test]
    async fn test_in_place_normalization() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let test_uris = [
            "/v1//something////",
            "///v1/something",
            "//v1///something",
            "/v1//something",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_srv_request();
            let res = normalize.call(req).await.unwrap();
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn should_normalize_nothing() {
        const URI: &str = "/v1/something";

        let srv = |req: ServiceRequest| {
            assert_eq!(URI, req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri(URI).to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_normalize_no_trail() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_return_redirects_when_configured() {
        let normalize = NormalizePath::trim()
            .use_redirects()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);
        assert_eq!(
            res.headers().get(header::LOCATION).unwrap(),
            "/v1/something"
        );

        let normalize = NormalizePath::trim()
            .use_redirects_with(StatusCode::PERMANENT_REDIRECT)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(
            res.headers().get(header::LOCATION).unwrap(),
            "/v1/something"
        );
    }

    #[actix_web::test]
    async fn trim_with_redirect() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim().use_redirects())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        // list of uri and if it should result in a redirect
        let test_uris = [
            ("/", false),
            ("///", true),
            ("/v1/something", false),
            ("/v1/something/", true),
            ("/v1/something////", true),
            ("//v1//something", true),
            ("//v1//something//", true),
            ("/v2/something?query=test", false),
            ("/v2/something/?query=test", true),
            ("/v2/something////?query=test", true),
            ("//v2//something?query=test", true),
            ("//v2//something//?query=test", true),
        ];

        for (uri, should_redirect) in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;

            if should_redirect {
                assert!(res.status().is_redirection(), "URI did not redirect: {uri}");

                // Following the redirect must land on a matching, already-normalized route.
                let location = res
                    .headers()
                    .get(header::LOCATION)
                    .expect("redirect response should carry a Location header")
                    .to_str()
                    .expect("Location header should be ASCII")
                    .to_owned();

                let req = TestRequest::with_uri(&location).to_request();
                let res = call_service(&app, req).await;
                assert!(
                    res.status().is_success(),
                    "Redirect target failed: {uri} -> {location}"
                );
            } else {
                assert!(res.status().is_success(), "Failed URI: {uri}");
            }
        }
    }
}