1use std::{
2 collections::HashSet,
3 convert::{TryFrom, TryInto},
4 fmt,
5 rc::Rc,
6};
7
8use actix_web::{
9 dev::RequestHead,
10 error::Result,
11 http::{
12 header::{self, HeaderMap, HeaderName, HeaderValue},
13 Method,
14 },
15};
16use once_cell::sync::Lazy;
17use smallvec::SmallVec;
18
19use crate::{AllOrSome, CorsError};
20
21#[derive(Clone)]
22pub(crate) struct OriginFn {
23 #[allow(clippy::type_complexity)]
24 pub(crate) boxed_fn: Rc<dyn Fn(&HeaderValue, &RequestHead) -> bool>,
25}
26
27impl Default for OriginFn {
28 fn default() -> Self {
30 let boxed_fn: Rc<dyn Fn(&_, &_) -> _> = Rc::new(|_origin, _req_head| false);
31 Self { boxed_fn }
32 }
33}
34
35impl fmt::Debug for OriginFn {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.write_str("origin_fn")
38 }
39}
40
41pub(crate) fn header_value_try_into_method(hdr: &HeaderValue) -> Option<Method> {
43 hdr.to_str()
44 .ok()
45 .and_then(|meth| Method::try_from(meth).ok())
46}
47
48#[derive(Debug, Clone)]
49pub(crate) struct Inner {
50 pub(crate) allowed_origins: AllOrSome<HashSet<HeaderValue>>,
51 pub(crate) allowed_origins_fns: SmallVec<[OriginFn; 4]>,
52
53 pub(crate) allowed_methods: HashSet<Method>,
54 pub(crate) allowed_methods_baked: Option<HeaderValue>,
55
56 pub(crate) allowed_headers: AllOrSome<HashSet<HeaderName>>,
57 pub(crate) allowed_headers_baked: Option<HeaderValue>,
58
59 pub(crate) expose_headers: AllOrSome<HashSet<HeaderName>>,
61 pub(crate) expose_headers_baked: Option<HeaderValue>,
62
63 pub(crate) max_age: Option<usize>,
64 pub(crate) preflight: bool,
65 pub(crate) send_wildcard: bool,
66 pub(crate) supports_credentials: bool,
67 #[cfg(feature = "draft-private-network-access")]
68 pub(crate) allow_private_network_access: bool,
69 pub(crate) vary_header: bool,
70 pub(crate) block_on_origin_mismatch: bool,
71}
72
73static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::new);
74
75impl Inner {
76 pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<bool, CorsError> {
79 #[allow(clippy::mutable_key_type)]
81 let allowed_origins = match &self.allowed_origins {
82 AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(true),
83 AllOrSome::Some(allowed_origins) => allowed_origins,
84 _ => &EMPTY_ORIGIN_SET,
86 };
87
88 match req.headers().get(header::ORIGIN) {
90 Some(origin) => {
92 if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) {
93 Ok(true)
94 } else if self.block_on_origin_mismatch {
95 Err(CorsError::OriginNotAllowed)
96 } else {
97 Ok(false)
98 }
99 }
100
101 None => Err(CorsError::MissingOrigin),
105 }
106 }
107
108 fn validate_origin_fns(&self, origin: &HeaderValue, req: &RequestHead) -> bool {
110 self.allowed_origins_fns
111 .iter()
112 .any(|origin_fn| (origin_fn.boxed_fn)(origin, req))
113 }
114
115 pub(crate) fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
117 let origin = req.headers().get(header::ORIGIN);
118
119 match self.allowed_origins {
120 AllOrSome::All => {
121 if self.send_wildcard {
122 Some(HeaderValue::from_static("*"))
123 } else {
124 origin.cloned()
126 }
127 }
128
129 AllOrSome::Some(_) => {
130 origin.cloned()
134 }
135 }
136 }
137
138 pub(crate) fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
141 let request_method = req
143 .headers()
144 .get(header::ACCESS_CONTROL_REQUEST_METHOD)
145 .map(header_value_try_into_method);
146
147 match request_method {
148 Some(Some(method)) if self.allowed_methods.contains(&method) => Ok(()),
150
151 Some(Some(_)) => Err(CorsError::MethodNotAllowed),
153
154 Some(_) => Err(CorsError::BadRequestMethod),
156
157 None => Err(CorsError::MissingRequestMethod),
159 }
160 }
161
162 pub(crate) fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
163 #[allow(clippy::mutable_key_type)]
165 let allowed_headers = match &self.allowed_headers {
166 AllOrSome::All => return Ok(()),
167 AllOrSome::Some(allowed_headers) => allowed_headers,
168 };
169
170 let request_headers = req
173 .headers()
174 .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
175 .map(|hdr| hdr.to_str());
176
177 match request_headers {
178 Some(Ok(headers)) => {
180 #[allow(clippy::mutable_key_type)]
183 let mut request_headers = HashSet::with_capacity(8);
184
185 for hdr in headers.split(',') {
187 match hdr.trim().try_into() {
188 Ok(hdr) => request_headers.insert(hdr),
189 Err(_) => return Err(CorsError::BadRequestHeaders),
190 };
191 }
192
193 if request_headers.is_empty() {
195 return Err(CorsError::BadRequestHeaders);
196 }
197
198 if !request_headers.is_subset(allowed_headers) {
200 return Err(CorsError::HeadersNotAllowed);
201 }
202
203 Ok(())
204 }
205
206 Some(Err(_)) => Err(CorsError::BadRequestHeaders),
208
209 None => Ok(()),
211 }
212 }
213}
214
215pub(crate) fn add_vary_header(headers: &mut HeaderMap) {
219 let value = match headers.get(header::VARY) {
220 Some(hdr) => {
221 let mut val: Vec<u8> = Vec::with_capacity(hdr.len() + 71);
222 val.extend(hdr.as_bytes());
223 val.extend(b", Origin, Access-Control-Request-Method, Access-Control-Request-Headers");
224
225 #[cfg(feature = "draft-private-network-access")]
226 val.extend(b", Access-Control-Request-Private-Network");
227
228 val.try_into().unwrap()
229 }
230
231 #[cfg(feature = "draft-private-network-access")]
232 None => HeaderValue::from_static(
233 "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, \
234 Access-Control-Request-Private-Network",
235 ),
236
237 #[cfg(not(feature = "draft-private-network-access"))]
238 None => HeaderValue::from_static(
239 "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
240 ),
241 };
242
243 headers.insert(header::VARY, value);
244}
245
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use actix_web::{
        dev::Transform,
        http::{
            header::{self, HeaderValue},
            Method, StatusCode,
        },
        test::{self, TestRequest},
    };

    use crate::Cors;

    /// Origin used by the requests that are expected to be accepted.
    const EXAMPLE_ORIGIN: &str = "https://www.example.com";

    /// Views a header value as `&str`, panicking when that is not possible.
    fn val_as_str(val: &HeaderValue) -> &str {
        val.to_str().unwrap()
    }

    /// Starts an `OPTIONS` request carrying `Origin: https://www.example.com`.
    fn options_from_example() -> TestRequest {
        TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", EXAMPLE_ORIGIN))
    }

    /// Extends `options_from_example` into a preflight that asks for `POST` with the
    /// `AUTHORIZATION` and `ACCEPT` headers.
    fn post_preflight() -> TestRequest {
        options_from_example()
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "AUTHORIZATION,ACCEPT"))
    }

    /// A request from an unlisted origin, without a request-method header and asking for an
    /// unlisted header, must fail all three validators.
    #[actix_web::test]
    async fn test_validate_not_allowed_origin() {
        let cors = Cors::default()
            .allowed_origin(EXAMPLE_ORIGIN)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "https://www.unknown.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT"))
            .to_srv_request();
        let head = req.head();

        assert!(cors.inner.validate_origin(head).is_err());
        assert!(cors.inner.validate_allowed_method(head).is_err());
        assert!(cors.inner.validate_allowed_headers(head).is_err());
    }

    /// Walks a wildcard-origin configuration through rejected and accepted preflight requests
    /// and checks the response headers of the accepted one.
    #[actix_web::test]
    async fn test_preflight() {
        let mut cors = Cors::default()
            .allow_any_origin()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // No request method and a header outside the allowed set.
        let unlisted_header = options_from_example()
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed"))
            .to_srv_request();

        assert!(cors
            .inner
            .validate_allowed_method(unlisted_header.head())
            .is_err());
        assert!(cors
            .inner
            .validate_allowed_headers(unlisted_header.head())
            .is_err());
        let resp = test::call_service(&cors, unlisted_header).await;
        assert_eq!(resp.status(), StatusCode::OK);

        // Lower-case `put` is not among the allowed methods; no header list is sent.
        let unlisted_method = options_from_example()
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "put"))
            .to_srv_request();

        assert!(cors
            .inner
            .validate_allowed_method(unlisted_method.head())
            .is_err());
        assert!(cors
            .inner
            .validate_allowed_headers(unlisted_method.head())
            .is_ok());

        // A fully acceptable preflight.
        let resp = test::call_service(&cors, post_preflight().to_srv_request()).await;
        let resp_headers = resp.headers();

        assert_eq!(
            resp_headers
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes),
            Some(&b"*"[..]),
        );
        assert_eq!(
            resp_headers
                .get(header::ACCESS_CONTROL_MAX_AGE)
                .map(HeaderValue::as_bytes),
            Some(&b"3600"[..]),
        );

        let allow_headers = resp_headers
            .get(header::ACCESS_CONTROL_ALLOW_HEADERS)
            .map(val_as_str)
            .unwrap();
        for expected in &["authorization", "accept", "content-type"] {
            assert!(allow_headers.contains(*expected));
        }

        let allow_methods = resp_headers
            .get(header::ACCESS_CONTROL_ALLOW_METHODS)
            .map(val_as_str)
            .unwrap();
        for expected in &["POST", "GET", "OPTIONS"] {
            assert!(allow_methods.contains(*expected));
        }

        // Switch preflight handling off and send the same request again.
        Rc::get_mut(&mut cors.inner).unwrap().preflight = false;

        let resp = test::call_service(&cors, post_preflight().to_srv_request()).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// The origin passed to an origin function must equal the `Origin` header found on the
    /// request head it receives alongside.
    #[actix_web::test]
    async fn allow_fn_origin_equals_head_origin() {
        let cors = Cors::default()
            .allowed_origin_fn(|origin, head| {
                let head_origin = head
                    .headers()
                    .get(header::ORIGIN)
                    .expect("unwrapping origin header should never fail in allowed_origin_fn");
                assert!(origin == head_origin);
                true
            })
            .allow_any_method()
            .allow_any_header()
            .new_transform(test::status_service(StatusCode::NO_CONTENT))
            .await
            .unwrap();

        let preflight = options_from_example()
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .to_srv_request();
        let resp = test::call_service(&cors, preflight).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let simple = TestRequest::default()
            .method(Method::GET)
            .insert_header(("Origin", EXAMPLE_ORIGIN))
            .to_srv_request();
        let resp = test::call_service(&cors, simple).await;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
    }
}