actix_web_lab/
json.rs

1//! JSON extractor with const-generic payload size limit.
2
3use std::{
4    future::Future,
5    marker::PhantomData,
6    pin::Pin,
7    task::{ready, Context, Poll},
8};
9
10// #[cfg(feature = "__compress")]
11// use crate::dev::Decompress;
12use actix_web::{
13    dev::Payload, error::JsonPayloadError, http::header, web, Error, FromRequest, HttpMessage,
14    HttpRequest,
15};
16use futures_core::Stream as _;
17use serde::de::DeserializeOwned;
18use tracing::debug;
19
/// Default maximum size of a JSON payload, in bytes: 2 MiB (2 × 1024 × 1024).
pub const DEFAULT_JSON_LIMIT: usize = 2 * 1024 * 1024;
22
23/// JSON extractor with const-generic payload size limit.
24///
25/// `Json` is used to extract typed data from JSON request payloads.
26///
27/// # Extractor
28/// To extract typed data from a request body, the inner type `T` must implement the
29/// [`serde::Deserialize`] trait.
30///
31/// Use the `LIMIT` const generic parameter to control the payload size limit. The default limit
32/// that is exported (`DEFAULT_LIMIT`) is 2MiB.
33///
34/// ```
35/// use actix_web::{post, App};
36/// use actix_web_lab::extract::{Json, DEFAULT_JSON_LIMIT};
37/// use serde::Deserialize;
38///
39/// #[derive(Deserialize)]
40/// struct Info {
41///     username: String,
42/// }
43///
44/// /// Deserialize `Info` from request's body.
45/// #[post("/")]
46/// async fn index(info: Json<Info>) -> String {
47///     format!("Welcome {}!", info.username)
48/// }
49///
50/// const LIMIT_32_MB: usize = 33_554_432;
51///
52/// /// Deserialize payload with a higher 32MiB limit.
53/// #[post("/big-payload")]
54/// async fn big_payload(info: Json<Info, LIMIT_32_MB>) -> String {
55///     format!("Welcome {}!", info.username)
56/// }
57/// ```
58#[derive(Debug)]
59// #[derive(Debug, Deref, DerefMut, Display)]
60pub struct Json<T, const LIMIT: usize = DEFAULT_JSON_LIMIT>(pub T);
61
62mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
63    use super::*;
64
65    impl<T, const LIMIT: usize> std::ops::Deref for Json<T, LIMIT> {
66        type Target = T;
67
68        fn deref(&self) -> &Self::Target {
69            &self.0
70        }
71    }
72
73    impl<T, const LIMIT: usize> std::ops::DerefMut for Json<T, LIMIT> {
74        fn deref_mut(&mut self) -> &mut Self::Target {
75            &mut self.0
76        }
77    }
78
79    impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for Json<T, LIMIT> {
80        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81            std::fmt::Display::fmt(&self.0, f)
82        }
83    }
84}
85
86impl<T, const LIMIT: usize> Json<T, LIMIT> {
87    /// Unwraps into inner `T` value.
88    pub fn into_inner(self) -> T {
89        self.0
90    }
91}
92
93/// See [here](#extractor) for example of usage as an extractor.
94impl<T: DeserializeOwned, const LIMIT: usize> FromRequest for Json<T, LIMIT> {
95    type Error = Error;
96    type Future = JsonExtractFut<T, LIMIT>;
97
98    #[inline]
99    fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
100        JsonExtractFut {
101            req: Some(req.clone()),
102            fut: JsonBody::new(req, payload),
103        }
104    }
105}
106
107pub struct JsonExtractFut<T, const LIMIT: usize> {
108    req: Option<HttpRequest>,
109    fut: JsonBody<T, LIMIT>,
110}
111
112impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonExtractFut<T, LIMIT> {
113    type Output = Result<Json<T, LIMIT>, Error>;
114
115    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116        let this = self.get_mut();
117
118        let res = ready!(Pin::new(&mut this.fut).poll(cx));
119
120        let res = match res {
121            Err(err) => {
122                let req = this.req.take().unwrap();
123                debug!(
124                    "Failed to deserialize Json<{}> from payload in handler: {}",
125                    core::any::type_name::<T>(),
126                    req.match_name().unwrap_or_else(|| req.path())
127                );
128
129                Err(err.into())
130            }
131            Ok(data) => Ok(Json(data)),
132        };
133
134        Poll::Ready(res)
135    }
136}
137
138/// Future that resolves to some `T` when parsed from a JSON payload.
139///
140/// Can deserialize any type `T` that implements [`Deserialize`][serde::Deserialize].
141///
142/// Returns error if:
143/// - `Content-Type` is not `application/json`.
144/// - `Content-Length` is greater than `LIMIT`.
145/// - The payload, when consumed, is not valid JSON.
146pub enum JsonBody<T, const LIMIT: usize> {
147    Error(Option<JsonPayloadError>),
148    Body {
149        /// Length as reported by `Content-Length` header, if present.
150        length: Option<usize>,
151        // #[cfg(feature = "__compress")]
152        // payload: Decompress<Payload>,
153        // #[cfg(not(feature = "__compress"))]
154        payload: Payload,
155        buf: web::BytesMut,
156        _res: PhantomData<T>,
157    },
158}
159
160impl<T, const LIMIT: usize> Unpin for JsonBody<T, LIMIT> {}
161
162impl<T: DeserializeOwned, const LIMIT: usize> JsonBody<T, LIMIT> {
163    /// Create a new future to decode a JSON request payload.
164    // #[allow(clippy::borrow_interior_mutable_const)]
165    pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
166        // check content-type
167        let can_parse_json = if let Ok(Some(mime)) = req.mime_type() {
168            mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
169        } else {
170            false
171        };
172
173        if !can_parse_json {
174            return JsonBody::Error(Some(JsonPayloadError::ContentType));
175        }
176
177        let length = req
178            .headers()
179            .get(&header::CONTENT_LENGTH)
180            .and_then(|l| l.to_str().ok())
181            .and_then(|s| s.parse::<usize>().ok());
182
183        // Notice the content-length is not checked against limit of json config here.
184        // As the internal usage always call JsonBody::limit after JsonBody::new.
185        // And limit check to return an error variant of JsonBody happens there.
186
187        let payload = {
188            // cfg_if::cfg_if! {
189            //     if #[cfg(feature = "__compress")] {
190            //         Decompress::from_headers(payload.take(), req.headers())
191            //     } else {
192            payload.take()
193            //     }
194            // }
195        };
196
197        if let Some(len) = length {
198            if len > LIMIT {
199                return JsonBody::Error(Some(JsonPayloadError::OverflowKnownLength {
200                    length: len,
201                    limit: LIMIT,
202                }));
203            }
204        }
205
206        JsonBody::Body {
207            length,
208            payload,
209            buf: web::BytesMut::with_capacity(8192),
210            _res: PhantomData,
211        }
212    }
213}
214
215impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonBody<T, LIMIT> {
216    type Output = Result<T, JsonPayloadError>;
217
218    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219        let this = self.get_mut();
220
221        match this {
222            JsonBody::Body { buf, payload, .. } => loop {
223                let res = ready!(Pin::new(&mut *payload).poll_next(cx));
224
225                match res {
226                    Some(chunk) => {
227                        let chunk = chunk?;
228                        let buf_len = buf.len() + chunk.len();
229                        if buf_len > LIMIT {
230                            return Poll::Ready(Err(JsonPayloadError::Overflow { limit: LIMIT }));
231                        } else {
232                            buf.extend_from_slice(&chunk);
233                        }
234                    }
235
236                    None => {
237                        let json = serde_json::from_slice::<T>(buf)
238                            .map_err(JsonPayloadError::Deserialize)?;
239                        return Poll::Ready(Ok(json));
240                    }
241                }
242            },
243
244            JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web::Bytes};
    use serde::{Deserialize, Serialize};

    use super::*;

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Compares two `JsonPayloadError`s by variant only (field values are ignored).
    ///
    /// Variants not listed below (e.g. deserialization or payload errors) never compare equal.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        match err {
            JsonPayloadError::Overflow { .. } => {
                matches!(other, JsonPayloadError::Overflow { .. })
            }
            JsonPayloadError::OverflowKnownLength { .. } => {
                matches!(other, JsonPayloadError::OverflowKnownLength { .. })
            }
            JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
            _ => false,
        }
    }

    #[actix_web::test]
    async fn test_extract() {
        // Valid body within the default limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Json::<MyObject, DEFAULT_JSON_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.name, "test");
        assert_eq!(
            s.into_inner(),
            MyObject {
                name: "test".to_string()
            }
        );

        // Declared Content-Length above the limit is rejected before reading the body.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        let err = format!("{}", s.unwrap_err());
        assert!(
            err.contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)."),
            "unexpected error string: {err:?}"
        );

        // Without a Content-Length header, the limit is enforced while streaming the body.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();
        let s = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        let err = s.unwrap_err();
        assert!(
            matches!(
                err.as_error::<JsonPayloadError>(),
                Some(JsonPayloadError::Overflow { limit: 10 })
            ),
            "unexpected error: {err:?}"
        );
    }

    #[actix_web::test]
    async fn test_json_body() {
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .to_http_parts();

        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        ));

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;

        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::Overflow { limit: 100 }
        ));

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(
            json.unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    #[actix_web::test]
    async fn test_with_json_and_bad_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Json::<MyObject, 4096>::from_request(&req, &mut pl).await;
        assert!(s.is_err())
    }

    /// Content-Length given as a numeric header value (not a string) is still checked
    /// against the limit.
    #[actix_web::test]
    async fn test_known_length_overflow_with_numeric_header() {
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.unwrap_err().to_string();
        assert!(
            err_str.contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes).")
        );
    }
}