actix_web_lab/
redirect_to_https.rs1use std::{
2 future::{ready, Ready},
3 rc::Rc,
4};
5
6use actix_web::{
7 body::EitherBody,
8 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
9 http::header::TryIntoHeaderPair,
10 web::Redirect,
11 HttpResponse, Responder as _,
12};
13use futures_core::future::LocalBoxFuture;
14
15use crate::header::StrictTransportSecurity;
16
17#[derive(Debug, Clone, Default)]
46pub struct RedirectHttps {
47 hsts: Option<StrictTransportSecurity>,
48 port: Option<u16>,
49}
50
51impl RedirectHttps {
52 pub fn with_hsts(hsts: StrictTransportSecurity) -> Self {
54 Self {
55 hsts: Some(hsts),
56 ..Self::default()
57 }
58 }
59
60 pub fn to_port(mut self, port: u16) -> Self {
64 self.port = Some(port);
65 self
66 }
67}
68
69impl<S, B> Transform<S, ServiceRequest> for RedirectHttps
70where
71 S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
72{
73 type Response = ServiceResponse<EitherBody<B, ()>>;
74 type Error = S::Error;
75 type Transform = RedirectHttpsMiddleware<S>;
76 type InitError = ();
77 type Future = Ready<Result<Self::Transform, Self::InitError>>;
78
79 fn new_transform(&self, service: S) -> Self::Future {
80 ready(Ok(RedirectHttpsMiddleware {
81 service: Rc::new(service),
82 hsts: self.hsts,
83 port: self.port,
84 }))
85 }
86}
87
88pub struct RedirectHttpsMiddleware<S> {
89 service: Rc<S>,
90 hsts: Option<StrictTransportSecurity>,
91 port: Option<u16>,
92}
93
94impl<S, B> Service<ServiceRequest> for RedirectHttpsMiddleware<S>
95where
96 S: Service<ServiceRequest, Response = ServiceResponse<B>> + 'static,
97{
98 type Response = ServiceResponse<EitherBody<B, ()>>;
99 type Error = S::Error;
100 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
101
102 forward_ready!(service);
103
104 fn call(&self, req: ServiceRequest) -> Self::Future {
105 #![allow(clippy::await_holding_refcell_ref)] let service = Rc::clone(&self.service);
108 let hsts = self.hsts;
109 let port = self.port;
110
111 Box::pin(async move {
112 let (req, pl) = req.into_parts();
113 let conn_info = req.connection_info();
114
115 if conn_info.scheme() != "https" {
116 let host = conn_info.host();
117
118 let (hostname, _port) = host.split_once(':').unwrap_or((host, ""));
120
121 let path = req.uri().path();
122 let uri = match port {
123 Some(port) => format!("https://{hostname}:{port}{path}"),
124 None => format!("https://{hostname}{path}"),
125 };
126
127 drop(conn_info);
129
130 let redirect = Redirect::to(uri);
132
133 let mut res = redirect.respond_to(&req).map_into_right_body();
134 apply_hsts(&mut res, hsts);
135
136 return Ok(ServiceResponse::new(req, res));
137 }
138
139 drop(conn_info);
140
141 let req = ServiceRequest::from_parts(req, pl);
142
143 service.call(req).await.map(|mut res| {
146 apply_hsts(res.response_mut(), hsts);
147 res.map_into_left_body()
148 })
149 })
150 }
151}
152
153fn apply_hsts<B>(res: &mut HttpResponse<B>, hsts: Option<StrictTransportSecurity>) {
155 if let Some(hsts) = hsts {
156 let (name, val) = hsts.try_into_pair().unwrap();
157 res.headers_mut().insert(name, val);
158 }
159}
160
#[cfg(test)]
mod tests {
    use actix_web::{
        body::MessageBody,
        dev::ServiceFactory,
        http::{
            header::{self, Header as _},
            StatusCode,
        },
        test, web, App, Error, HttpResponse,
    };

    use super::*;
    use crate::{assert_response_matches, test_request};

    /// App with the default redirect middleware in front of a single `GET /` route that
    /// responds with `content`.
    fn test_app() -> App<
        impl ServiceFactory<
            ServiceRequest,
            Response = ServiceResponse<impl MessageBody>,
            Config = (),
            InitError = (),
            Error = Error,
        >,
    > {
        let index = web::get().to(|| async { HttpResponse::Ok().body("content") });

        App::new().wrap(RedirectHttps::default()).route("/", index)
    }

    #[actix_web::test]
    async fn redirect_non_https() {
        let srv = test::init_service(test_app()).await;

        // a default test request is not HTTPS, so it must be redirected
        let res = test::call_service(&srv, test::TestRequest::default().to_request()).await;
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);

        let location = res
            .headers()
            .get(header::LOCATION)
            .expect("redirect response should carry a Location header");
        assert!(location.as_bytes().starts_with(b"https://"));

        // redirects carry no body
        assert!(test::read_body(res).await.is_empty());
    }

    #[actix_web::test]
    async fn do_not_redirect_already_https() {
        let srv = test::init_service(test_app()).await;

        let https_req = test::TestRequest::default()
            .uri("https://localhost:443/")
            .to_request();
        let res = test::call_service(&srv, https_req).await;

        assert_eq!(res.status(), StatusCode::OK);
        assert!(!res.headers().contains_key(header::LOCATION));

        assert_eq!(test::read_body(res).await, "content");
    }

    #[actix_web::test]
    async fn with_hsts() {
        // without HSTS configured, neither redirects nor pass-through responses get the header
        let plain = RedirectHttps::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let http_req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&plain, http_req).await;
        assert!(!res.headers().contains_key(StrictTransportSecurity::name()));

        let https_req = test_request!(GET "https://localhost:443/").to_srv_request();
        let res = test::call_service(&plain, https_req).await;
        assert!(!res.headers().contains_key(StrictTransportSecurity::name()));

        // with HSTS configured, both kinds of response carry the header
        let strict = RedirectHttps::with_hsts(StrictTransportSecurity::recommended())
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let http_req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&strict, http_req).await;
        assert!(res.headers().contains_key(StrictTransportSecurity::name()));

        let https_req = test_request!(GET "https://localhost:443/").to_srv_request();
        let res = test::call_service(&strict, https_req).await;
        assert!(res.headers().contains_key(StrictTransportSecurity::name()));
    }

    #[actix_web::test]
    async fn to_custom_port() {
        let srv = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let http_req = test_request!(GET "http://localhost/").to_srv_request();
        let res = test::call_service(&srv, http_req).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }

    #[actix_web::test]
    async fn to_custom_port_when_port_in_host() {
        let srv = RedirectHttps::default()
            .to_port(8443)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // the incoming :8080 must be replaced by the configured port, not appended to
        let http_req = test_request!(GET "http://localhost:8080/").to_srv_request();
        let res = test::call_service(&srv, http_req).await;
        assert_response_matches!(res, TEMPORARY_REDIRECT; "location" => "https://localhost:8443/");
    }
}