use std::{
collections::HashSet,
convert::{TryFrom, TryInto},
fmt,
rc::Rc,
};
use actix_web::{
dev::RequestHead,
error::Result,
http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
Method,
},
};
use once_cell::sync::Lazy;
use smallvec::SmallVec;
use crate::{AllOrSome, CorsError};
/// Type of a user-supplied origin predicate.
///
/// Named so that the field below stays readable without silencing `clippy::type_complexity`.
pub(crate) type BoxedOriginFn = Rc<dyn Fn(&HeaderValue, &RequestHead) -> bool>;

/// Custom origin validator.
///
/// The predicate is held behind an `Rc`, so cloning an `OriginFn` only bumps a reference count.
#[derive(Clone)]
pub(crate) struct OriginFn {
    /// Called with the request's `Origin` header value and the request head; returning `true`
    /// marks the origin as allowed.
    pub(crate) boxed_fn: BoxedOriginFn,
}
impl Default for OriginFn {
fn default() -> Self {
let boxed_fn: Rc<dyn Fn(&_, &_) -> _> = Rc::new(|_origin, _req_head| false);
Self { boxed_fn }
}
}
impl fmt::Debug for OriginFn {
    // Closures are opaque, so a fixed placeholder is printed instead of the predicate.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "origin_fn")
    }
}
/// Parses a header value as an HTTP method.
///
/// Returns `None` if the value cannot be viewed as a `&str` (`HeaderValue::to_str` fails) or if
/// `Method::try_from` rejects the text.
pub(crate) fn header_value_try_into_method(hdr: &HeaderValue) -> Option<Method> {
    let method_text = hdr.to_str().ok()?;
    Method::try_from(method_text).ok()
}
/// CORS configuration consulted by the validation methods implemented on `Inner`.
#[derive(Debug, Clone)]
pub(crate) struct Inner {
    /// Origins accepted by exact comparison with the request's `Origin` header.
    /// `All` skips this comparison in `validate_origin`.
    pub(crate) allowed_origins: AllOrSome<HashSet<HeaderValue>>,

    /// Custom origin predicates, tried in order when the origin is not found in
    /// `allowed_origins`; the first one returning `true` accepts the origin.
    pub(crate) allowed_origins_fns: SmallVec<[OriginFn; 4]>,

    /// Methods accepted in the `Access-Control-Request-Method` request header.
    pub(crate) allowed_methods: HashSet<Method>,

    /// NOTE(review): not read in this file; presumably a pre-rendered header value derived
    /// from `allowed_methods` — confirm where it is built.
    pub(crate) allowed_methods_baked: Option<HeaderValue>,

    /// Header names accepted in the `Access-Control-Request-Headers` request header;
    /// `All` accepts any list without parsing it.
    pub(crate) allowed_headers: AllOrSome<HashSet<HeaderName>>,

    /// NOTE(review): not read in this file; presumably pre-rendered from `allowed_headers`
    /// — confirm.
    pub(crate) allowed_headers_baked: Option<HeaderValue>,

    /// NOTE(review): not read in this file; presumably the header names advertised via
    /// `Access-Control-Expose-Headers` — confirm.
    pub(crate) expose_headers: AllOrSome<HashSet<HeaderName>>,

    /// NOTE(review): not read in this file; presumably pre-rendered from `expose_headers`
    /// — confirm.
    pub(crate) expose_headers_baked: Option<HeaderValue>,

    /// NOTE(review): presumably the `Access-Control-Max-Age` value in seconds (the preflight
    /// test expects `max_age(3600)` to yield `3600`) — confirm units where it is emitted.
    pub(crate) max_age: Option<usize>,

    /// NOTE(review): not read in this file; presumably whether OPTIONS preflight requests are
    /// answered by the middleware itself — confirm in the middleware.
    pub(crate) preflight: bool,

    /// When all origins are allowed, respond with `*` instead of echoing the request origin.
    pub(crate) send_wildcard: bool,

    /// NOTE(review): not read in this file; presumably controls
    /// `Access-Control-Allow-Credentials` — confirm.
    pub(crate) supports_credentials: bool,

    /// NOTE(review): feature-gated and not read in this file; presumably enables the Private
    /// Network Access draft handling — confirm.
    #[cfg(feature = "draft-private-network-access")]
    pub(crate) allow_private_network_access: bool,

    /// NOTE(review): not read in this file; presumably decides whether `add_vary_header` is
    /// applied to responses — confirm.
    pub(crate) vary_header: bool,

    /// When set, a disallowed origin makes `validate_origin` return
    /// `Err(CorsError::OriginNotAllowed)` instead of `Ok(false)`.
    pub(crate) block_on_origin_mismatch: bool,
}
/// Empty origin set used by `Inner::validate_origin` when `allowed_origins` is `All` but custom
/// origin functions are registered, so that only those functions decide.
static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::default);
impl Inner {
pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<bool, CorsError> {
#[allow(clippy::mutable_key_type)]
let allowed_origins = match &self.allowed_origins {
AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(true),
AllOrSome::Some(allowed_origins) => allowed_origins,
_ => &EMPTY_ORIGIN_SET,
};
match req.headers().get(header::ORIGIN) {
Some(origin) => {
if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) {
Ok(true)
} else if self.block_on_origin_mismatch {
Err(CorsError::OriginNotAllowed)
} else {
Ok(false)
}
}
None => Err(CorsError::MissingOrigin),
}
}
fn validate_origin_fns(&self, origin: &HeaderValue, req: &RequestHead) -> bool {
self.allowed_origins_fns
.iter()
.any(|origin_fn| (origin_fn.boxed_fn)(origin, req))
}
pub(crate) fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
let origin = req.headers().get(header::ORIGIN);
match self.allowed_origins {
AllOrSome::All => {
if self.send_wildcard {
Some(HeaderValue::from_static("*"))
} else {
origin.cloned()
}
}
AllOrSome::Some(_) => {
origin.cloned()
}
}
}
pub(crate) fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
let request_method = req
.headers()
.get(header::ACCESS_CONTROL_REQUEST_METHOD)
.map(header_value_try_into_method);
match request_method {
Some(Some(method)) if self.allowed_methods.contains(&method) => Ok(()),
Some(Some(_)) => Err(CorsError::MethodNotAllowed),
Some(_) => Err(CorsError::BadRequestMethod),
None => Err(CorsError::MissingRequestMethod),
}
}
pub(crate) fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
#[allow(clippy::mutable_key_type)]
let allowed_headers = match &self.allowed_headers {
AllOrSome::All => return Ok(()),
AllOrSome::Some(allowed_headers) => allowed_headers,
};
let request_headers = req
.headers()
.get(header::ACCESS_CONTROL_REQUEST_HEADERS)
.map(|hdr| hdr.to_str());
match request_headers {
Some(Ok(headers)) => {
#[allow(clippy::mutable_key_type)]
let mut request_headers = HashSet::with_capacity(8);
for hdr in headers.split(',') {
match hdr.trim().try_into() {
Ok(hdr) => request_headers.insert(hdr),
Err(_) => return Err(CorsError::BadRequestHeaders),
};
}
if request_headers.is_empty() {
return Err(CorsError::BadRequestHeaders);
}
if !request_headers.is_subset(allowed_headers) {
return Err(CorsError::HeadersNotAllowed);
}
Ok(())
}
Some(Err(_)) => Err(CorsError::BadRequestHeaders),
None => Ok(()),
}
}
}
/// Adds the CORS request header names to the response's `Vary` header.
///
/// Every existing `Vary` field value is kept: they are joined with `", "` in order, the CORS
/// names are appended, and the result replaces all previous `Vary` fields. Without an existing
/// `Vary` header a static value containing only the CORS names is inserted.
///
/// See <https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches>.
pub(crate) fn add_vary_header(headers: &mut HeaderMap) {
    #[cfg(not(feature = "draft-private-network-access"))]
    const CORS_VARY: &str = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers";

    #[cfg(feature = "draft-private-network-access")]
    const CORS_VARY: &str = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, \
        Access-Control-Allow-Private-Network";

    // Scoped so the shared borrow of `headers` ends before `insert` below.
    let value = {
        // `get` would only return the first `Vary` field line, and `insert` replaces them all,
        // so read every existing line to avoid silently dropping any of them.
        let mut existing = headers.get_all(header::VARY).peekable();

        if existing.peek().is_none() {
            HeaderValue::from_static(CORS_VARY)
        } else {
            let mut merged: Vec<u8> = Vec::new();
            for hdr in existing {
                merged.extend_from_slice(hdr.as_bytes());
                merged.extend_from_slice(b", ");
            }
            merged.extend_from_slice(CORS_VARY.as_bytes());

            // existing bytes came from valid header values and only ASCII names are appended
            merged
                .try_into()
                .expect("joined Vary values plus ASCII header names form a valid HeaderValue")
        }
    };

    headers.insert(header::VARY, value);
}
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use actix_web::{
        dev::Transform,
        http::{
            header::{self, HeaderValue},
            Method, StatusCode,
        },
        test::{self, TestRequest},
    };

    use crate::Cors;

    /// Returns the header value as `&str`, panicking if `to_str` fails.
    fn val_as_str(val: &HeaderValue) -> &str {
        val.to_str().unwrap()
    }

    /// An origin outside the allow-list is rejected. The same request has no
    /// `Access-Control-Request-Method` and asks for `DNT`, and both of those checks are also
    /// expected to fail under this configuration.
    #[actix_web::test]
    async fn test_validate_not_allowed_origin() {
        let cors = Cors::default()
            .allowed_origin("https://www.example.com")
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "https://www.unknown.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT"))
            .to_srv_request();

        assert!(cors.inner.validate_origin(req.head()).is_err());
        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
    }

    /// Exercises preflight validation and the headers of a successful preflight response, then
    /// disables preflight handling and checks that the request still gets a 200.
    #[actix_web::test]
    async fn test_preflight() {
        let mut cors = Cors::default()
            .allow_any_origin()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // no request method and a header outside the allow-list: both validators fail
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed"))
            .to_srv_request();

        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        // lowercase "put" is not an allowed method; with no request-headers header the
        // header check passes
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "put"))
            .to_srv_request();

        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_ok());

        // a valid preflight request
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .insert_header((
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "AUTHORIZATION,ACCEPT",
            ))
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        // any origin + send_wildcard: the literal `*` is sent instead of the request origin
        assert_eq!(
            Some(&b"*"[..]),
            resp.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );
        assert_eq!(
            Some(&b"3600"[..]),
            resp.headers()
                .get(header::ACCESS_CONTROL_MAX_AGE)
                .map(HeaderValue::as_bytes)
        );

        // header names are checked in lowercase form
        let hdr = resp
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_HEADERS)
            .map(val_as_str)
            .unwrap();
        assert!(hdr.contains("authorization"));
        assert!(hdr.contains("accept"));
        assert!(hdr.contains("content-type"));

        let methods = resp
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_METHODS)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(methods.contains("POST"));
        assert!(methods.contains("GET"));
        assert!(methods.contains("OPTIONS"));

        // `Rc::get_mut` only succeeds while the middleware holds the sole reference to `inner`;
        // the unwrap panics otherwise
        Rc::get_mut(&mut cors.inner).unwrap().preflight = false;

        // with preflight disabled the same request is expected to reach `ok_service` (200 OK)
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .insert_header((
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "AUTHORIZATION,ACCEPT",
            ))
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    /// The origin passed to an `allowed_origin_fn` closure equals the request's `Origin` header,
    /// both for a preflight request and for a plain GET that reaches the inner service.
    #[actix_web::test]
    async fn allow_fn_origin_equals_head_origin() {
        let cors = Cors::default()
            .allowed_origin_fn(|origin, head| {
                let head_origin = head
                    .headers()
                    .get(header::ORIGIN)
                    .expect("unwrapping origin header should never fail in allowed_origin_fn");
                assert!(origin == head_origin);
                true
            })
            .allow_any_method()
            .allow_any_header()
            .new_transform(test::status_service(StatusCode::NO_CONTENT))
            .await
            .unwrap();

        // preflight request: expected 200 OK
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        // non-preflight request: expected to reach the inner service, which returns 204
        let req = TestRequest::default()
            .method(Method::GET)
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
    }
}