use std::{convert::Infallible, str, time::Duration};

use actix_web::{
    error::ParseError,
    http::header::{
        from_one_raw_str, Header, HeaderName, HeaderValue, TryIntoHeaderValue,
        STRICT_TRANSPORT_SECURITY,
    },
    HttpMessage,
};
15
/// Number of seconds in a 365-day year (leap days are not accounted for).
const SECS_IN_YEAR: u64 = 3600 * 24 * 365;

/// HTTP Strict Transport Security (HSTS) configuration.
///
/// Models the `Strict-Transport-Security` header (RFC 6797): a `max-age` duration plus the
/// optional `includeSubDomains` and `preload` directives.
///
/// # `Default`
///
/// The `Default` configuration uses a short 300-second (5-minute) `max-age` and sends neither
/// `includeSubDomains` nor `preload`. [`recommended`](Self::recommended) provides a longer-lived
/// configuration.
///
/// # Builder methods
///
/// This type is `Copy` and its builder methods take `self` by value, so they return an updated
/// copy. Calling one without using the result has no effect, which is why they are `#[must_use]`.
///
/// See the MDN page on `Strict-Transport-Security` for background on the directives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[doc(alias = "hsts", alias = "sts")]
pub struct StrictTransportSecurity {
    /// Value of the `max-age` directive.
    duration: Duration,

    /// Whether the `includeSubDomains` directive is sent.
    pub include_subdomains: bool,

    /// Whether the `preload` directive is sent.
    pub preload: bool,
}

impl StrictTransportSecurity {
    /// Constructs a configuration with the given `max-age` `duration`.
    ///
    /// All other settings take their [`Default`] values, i.e. neither `includeSubDomains` nor
    /// `preload` is sent.
    #[must_use]
    pub fn new(duration: Duration) -> Self {
        Self {
            duration,
            ..Self::default()
        }
    }

    /// Constructs the recommended configuration.
    ///
    /// It uses a `max-age` of two 365-day years (63 072 000 seconds) with `includeSubDomains`
    /// enabled and `preload` disabled.
    #[must_use]
    pub fn recommended() -> Self {
        Self {
            duration: Duration::from_secs(2 * SECS_IN_YEAR),
            include_subdomains: true,
            ..Self::default()
        }
    }

    /// Returns a copy of this configuration with the `includeSubDomains` directive enabled.
    #[must_use]
    pub fn include_subdomains(mut self) -> Self {
        self.include_subdomains = true;
        self
    }

    /// Returns a copy of this configuration with the `preload` directive enabled.
    ///
    /// See <https://hstspreload.org/> for what preloading entails before enabling it.
    #[must_use]
    pub fn preload(mut self) -> Self {
        self.preload = true;
        self
    }
}

impl Default for StrictTransportSecurity {
    /// A 300-second (5-minute) `max-age` with neither `includeSubDomains` nor `preload`.
    fn default() -> Self {
        Self {
            duration: Duration::from_secs(300),
            include_subdomains: false,
            preload: false,
        }
    }
}
95
96impl str::FromStr for StrictTransportSecurity {
97 type Err = ParseError;
98
99 fn from_str(val: &str) -> Result<Self, Self::Err> {
100 let mut parts = val.split(';').map(str::trim);
101
102 let duration = parts
104 .next()
105 .ok_or(ParseError::Header)?
106 .split_once('=')
107 .and_then(|(key, max_age)| {
108 if key.trim() != "max-age" {
109 return None;
110 }
111
112 max_age.trim().parse().ok()
113 })
114 .map(Duration::from_secs)
115 .ok_or(ParseError::Header)?;
116
117 let mut include_subdomains = false;
118 let mut preload = false;
119
120 for part in parts {
122 if part == "includeSubdomains" {
123 include_subdomains = true;
124 }
125
126 if part == "preload" {
127 preload = true;
128 }
129 }
130
131 Ok(Self {
132 duration,
133 include_subdomains,
134 preload,
135 })
136 }
137}
138
139impl TryIntoHeaderValue for StrictTransportSecurity {
140 type Error = Infallible;
141
142 fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
143 let secs = self.duration.as_secs();
144 let subdomains = if self.include_subdomains {
145 "; includeSubDomains"
146 } else {
147 ""
148 };
149 let preload = if self.preload { "; preload" } else { "" };
150
151 let sts = format!("max-age={secs}{subdomains}{preload}")
153 .parse()
154 .unwrap();
155
156 Ok(sts)
157 }
158}
159
160impl Header for StrictTransportSecurity {
161 fn name() -> HeaderName {
162 STRICT_TRANSPORT_SECURITY
163 }
164
165 fn parse<M: HttpMessage>(msg: &M) -> Result<Self, ParseError> {
166 from_one_raw_str(msg.headers().get(Self::name()))
167 }
168}
169
#[cfg(test)]
mod test {
    use actix_web::HttpResponse;

    use super::*;

    /// Inserts `hsts` into an empty `200 OK` response and returns the rendered header value.
    fn rendered(hsts: StrictTransportSecurity) -> String {
        let res = HttpResponse::Ok().insert_header(hsts).finish();
        let value = res
            .headers()
            .get(StrictTransportSecurity::name())
            .unwrap();
        value.to_str().unwrap().to_owned()
    }

    #[test]
    fn hsts_as_header() {
        let base = StrictTransportSecurity::default();

        let cases = [
            (base, "max-age=300"),
            (base.include_subdomains(), "max-age=300; includeSubDomains"),
            (base.preload(), "max-age=300; preload"),
            (
                base.include_subdomains().preload(),
                "max-age=300; includeSubDomains; preload",
            ),
        ];

        for (hsts, expected) in cases {
            assert_eq!(rendered(hsts), expected);
        }
    }

    #[test]
    fn recommended_config() {
        let res = HttpResponse::Ok()
            .insert_header(StrictTransportSecurity::recommended())
            .finish();
        let header = res.headers().get("strict-transport-security").unwrap();
        assert_eq!(header, "max-age=63072000; includeSubDomains");
    }

    #[test]
    fn parsing() {
        let parse = |raw: &str| raw.parse::<StrictTransportSecurity>();
        let hsts = |secs: u64, include_subdomains: bool, preload: bool| StrictTransportSecurity {
            duration: Duration::from_secs(secs),
            include_subdomains,
            preload,
        };

        assert!(parse("").is_err());
        assert!(parse("duration=1").is_err());

        assert_eq!(parse("max-age=1").unwrap(), hsts(1, false, false));
        assert_eq!(
            parse("max-age=1; includeSubdomains").unwrap(),
            hsts(1, true, false)
        );
        assert_eq!(parse("max-age=1; preload").unwrap(), hsts(1, false, true));
        assert_eq!(
            parse("max-age=1; includeSubdomains; preload").unwrap(),
            hsts(1, true, true)
        );
    }
}