1use std::{
4 future::Future,
5 marker::PhantomData,
6 pin::Pin,
7 task::{ready, Context, Poll},
8};
9
10use actix_web::{
13 dev::Payload, error::JsonPayloadError, http::header, web, Error, FromRequest, HttpMessage,
14 HttpRequest,
15};
16use futures_core::Stream as _;
17use serde::de::DeserializeOwned;
18use tracing::debug;
19
/// Default upper bound on a JSON request body, in bytes: 2 MiB (2 × 1024 × 1024).
pub const DEFAULT_JSON_LIMIT: usize = 2 * 1024 * 1024;

/// JSON request-body extractor wrapping the deserialized value `T`.
///
/// `LIMIT` is the largest body, in bytes, that will be accepted; it defaults to
/// [`DEFAULT_JSON_LIMIT`] when omitted.
#[derive(Debug)]
pub struct Json<T, const LIMIT: usize = DEFAULT_JSON_LIMIT>(pub T);
61
62mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
63 use super::*;
64
65 impl<T, const LIMIT: usize> std::ops::Deref for Json<T, LIMIT> {
66 type Target = T;
67
68 fn deref(&self) -> &Self::Target {
69 &self.0
70 }
71 }
72
73 impl<T, const LIMIT: usize> std::ops::DerefMut for Json<T, LIMIT> {
74 fn deref_mut(&mut self) -> &mut Self::Target {
75 &mut self.0
76 }
77 }
78
79 impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for Json<T, LIMIT> {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 std::fmt::Display::fmt(&self.0, f)
82 }
83 }
84}
85
86impl<T, const LIMIT: usize> Json<T, LIMIT> {
87 pub fn into_inner(self) -> T {
89 self.0
90 }
91}
92
93impl<T: DeserializeOwned, const LIMIT: usize> FromRequest for Json<T, LIMIT> {
95 type Error = Error;
96 type Future = JsonExtractFut<T, LIMIT>;
97
98 #[inline]
99 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
100 JsonExtractFut {
101 req: Some(req.clone()),
102 fut: JsonBody::new(req, payload),
103 }
104 }
105}
106
/// Future returned by `Json`'s [`FromRequest`] implementation; resolves to the
/// extracted [`Json`] value or an [`Error`].
pub struct JsonExtractFut<T, const LIMIT: usize> {
    // Request handle, used only to log the type name and route/path when
    // extraction fails; it is taken (left as `None`) at that point.
    req: Option<HttpRequest>,
    // Inner future that buffers the payload and deserializes it.
    fut: JsonBody<T, LIMIT>,
}
111
112impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonExtractFut<T, LIMIT> {
113 type Output = Result<Json<T, LIMIT>, Error>;
114
115 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
116 let this = self.get_mut();
117
118 let res = ready!(Pin::new(&mut this.fut).poll(cx));
119
120 let res = match res {
121 Err(err) => {
122 let req = this.req.take().unwrap();
123 debug!(
124 "Failed to deserialize Json<{}> from payload in handler: {}",
125 core::any::type_name::<T>(),
126 req.match_name().unwrap_or_else(|| req.path())
127 );
128
129 Err(err.into())
130 }
131 Ok(data) => Ok(Json(data)),
132 };
133
134 Poll::Ready(res)
135 }
136}
137
/// Future that buffers a request body and deserializes it as JSON into `T`,
/// rejecting bodies larger than `LIMIT` bytes.
pub enum JsonBody<T, const LIMIT: usize> {
    /// Request rejected before any body was read: the content type is not JSON, or
    /// the `Content-Length` header exceeds `LIMIT`. The error is yielded on the first
    /// poll.
    Error(Option<JsonPayloadError>),
    /// Body is being read.
    Body {
        /// `Content-Length` header value, when present and parseable as `usize`.
        length: Option<usize>,
        /// Payload stream taken out of the request.
        payload: Payload,
        /// Bytes received so far; never allowed to grow past `LIMIT`.
        buf: web::BytesMut,
        _res: PhantomData<T>,
    },
}

// `poll` only ever reaches the fields through `Pin::get_mut`, so nothing is
// structurally pinned. This manual impl drops the `T: Unpin` requirement that
// `PhantomData<T>` would otherwise impose on the auto-derived `Unpin`.
impl<T, const LIMIT: usize> Unpin for JsonBody<T, LIMIT> {}
161
162impl<T: DeserializeOwned, const LIMIT: usize> JsonBody<T, LIMIT> {
163 pub fn new(req: &HttpRequest, payload: &mut Payload) -> Self {
166 let can_parse_json = if let Ok(Some(mime)) = req.mime_type() {
168 mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
169 } else {
170 false
171 };
172
173 if !can_parse_json {
174 return JsonBody::Error(Some(JsonPayloadError::ContentType));
175 }
176
177 let length = req
178 .headers()
179 .get(&header::CONTENT_LENGTH)
180 .and_then(|l| l.to_str().ok())
181 .and_then(|s| s.parse::<usize>().ok());
182
183 let payload = {
188 payload.take()
193 };
196
197 if let Some(len) = length {
198 if len > LIMIT {
199 return JsonBody::Error(Some(JsonPayloadError::OverflowKnownLength {
200 length: len,
201 limit: LIMIT,
202 }));
203 }
204 }
205
206 JsonBody::Body {
207 length,
208 payload,
209 buf: web::BytesMut::with_capacity(8192),
210 _res: PhantomData,
211 }
212 }
213}
214
215impl<T: DeserializeOwned, const LIMIT: usize> Future for JsonBody<T, LIMIT> {
216 type Output = Result<T, JsonPayloadError>;
217
218 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
219 let this = self.get_mut();
220
221 match this {
222 JsonBody::Body { buf, payload, .. } => loop {
223 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
224
225 match res {
226 Some(chunk) => {
227 let chunk = chunk?;
228 let buf_len = buf.len() + chunk.len();
229 if buf_len > LIMIT {
230 return Poll::Ready(Err(JsonPayloadError::Overflow { limit: LIMIT }));
231 } else {
232 buf.extend_from_slice(&chunk);
233 }
234 }
235
236 None => {
237 let json = serde_json::from_slice::<T>(buf)
238 .map_err(JsonPayloadError::Deserialize)?;
239 return Poll::Ready(Ok(json));
240 }
241 }
242 },
243
244 JsonBody::Error(e) => Poll::Ready(Err(e.take().unwrap())),
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web::Bytes};
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Well-formed JSON for `MyObject`; exactly 16 bytes long.
    const BODY: &[u8] = b"{\"name\": \"test\"}";

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Returns `true` when `err` and `other` are the same size/content-type variant,
    /// ignoring the variants' fields. Any other variant (e.g. `Deserialize`) compares
    /// unequal.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        match err {
            JsonPayloadError::Overflow { .. } => {
                matches!(other, JsonPayloadError::Overflow { .. })
            }
            JsonPayloadError::OverflowKnownLength { .. } => {
                matches!(other, JsonPayloadError::OverflowKnownLength { .. })
            }
            JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
            _ => false,
        }
    }

    /// Builds request parts carrying `body` with a JSON content type and an explicit,
    /// matching `Content-Length` header.
    fn json_request(body: &'static [u8]) -> (HttpRequest, Payload) {
        TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, body.len()))
            .set_payload(Bytes::from_static(body))
            .to_http_parts()
    }

    #[actix_web::test]
    async fn test_extract() {
        let (req, mut pl) = json_request(BODY);
        let s = Json::<MyObject, DEFAULT_JSON_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.name, "test");
        assert_eq!(
            s.into_inner(),
            MyObject {
                name: "test".to_string()
            }
        );

        // Declared length above the limit is rejected before the body is read.
        let (req, mut pl) = json_request(BODY);
        let err = Json::<MyObject, 10>::from_request(&req, &mut pl)
            .await
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)."),
            "unexpected error string: {err:?}"
        );

        // The limit is inclusive: a body exactly `LIMIT` bytes long is accepted.
        let (req, mut pl) = json_request(BODY);
        let s = Json::<MyObject, 16>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.name, "test");
    }

    #[actix_web::test]
    async fn test_json_body() {
        // No content type at all.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        // Non-JSON content type.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .to_http_parts();
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(json_eq(json.unwrap_err(), JsonPayloadError::ContentType));

        // Declared `Content-Length` above the limit.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .to_http_parts();

        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;
        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        ));

        // Body above the limit without an explicit `Content-Length` header.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let json = JsonBody::<MyObject, 100>::new(&req, &mut pl).await;

        assert!(json_eq(
            json.unwrap_err(),
            JsonPayloadError::Overflow { limit: 100 }
        ));

        // Well-sized body that does not match the target type.
        let (req, mut pl) = json_request(b"{\"name\": 1}");
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert!(matches!(json, Err(JsonPayloadError::Deserialize(_))));

        // Well-formed body.
        let (req, mut pl) = json_request(BODY);
        let json = JsonBody::<MyObject, DEFAULT_JSON_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(
            json.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    #[actix_web::test]
    async fn test_with_json_and_bad_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(BODY))
            .to_http_parts();

        let s = Json::<MyObject, 4096>::from_request(&req, &mut pl).await;
        assert!(s.is_err())
    }

    #[actix_web::test]
    async fn test_integer_content_length_header() {
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(Bytes::from_static(BODY))
            .to_http_parts();

        let s = Json::<MyObject, 10>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.unwrap_err().to_string();
        assert!(
            err_str.contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes).")
        );
    }
}