actix_cors/
builder.rs

1use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc};
2
3use actix_utils::future::{self, Ready};
4use actix_web::{
5    body::{EitherBody, MessageBody},
6    dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
7    error::HttpError,
8    http::{
9        header::{HeaderName, HeaderValue},
10        Method, Uri,
11    },
12    Either, Error, Result,
13};
14use log::error;
15use once_cell::sync::Lazy;
16use smallvec::smallvec;
17
18use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
19
20/// Convenience for getting mut refs to inner. Cleaner than `Rc::get_mut`.
21/// Additionally, always causes first error (if any) to be reported during initialization.
22fn cors<'a>(
23    inner: &'a mut Rc<Inner>,
24    err: &Option<Either<HttpError, CorsError>>,
25) -> Option<&'a mut Inner> {
26    if err.is_some() {
27        return None;
28    }
29
30    Rc::get_mut(inner)
31}
32
33static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
34    HashSet::from_iter(vec![
35        Method::GET,
36        Method::POST,
37        Method::PUT,
38        Method::DELETE,
39        Method::HEAD,
40        Method::OPTIONS,
41        Method::CONNECT,
42        Method::PATCH,
43        Method::TRACE,
44    ])
45});
46
47/// Builder for CORS middleware.
48///
49/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder.
50/// Then use any of the builder methods to customize CORS behavior.
51///
52/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing
53/// all origins and headers, etc. **The permissive constructor should not be used in production.**
54///
55/// # Behavior
56///
57/// In all cases, behavior for this crate follows the [Fetch Standard CORS protocol]. See that
58/// document for information on exact semantics for configuration options and combinations.
59///
60/// # Errors
61///
62/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled
63/// in Actix Web (using `env_logger` or other crate that exposes logs from the `log` crate), error
64/// messages will outline what is wrong with the CORS configuration in the server logs and the
65/// server will fail to start up or serve requests.
66///
67/// # Example
68///
69/// ```
70/// use actix_cors::Cors;
71/// use actix_web::http::header;
72///
73/// let cors = Cors::default()
74///     .allowed_origin("https://www.rust-lang.org")
75///     .allowed_methods(vec!["GET", "POST"])
76///     .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
77///     .allowed_header(header::CONTENT_TYPE)
78///     .max_age(3600);
79///
80/// // `cors` can now be used in `App::wrap`.
81/// ```
82///
83/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
84#[derive(Debug)]
85pub struct Cors {
86    inner: Rc<Inner>,
87    error: Option<Either<HttpError, CorsError>>,
88}
89
90impl Cors {
91    /// Constructs a very permissive set of defaults for quick development. (Not recommended for
92    /// production use.)
93    ///
94    /// *All* origins, methods, request headers and exposed headers allowed. Credentials supported.
95    /// Max age 1 hour. Does not send wildcard.
96    pub fn permissive() -> Self {
97        let inner = Inner {
98            allowed_origins: AllOrSome::All,
99            allowed_origins_fns: smallvec![],
100
101            allowed_methods: ALL_METHODS_SET.clone(),
102            allowed_methods_baked: None,
103
104            allowed_headers: AllOrSome::All,
105            allowed_headers_baked: None,
106
107            expose_headers: AllOrSome::All,
108            expose_headers_baked: None,
109
110            max_age: Some(3600),
111            preflight: true,
112            send_wildcard: false,
113            supports_credentials: true,
114            #[cfg(feature = "draft-private-network-access")]
115            allow_private_network_access: false,
116            vary_header: true,
117            block_on_origin_mismatch: true,
118        };
119
120        Cors {
121            inner: Rc::new(inner),
122            error: None,
123        }
124    }
125
126    /// Resets allowed origin list to a state where any origin is accepted.
127    ///
128    /// See [`Cors::allowed_origin`] for more info on allowed origins.
129    pub fn allow_any_origin(mut self) -> Cors {
130        if let Some(cors) = cors(&mut self.inner, &self.error) {
131            cors.allowed_origins = AllOrSome::All;
132        }
133
134        self
135    }
136
137    /// Adds an origin that is allowed to make requests.
138    ///
139    /// This method allows specifying a finite set of origins to verify the value of the `Origin`
140    /// request header. These are `origin-or-null` types in the [Fetch Standard].
141    ///
142    /// By default, no origins are accepted.
143    ///
144    /// When this list is set, the client's `Origin` request header will be checked in a
145    /// case-sensitive manner.
146    ///
147    /// When all origins are allowed and `send_wildcard` is set, `*` will be sent in the
148    /// `Access-Control-Allow-Origin` response header. If `send_wildcard` is not set, the client's
149    /// `Origin` request header will be echoed back in the `Access-Control-Allow-Origin`
150    /// response header.
151    ///
152    /// If the origin of the request doesn't match any allowed origins and at least one
153    /// `allowed_origin_fn` function is set, these functions will be used to determinate
154    /// allowed origins.
155    ///
156    /// # Initialization Errors
157    /// - If supplied origin is not valid uri
158    /// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
159    ///
160    /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
161    pub fn allowed_origin(mut self, origin: &str) -> Cors {
162        if let Some(cors) = cors(&mut self.inner, &self.error) {
163            match TryInto::<Uri>::try_into(origin) {
164                Ok(_) if origin == "*" => {
165                    error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
166                    self.error = Some(Either::Right(CorsError::WildcardOrigin));
167                }
168
169                Ok(_) => {
170                    if cors.allowed_origins.is_all() {
171                        cors.allowed_origins = AllOrSome::Some(HashSet::with_capacity(8));
172                    }
173
174                    if let Some(origins) = cors.allowed_origins.as_mut() {
175                        // any uri is a valid header value
176                        let hv = origin.try_into().unwrap();
177                        origins.insert(hv);
178                    }
179                }
180
181                Err(err) => {
182                    self.error = Some(Either::Left(err.into()));
183                }
184            }
185        }
186
187        self
188    }
189
190    /// Determinates allowed origins by processing requests which didn't match any origins specified
191    /// in the `allowed_origin`.
192    ///
193    /// The function will receive two parameters, the Origin header value, and the `RequestHead` of
194    /// each request, which can be used to determine whether to allow the request or not.
195    ///
196    /// If the function returns `true`, the client's `Origin` request header will be echoed back
197    /// into the `Access-Control-Allow-Origin` response header.
198    pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
199    where
200        F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
201    {
202        if let Some(cors) = cors(&mut self.inner, &self.error) {
203            cors.allowed_origins_fns.push(OriginFn {
204                boxed_fn: Rc::new(f),
205            });
206        }
207
208        self
209    }
210
211    /// Resets allowed methods list to all methods.
212    ///
213    /// See [`Cors::allowed_methods`] for more info on allowed methods.
214    pub fn allow_any_method(mut self) -> Cors {
215        if let Some(cors) = cors(&mut self.inner, &self.error) {
216            cors.allowed_methods = ALL_METHODS_SET.clone();
217        }
218
219        self
220    }
221
222    /// Sets a list of methods which allowed origins can perform.
223    ///
224    /// These will be sent in the `Access-Control-Allow-Methods` response header.
225    ///
226    /// This defaults to an empty set.
227    pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
228    where
229        U: IntoIterator<Item = M>,
230        M: TryInto<Method>,
231        <M as TryInto<Method>>::Error: Into<HttpError>,
232    {
233        if let Some(cors) = cors(&mut self.inner, &self.error) {
234            for m in methods {
235                match m.try_into() {
236                    Ok(method) => {
237                        cors.allowed_methods.insert(method);
238                    }
239
240                    Err(err) => {
241                        self.error = Some(Either::Left(err.into()));
242                        break;
243                    }
244                }
245            }
246        }
247
248        self
249    }
250
251    /// Resets allowed request header list to a state where any header is accepted.
252    ///
253    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
254    pub fn allow_any_header(mut self) -> Cors {
255        if let Some(cors) = cors(&mut self.inner, &self.error) {
256            cors.allowed_headers = AllOrSome::All;
257        }
258
259        self
260    }
261
262    /// Add an allowed request header.
263    ///
264    /// See [`Cors::allowed_headers`] for more info on allowed request headers.
265    pub fn allowed_header<H>(mut self, header: H) -> Cors
266    where
267        H: TryInto<HeaderName>,
268        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
269    {
270        if let Some(cors) = cors(&mut self.inner, &self.error) {
271            match header.try_into() {
272                Ok(method) => {
273                    if cors.allowed_headers.is_all() {
274                        cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
275                    }
276
277                    if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
278                        headers.insert(method);
279                    }
280                }
281
282                Err(err) => self.error = Some(Either::Left(err.into())),
283            }
284        }
285
286        self
287    }
288
289    /// Sets a list of request header field names which can be used when this resource is accessed
290    /// by allowed origins.
291    ///
292    /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
293    /// will be echoed back in the `Access-Control-Allow-Headers` header.
294    ///
295    /// This defaults to an empty set.
296    pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
297    where
298        U: IntoIterator<Item = H>,
299        H: TryInto<HeaderName>,
300        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
301    {
302        if let Some(cors) = cors(&mut self.inner, &self.error) {
303            for h in headers {
304                match h.try_into() {
305                    Ok(method) => {
306                        if cors.allowed_headers.is_all() {
307                            cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
308                        }
309
310                        if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
311                            headers.insert(method);
312                        }
313                    }
314                    Err(err) => {
315                        self.error = Some(Either::Left(err.into()));
316                        break;
317                    }
318                }
319            }
320        }
321
322        self
323    }
324
325    /// Resets exposed response header list to a state where all headers are exposed.
326    ///
327    /// See [`Cors::expose_headers`] for more info on exposed response headers.
328    pub fn expose_any_header(mut self) -> Cors {
329        if let Some(cors) = cors(&mut self.inner, &self.error) {
330            cors.expose_headers = AllOrSome::All;
331        }
332
333        self
334    }
335
336    /// Sets a list of headers which are safe to expose to the API of a CORS API specification.
337    ///
338    /// This corresponds to the `Access-Control-Expose-Headers` response header.
339    ///
340    /// This defaults to an empty set.
341    pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
342    where
343        U: IntoIterator<Item = H>,
344        H: TryInto<HeaderName>,
345        <H as TryInto<HeaderName>>::Error: Into<HttpError>,
346    {
347        for h in headers {
348            match h.try_into() {
349                Ok(header) => {
350                    if let Some(cors) = cors(&mut self.inner, &self.error) {
351                        if cors.expose_headers.is_all() {
352                            cors.expose_headers = AllOrSome::Some(HashSet::with_capacity(8));
353                        }
354                        if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
355                            headers.insert(header);
356                        }
357                    }
358                }
359                Err(err) => {
360                    self.error = Some(Either::Left(err.into()));
361                    break;
362                }
363            }
364        }
365
366        self
367    }
368
369    /// Sets a maximum time (in seconds) for which this CORS request may be cached.
370    ///
371    /// This value is set as the `Access-Control-Max-Age` header.
372    ///
373    /// Pass a number (of seconds) or use None to disable sending max age header.
374    pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
375        if let Some(cors) = cors(&mut self.inner, &self.error) {
376            cors.max_age = max_age.into();
377        }
378
379        self
380    }
381
382    /// Configures use of wildcard (`*`) origin in responses when appropriate.
383    ///
384    /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
385    /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s
386    /// `Origin` header.
387    ///
388    /// This option **CANNOT** be used in conjunction with a [credential
389    /// supported](Self::supports_credentials()) configuration. Doing so will result in an error
390    /// during server startup.
391    ///
392    /// Defaults to disabled.
393    pub fn send_wildcard(mut self) -> Cors {
394        if let Some(cors) = cors(&mut self.inner, &self.error) {
395            cors.send_wildcard = true;
396        }
397
398        self
399    }
400
401    /// Allows users to make authenticated requests.
402    ///
403    /// If true, injects the `Access-Control-Allow-Credentials` header in responses. This allows
404    /// cookies and credentials to be submitted across domains.
405    ///
406    /// This option **CANNOT** be used in conjunction with option cannot be used in conjunction
407    /// with [wildcard origins](Self::send_wildcard()) configured. Doing so will result in an error
408    /// during server startup.
409    ///
410    /// Defaults to disabled.
411    pub fn supports_credentials(mut self) -> Cors {
412        if let Some(cors) = cors(&mut self.inner, &self.error) {
413            cors.supports_credentials = true;
414        }
415
416        self
417    }
418
419    /// Allow private network access.
420    ///
421    /// If true, injects the `Access-Control-Allow-Private-Network: true` header in responses if the
422    /// request contained the `Access-Control-Request-Private-Network: true` header.
423    ///
424    /// For more information on this behavior, see the draft [Private Network Access] spec.
425    ///
426    /// Defaults to `false`.
427    ///
428    /// [Private Network Access]: https://wicg.github.io/private-network-access
429    #[cfg(feature = "draft-private-network-access")]
430    pub fn allow_private_network_access(mut self) -> Cors {
431        if let Some(cors) = cors(&mut self.inner, &self.error) {
432            cors.allow_private_network_access = true;
433        }
434
435        self
436    }
437
438    /// Disables `Vary` header support.
439    ///
440    /// When enabled the header `Vary: Origin` will be returned as per the Fetch Standard
441    /// implementation guidelines.
442    ///
443    /// Setting this header when the `Access-Control-Allow-Origin` is dynamically generated
444    /// (eg. when there is more than one allowed origin, and an Origin other than '*' is returned)
445    /// informs CDNs and other caches that the CORS headers are dynamic, and cannot be cached.
446    ///
447    /// By default, `Vary` header support is enabled.
448    pub fn disable_vary_header(mut self) -> Cors {
449        if let Some(cors) = cors(&mut self.inner, &self.error) {
450            cors.vary_header = false;
451        }
452
453        self
454    }
455
456    /// Disables preflight request handling.
457    ///
458    /// When enabled CORS middleware automatically handles `OPTIONS` requests. This is useful for
459    /// application level middleware.
460    ///
461    /// By default, preflight support is enabled.
462    pub fn disable_preflight(mut self) -> Cors {
463        if let Some(cors) = cors(&mut self.inner, &self.error) {
464            cors.preflight = false;
465        }
466
467        self
468    }
469
470    /// Configures whether requests should be pre-emptively blocked on mismatched origin.
471    ///
472    /// If `true`, a 400 Bad Request is returned immediately when a request fails origin validation.
473    ///
474    /// If `false`, the request will be processed as normal but relevant CORS headers will not be
475    /// appended to the response. In this case, the browser is trusted to validate CORS headers and
476    /// and block requests based on pre-flight requests. Use this setting to allow cURL and other
477    /// non-browser HTTP clients to function as normal, no matter what `Origin` the request has.
478    ///
479    /// Defaults to true.
480    pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors {
481        if let Some(cors) = cors(&mut self.inner, &self.error) {
482            cors.block_on_origin_mismatch = block;
483        }
484
485        self
486    }
487}
488
489impl Default for Cors {
490    /// A restrictive (security paranoid) set of defaults.
491    ///
492    /// *No* allowed origins, methods, request headers or exposed headers. Credentials
493    /// not supported. No max age (will use browser's default).
494    fn default() -> Cors {
495        let inner = Inner {
496            allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
497            allowed_origins_fns: smallvec![],
498
499            allowed_methods: HashSet::with_capacity(8),
500            allowed_methods_baked: None,
501
502            allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
503            allowed_headers_baked: None,
504
505            expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
506            expose_headers_baked: None,
507
508            max_age: None,
509            preflight: true,
510            send_wildcard: false,
511            supports_credentials: false,
512            #[cfg(feature = "draft-private-network-access")]
513            allow_private_network_access: false,
514            vary_header: true,
515            block_on_origin_mismatch: true,
516        };
517
518        Cors {
519            inner: Rc::new(inner),
520            error: None,
521        }
522    }
523}
524
525impl<S, B> Transform<S, ServiceRequest> for Cors
526where
527    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
528    S::Future: 'static,
529
530    B: MessageBody + 'static,
531{
532    type Response = ServiceResponse<EitherBody<B>>;
533    type Error = Error;
534    type InitError = ();
535    type Transform = CorsMiddleware<S>;
536    type Future = Ready<Result<Self::Transform, Self::InitError>>;
537
538    fn new_transform(&self, service: S) -> Self::Future {
539        if let Some(ref err) = self.error {
540            match err {
541                Either::Left(err) => error!("{}", err),
542                Either::Right(err) => error!("{}", err),
543            }
544
545            return future::err(());
546        }
547
548        let mut inner = Rc::clone(&self.inner);
549
550        if inner.supports_credentials && inner.send_wildcard && inner.allowed_origins.is_all() {
551            error!(
552                "Illegal combination of CORS options: credentials can not be supported when all \
553                    origins are allowed and `send_wildcard` is enabled."
554            );
555            return future::err(());
556        }
557
558        // bake allowed headers value if Some and not empty
559        match inner.allowed_headers.as_ref() {
560            Some(header_set) if !header_set.is_empty() => {
561                let allowed_headers_str = intersperse_header_values(header_set);
562                Rc::make_mut(&mut inner).allowed_headers_baked = Some(allowed_headers_str);
563            }
564            _ => {}
565        }
566
567        // bake allowed methods value if not empty
568        if !inner.allowed_methods.is_empty() {
569            let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
570            Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
571        }
572
573        // bake exposed headers value if Some and not empty
574        match inner.expose_headers.as_ref() {
575            Some(header_set) if !header_set.is_empty() => {
576                let expose_headers_str = intersperse_header_values(header_set);
577                Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
578            }
579            _ => {}
580        }
581
582        future::ok(CorsMiddleware { service, inner })
583    }
584}
585
586/// Only call when values are guaranteed to be valid header values and set is not empty.
587pub(crate) fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
588where
589    T: AsRef<str>,
590{
591    debug_assert!(
592        !val_set.is_empty(),
593        "only call `intersperse_header_values` when set is not empty"
594    );
595
596    val_set
597        .iter()
598        .fold(String::with_capacity(64), |mut acc, val| {
599            acc.push_str(", ");
600            acc.push_str(val.as_ref());
601            acc
602        })
603        // set is not empty so string will always have leading ", " to trim
604        [2..]
605        .try_into()
606        // all method names are valid header values
607        .unwrap()
608}
609
#[cfg(test)]
mod test {
    use std::convert::{Infallible, TryInto};

    use actix_web::{
        body,
        dev::{fn_service, Transform},
        http::{header::HeaderName, StatusCode},
        test::{self, TestRequest},
        HttpResponse,
    };

    use super::*;

    /// Permissive config (all origins) + `send_wildcard` + credentials must fail to initialize.
    #[test]
    fn illegal_allow_credentials() {
        let transform = Cors::permissive()
            .supports_credentials()
            .send_wildcard()
            .new_transform(test::ok_service());

        let outcome = transform.into_inner();
        assert!(outcome.is_err());
    }

    /// The default config allows no origins, so any request carrying an `Origin` is rejected.
    #[actix_web::test]
    async fn restrictive_defaults() {
        let middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default()
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();
        let response = test::call_service(&middleware, request).await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// `allowed_header` accepts a plain string.
    #[actix_web::test]
    async fn allowed_header_try_from() {
        let _cors = Cors::default().allowed_header("Content-Type");
    }

    /// `allowed_header` accepts any type with an infallible `TryInto<HeaderName>` impl.
    #[actix_web::test]
    async fn allowed_header_try_into() {
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _cors = Cors::default().allowed_header(ContentType);
    }

    /// The middleware can wrap services whose body type is not the default boxed body.
    #[actix_web::test]
    async fn middleware_generic_over_body_type() {
        let service = fn_service(|req: ServiceRequest| async move {
            let empty_ok = HttpResponse::with_body(StatusCode::OK, body::None::new());
            Ok(req.into_response(empty_ok))
        });

        Cors::default().new_transform(service).await.unwrap();
    }
}
680}