actix_web_lab/
redirect_to_https.rs

1use std::{
2    future::{ready, Ready},
3    rc::Rc,
4};
5
6use actix_web::{
7    body::EitherBody,
8    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
9    http::header::TryIntoHeaderPair,
10    web::Redirect,
11    HttpResponse, Responder as _,
12};
13use futures_core::future::LocalBoxFuture;
14
15use crate::header::StrictTransportSecurity;
16
17/// A middleware to redirect traffic to HTTPS if connection is insecure.
18///
19/// # HSTS
20///
21/// [HTTP Strict Transport Security (HSTS)] is configurable. Care should be taken when setting up
22/// HSTS for your site; misconfiguration can potentially leave parts of your site in an unusable
23/// state. By default it is disabled.
24///
25/// See [`StrictTransportSecurity`] docs for more info.
26///
27/// # Examples
28///
29/// ```
30/// # use std::time::Duration;
31/// # use actix_web::App;
32/// use actix_web_lab::{header::StrictTransportSecurity, middleware::RedirectHttps};
33///
34/// let mw = RedirectHttps::default();
35/// let mw = RedirectHttps::default().to_port(8443);
36/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::default());
37/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::new(Duration::from_secs(60 * 60)));
38/// let mw = RedirectHttps::with_hsts(StrictTransportSecurity::recommended());
39///
40/// App::new().wrap(mw)
41/// # ;
42/// ```
43///
44/// [HTTP Strict Transport Security (HSTS)]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security
45#[derive(Debug, Clone, Default)]
46pub struct RedirectHttps {
47    hsts: Option<StrictTransportSecurity>,
48    port: Option<u16>,
49}
50
51impl RedirectHttps {
52    /// Construct new HTTP redirect middleware with strict transport security configuration.
53    pub fn with_hsts(hsts: StrictTransportSecurity) -> Self {
54        Self {
55            hsts: Some(hsts),
56            ..Self::default()
57        }
58    }
59
60    /// Sets custom secure redirect port.
61    ///
62    /// By default, no port is set explicitly so the standard HTTPS port (443) is used.
63    pub fn to_port(mut self, port: u16) -> Self {
64        self.port = Some(port);
65        self
66    }
67}
68
69impl<S, B> Transform<S, ServiceRequest> for RedirectHttps
70where
71    S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
72{
73    type Response = ServiceResponse<EitherBody<B, ()>>;
74    type Error = S::Error;
75    type Transform = RedirectHttpsMiddleware<S>;
76    type InitError = ();
77    type Future = Ready<Result<Self::Transform, Self::InitError>>;
78
79    fn new_transform(&self, service: S) -> Self::Future {
80        ready(Ok(RedirectHttpsMiddleware {
81            service: Rc::new(service),
82            hsts: self.hsts,
83            port: self.port,
84        }))
85    }
86}
87
88pub struct RedirectHttpsMiddleware<S> {
89    service: Rc<S>,
90    hsts: Option<StrictTransportSecurity>,
91    port: Option<u16>,
92}
93
94impl<S, B> Service<ServiceRequest> for RedirectHttpsMiddleware<S>
95where
96    S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
97{
98    type Response = ServiceResponse<EitherBody<B, ()>>;
99    type Error = S::Error;
100    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
101
102    forward_ready!(service);
103
104    fn call(&self, req: ServiceRequest) -> Self::Future {
105        #![allow(clippy::await_holding_refcell_ref)] // RefCell is dropped before await
106
107        let service = Rc::clone(&self.service);
108        let hsts = self.hsts;
109        let port = self.port;
110
111        Box::pin(async move {
112            let (req, pl) = req.into_parts();
113            let conn_info = req.connection_info();
114
115            if conn_info.scheme() != "https" {
116                let host = conn_info.host();
117
118                // construct equivalent https path
119                let (hostname, _port) = host.split_once(':').unwrap_or((host, ""));
120
121                let path = req.uri().path();
122                let uri = match port {
123                    Some(port) => format!("https://{hostname}:{port}{path}"),
124                    None => format!("https://{hostname}{path}"),
125                };
126
127                // all connection info is acquired
128                drop(conn_info);
129
130                // create redirection response
131                let redirect = Redirect::to(uri);
132
133                let mut res = redirect.respond_to(&req).map_into_right_body();
134                apply_hsts(&mut res, hsts);
135
136                return Ok(ServiceResponse::new(req, res));
137            }
138
139            drop(conn_info);
140
141            let req = ServiceRequest::from_parts(req, pl);
142
143            // TODO: apply HSTS header to error case
144
145            service.call(req).await.map(|mut res| {
146                apply_hsts(res.response_mut(), hsts);
147                res.map_into_left_body()
148            })
149        })
150    }
151}
152
153/// Apply HSTS config to an `HttpResponse`.
154fn apply_hsts<B>(res: &mut HttpResponse<B>, hsts: Option<StrictTransportSecurity>) {
155    if let Some(hsts) = hsts {
156        let (name, val) = hsts.try_into_pair().unwrap();
157        res.headers_mut().insert(name, val);
158    }
159}
160
#[cfg(test)]
mod tests {
    use actix_web::{
        body::MessageBody,
        dev::ServiceFactory,
        http::{
            header::{self, Header as _},
            StatusCode,
        },
        test, web, App, Error, HttpResponse,
    };

    use super::*;
    use crate::{assert_response_matches, test_request};

    /// Returns `true` if the response carries a `Strict-Transport-Security` header.
    fn has_hsts<B>(res: &ServiceResponse<B>) -> bool {
        res.headers().contains_key(StrictTransportSecurity::name())
    }

    /// App wrapped in the default redirect middleware, serving `"content"` on `GET /`.
    fn test_app() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        let index = web::get().to(|| async { HttpResponse::Ok().body("content") });
        App::new().wrap(RedirectHttps::default()).route("/", index)
    }

    #[actix_web::test]
    async fn redirect_non_https() {
        let app = test::init_service(test_app()).await;

        let res = test::call_service(&app, test::TestRequest::default().to_request()).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);

        let location = res
            .headers()
            .get(header::LOCATION)
            .expect("redirect response should have a Location header");
        assert!(location.as_bytes().starts_with(b"https://"));

        assert!(test::read_body(res).await.is_empty());
    }

    #[actix_web::test]
    async fn do_not_redirect_already_https() {
        let app = test::init_service(test_app()).await;

        let secure_req = test::TestRequest::default()
            .uri("https://localhost:443/")
            .to_request();
        let res = test::call_service(&app, secure_req).await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(res.headers().get(header::LOCATION).is_none());
        assert_eq!(test::read_body(res).await, "content");
    }

    #[actix_web::test]
    async fn with_hsts() {
        // HSTS disabled: neither redirects nor passthrough responses carry the header
        let plain = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let insecure = test_request!(GET "http://localhost/").to_srv_request();
        assert!(!has_hsts(&test::call_service(&plain, insecure).await));

        let secure = test_request!(GET "https://localhost:443/").to_srv_request();
        assert!(!has_hsts(&test::call_service(&plain, secure).await));

        // HSTS enabled: both redirects and passthrough responses carry the header
        let strict = RedirectHttps::with_hsts(StrictTransportSecurity::recommended())
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let insecure = test_request!(GET "http://localhost/").to_srv_request();
        assert!(has_hsts(&test::call_service(&strict, insecure).await));

        let secure = test_request!(GET "https://localhost:443/").to_srv_request();
        assert!(has_hsts(&test::call_service(&strict, secure).await));
    }

    #[actix_web::test]
    async fn to_custom_port() {
        let mw = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, test_request!(GET "http://localhost/").to_srv_request())
            .await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }

    #[actix_web::test]
    async fn to_custom_port_when_port_in_host() {
        let mw = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req_with_port = test_request!(GET "http://localhost:8080/").to_srv_request();
        let res = test::call_service(&mw, req_with_port).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }
}