actix_cors/builder.rs
1use std::{collections::HashSet, convert::TryInto, iter::FromIterator, rc::Rc};
2
3use actix_utils::future::{self, Ready};
4use actix_web::{
5 body::{EitherBody, MessageBody},
6 dev::{RequestHead, Service, ServiceRequest, ServiceResponse, Transform},
7 error::HttpError,
8 http::{
9 header::{HeaderName, HeaderValue},
10 Method, Uri,
11 },
12 Either, Error, Result,
13};
14use log::error;
15use once_cell::sync::Lazy;
16use smallvec::smallvec;
17
18use crate::{AllOrSome, CorsError, CorsMiddleware, Inner, OriginFn};
19
20/// Convenience for getting mut refs to inner. Cleaner than `Rc::get_mut`.
21/// Additionally, always causes first error (if any) to be reported during initialization.
22fn cors<'a>(
23 inner: &'a mut Rc<Inner>,
24 err: &Option<Either<HttpError, CorsError>>,
25) -> Option<&'a mut Inner> {
26 if err.is_some() {
27 return None;
28 }
29
30 Rc::get_mut(inner)
31}
32
33static ALL_METHODS_SET: Lazy<HashSet<Method>> = Lazy::new(|| {
34 HashSet::from_iter(vec![
35 Method::GET,
36 Method::POST,
37 Method::PUT,
38 Method::DELETE,
39 Method::HEAD,
40 Method::OPTIONS,
41 Method::CONNECT,
42 Method::PATCH,
43 Method::TRACE,
44 ])
45});
46
47/// Builder for CORS middleware.
48///
49/// To construct a CORS middleware, call [`Cors::default()`] to create a blank, restrictive builder.
50/// Then use any of the builder methods to customize CORS behavior.
51///
52/// The alternative [`Cors::permissive()`] constructor is available for local development, allowing
53/// all origins and headers, etc. **The permissive constructor should not be used in production.**
54///
55/// # Behavior
56///
57/// In all cases, behavior for this crate follows the [Fetch Standard CORS protocol]. See that
58/// document for information on exact semantics for configuration options and combinations.
59///
60/// # Errors
61///
62/// Errors surface in the middleware initialization phase. This means that, if you have logs enabled
63/// in Actix Web (using `env_logger` or other crate that exposes logs from the `log` crate), error
64/// messages will outline what is wrong with the CORS configuration in the server logs and the
65/// server will fail to start up or serve requests.
66///
67/// # Example
68///
69/// ```
70/// use actix_cors::Cors;
71/// use actix_web::http::header;
72///
73/// let cors = Cors::default()
74/// .allowed_origin("https://www.rust-lang.org")
75/// .allowed_methods(vec!["GET", "POST"])
76/// .allowed_headers(vec![header::AUTHORIZATION, header::ACCEPT])
77/// .allowed_header(header::CONTENT_TYPE)
78/// .max_age(3600);
79///
80/// // `cors` can now be used in `App::wrap`.
81/// ```
82///
83/// [Fetch Standard CORS protocol]: https://fetch.spec.whatwg.org/#http-cors-protocol
84#[derive(Debug)]
85pub struct Cors {
86 inner: Rc<Inner>,
87 error: Option<Either<HttpError, CorsError>>,
88}
89
90impl Cors {
91 /// Constructs a very permissive set of defaults for quick development. (Not recommended for
92 /// production use.)
93 ///
94 /// *All* origins, methods, request headers and exposed headers allowed. Credentials supported.
95 /// Max age 1 hour. Does not send wildcard.
96 pub fn permissive() -> Self {
97 let inner = Inner {
98 allowed_origins: AllOrSome::All,
99 allowed_origins_fns: smallvec![],
100
101 allowed_methods: ALL_METHODS_SET.clone(),
102 allowed_methods_baked: None,
103
104 allowed_headers: AllOrSome::All,
105 allowed_headers_baked: None,
106
107 expose_headers: AllOrSome::All,
108 expose_headers_baked: None,
109
110 max_age: Some(3600),
111 preflight: true,
112 send_wildcard: false,
113 supports_credentials: true,
114 #[cfg(feature = "draft-private-network-access")]
115 allow_private_network_access: false,
116 vary_header: true,
117 block_on_origin_mismatch: true,
118 };
119
120 Cors {
121 inner: Rc::new(inner),
122 error: None,
123 }
124 }
125
126 /// Resets allowed origin list to a state where any origin is accepted.
127 ///
128 /// See [`Cors::allowed_origin`] for more info on allowed origins.
129 pub fn allow_any_origin(mut self) -> Cors {
130 if let Some(cors) = cors(&mut self.inner, &self.error) {
131 cors.allowed_origins = AllOrSome::All;
132 }
133
134 self
135 }
136
137 /// Adds an origin that is allowed to make requests.
138 ///
139 /// This method allows specifying a finite set of origins to verify the value of the `Origin`
140 /// request header. These are `origin-or-null` types in the [Fetch Standard].
141 ///
142 /// By default, no origins are accepted.
143 ///
144 /// When this list is set, the client's `Origin` request header will be checked in a
145 /// case-sensitive manner.
146 ///
147 /// When all origins are allowed and `send_wildcard` is set, `*` will be sent in the
148 /// `Access-Control-Allow-Origin` response header. If `send_wildcard` is not set, the client's
149 /// `Origin` request header will be echoed back in the `Access-Control-Allow-Origin`
150 /// response header.
151 ///
152 /// If the origin of the request doesn't match any allowed origins and at least one
153 /// `allowed_origin_fn` function is set, these functions will be used to determinate
154 /// allowed origins.
155 ///
156 /// # Initialization Errors
157 /// - If supplied origin is not valid uri
158 /// - If supplied origin is a wildcard (`*`). [`Cors::send_wildcard`] should be used instead.
159 ///
160 /// [Fetch Standard]: https://fetch.spec.whatwg.org/#origin-header
161 pub fn allowed_origin(mut self, origin: &str) -> Cors {
162 if let Some(cors) = cors(&mut self.inner, &self.error) {
163 match TryInto::<Uri>::try_into(origin) {
164 Ok(_) if origin == "*" => {
165 error!("Wildcard in `allowed_origin` is not allowed. Use `send_wildcard`.");
166 self.error = Some(Either::Right(CorsError::WildcardOrigin));
167 }
168
169 Ok(_) => {
170 if cors.allowed_origins.is_all() {
171 cors.allowed_origins = AllOrSome::Some(HashSet::with_capacity(8));
172 }
173
174 if let Some(origins) = cors.allowed_origins.as_mut() {
175 // any uri is a valid header value
176 let hv = origin.try_into().unwrap();
177 origins.insert(hv);
178 }
179 }
180
181 Err(err) => {
182 self.error = Some(Either::Left(err.into()));
183 }
184 }
185 }
186
187 self
188 }
189
190 /// Determinates allowed origins by processing requests which didn't match any origins specified
191 /// in the `allowed_origin`.
192 ///
193 /// The function will receive two parameters, the Origin header value, and the `RequestHead` of
194 /// each request, which can be used to determine whether to allow the request or not.
195 ///
196 /// If the function returns `true`, the client's `Origin` request header will be echoed back
197 /// into the `Access-Control-Allow-Origin` response header.
198 pub fn allowed_origin_fn<F>(mut self, f: F) -> Cors
199 where
200 F: (Fn(&HeaderValue, &RequestHead) -> bool) + 'static,
201 {
202 if let Some(cors) = cors(&mut self.inner, &self.error) {
203 cors.allowed_origins_fns.push(OriginFn {
204 boxed_fn: Rc::new(f),
205 });
206 }
207
208 self
209 }
210
211 /// Resets allowed methods list to all methods.
212 ///
213 /// See [`Cors::allowed_methods`] for more info on allowed methods.
214 pub fn allow_any_method(mut self) -> Cors {
215 if let Some(cors) = cors(&mut self.inner, &self.error) {
216 cors.allowed_methods = ALL_METHODS_SET.clone();
217 }
218
219 self
220 }
221
222 /// Sets a list of methods which allowed origins can perform.
223 ///
224 /// These will be sent in the `Access-Control-Allow-Methods` response header.
225 ///
226 /// This defaults to an empty set.
227 pub fn allowed_methods<U, M>(mut self, methods: U) -> Cors
228 where
229 U: IntoIterator<Item = M>,
230 M: TryInto<Method>,
231 <M as TryInto<Method>>::Error: Into<HttpError>,
232 {
233 if let Some(cors) = cors(&mut self.inner, &self.error) {
234 for m in methods {
235 match m.try_into() {
236 Ok(method) => {
237 cors.allowed_methods.insert(method);
238 }
239
240 Err(err) => {
241 self.error = Some(Either::Left(err.into()));
242 break;
243 }
244 }
245 }
246 }
247
248 self
249 }
250
251 /// Resets allowed request header list to a state where any header is accepted.
252 ///
253 /// See [`Cors::allowed_headers`] for more info on allowed request headers.
254 pub fn allow_any_header(mut self) -> Cors {
255 if let Some(cors) = cors(&mut self.inner, &self.error) {
256 cors.allowed_headers = AllOrSome::All;
257 }
258
259 self
260 }
261
262 /// Add an allowed request header.
263 ///
264 /// See [`Cors::allowed_headers`] for more info on allowed request headers.
265 pub fn allowed_header<H>(mut self, header: H) -> Cors
266 where
267 H: TryInto<HeaderName>,
268 <H as TryInto<HeaderName>>::Error: Into<HttpError>,
269 {
270 if let Some(cors) = cors(&mut self.inner, &self.error) {
271 match header.try_into() {
272 Ok(method) => {
273 if cors.allowed_headers.is_all() {
274 cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
275 }
276
277 if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
278 headers.insert(method);
279 }
280 }
281
282 Err(err) => self.error = Some(Either::Left(err.into())),
283 }
284 }
285
286 self
287 }
288
289 /// Sets a list of request header field names which can be used when this resource is accessed
290 /// by allowed origins.
291 ///
292 /// If `All` is set, whatever is requested by the client in `Access-Control-Request-Headers`
293 /// will be echoed back in the `Access-Control-Allow-Headers` header.
294 ///
295 /// This defaults to an empty set.
296 pub fn allowed_headers<U, H>(mut self, headers: U) -> Cors
297 where
298 U: IntoIterator<Item = H>,
299 H: TryInto<HeaderName>,
300 <H as TryInto<HeaderName>>::Error: Into<HttpError>,
301 {
302 if let Some(cors) = cors(&mut self.inner, &self.error) {
303 for h in headers {
304 match h.try_into() {
305 Ok(method) => {
306 if cors.allowed_headers.is_all() {
307 cors.allowed_headers = AllOrSome::Some(HashSet::with_capacity(8));
308 }
309
310 if let AllOrSome::Some(ref mut headers) = cors.allowed_headers {
311 headers.insert(method);
312 }
313 }
314 Err(err) => {
315 self.error = Some(Either::Left(err.into()));
316 break;
317 }
318 }
319 }
320 }
321
322 self
323 }
324
325 /// Resets exposed response header list to a state where all headers are exposed.
326 ///
327 /// See [`Cors::expose_headers`] for more info on exposed response headers.
328 pub fn expose_any_header(mut self) -> Cors {
329 if let Some(cors) = cors(&mut self.inner, &self.error) {
330 cors.expose_headers = AllOrSome::All;
331 }
332
333 self
334 }
335
336 /// Sets a list of headers which are safe to expose to the API of a CORS API specification.
337 ///
338 /// This corresponds to the `Access-Control-Expose-Headers` response header.
339 ///
340 /// This defaults to an empty set.
341 pub fn expose_headers<U, H>(mut self, headers: U) -> Cors
342 where
343 U: IntoIterator<Item = H>,
344 H: TryInto<HeaderName>,
345 <H as TryInto<HeaderName>>::Error: Into<HttpError>,
346 {
347 for h in headers {
348 match h.try_into() {
349 Ok(header) => {
350 if let Some(cors) = cors(&mut self.inner, &self.error) {
351 if cors.expose_headers.is_all() {
352 cors.expose_headers = AllOrSome::Some(HashSet::with_capacity(8));
353 }
354 if let AllOrSome::Some(ref mut headers) = cors.expose_headers {
355 headers.insert(header);
356 }
357 }
358 }
359 Err(err) => {
360 self.error = Some(Either::Left(err.into()));
361 break;
362 }
363 }
364 }
365
366 self
367 }
368
369 /// Sets a maximum time (in seconds) for which this CORS request may be cached.
370 ///
371 /// This value is set as the `Access-Control-Max-Age` header.
372 ///
373 /// Pass a number (of seconds) or use None to disable sending max age header.
374 pub fn max_age(mut self, max_age: impl Into<Option<usize>>) -> Cors {
375 if let Some(cors) = cors(&mut self.inner, &self.error) {
376 cors.max_age = max_age.into();
377 }
378
379 self
380 }
381
382 /// Configures use of wildcard (`*`) origin in responses when appropriate.
383 ///
384 /// If send wildcard is set and the `allowed_origins` parameter is `All`, a wildcard
385 /// `Access-Control-Allow-Origin` response header is sent, rather than the request’s
386 /// `Origin` header.
387 ///
388 /// This option **CANNOT** be used in conjunction with a [credential
389 /// supported](Self::supports_credentials()) configuration. Doing so will result in an error
390 /// during server startup.
391 ///
392 /// Defaults to disabled.
393 pub fn send_wildcard(mut self) -> Cors {
394 if let Some(cors) = cors(&mut self.inner, &self.error) {
395 cors.send_wildcard = true;
396 }
397
398 self
399 }
400
401 /// Allows users to make authenticated requests.
402 ///
403 /// If true, injects the `Access-Control-Allow-Credentials` header in responses. This allows
404 /// cookies and credentials to be submitted across domains.
405 ///
406 /// This option **CANNOT** be used in conjunction with option cannot be used in conjunction
407 /// with [wildcard origins](Self::send_wildcard()) configured. Doing so will result in an error
408 /// during server startup.
409 ///
410 /// Defaults to disabled.
411 pub fn supports_credentials(mut self) -> Cors {
412 if let Some(cors) = cors(&mut self.inner, &self.error) {
413 cors.supports_credentials = true;
414 }
415
416 self
417 }
418
419 /// Allow private network access.
420 ///
421 /// If true, injects the `Access-Control-Allow-Private-Network: true` header in responses if the
422 /// request contained the `Access-Control-Request-Private-Network: true` header.
423 ///
424 /// For more information on this behavior, see the draft [Private Network Access] spec.
425 ///
426 /// Defaults to `false`.
427 ///
428 /// [Private Network Access]: https://wicg.github.io/private-network-access
429 #[cfg(feature = "draft-private-network-access")]
430 pub fn allow_private_network_access(mut self) -> Cors {
431 if let Some(cors) = cors(&mut self.inner, &self.error) {
432 cors.allow_private_network_access = true;
433 }
434
435 self
436 }
437
438 /// Disables `Vary` header support.
439 ///
440 /// When enabled the header `Vary: Origin` will be returned as per the Fetch Standard
441 /// implementation guidelines.
442 ///
443 /// Setting this header when the `Access-Control-Allow-Origin` is dynamically generated
444 /// (eg. when there is more than one allowed origin, and an Origin other than '*' is returned)
445 /// informs CDNs and other caches that the CORS headers are dynamic, and cannot be cached.
446 ///
447 /// By default, `Vary` header support is enabled.
448 pub fn disable_vary_header(mut self) -> Cors {
449 if let Some(cors) = cors(&mut self.inner, &self.error) {
450 cors.vary_header = false;
451 }
452
453 self
454 }
455
456 /// Disables preflight request handling.
457 ///
458 /// When enabled CORS middleware automatically handles `OPTIONS` requests. This is useful for
459 /// application level middleware.
460 ///
461 /// By default, preflight support is enabled.
462 pub fn disable_preflight(mut self) -> Cors {
463 if let Some(cors) = cors(&mut self.inner, &self.error) {
464 cors.preflight = false;
465 }
466
467 self
468 }
469
470 /// Configures whether requests should be pre-emptively blocked on mismatched origin.
471 ///
472 /// If `true`, a 400 Bad Request is returned immediately when a request fails origin validation.
473 ///
474 /// If `false`, the request will be processed as normal but relevant CORS headers will not be
475 /// appended to the response. In this case, the browser is trusted to validate CORS headers and
476 /// and block requests based on pre-flight requests. Use this setting to allow cURL and other
477 /// non-browser HTTP clients to function as normal, no matter what `Origin` the request has.
478 ///
479 /// Defaults to true.
480 pub fn block_on_origin_mismatch(mut self, block: bool) -> Cors {
481 if let Some(cors) = cors(&mut self.inner, &self.error) {
482 cors.block_on_origin_mismatch = block;
483 }
484
485 self
486 }
487}
488
489impl Default for Cors {
490 /// A restrictive (security paranoid) set of defaults.
491 ///
492 /// *No* allowed origins, methods, request headers or exposed headers. Credentials
493 /// not supported. No max age (will use browser's default).
494 fn default() -> Cors {
495 let inner = Inner {
496 allowed_origins: AllOrSome::Some(HashSet::with_capacity(8)),
497 allowed_origins_fns: smallvec![],
498
499 allowed_methods: HashSet::with_capacity(8),
500 allowed_methods_baked: None,
501
502 allowed_headers: AllOrSome::Some(HashSet::with_capacity(8)),
503 allowed_headers_baked: None,
504
505 expose_headers: AllOrSome::Some(HashSet::with_capacity(8)),
506 expose_headers_baked: None,
507
508 max_age: None,
509 preflight: true,
510 send_wildcard: false,
511 supports_credentials: false,
512 #[cfg(feature = "draft-private-network-access")]
513 allow_private_network_access: false,
514 vary_header: true,
515 block_on_origin_mismatch: true,
516 };
517
518 Cors {
519 inner: Rc::new(inner),
520 error: None,
521 }
522 }
523}
524
525impl<S, B> Transform<S, ServiceRequest> for Cors
526where
527 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
528 S::Future: 'static,
529
530 B: MessageBody + 'static,
531{
532 type Response = ServiceResponse<EitherBody<B>>;
533 type Error = Error;
534 type InitError = ();
535 type Transform = CorsMiddleware<S>;
536 type Future = Ready<Result<Self::Transform, Self::InitError>>;
537
538 fn new_transform(&self, service: S) -> Self::Future {
539 if let Some(ref err) = self.error {
540 match err {
541 Either::Left(err) => error!("{}", err),
542 Either::Right(err) => error!("{}", err),
543 }
544
545 return future::err(());
546 }
547
548 let mut inner = Rc::clone(&self.inner);
549
550 if inner.supports_credentials && inner.send_wildcard && inner.allowed_origins.is_all() {
551 error!(
552 "Illegal combination of CORS options: credentials can not be supported when all \
553 origins are allowed and `send_wildcard` is enabled."
554 );
555 return future::err(());
556 }
557
558 // bake allowed headers value if Some and not empty
559 match inner.allowed_headers.as_ref() {
560 Some(header_set) if !header_set.is_empty() => {
561 let allowed_headers_str = intersperse_header_values(header_set);
562 Rc::make_mut(&mut inner).allowed_headers_baked = Some(allowed_headers_str);
563 }
564 _ => {}
565 }
566
567 // bake allowed methods value if not empty
568 if !inner.allowed_methods.is_empty() {
569 let allowed_methods_str = intersperse_header_values(&inner.allowed_methods);
570 Rc::make_mut(&mut inner).allowed_methods_baked = Some(allowed_methods_str);
571 }
572
573 // bake exposed headers value if Some and not empty
574 match inner.expose_headers.as_ref() {
575 Some(header_set) if !header_set.is_empty() => {
576 let expose_headers_str = intersperse_header_values(header_set);
577 Rc::make_mut(&mut inner).expose_headers_baked = Some(expose_headers_str);
578 }
579 _ => {}
580 }
581
582 future::ok(CorsMiddleware { service, inner })
583 }
584}
585
586/// Only call when values are guaranteed to be valid header values and set is not empty.
587pub(crate) fn intersperse_header_values<T>(val_set: &HashSet<T>) -> HeaderValue
588where
589 T: AsRef<str>,
590{
591 debug_assert!(
592 !val_set.is_empty(),
593 "only call `intersperse_header_values` when set is not empty"
594 );
595
596 val_set
597 .iter()
598 .fold(String::with_capacity(64), |mut acc, val| {
599 acc.push_str(", ");
600 acc.push_str(val.as_ref());
601 acc
602 })
603 // set is not empty so string will always have leading ", " to trim
604 [2..]
605 .try_into()
606 // all method names are valid header values
607 .unwrap()
608}
609
#[cfg(test)]
mod test {
    use std::convert::{Infallible, TryInto};

    use actix_web::{
        body,
        dev::{fn_service, Transform},
        http::{header::HeaderName, StatusCode},
        test::{self, TestRequest},
        HttpResponse,
    };

    use super::*;

    /// Credentials combined with a wildcard origin must be rejected when the middleware is built.
    #[test]
    fn illegal_allow_credentials() {
        // `permissive()` already allows every origin, so adding `send_wildcard` and
        // `supports_credentials` produces the forbidden combination
        let builder = Cors::permissive().supports_credentials().send_wildcard();

        let built = builder.new_transform(test::ok_service()).into_inner();
        assert!(built.is_err());
    }

    /// With the restrictive defaults, a request carrying any `Origin` is refused.
    #[actix_web::test]
    async fn restrictive_defaults() {
        let middleware = Cors::default()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let request = TestRequest::default()
            .insert_header(("Origin", "https://www.example.com"))
            .to_srv_request();

        let response = test::call_service(&middleware, request).await;
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    /// `allowed_header` accepts a plain string slice.
    #[actix_web::test]
    async fn allowed_header_try_from() {
        let _cors = Cors::default().allowed_header("Content-Type");
    }

    /// `allowed_header` accepts a user type that only implements `TryInto<HeaderName>`.
    #[actix_web::test]
    async fn allowed_header_try_into() {
        struct ContentType;

        impl TryInto<HeaderName> for ContentType {
            type Error = Infallible;

            fn try_into(self) -> Result<HeaderName, Self::Error> {
                Ok(HeaderName::from_static("content-type"))
            }
        }

        let _cors = Cors::default().allowed_header(ContentType);
    }

    /// The middleware can wrap services whose body type is not the default boxed body.
    #[actix_web::test]
    async fn middleware_generic_over_body_type() {
        let service = fn_service(|req: ServiceRequest| async move {
            let res = HttpResponse::with_body(StatusCode::OK, body::None::new());
            Ok(req.into_response(res))
        });

        Cors::default().new_transform(service).await.unwrap();
    }
}