1use std::{collections::HashSet, rc::Rc};
2
3use actix_utils::future::ok;
4use actix_web::{
5 body::{EitherBody, MessageBody},
6 dev::{forward_ready, Service, ServiceRequest, ServiceResponse},
7 http::{
8 header::{self, HeaderValue},
9 Method,
10 },
11 Error, HttpResponse, Result,
12};
13use futures_util::future::{FutureExt as _, LocalBoxFuture};
14use log::debug;
15
16use crate::{
17 builder::intersperse_header_values,
18 inner::{add_vary_header, header_value_try_into_method},
19 AllOrSome, CorsError, Inner,
20};
21
/// Middleware service that applies CORS handling around the wrapped service `S`.
///
/// When preflight handling is enabled, preflight requests are answered by the middleware
/// itself. Other requests have their origin validated before being forwarded to `service`,
/// and the resulting response is augmented with CORS headers.
#[doc(hidden)]
#[derive(Debug, Clone)]
pub struct CorsMiddleware<S> {
    /// The wrapped service. It is not called for requests answered as preflight or for
    /// requests whose origin validation returned an error.
    pub(crate) service: S,
    /// Shared CORS configuration. It is reference-counted so a clone can be moved into the
    /// response future.
    pub(crate) inner: Rc<Inner>,
}
32
33impl<S> CorsMiddleware<S> {
34 fn is_request_preflight(req: &ServiceRequest) -> bool {
36 if req.method() != Method::OPTIONS {
38 return false;
39 }
40
41 if req
43 .headers()
44 .get(header::ACCESS_CONTROL_REQUEST_METHOD)
45 .and_then(header_value_try_into_method)
46 .is_none()
47 {
48 return false;
49 }
50
51 true
52 }
53
54 fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse {
61 let inner = Rc::clone(&self.inner);
62
63 match inner.validate_origin(req.head()) {
64 Ok(true) => {}
65 Ok(false) => return req.error_response(CorsError::OriginNotAllowed),
66 Err(err) => return req.error_response(err),
67 };
68
69 if let Err(err) = inner
70 .validate_allowed_method(req.head())
71 .and_then(|_| inner.validate_allowed_headers(req.head()))
72 {
73 return req.error_response(err);
74 }
75
76 let mut res = HttpResponse::Ok();
77
78 if let Some(origin) = inner.access_control_allow_origin(req.head()) {
79 res.insert_header((header::ACCESS_CONTROL_ALLOW_ORIGIN, origin));
80 }
81
82 if let Some(ref allowed_methods) = inner.allowed_methods_baked {
83 res.insert_header((
84 header::ACCESS_CONTROL_ALLOW_METHODS,
85 allowed_methods.clone(),
86 ));
87 }
88
89 if let Some(ref headers) = inner.allowed_headers_baked {
90 res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
91 } else if let Some(headers) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
92 res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
94 }
95
96 #[cfg(feature = "draft-private-network-access")]
97 if inner.allow_private_network_access
98 && req
99 .headers()
100 .contains_key("access-control-request-private-network")
101 {
102 res.insert_header((
103 header::HeaderName::from_static("access-control-allow-private-network"),
104 HeaderValue::from_static("true"),
105 ));
106 }
107
108 if inner.supports_credentials {
109 res.insert_header((
110 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
111 HeaderValue::from_static("true"),
112 ));
113 }
114
115 if let Some(max_age) = inner.max_age {
116 res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()));
117 }
118
119 let mut res = res.finish();
120
121 if inner.vary_header {
122 add_vary_header(res.headers_mut());
123 }
124
125 req.into_response(res)
126 }
127
128 fn augment_response<B>(
129 inner: &Inner,
130 origin_allowed: bool,
131 mut res: ServiceResponse<B>,
132 ) -> ServiceResponse<B> {
133 if origin_allowed {
134 if let Some(origin) = inner.access_control_allow_origin(res.request().head()) {
135 res.headers_mut()
136 .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
137 };
138 }
139
140 if let Some(ref expose) = inner.expose_headers_baked {
141 log::trace!("exposing selected headers: {:?}", expose);
142
143 res.headers_mut()
144 .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
145 } else if matches!(inner.expose_headers, AllOrSome::All) {
146 if !res.headers().is_empty() {
148 let expose_all_request_headers = res
150 .headers()
151 .keys()
152 .map(|name| name.as_str())
153 .collect::<HashSet<_>>();
154
155 let expose_headers_value = intersperse_header_values(&expose_all_request_headers);
157
158 log::trace!(
159 "exposing all headers from request: {:?}",
160 expose_headers_value
161 );
162
163 res.headers_mut()
165 .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers_value);
166 }
167 }
168
169 if inner.supports_credentials {
170 res.headers_mut().insert(
171 header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
172 HeaderValue::from_static("true"),
173 );
174 }
175
176 #[cfg(feature = "draft-private-network-access")]
177 if inner.allow_private_network_access
178 && res
179 .request()
180 .headers()
181 .contains_key("access-control-request-private-network")
182 {
183 res.headers_mut().insert(
184 header::HeaderName::from_static("access-control-allow-private-network"),
185 HeaderValue::from_static("true"),
186 );
187 }
188
189 if inner.vary_header {
190 add_vary_header(res.headers_mut());
191 }
192
193 res
194 }
195}
196
197impl<S, B> Service<ServiceRequest> for CorsMiddleware<S>
198where
199 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
200 S::Future: 'static,
201
202 B: MessageBody + 'static,
203{
204 type Response = ServiceResponse<EitherBody<B>>;
205 type Error = Error;
206 type Future = LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>;
207
208 forward_ready!(service);
209
210 fn call(&self, req: ServiceRequest) -> Self::Future {
211 let origin = req.headers().get(header::ORIGIN);
212
213 if self.inner.preflight && Self::is_request_preflight(&req) {
215 let res = self.handle_preflight(req);
216 return ok(res.map_into_right_body()).boxed_local();
217 }
218
219 let origin_allowed = match (origin, self.inner.validate_origin(req.head())) {
221 (None, _) => false,
222 (_, Ok(origin_allowed)) => origin_allowed,
223 (_, Err(err)) => {
224 debug!("origin validation failed; inner service is not called");
225 let mut res = req.error_response(err);
226
227 if self.inner.vary_header {
228 add_vary_header(res.headers_mut());
229 }
230
231 return ok(res.map_into_right_body()).boxed_local();
232 }
233 };
234
235 let inner = Rc::clone(&self.inner);
236 let fut = self.service.call(req);
237
238 Box::pin(async move {
239 let res = fut.await;
240 Ok(Self::augment_response(&inner, origin_allowed, res?).map_into_left_body())
241 })
242 }
243}
244
#[cfg(test)]
mod tests {
    use actix_web::{
        dev::Transform,
        middleware::Compat,
        test::{self, TestRequest},
        App,
    };

    use super::*;
    use crate::Cors;

    #[test]
    fn compat_compat() {
        let _ = App::new().wrap(Compat::new(Cors::default()));
    }

    /// Any origin is configured, plus an origin function that accepts only requests carrying
    /// a `DNT` header.
    ///
    /// Covers three cases:
    /// - a request without an `Origin` header;
    /// - a request whose origin is rejected;
    /// - a request whose origin is accepted.
    #[actix_web::test]
    async fn test_options_no_origin() {
        let cors = Cors::default()
            .allow_any_origin()
            .allowed_origin_fn(|origin, req_head| {
                assert_eq!(&origin, req_head.headers().get(header::ORIGIN).unwrap());
                req_head.headers().contains_key(header::DNT)
            })
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // No `Origin` header: the inner service runs and no allow-origin header is added.
        let req = TestRequest::get().to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert!(res.status().is_success());
        assert_eq!(
            None,
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );

        // Origin present but the origin function rejects it (no `DNT` header).
        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert_eq!(
            None,
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );

        // Origin present and accepted by the origin function.
        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .insert_header((header::DNT, "1"))
            .to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert_eq!(
            Some(&b"http://example.com"[..]),
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );
    }
}