actix_cors/
middleware.rs

1use std::{collections::HashSet, rc::Rc};
2
3use actix_utils::future::ok;
4use actix_web::{
5    body::{EitherBody, MessageBody},
6    dev::{forward_ready, Service, ServiceRequest, ServiceResponse},
7    http::{
8        header::{self, HeaderValue},
9        Method,
10    },
11    Error, HttpResponse, Result,
12};
13use futures_util::future::{FutureExt as _, LocalBoxFuture};
14use log::debug;
15
16use crate::{
17    builder::intersperse_header_values,
18    inner::{add_vary_header, header_value_try_into_method},
19    AllOrSome, CorsError, Inner,
20};
21
22/// Service wrapper for Cross-Origin Resource Sharing support.
23///
24/// This struct contains the settings for CORS requests to be validated and for responses to
25/// be generated.
26#[doc(hidden)]
27#[derive(Debug, Clone)]
28pub struct CorsMiddleware<S> {
29    pub(crate) service: S,
30    pub(crate) inner: Rc<Inner>,
31}
32
33impl<S> CorsMiddleware<S> {
34    /// Returns true if request is `OPTIONS` and contains an `Access-Control-Request-Method` header.
35    fn is_request_preflight(req: &ServiceRequest) -> bool {
36        // check request method is OPTIONS
37        if req.method() != Method::OPTIONS {
38            return false;
39        }
40
41        // check follow-up request method is present and valid
42        if req
43            .headers()
44            .get(header::ACCESS_CONTROL_REQUEST_METHOD)
45            .and_then(header_value_try_into_method)
46            .is_none()
47        {
48            return false;
49        }
50
51        true
52    }
53
54    /// Validates preflight request headers against configuration and constructs preflight response.
55    ///
56    /// Checks:
57    /// - `Origin` header is acceptable;
58    /// - `Access-Control-Request-Method` header is acceptable;
59    /// - `Access-Control-Request-Headers` header is acceptable.
60    fn handle_preflight(&self, req: ServiceRequest) -> ServiceResponse {
61        let inner = Rc::clone(&self.inner);
62
63        match inner.validate_origin(req.head()) {
64            Ok(true) => {}
65            Ok(false) => return req.error_response(CorsError::OriginNotAllowed),
66            Err(err) => return req.error_response(err),
67        };
68
69        if let Err(err) = inner
70            .validate_allowed_method(req.head())
71            .and_then(|_| inner.validate_allowed_headers(req.head()))
72        {
73            return req.error_response(err);
74        }
75
76        let mut res = HttpResponse::Ok();
77
78        if let Some(origin) = inner.access_control_allow_origin(req.head()) {
79            res.insert_header((header::ACCESS_CONTROL_ALLOW_ORIGIN, origin));
80        }
81
82        if let Some(ref allowed_methods) = inner.allowed_methods_baked {
83            res.insert_header((
84                header::ACCESS_CONTROL_ALLOW_METHODS,
85                allowed_methods.clone(),
86            ));
87        }
88
89        if let Some(ref headers) = inner.allowed_headers_baked {
90            res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
91        } else if let Some(headers) = req.headers().get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
92            // all headers allowed, return
93            res.insert_header((header::ACCESS_CONTROL_ALLOW_HEADERS, headers.clone()));
94        }
95
96        #[cfg(feature = "draft-private-network-access")]
97        if inner.allow_private_network_access
98            && req
99                .headers()
100                .contains_key("access-control-request-private-network")
101        {
102            res.insert_header((
103                header::HeaderName::from_static("access-control-allow-private-network"),
104                HeaderValue::from_static("true"),
105            ));
106        }
107
108        if inner.supports_credentials {
109            res.insert_header((
110                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
111                HeaderValue::from_static("true"),
112            ));
113        }
114
115        if let Some(max_age) = inner.max_age {
116            res.insert_header((header::ACCESS_CONTROL_MAX_AGE, max_age.to_string()));
117        }
118
119        let mut res = res.finish();
120
121        if inner.vary_header {
122            add_vary_header(res.headers_mut());
123        }
124
125        req.into_response(res)
126    }
127
128    fn augment_response<B>(
129        inner: &Inner,
130        origin_allowed: bool,
131        mut res: ServiceResponse<B>,
132    ) -> ServiceResponse<B> {
133        if origin_allowed {
134            if let Some(origin) = inner.access_control_allow_origin(res.request().head()) {
135                res.headers_mut()
136                    .insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
137            };
138        }
139
140        if let Some(ref expose) = inner.expose_headers_baked {
141            log::trace!("exposing selected headers: {:?}", expose);
142
143            res.headers_mut()
144                .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose.clone());
145        } else if matches!(inner.expose_headers, AllOrSome::All) {
146            // intersperse_header_values requires that argument is non-empty
147            if !res.headers().is_empty() {
148                // extract header names from request
149                let expose_all_request_headers = res
150                    .headers()
151                    .keys()
152                    .map(|name| name.as_str())
153                    .collect::<HashSet<_>>();
154
155                // create comma separated string of header names
156                let expose_headers_value = intersperse_header_values(&expose_all_request_headers);
157
158                log::trace!(
159                    "exposing all headers from request: {:?}",
160                    expose_headers_value
161                );
162
163                // add header names to expose response header
164                res.headers_mut()
165                    .insert(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers_value);
166            }
167        }
168
169        if inner.supports_credentials {
170            res.headers_mut().insert(
171                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
172                HeaderValue::from_static("true"),
173            );
174        }
175
176        #[cfg(feature = "draft-private-network-access")]
177        if inner.allow_private_network_access
178            && res
179                .request()
180                .headers()
181                .contains_key("access-control-request-private-network")
182        {
183            res.headers_mut().insert(
184                header::HeaderName::from_static("access-control-allow-private-network"),
185                HeaderValue::from_static("true"),
186            );
187        }
188
189        if inner.vary_header {
190            add_vary_header(res.headers_mut());
191        }
192
193        res
194    }
195}
196
197impl<S, B> Service<ServiceRequest> for CorsMiddleware<S>
198where
199    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
200    S::Future: 'static,
201
202    B: MessageBody + 'static,
203{
204    type Response = ServiceResponse<EitherBody<B>>;
205    type Error = Error;
206    type Future = LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>;
207
208    forward_ready!(service);
209
210    fn call(&self, req: ServiceRequest) -> Self::Future {
211        let origin = req.headers().get(header::ORIGIN);
212
213        // handle preflight requests
214        if self.inner.preflight && Self::is_request_preflight(&req) {
215            let res = self.handle_preflight(req);
216            return ok(res.map_into_right_body()).boxed_local();
217        }
218
219        // only check actual requests with a origin header
220        let origin_allowed = match (origin, self.inner.validate_origin(req.head())) {
221            (None, _) => false,
222            (_, Ok(origin_allowed)) => origin_allowed,
223            (_, Err(err)) => {
224                debug!("origin validation failed; inner service is not called");
225                let mut res = req.error_response(err);
226
227                if self.inner.vary_header {
228                    add_vary_header(res.headers_mut());
229                }
230
231                return ok(res.map_into_right_body()).boxed_local();
232            }
233        };
234
235        let inner = Rc::clone(&self.inner);
236        let fut = self.service.call(req);
237
238        Box::pin(async move {
239            let res = fut.await;
240            Ok(Self::augment_response(&inner, origin_allowed, res?).map_into_left_body())
241        })
242    }
243}
244
#[cfg(test)]
mod tests {
    use actix_web::{
        dev::Transform,
        middleware::Compat,
        test::{self, TestRequest},
        App,
    };

    use super::*;
    use crate::Cors;

    #[test]
    fn compat_compat() {
        let _ = App::new().wrap(Compat::new(Cors::default()));
    }

    #[actix_web::test]
    async fn test_options_no_origin() {
        // `allowed_origins` is `All`, but an origin validation function is also registered, so
        // that function still has the final say: origins are only allowed when the DNT header
        // is sent. Requests without an `Origin` header must pass through untouched.

        let cors = Cors::default()
            .allow_any_origin()
            .allowed_origin_fn(|origin, req_head| {
                assert_eq!(&origin, req_head.headers.get(header::ORIGIN).unwrap());
                req_head.headers().contains_key(header::DNT)
            })
            .new_transform(test::ok_service())
            .await
            .unwrap();

        // no `Origin` header: the inner service is called and no allow-origin header is added
        let req = TestRequest::get().to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert!(res.status().is_success());
        assert_eq!(
            None,
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );

        // `Origin` present but the validation function rejects it (no DNT header)
        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert_eq!(
            None,
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );

        // `Origin` present and accepted by the validation function (DNT header sent)
        let req = TestRequest::get()
            .insert_header((header::ORIGIN, "http://example.com"))
            .insert_header((header::DNT, "1"))
            .to_srv_request();
        let res = cors.call(req).await.unwrap();
        assert_eq!(
            Some(&b"http://example.com"[..]),
            res.headers()
                .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
                .map(HeaderValue::as_bytes)
        );
    }
}
300}