actix_web_lab/
body_limit.rs

1//! Body limit extractor.
2//!
3//! See [`BodyLimit`] docs.
4
5use std::{
6    fmt,
7    future::Future,
8    pin::Pin,
9    task::{ready, Context, Poll},
10};
11
12use actix_web::{
13    dev::{self, Payload},
14    FromRequest, HttpMessage as _, HttpRequest, ResponseError,
15};
16use derive_more::Display;
17use futures_core::Stream as _;
18
19use crate::header::ContentLength;
20
/// Default body size limit, in bytes, used when `BodyLimit`'s `LIMIT` parameter is omitted.
///
/// Equal to 2 MiB (2 × 1024 × 1024 = 2,097,152 bytes).
pub const DEFAULT_BODY_LIMIT: usize = 2 * 1024 * 1024;
23
24/// Extractor wrapper that limits size of payload used.
25///
26/// # Examples
27/// ```no_run
28/// use actix_web::{get, web::Bytes, Responder};
29/// use actix_web_lab::extract::BodyLimit;
30///
31/// const BODY_LIMIT: usize = 1_048_576; // 1MB
32///
33/// #[get("/")]
34/// async fn handler(body: BodyLimit<Bytes, BODY_LIMIT>) -> impl Responder {
35///     let body = body.into_inner();
36///     assert!(body.len() < BODY_LIMIT);
37///     body
38/// }
39/// ```
40#[derive(Debug, PartialEq, Eq)]
41pub struct BodyLimit<T, const LIMIT: usize = DEFAULT_BODY_LIMIT> {
42    inner: T,
43}
44
45mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
46    use super::*;
47
48    impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for BodyLimit<T, LIMIT> {
49        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50            std::fmt::Display::fmt(&self.inner, f)
51        }
52    }
53
54    impl<T, const LIMIT: usize> AsRef<T> for BodyLimit<T, LIMIT> {
55        fn as_ref(&self) -> &T {
56            &self.inner
57        }
58    }
59
60    impl<T, const LIMIT: usize> From<T> for BodyLimit<T, LIMIT> {
61        fn from(inner: T) -> Self {
62            Self { inner }
63        }
64    }
65}
66
67impl<T, const LIMIT: usize> BodyLimit<T, LIMIT> {
68    /// Returns inner extracted type.
69    pub fn into_inner(self) -> T {
70        self.inner
71    }
72}
73
74impl<T, const LIMIT: usize> FromRequest for BodyLimit<T, LIMIT>
75where
76    T: FromRequest + 'static,
77    T::Error: fmt::Debug + fmt::Display,
78{
79    type Error = BodyLimitError<T>;
80    type Future = BodyLimitFut<T, LIMIT>;
81
82    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
83        // fast check of Content-Length header
84        match req.get_header::<ContentLength>() {
85            // CL header indicated that payload would be too large
86            Some(len) if len > LIMIT => return BodyLimitFut::new_error(BodyLimitError::Overflow),
87            _ => {}
88        }
89
90        let counter = crate::util::fork_request_payload(payload);
91
92        BodyLimitFut {
93            inner: Inner::Body {
94                fut: Box::pin(T::from_request(req, payload)),
95                counter_pl: counter,
96                size: 0,
97            },
98        }
99    }
100}
101
102pub struct BodyLimitFut<T, const LIMIT: usize>
103where
104    T: FromRequest + 'static,
105    T::Error: fmt::Debug + fmt::Display,
106{
107    inner: Inner<T, LIMIT>,
108}
109
110impl<T, const LIMIT: usize> BodyLimitFut<T, LIMIT>
111where
112    T: FromRequest + 'static,
113    T::Error: fmt::Debug + fmt::Display,
114{
115    fn new_error(err: BodyLimitError<T>) -> Self {
116        Self {
117            inner: Inner::Error { err: Some(err) },
118        }
119    }
120}
121
122enum Inner<T, const LIMIT: usize>
123where
124    T: FromRequest + 'static,
125    T::Error: fmt::Debug + fmt::Display,
126{
127    Error {
128        err: Option<BodyLimitError<T>>,
129    },
130
131    Body {
132        /// Wrapped extractor future.
133        fut: Pin<Box<T::Future>>,
134
135        /// Forked request payload.
136        counter_pl: dev::Payload,
137
138        /// Running payload size count.
139        size: usize,
140    },
141}
142
143impl<T, const LIMIT: usize> Unpin for Inner<T, LIMIT>
144where
145    T: FromRequest + 'static,
146    T::Error: fmt::Debug + fmt::Display,
147{
148}
149
150impl<T, const LIMIT: usize> Future for BodyLimitFut<T, LIMIT>
151where
152    T: FromRequest + 'static,
153    T::Error: fmt::Debug + fmt::Display,
154{
155    type Output = Result<BodyLimit<T, LIMIT>, BodyLimitError<T>>;
156
157    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        let this = &mut self.get_mut().inner;
159
160        match this {
161            Inner::Error { err } => Poll::Ready(Err(err.take().unwrap())),
162
163            Inner::Body {
164                fut,
165                counter_pl,
166                size,
167            } => {
168                // poll inner extractor first which also polls original payload stream
169                let res = ready!(fut.as_mut().poll(cx).map_err(BodyLimitError::Extractor)?);
170
171                // catch up with payload length counter checks
172                while let Poll::Ready(Some(Ok(chunk))) = Pin::new(&mut *counter_pl).poll_next(cx) {
173                    // update running size
174                    *size += chunk.len();
175
176                    if *size > LIMIT {
177                        return Poll::Ready(Err(BodyLimitError::Overflow));
178                    }
179                }
180
181                let ret = BodyLimit { inner: res };
182
183                Poll::Ready(Ok(ret))
184            }
185        }
186    }
187}
188
/// Error produced by the [`BodyLimit`] extractor.
#[derive(Display)]
pub enum BodyLimitError<T>
where
    T: FromRequest + 'static,
    T::Error: fmt::Debug + fmt::Display,
{
    /// The wrapped extractor `T` failed; holds the error it returned.
    #[display(fmt = "Wrapped extractor error: {_0}")]
    Extractor(T::Error),

    /// The request body was larger than the limit, either as announced by the `Content-Length`
    /// header or as counted while the payload was read.
    #[display(fmt = "Body was too large")]
    Overflow,
}
201
202impl<T> fmt::Debug for BodyLimitError<T>
203where
204    T: FromRequest + 'static,
205    T::Error: fmt::Debug + fmt::Display,
206{
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Self::Extractor(err) => f
210                .debug_tuple("BodyLimitError::Extractor")
211                .field(err)
212                .finish(),
213
214            Self::Overflow => write!(f, "BodyLimitError::Overflow"),
215        }
216    }
217}
218
219impl<T> ResponseError for BodyLimitError<T>
220where
221    T: FromRequest + 'static,
222    T::Error: fmt::Debug + fmt::Display,
223{
224}
225
226#[cfg(test)]
227mod tests {
228    use actix_web::{http::header, test::TestRequest};
229    use bytes::Bytes;
230
231    use super::*;
232
233    static_assertions::assert_impl_all!(BodyLimitFut<(), 100>: Unpin);
234    static_assertions::assert_impl_all!(BodyLimitFut<Bytes, 100>: Unpin);
235
236    #[actix_web::test]
237    async fn within_limit() {
238        let (req, mut pl) = TestRequest::default()
239            .insert_header(header::ContentType::plaintext())
240            .insert_header((
241                header::CONTENT_LENGTH,
242                header::HeaderValue::from_static("9"),
243            ))
244            .set_payload(Bytes::from_static(b"123456789"))
245            .to_http_parts();
246
247        let body = BodyLimit::<Bytes, 10>::from_request(&req, &mut pl).await;
248        assert_eq!(
249            body.ok().unwrap().into_inner(),
250            Bytes::from_static(b"123456789")
251        );
252    }
253
254    #[actix_web::test]
255    async fn exceeds_limit() {
256        let (req, mut pl) = TestRequest::default()
257            .insert_header(header::ContentType::plaintext())
258            .insert_header((
259                header::CONTENT_LENGTH,
260                header::HeaderValue::from_static("10"),
261            ))
262            .set_payload(Bytes::from_static(b"0123456789"))
263            .to_http_parts();
264
265        let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
266        assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));
267
268        let (req, mut pl) = TestRequest::default()
269            .insert_header(header::ContentType::plaintext())
270            .insert_header((
271                header::TRANSFER_ENCODING,
272                header::HeaderValue::from_static("chunked"),
273            ))
274            .set_payload(Bytes::from_static(b"10\r\n0123456789\r\n0"))
275            .to_http_parts();
276
277        let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
278        assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));
279    }
280}