1use std::{
6 future::Future,
7 pin::Pin,
8 task::{ready, Context, Poll},
9};
10
11use actix_web::{
12 dev, http::StatusCode, web, Error, FromRequest, HttpMessage, HttpRequest, ResponseError,
13};
14use derive_more::{Display, Error};
15use futures_core::Stream as _;
16use tracing::debug;
17
/// Default payload size limit for the [`Bytes`] extractor: 4 MiB (4 × 1024 × 1024 bytes).
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
20
21#[derive(Debug)]
54pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
56
57mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
58 use super::*;
59
60 impl<const LIMIT: usize> std::ops::Deref for Bytes<LIMIT> {
61 type Target = web::Bytes;
62
63 fn deref(&self) -> &Self::Target {
64 &self.0
65 }
66 }
67
68 impl<const LIMIT: usize> std::ops::DerefMut for Bytes<LIMIT> {
69 fn deref_mut(&mut self) -> &mut Self::Target {
70 &mut self.0
71 }
72 }
73
74 impl<const LIMIT: usize> AsRef<web::Bytes> for Bytes<LIMIT> {
75 fn as_ref(&self) -> &web::Bytes {
76 &self.0
77 }
78 }
79
80 impl<const LIMIT: usize> AsMut<web::Bytes> for Bytes<LIMIT> {
81 fn as_mut(&mut self) -> &mut web::Bytes {
82 &mut self.0
83 }
84 }
85}
86
87impl<const LIMIT: usize> Bytes<LIMIT> {
88 pub fn into_inner(self) -> web::Bytes {
90 self.0
91 }
92}
93
94impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
96 type Error = Error;
97 type Future = BytesExtractFut<LIMIT>;
98
99 #[inline]
100 fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
101 BytesExtractFut {
102 req: Some(req.clone()),
103 fut: BytesBody::new(req, payload),
104 }
105 }
106}
107
108pub struct BytesExtractFut<const LIMIT: usize> {
109 req: Option<HttpRequest>,
110 fut: BytesBody<LIMIT>,
111}
112
113impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
114 type Output = Result<Bytes<LIMIT>, Error>;
115
116 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
117 let this = self.get_mut();
118
119 let res = ready!(Pin::new(&mut this.fut).poll(cx));
120
121 let res = match res {
122 Err(err) => {
123 let req = this.req.take().unwrap();
124
125 debug!(
126 "Failed to extract Bytes from payload in handler: {}",
127 req.match_name().unwrap_or_else(|| req.path())
128 );
129
130 Err(err.into())
131 }
132 Ok(data) => Ok(Bytes(data)),
133 };
134
135 Poll::Ready(res)
136 }
137}
138
139pub enum BytesBody<const LIMIT: usize> {
144 Error(Option<BytesPayloadError>),
145 Body {
146 length: Option<usize>,
148 payload: dev::Payload,
149 buf: web::BytesMut,
150 },
151}
152
153impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
154
155impl<const LIMIT: usize> BytesBody<LIMIT> {
156 pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
158 let payload = payload.take();
159
160 let length = req
161 .get_header::<crate::header::ContentLength>()
162 .map(|cl| cl.into_inner());
163
164 if let Some(len) = length {
169 if len > LIMIT {
170 return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
171 length: len,
172 limit: LIMIT,
173 }));
174 }
175 }
176
177 BytesBody::Body {
178 length,
179 payload,
180 buf: web::BytesMut::with_capacity(8192),
181 }
182 }
183}
184
185impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
186 type Output = Result<web::Bytes, BytesPayloadError>;
187
188 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 let this = self.get_mut();
190
191 match this {
192 BytesBody::Body { buf, payload, .. } => loop {
193 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
194
195 match res {
196 Some(chunk) => {
197 let chunk = chunk?;
198 let buf_len = buf.len() + chunk.len();
199 if buf_len > LIMIT {
200 return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
201 } else {
202 buf.extend_from_slice(&chunk);
203 }
204 }
205
206 None => return Poll::Ready(Ok(buf.split().freeze())),
207 }
208 },
209
210 BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
211 }
212 }
213}
214
215#[derive(Debug, Display, Error)]
217#[non_exhaustive]
218pub enum BytesPayloadError {
219 #[display(fmt = "Payload ({length} bytes) is larger than allowed (limit: {limit} bytes).")]
221 OverflowKnownLength { length: usize, limit: usize },
222
223 #[display(fmt = "Payload has exceeded limit ({limit} bytes).")]
225 Overflow { limit: usize },
226
227 #[display(fmt = "Error that occur during reading payload: {_0}")]
229 Payload(actix_web::error::PayloadError),
230}
231
232impl From<actix_web::error::PayloadError> for BytesPayloadError {
233 fn from(err: actix_web::error::PayloadError) -> Self {
234 Self::Payload(err)
235 }
236}
237
238impl ResponseError for BytesPayloadError {
239 fn status_code(&self) -> StatusCode {
240 match self {
241 Self::OverflowKnownLength { .. } => StatusCode::PAYLOAD_TOO_LARGE,
242 Self::Overflow { .. } => StatusCode::PAYLOAD_TOO_LARGE,
243 Self::Payload(err) => err.status_code(),
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use actix_web::{http::header, test::TestRequest, web};

    use super::*;

    // Structural equality for the overflow variants only; `Payload` variants never compare
    // equal. Needed by the `assert_eq!`s on `BytesPayloadError` below.
    impl PartialEq for BytesPayloadError {
        fn eq(&self, other: &Self) -> bool {
            match (self, other) {
                (
                    Self::OverflowKnownLength {
                        length: l_length,
                        limit: l_limit,
                    },
                    Self::OverflowKnownLength {
                        length: r_length,
                        limit: r_limit,
                    },
                ) => l_length == r_length && l_limit == r_limit,

                (Self::Overflow { limit: l_limit }, Self::Overflow { limit: r_limit }) => {
                    l_limit == r_limit
                }

                _ => false,
            }
        }
    }

    #[actix_web::test]
    async fn extract() {
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(3))
            .set_payload(web::Bytes::from_static(b"foo"))
            .to_http_parts();

        let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
            .await
            .unwrap();
        assert_eq!(s.as_ref(), "foo");

        // Declared length (15, matching the 15-byte payload) exceeds the limit of 10, so the
        // body is rejected up front.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(15))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let err = Bytes::<10>::from_request(&req, &mut pl).await.unwrap_err();
        assert_eq!(
            err.to_string(),
            "Payload (15 bytes) is larger than allowed (limit: 10 bytes).",
        );

        // The extractor error must surface as 413 Payload Too Large.
        assert_eq!(
            err.as_response_error().status_code(),
            StatusCode::PAYLOAD_TOO_LARGE,
        );
    }

    #[actix_web::test]
    async fn body() {
        // No payload at all yields an empty body.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();
        assert!(bytes.is_empty());

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType("application/text".parse().unwrap()))
            .to_http_parts();
        BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
            .await
            .unwrap();

        // Known length over the limit: rejected without reading the body.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(10000))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        );

        // No declared length: overflow is detected while streaming.
        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .set_payload(web::Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let bytes = BytesBody::<100>::new(&req, &mut pl).await;
        assert_eq!(
            bytes.unwrap_err(),
            BytesPayloadError::Overflow { limit: 100 }
        );

        let (req, mut pl) = TestRequest::default()
            .insert_header(header::ContentType::json())
            .insert_header(crate::header::ContentLength::from(15))
            .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
            .to_http_parts();

        let bytes = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
        assert_eq!(bytes.unwrap(), "foo foo foo foo");
    }

    #[actix_web::test]
    async fn test_with_config_in_data_wrapper() {
        // A `PayloadConfig` in app data does not affect this extractor; the const `LIMIT` is
        // what is enforced (10, not 8, appears in the error).
        let (req, mut pl) = TestRequest::default()
            .app_data(web::Data::new(web::PayloadConfig::default().limit(8)))
            .insert_header(header::ContentType::json())
            .insert_header((header::CONTENT_LENGTH, 16))
            .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let err = Bytes::<10>::from_request(&req, &mut pl).await.unwrap_err();
        assert_eq!(
            err.to_string(),
            "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
        );
    }
}