actix_cors/
inner.rs

1use std::{
2    collections::HashSet,
3    convert::{TryFrom, TryInto},
4    fmt,
5    rc::Rc,
6};
7
8use actix_web::{
9    dev::RequestHead,
10    error::Result,
11    http::{
12        header::{self, HeaderMap, HeaderName, HeaderValue},
13        Method,
14    },
15};
16use once_cell::sync::Lazy;
17use smallvec::SmallVec;
18
19use crate::{AllOrSome, CorsError};
20
21#[derive(Clone)]
22pub(crate) struct OriginFn {
23    #[allow(clippy::type_complexity)]
24    pub(crate) boxed_fn: Rc<dyn Fn(&HeaderValue, &RequestHead) -> bool>,
25}
26
27impl Default for OriginFn {
28    /// Dummy default for use in tiny_vec. Do not use.
29    fn default() -> Self {
30        let boxed_fn: Rc<dyn Fn(&_, &_) -> _> = Rc::new(|_origin, _req_head| false);
31        Self { boxed_fn }
32    }
33}
34
35impl fmt::Debug for OriginFn {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.write_str("origin_fn")
38    }
39}
40
41/// Try to parse header value as HTTP method.
42pub(crate) fn header_value_try_into_method(hdr: &HeaderValue) -> Option<Method> {
43    hdr.to_str()
44        .ok()
45        .and_then(|meth| Method::try_from(meth).ok())
46}
47
/// Resolved CORS configuration consulted by the validation helpers below.
///
/// NOTE(review): the `*_baked` fields presumably hold header values pre-rendered from the
/// corresponding sets by the builder — confirm against the builder code.
#[derive(Debug, Clone)]
pub(crate) struct Inner {
    /// Origins accepted verbatim. `All` accepts any origin unless origin fns are also configured,
    /// in which case only the fns decide (see `validate_origin`).
    pub(crate) allowed_origins: AllOrSome<HashSet<HeaderValue>>,
    /// Origin predicates; an origin is accepted if _any_ of them returns `true`.
    pub(crate) allowed_origins_fns: SmallVec<[OriginFn; 4]>,

    /// Methods accepted in a preflight's `Access-Control-Request-Method` header.
    pub(crate) allowed_methods: HashSet<Method>,
    pub(crate) allowed_methods_baked: Option<HeaderValue>,

    /// Header names accepted in a preflight's `Access-Control-Request-Headers` list.
    /// `All` skips the check entirely.
    pub(crate) allowed_headers: AllOrSome<HashSet<HeaderName>>,
    pub(crate) allowed_headers_baked: Option<HeaderValue>,

    /// `All` will echo back `Access-Control-Request-Headers` list.
    ///
    /// NOTE(review): echoing the preflight request-header list sounds like `allowed_headers`
    /// behavior; confirm this note is attached to the right field.
    pub(crate) expose_headers: AllOrSome<HashSet<HeaderName>>,
    pub(crate) expose_headers_baked: Option<HeaderValue>,

    /// Presumably the `Access-Control-Max-Age` value in seconds (a `max_age(3600)` config yields
    /// a `3600` header in the tests).
    pub(crate) max_age: Option<usize>,
    /// Presumably whether preflight requests are answered by the middleware itself.
    pub(crate) preflight: bool,
    /// With `allowed_origins == All`, respond with `Access-Control-Allow-Origin: *` instead of
    /// echoing the request origin.
    pub(crate) send_wildcard: bool,
    /// Presumably controls `Access-Control-Allow-Credentials` — confirm in the middleware.
    pub(crate) supports_credentials: bool,
    /// Presumably controls the (draft) Private Network Access response header.
    #[cfg(feature = "draft-private-network-access")]
    pub(crate) allow_private_network_access: bool,
    /// Presumably whether `add_vary_header` is applied to responses.
    pub(crate) vary_header: bool,
    /// When set, a disallowed origin makes `validate_origin` return
    /// `Err(CorsError::OriginNotAllowed)` instead of `Ok(false)`.
    pub(crate) block_on_origin_mismatch: bool,
}
72
73static EMPTY_ORIGIN_SET: Lazy<HashSet<HeaderValue>> = Lazy::new(HashSet::new);
74
75impl Inner {
76    /// The bool returned in Ok(_) position indicates whether the `Access-Control-Allow-Origin`
77    /// header should be added to the response or not.
78    pub(crate) fn validate_origin(&self, req: &RequestHead) -> Result<bool, CorsError> {
79        // return early if all origins are allowed or get ref to allowed origins set
80        #[allow(clippy::mutable_key_type)]
81        let allowed_origins = match &self.allowed_origins {
82            AllOrSome::All if self.allowed_origins_fns.is_empty() => return Ok(true),
83            AllOrSome::Some(allowed_origins) => allowed_origins,
84            // only function origin validators are defined
85            _ => &EMPTY_ORIGIN_SET,
86        };
87
88        // get origin header and try to parse as string
89        match req.headers().get(header::ORIGIN) {
90            // origin header exists and is a string
91            Some(origin) => {
92                if allowed_origins.contains(origin) || self.validate_origin_fns(origin, req) {
93                    Ok(true)
94                } else if self.block_on_origin_mismatch {
95                    Err(CorsError::OriginNotAllowed)
96                } else {
97                    Ok(false)
98                }
99            }
100
101            // origin header is missing
102            // note: with our implementation, the origin header is required for OPTIONS request or
103            // else this would be unreachable
104            None => Err(CorsError::MissingOrigin),
105        }
106    }
107
108    /// Accepts origin if _ANY_ functions return true. Only called when Origin exists.
109    fn validate_origin_fns(&self, origin: &HeaderValue, req: &RequestHead) -> bool {
110        self.allowed_origins_fns
111            .iter()
112            .any(|origin_fn| (origin_fn.boxed_fn)(origin, req))
113    }
114
115    /// Only called if origin exists and always after it's validated.
116    pub(crate) fn access_control_allow_origin(&self, req: &RequestHead) -> Option<HeaderValue> {
117        let origin = req.headers().get(header::ORIGIN);
118
119        match self.allowed_origins {
120            AllOrSome::All => {
121                if self.send_wildcard {
122                    Some(HeaderValue::from_static("*"))
123                } else {
124                    // see note below about why `.cloned()` is correct
125                    origin.cloned()
126                }
127            }
128
129            AllOrSome::Some(_) => {
130                // since origin (if it exists) is known to be allowed if this method is called
131                // then cloning the option is all that is required to be used as an echoed back
132                // header value (or omitted if None)
133                origin.cloned()
134            }
135        }
136    }
137
138    /// Use in preflight checks and therefore operates on header list in
139    /// `Access-Control-Request-Headers` not the actual header set.
140    pub(crate) fn validate_allowed_method(&self, req: &RequestHead) -> Result<(), CorsError> {
141        // extract access control header and try to parse as method
142        let request_method = req
143            .headers()
144            .get(header::ACCESS_CONTROL_REQUEST_METHOD)
145            .map(header_value_try_into_method);
146
147        match request_method {
148            // method valid and allowed
149            Some(Some(method)) if self.allowed_methods.contains(&method) => Ok(()),
150
151            // method valid but not allowed
152            Some(Some(_)) => Err(CorsError::MethodNotAllowed),
153
154            // method invalid
155            Some(_) => Err(CorsError::BadRequestMethod),
156
157            // method missing so this is not a preflight request
158            None => Err(CorsError::MissingRequestMethod),
159        }
160    }
161
162    pub(crate) fn validate_allowed_headers(&self, req: &RequestHead) -> Result<(), CorsError> {
163        // return early if all headers are allowed or get ref to allowed origins set
164        #[allow(clippy::mutable_key_type)]
165        let allowed_headers = match &self.allowed_headers {
166            AllOrSome::All => return Ok(()),
167            AllOrSome::Some(allowed_headers) => allowed_headers,
168        };
169
170        // extract access control header as string
171        // header format should be comma separated header names
172        let request_headers = req
173            .headers()
174            .get(header::ACCESS_CONTROL_REQUEST_HEADERS)
175            .map(|hdr| hdr.to_str());
176
177        match request_headers {
178            // header list is valid string
179            Some(Ok(headers)) => {
180                // the set is ephemeral we take care not to mutate the
181                // inserted keys so this lint exception is acceptable
182                #[allow(clippy::mutable_key_type)]
183                let mut request_headers = HashSet::with_capacity(8);
184
185                // try to convert each header name in the comma-separated list
186                for hdr in headers.split(',') {
187                    match hdr.trim().try_into() {
188                        Ok(hdr) => request_headers.insert(hdr),
189                        Err(_) => return Err(CorsError::BadRequestHeaders),
190                    };
191                }
192
193                // header list must contain 1 or more header name
194                if request_headers.is_empty() {
195                    return Err(CorsError::BadRequestHeaders);
196                }
197
198                // request header list must be a subset of allowed headers
199                if !request_headers.is_subset(allowed_headers) {
200                    return Err(CorsError::HeadersNotAllowed);
201                }
202
203                Ok(())
204            }
205
206            // header list is not a string
207            Some(Err(_)) => Err(CorsError::BadRequestHeaders),
208
209            // header list missing
210            None => Ok(()),
211        }
212    }
213}
214
215/// Add CORS related request headers to response's Vary header.
216///
217/// See <https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches>.
218pub(crate) fn add_vary_header(headers: &mut HeaderMap) {
219    let value = match headers.get(header::VARY) {
220        Some(hdr) => {
221            let mut val: Vec<u8> = Vec::with_capacity(hdr.len() + 71);
222            val.extend(hdr.as_bytes());
223            val.extend(b", Origin, Access-Control-Request-Method, Access-Control-Request-Headers");
224
225            #[cfg(feature = "draft-private-network-access")]
226            val.extend(b", Access-Control-Request-Private-Network");
227
228            val.try_into().unwrap()
229        }
230
231        #[cfg(feature = "draft-private-network-access")]
232        None => HeaderValue::from_static(
233            "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, \
234            Access-Control-Request-Private-Network",
235        ),
236
237        #[cfg(not(feature = "draft-private-network-access"))]
238        None => HeaderValue::from_static(
239            "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
240        ),
241    };
242
243    headers.insert(header::VARY, value);
244}
245
// Unit tests for the `Inner` validators and for requests driven through the `Cors` middleware.
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use actix_web::{
        dev::Transform,
        http::{
            header::{self, HeaderValue},
            Method, StatusCode,
        },
        test::{self, TestRequest},
    };

    use crate::Cors;

    /// Test helper: view a header value as `&str`, panicking if it is not a visible-ASCII string.
    fn val_as_str(val: &HeaderValue) -> &str {
        val.to_str().unwrap()
    }

    // An origin outside the allow-list, a missing request method and a request header list
    // that is not in the allowed headers must all be rejected by the validators.
    // NOTE(review): `validate_origin` only errs here if `block_on_origin_mismatch` is on by
    // default for `Cors::default()` — confirm against the builder.
    #[actix_web::test]
    async fn test_validate_not_allowed_origin() {
        let cors = Cors::default()
            .allowed_origin("https://www.example.com")
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "https://www.unknown.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "DNT"))
            .to_srv_request();

        assert!(cors.inner.validate_origin(req.head()).is_err());
        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
    }

    // Exercises preflight validation and the headers of a successful preflight response.
    #[actix_web::test]
    async fn test_preflight() {
        let mut cors = Cors::default()
            .allow_any_origin()
            .send_wildcard()
            .max_age(3600)
            .allowed_methods(vec![Method::GET, Method::OPTIONS, Method::POST])
            .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
            .allowed_header(header::CONTENT_TYPE)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // OPTIONS without `Access-Control-Request-Method` (so `validate_allowed_method` errs) and
        // with a disallowed request header. The service still answers 200; presumably the
        // middleware does not treat it as a preflight — confirm in the middleware.
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_HEADERS, "X-Not-Allowed"))
            .to_srv_request();

        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_err());
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        // requested method is not in the allowed set; no request-header list means headers pass
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "put"))
            .to_srv_request();

        assert!(cors.inner.validate_allowed_method(req.head()).is_err());
        assert!(cors.inner.validate_allowed_headers(req.head()).is_ok());

        // a fully valid preflight
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .insert_header((
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "AUTHORIZATION,ACCEPT",
            ))
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        // any origin + `send_wildcard` => `*` rather than the echoed origin
        assert_eq!(
            Some(&b"*"[..]),
            resp.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );
        assert_eq!(
            Some(&b"3600"[..]),
            resp.headers()
                .get(header::ACCESS_CONTROL_MAX_AGE)
                .map(HeaderValue::as_bytes)
        );

        // header names come back lowercase
        let hdr = resp
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_HEADERS)
            .map(val_as_str)
            .unwrap();
        assert!(hdr.contains("authorization"));
        assert!(hdr.contains("accept"));
        assert!(hdr.contains("content-type"));

        let methods = resp
            .headers()
            .get(header::ACCESS_CONTROL_ALLOW_METHODS)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(methods.contains("POST"));
        assert!(methods.contains("GET"));
        assert!(methods.contains("OPTIONS"));

        // `Rc::get_mut` succeeds only while no other clone of `inner` exists
        Rc::get_mut(&mut cors.inner).unwrap().preflight = false;

        // with preflight handling disabled the same request still gets a 200
        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .insert_header((
                header::ACCESS_CONTROL_REQUEST_HEADERS,
                "AUTHORIZATION,ACCEPT",
            ))
            .to_srv_request();

        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);
    }

    // The origin handed to an `allowed_origin_fn` must be the request's own `Origin` header.
    // The inner service answers 204, so the preflight's 200 shows it did not reach the inner
    // service, while the GET does.
    #[actix_web::test]
    async fn allow_fn_origin_equals_head_origin() {
        let cors = Cors::default()
            .allowed_origin_fn(|origin, head| {
                let head_origin = head
                    .headers()
                    .get(header::ORIGIN)
                    .expect("unwrapping origin header should never fail in allowed_origin_fn");
                assert!(origin == head_origin);
                true
            })
            .allow_any_method()
            .allow_any_header()
            .new_transform(test::status_service(StatusCode::NO_CONTENT))
            .await
            .unwrap();

        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .insert_header(("Origin", "https://www.example.com"))
            .insert_header((header::ACCESS_CONTROL_REQUEST_METHOD, "POST"))
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::OK);

        let req = TestRequest::default()
            .method(Method::GET)
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();
        let resp = test::call_service(&cors, req).await;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
    }
}