1use std::{
6 fmt,
7 future::Future,
8 pin::Pin,
9 task::{ready, Context, Poll},
10};
11
12use actix_web::{
13 dev::{self, Payload},
14 FromRequest, HttpMessage as _, HttpRequest, ResponseError,
15};
16use derive_more::Display;
17use futures_core::Stream as _;
18
19use crate::header::ContentLength;
20
21pub const DEFAULT_BODY_LIMIT: usize = 2_097_152;
23
24#[derive(Debug, PartialEq, Eq)]
41pub struct BodyLimit<T, const LIMIT: usize = DEFAULT_BODY_LIMIT> {
42 inner: T,
43}
44
45mod waiting_on_derive_more_to_start_using_syn_2_due_to_proc_macro_panic {
46 use super::*;
47
48 impl<T: std::fmt::Display, const LIMIT: usize> std::fmt::Display for BodyLimit<T, LIMIT> {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 std::fmt::Display::fmt(&self.inner, f)
51 }
52 }
53
54 impl<T, const LIMIT: usize> AsRef<T> for BodyLimit<T, LIMIT> {
55 fn as_ref(&self) -> &T {
56 &self.inner
57 }
58 }
59
60 impl<T, const LIMIT: usize> From<T> for BodyLimit<T, LIMIT> {
61 fn from(inner: T) -> Self {
62 Self { inner }
63 }
64 }
65}
66
67impl<T, const LIMIT: usize> BodyLimit<T, LIMIT> {
68 pub fn into_inner(self) -> T {
70 self.inner
71 }
72}
73
74impl<T, const LIMIT: usize> FromRequest for BodyLimit<T, LIMIT>
75where
76 T: FromRequest + 'static,
77 T::Error: fmt::Debug + fmt::Display,
78{
79 type Error = BodyLimitError<T>;
80 type Future = BodyLimitFut<T, LIMIT>;
81
82 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
83 match req.get_header::<ContentLength>() {
85 Some(len) if len > LIMIT => return BodyLimitFut::new_error(BodyLimitError::Overflow),
87 _ => {}
88 }
89
90 let counter = crate::util::fork_request_payload(payload);
91
92 BodyLimitFut {
93 inner: Inner::Body {
94 fut: Box::pin(T::from_request(req, payload)),
95 counter_pl: counter,
96 size: 0,
97 },
98 }
99 }
100}
101
102pub struct BodyLimitFut<T, const LIMIT: usize>
103where
104 T: FromRequest + 'static,
105 T::Error: fmt::Debug + fmt::Display,
106{
107 inner: Inner<T, LIMIT>,
108}
109
110impl<T, const LIMIT: usize> BodyLimitFut<T, LIMIT>
111where
112 T: FromRequest + 'static,
113 T::Error: fmt::Debug + fmt::Display,
114{
115 fn new_error(err: BodyLimitError<T>) -> Self {
116 Self {
117 inner: Inner::Error { err: Some(err) },
118 }
119 }
120}
121
122enum Inner<T, const LIMIT: usize>
123where
124 T: FromRequest + 'static,
125 T::Error: fmt::Debug + fmt::Display,
126{
127 Error {
128 err: Option<BodyLimitError<T>>,
129 },
130
131 Body {
132 fut: Pin<Box<T::Future>>,
134
135 counter_pl: dev::Payload,
137
138 size: usize,
140 },
141}
142
143impl<T, const LIMIT: usize> Unpin for Inner<T, LIMIT>
144where
145 T: FromRequest + 'static,
146 T::Error: fmt::Debug + fmt::Display,
147{
148}
149
150impl<T, const LIMIT: usize> Future for BodyLimitFut<T, LIMIT>
151where
152 T: FromRequest + 'static,
153 T::Error: fmt::Debug + fmt::Display,
154{
155 type Output = Result<BodyLimit<T, LIMIT>, BodyLimitError<T>>;
156
157 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158 let this = &mut self.get_mut().inner;
159
160 match this {
161 Inner::Error { err } => Poll::Ready(Err(err.take().unwrap())),
162
163 Inner::Body {
164 fut,
165 counter_pl,
166 size,
167 } => {
168 let res = ready!(fut.as_mut().poll(cx).map_err(BodyLimitError::Extractor)?);
170
171 while let Poll::Ready(Some(Ok(chunk))) = Pin::new(&mut *counter_pl).poll_next(cx) {
173 *size += chunk.len();
175
176 if *size > LIMIT {
177 return Poll::Ready(Err(BodyLimitError::Overflow));
178 }
179 }
180
181 let ret = BodyLimit { inner: res };
182
183 Poll::Ready(Ok(ret))
184 }
185 }
186 }
187}
188
189#[derive(Display)]
190pub enum BodyLimitError<T>
191where
192 T: FromRequest + 'static,
193 T::Error: fmt::Debug + fmt::Display,
194{
195 #[display(fmt = "Wrapped extractor error: {_0}")]
196 Extractor(T::Error),
197
198 #[display(fmt = "Body was too large")]
199 Overflow,
200}
201
202impl<T> fmt::Debug for BodyLimitError<T>
203where
204 T: FromRequest + 'static,
205 T::Error: fmt::Debug + fmt::Display,
206{
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Self::Extractor(err) => f
210 .debug_tuple("BodyLimitError::Extractor")
211 .field(err)
212 .finish(),
213
214 Self::Overflow => write!(f, "BodyLimitError::Overflow"),
215 }
216 }
217}
218
219impl<T> ResponseError for BodyLimitError<T>
220where
221 T: FromRequest + 'static,
222 T::Error: fmt::Debug + fmt::Display,
223{
224}
225
226#[cfg(test)]
227mod tests {
228 use actix_web::{http::header, test::TestRequest};
229 use bytes::Bytes;
230
231 use super::*;
232
233 static_assertions::assert_impl_all!(BodyLimitFut<(), 100>: Unpin);
234 static_assertions::assert_impl_all!(BodyLimitFut<Bytes, 100>: Unpin);
235
236 #[actix_web::test]
237 async fn within_limit() {
238 let (req, mut pl) = TestRequest::default()
239 .insert_header(header::ContentType::plaintext())
240 .insert_header((
241 header::CONTENT_LENGTH,
242 header::HeaderValue::from_static("9"),
243 ))
244 .set_payload(Bytes::from_static(b"123456789"))
245 .to_http_parts();
246
247 let body = BodyLimit::<Bytes, 10>::from_request(&req, &mut pl).await;
248 assert_eq!(
249 body.ok().unwrap().into_inner(),
250 Bytes::from_static(b"123456789")
251 );
252 }
253
254 #[actix_web::test]
255 async fn exceeds_limit() {
256 let (req, mut pl) = TestRequest::default()
257 .insert_header(header::ContentType::plaintext())
258 .insert_header((
259 header::CONTENT_LENGTH,
260 header::HeaderValue::from_static("10"),
261 ))
262 .set_payload(Bytes::from_static(b"0123456789"))
263 .to_http_parts();
264
265 let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
266 assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));
267
268 let (req, mut pl) = TestRequest::default()
269 .insert_header(header::ContentType::plaintext())
270 .insert_header((
271 header::TRANSFER_ENCODING,
272 header::HeaderValue::from_static("chunked"),
273 ))
274 .set_payload(Bytes::from_static(b"10\r\n0123456789\r\n0"))
275 .to_http_parts();
276
277 let body = BodyLimit::<Bytes, 4>::from_request(&req, &mut pl).await;
278 assert!(matches!(body.unwrap_err(), BodyLimitError::Overflow));
279 }
280}