actix_web_lab/
bytes.rs

//! Bytes extractor with a const-generic payload size limit.
//!
//! See docs for [`Bytes`].
4
5use std::{
6    future::Future,
7    pin::Pin,
8    task::{ready, Context, Poll},
9};
10
11use actix_web::{
12    dev, http::StatusCode, web, Error, FromRequest, HttpMessage, HttpRequest, ResponseError,
13};
14use derive_more::{Display, Error};
15use futures_core::Stream as _;
16use tracing::debug;
17
/// Default payload size limit used by [`Bytes`] when no `LIMIT` is given: 4MiB (4 × 1024 × 1024 bytes).
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
20
21/// Bytes extractor with const-generic payload size limit.
22///
23/// # Extractor
24/// Extracts raw bytes from a request body, even if it.
25///
26/// Use the `LIMIT` const generic parameter to control the payload size limit. The default limit
27/// that is exported (`DEFAULT_LIMIT`) is 4MiB.
28///
29/// # Differences from `actix_web::web::Bytes`
30/// - Does not read `PayloadConfig` from app data.
31/// - Supports const-generic size limits.
32/// - Will not automatically decompress request bodies.
33///
34/// # Examples
35/// ```
36/// use actix_web::{post, App};
37/// use actix_web_lab::extract::{Bytes, DEFAULT_BYTES_LIMIT};
38///
39/// /// Deserialize `Info` from request's body.
40/// #[post("/")]
41/// async fn index(info: Bytes) -> String {
42///     format!("Payload up to 4MiB: {info:?}!")
43/// }
44///
45/// const LIMIT_32_MB: usize = 33_554_432;
46///
47/// /// Deserialize payload with a higher 32MiB limit.
48/// #[post("/big-payload")]
49/// async fn big_payload(info: Bytes<LIMIT_32_MB>) -> String {
50///     format!("Payload up to 32MiB: {info:?}!")
51/// }
52/// ```
53#[derive(Debug)]
54// #[derive(Debug, Deref, DerefMut, AsRef, AsMut)]
55pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
56
57mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
58    use super::*;
59
60    impl<const LIMIT: usize> std::ops::Deref for Bytes<LIMIT> {
61        type Target = web::Bytes;
62
63        fn deref(&self) -> &Self::Target {
64            &self.0
65        }
66    }
67
68    impl<const LIMIT: usize> std::ops::DerefMut for Bytes<LIMIT> {
69        fn deref_mut(&mut self) -> &mut Self::Target {
70            &mut self.0
71        }
72    }
73
74    impl<const LIMIT: usize> AsRef<web::Bytes> for Bytes<LIMIT> {
75        fn as_ref(&self) -> &web::Bytes {
76            &self.0
77        }
78    }
79
80    impl<const LIMIT: usize> AsMut<web::Bytes> for Bytes<LIMIT> {
81        fn as_mut(&mut self) -> &mut web::Bytes {
82            &mut self.0
83        }
84    }
85}
86
87impl<const LIMIT: usize> Bytes<LIMIT> {
88    /// Unwraps into inner `Bytes`.
89    pub fn into_inner(self) -> web::Bytes {
90        self.0
91    }
92}
93
94/// See [here](#extractor) for example of usage as an extractor.
95impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
96    type Error = Error;
97    type Future = BytesExtractFut<LIMIT>;
98
99    #[inline]
100    fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
101        BytesExtractFut {
102            req: Some(req.clone()),
103            fut: BytesBody::new(req, payload),
104        }
105    }
106}
107
/// Future returned by the [`FromRequest`] implementation of [`Bytes`].
///
/// Resolves to `Bytes<LIMIT>` once the inner [`BytesBody`] future finishes reading the body, or to
/// an [`Error`] converted from the body-reading error.
pub struct BytesExtractFut<const LIMIT: usize> {
    /// Request handle, taken on failure to log the matched route name (or the path).
    req: Option<HttpRequest>,
    /// Inner future that reads the request body up to `LIMIT` bytes.
    fut: BytesBody<LIMIT>,
}
112
113impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
114    type Output = Result<Bytes<LIMIT>, Error>;
115
116    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117        let this = self.get_mut();
118
119        let res = ready!(Pin::new(&mut this.fut).poll(cx));
120
121        let res = match res {
122            Err(err) => {
123                let req = this.req.take().unwrap();
124
125                debug!(
126                    "Failed to extract Bytes from payload in handler: {}",
127                    req.match_name().unwrap_or_else(|| req.path())
128                );
129
130                Err(err.into())
131            }
132            Ok(data) => Ok(Bytes(data)),
133        };
134
135        Poll::Ready(res)
136    }
137}
138
139/// Future that resolves to `Bytes` when the payload is been completely read.
140///
141/// Returns error if:
142/// - `Content-Length` is greater than `LIMIT`.
143pub enum BytesBody<const LIMIT: usize> {
144    Error(Option<BytesPayloadError>),
145    Body {
146        /// Length as reported by `Content-Length` header, if present.
147        length: Option<usize>,
148        payload: dev::Payload,
149        buf: web::BytesMut,
150    },
151}
152
153impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
154
155impl<const LIMIT: usize> BytesBody<LIMIT> {
156    /// Create a new future to decode a JSON request payload.
157    pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
158        let payload = payload.take();
159
160        let length = req
161            .get_header::<crate::header::ContentLength>()
162            .map(|cl| cl.into_inner());
163
164        // Notice the content-length is not checked against limit here as the internal usage always
165        // call BytesBody::limit after BytesBody::new and limit check to return an error variant of
166        // BytesBody happens there.
167
168        if let Some(len) = length {
169            if len > LIMIT {
170                return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
171                    length: len,
172                    limit: LIMIT,
173                }));
174            }
175        }
176
177        BytesBody::Body {
178            length,
179            payload,
180            buf: web::BytesMut::with_capacity(8192),
181        }
182    }
183}
184
185impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
186    type Output = Result<web::Bytes, BytesPayloadError>;
187
188    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        let this = self.get_mut();
190
191        match this {
192            BytesBody::Body { buf, payload, .. } => loop {
193                let res = ready!(Pin::new(&mut *payload).poll_next(cx));
194
195                match res {
196                    Some(chunk) => {
197                        let chunk = chunk?;
198                        let buf_len = buf.len() + chunk.len();
199                        if buf_len > LIMIT {
200                            return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
201                        } else {
202                            buf.extend_from_slice(&chunk);
203                        }
204                    }
205
206                    None => return Poll::Ready(Ok(buf.split().freeze())),
207                }
208            },
209
210            BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
211        }
212    }
213}
214
215/// A set of errors that can occur during parsing json payloads
216#[derive(Debug, Display, Error)]
217#[non_exhaustive]
218pub enum BytesPayloadError {
219    /// Payload size is bigger than allowed & content length header set. (default: 4MiB)
220    #[display(fmt = "Payload ({length} bytes) is larger than allowed (limit: {limit} bytes).")]
221    OverflowKnownLength { length: usize, limit: usize },
222
223    /// Payload size is bigger than allowed but no content length header set. (default: 4MiB)
224    #[display(fmt = "Payload has exceeded limit ({limit} bytes).")]
225    Overflow { limit: usize },
226
227    /// Payload error.
228    #[display(fmt = "Error that occur during reading payload: {_0}")]
229    Payload(actix_web::error::PayloadError),
230}
231
232impl From<actix_web::error::PayloadError> for BytesPayloadError {
233    fn from(err: actix_web::error::PayloadError) -> Self {
234        Self::Payload(err)
235    }
236}
237
238impl ResponseError for BytesPayloadError {
239    fn status_code(&self) -> StatusCode {
240        match self {
241            Self::OverflowKnownLength { .. } => StatusCode::PAYLOAD_TOO_LARGE,
242            Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
243            Self::Payload(err) => err.status_code(),
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web};

    use super::*;

    // Test-only equality so error values can be compared with `assert_eq!`. The two overflow
    // variants compare field by field. Any other pairing returns `false`, including two
    // `Payload` errors.
    // NOTE(review): this `#[cfg(test)]` is redundant inside an already `#[cfg(test)]` module.
    #[cfg(test)]
    impl PartialEq for BytesPayloadError {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (
                    Self::OverflowKnownLength {
                        length: l_length,
                        limit: l_limit,
                    },
                    Self::OverflowKnownLength {
                        length: r_length,
                        limit: r_limit,
                    },
                ) => l_length == r_length && l_limit == r_limit,

                (Self::Overflow { limit: l_limit }, Self::Overflow { limit: r_limit }) => {
                    l_limit == r_limit
                }

                _ => false,
            }
        }
    }

    /// Exercises the full `FromRequest` path through `Bytes`.
    #[actix_web::test]
    async fn extract() {
        // A 3-byte body is well within the default limit and is extracted unchanged.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(3))
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.as_ref(), "foo");

        // Declared `Content-Length` (16) exceeds `LIMIT` (10), so the request is rejected with
        // the known-length overflow message.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );

        // NOTE(review): this repeats the previous request exactly and only checks a substring of
        // the message. It may have been intended to cover a different case.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();
        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        let err = format!("{}", s.unwrap_err());
        assert!(
            err.contains("larger than allowed"),
            "unexpected error string: {err:?}",
        );
    }

    /// Exercises `BytesBody` directly, covering both overflow variants.
    #[actix_web::test]
    async fn body() {
        // A request with no payload resolves successfully.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let _bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType("application/text".parse().unwrap()))
            .to_http_parts();
        // content-type doesn't matter
        BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        // Declared length over the limit: rejected up front with `OverflowKnownLength`.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(10000))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        );

        // No explicit `Content-Length`, so the limit is enforced while streaming and yields
        // `Overflow`. NOTE(review): this presumably relies on `set_payload` not adding a
        // `Content-Length` header itself; confirm against the actix-web version in use.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(web::Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;

        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::Overflow { limit: 100 }
        );

        // The header declares 16 bytes but the payload holds 15. `BytesBody` does not check that
        // the two agree, and returns whatever the stream yields.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(16))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(bytes.ok().unwrap(), "foo foo foo foo");
    }

    /// `Bytes` ignores `PayloadConfig` in app data: the const `LIMIT` (10) applies, not the
    /// configured limit (8).
    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        let (req, mut pl) = TestRequest::default()
            .app_data(web::Data::new(web::PayloadConfig::default().limit(8)))
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Bytes::<10>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.unwrap_err().to_string();
        assert_eq!(
            err_str,
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );
    }
}