use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use actix_web::{
dev, http::StatusCode, web, Error, FromRequest, HttpMessage, HttpRequest, ResponseError,
};
use derive_more::{AsMut, AsRef, Deref, DerefMut, Display, Error};
use futures_core::Stream as _;
use tracing::debug;
/// Default payload size limit, in bytes (4 MiB), used by [`Bytes`] when no `LIMIT` is specified.
pub const DEFAULT_BYTES_LIMIT: usize = 4 * 1024 * 1024;
/// Extractor wrapper around the raw request body bytes.
///
/// `LIMIT` is the maximum accepted payload size in bytes and defaults to
/// [`DEFAULT_BYTES_LIMIT`]. The inner [`web::Bytes`] is a public field and is also exposed
/// through the derived `Deref`, `DerefMut`, `AsRef` and `AsMut` impls.
#[derive(Debug, Deref, DerefMut, AsRef, AsMut)]
pub struct Bytes<const LIMIT: usize = DEFAULT_BYTES_LIMIT>(pub web::Bytes);
impl<const LIMIT: usize> Bytes<LIMIT> {
pub fn into_inner(self) -> web::Bytes {
self.0
}
}
/// Extracts the request payload as [`Bytes`], rejecting bodies larger than `LIMIT` bytes.
impl<const LIMIT: usize> FromRequest for Bytes<LIMIT> {
    type Error = Error;
    type Future = BytesExtractFut<LIMIT>;

    /// Builds the extraction future.
    ///
    /// A clone of the request is kept so the future can name the handler in its failure log;
    /// the payload itself is handed over to a [`BytesBody`].
    #[inline]
    fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
        let req_handle = Some(req.clone());
        let body = BytesBody::<LIMIT>::new(req, payload);

        BytesExtractFut {
            req: req_handle,
            fut: body,
        }
    }
}
/// Future returned by the [`FromRequest`] impl of [`Bytes`].
///
/// Drives a [`BytesBody`] to completion and wraps its successful output in [`Bytes`].
pub struct BytesExtractFut<const LIMIT: usize> {
/// Clone of the originating request; taken (and read) only when extraction fails, in order
/// to name the handler in the debug log.
req: Option<HttpRequest>,
/// Inner future that buffers the payload.
fut: BytesBody<LIMIT>,
}
impl<const LIMIT: usize> Future for BytesExtractFut<LIMIT> {
    type Output = Result<Bytes<LIMIT>, Error>;

    /// Polls the inner [`BytesBody`] and maps its result.
    ///
    /// On success the buffered bytes are wrapped in [`Bytes`]. On failure a debug event naming
    /// the matched handler (or the request path when no name is available) is emitted and the
    /// error is converted into an [`Error`].
    ///
    /// # Panics
    ///
    /// Panics if the failure branch is reached a second time, because the stored request has
    /// already been taken.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        match ready!(Pin::new(&mut this.fut).poll(cx)) {
            Ok(data) => Poll::Ready(Ok(Bytes(data))),
            Err(err) => {
                // The stored request is only needed to identify the handler in the log line.
                let req = this.req.take().unwrap();
                debug!(
                    "Failed to extract Bytes from payload in handler: {}",
                    req.match_name().unwrap_or_else(|| req.path())
                );
                Poll::Ready(Err(err.into()))
            }
        }
    }
}
/// Future that buffers a request payload into [`web::Bytes`], failing once more than `LIMIT`
/// bytes would be accumulated.
pub enum BytesBody<const LIMIT: usize> {
/// A failure detected at construction time; the stored error is yielded (and removed) by the
/// first poll.
Error(Option<BytesPayloadError>),
/// The payload is being read.
Body {
/// Length taken from the request's `ContentLength` header in `new`, if present.
/// NOTE(review): the `Future` impl in this file does not read this field.
length: Option<usize>,
/// Payload stream taken from the request.
payload: dev::Payload,
/// Bytes accumulated so far.
buf: web::BytesMut,
},
}
// The `Future` impls in this file call `Pin::get_mut` on `BytesBody` and `Pin::new` on a
// `&mut BytesBody`, both of which require `BytesBody: Unpin`; this impl states that explicitly.
impl<const LIMIT: usize> Unpin for BytesBody<LIMIT> {}
impl<const LIMIT: usize> BytesBody<LIMIT> {
pub fn new(req: &HttpRequest, payload: &mut dev::Payload) -> Self {
let payload = payload.take();
let length = req
.get_header::<crate::header::ContentLength>()
.map(|cl| cl.into_inner());
if let Some(len) = length {
if len > LIMIT {
return BytesBody::Error(Some(BytesPayloadError::OverflowKnownLength {
length: len,
limit: LIMIT,
}));
}
}
BytesBody::Body {
length,
payload,
buf: web::BytesMut::with_capacity(8192),
}
}
}
impl<const LIMIT: usize> Future for BytesBody<LIMIT> {
    type Output = Result<web::Bytes, BytesPayloadError>;

    /// Reads payload chunks until the stream ends, an error occurs, or the limit is exceeded.
    ///
    /// Each chunk is appended to the internal buffer unless doing so would make the buffer
    /// longer than `LIMIT`, in which case [`BytesPayloadError::Overflow`] is returned. Stream
    /// errors are converted into [`BytesPayloadError::Payload`].
    ///
    /// # Panics
    ///
    /// Panics if the `Error` variant is polled again after it has yielded its error.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.get_mut() {
            BytesBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),

            BytesBody::Body { buf, payload, .. } => {
                while let Some(chunk) = ready!(Pin::new(&mut *payload).poll_next(cx)) {
                    let chunk = chunk?;

                    // Check before copying so the buffer never grows past the limit.
                    if buf.len() + chunk.len() > LIMIT {
                        return Poll::Ready(Err(BytesPayloadError::Overflow { limit: LIMIT }));
                    }

                    buf.extend_from_slice(&chunk);
                }

                // Stream exhausted: hand out everything that was collected.
                Poll::Ready(Ok(buf.split().freeze()))
            }
        }
    }
}
/// Errors produced while buffering a request payload with [`BytesBody`].
#[derive(Debug, Display, Error)]
#[non_exhaustive]
pub enum BytesPayloadError {
/// The length declared by the request is greater than the limit; produced by
/// `BytesBody::new` before any payload data is read.
#[display(
fmt = "Payload ({length} bytes) is larger than allowed (limit: {limit} bytes)."
)]
OverflowKnownLength { length: usize, limit: usize },
/// Appending the next chunk would push the buffered payload past the limit; produced while
/// the payload is being read.
#[display(fmt = "Payload has exceeded limit ({limit} bytes).")]
Overflow { limit: usize },
/// The underlying payload stream returned an error.
#[display(fmt = "Error that occur during reading payload: {_0}")]
Payload(actix_web::error::PayloadError),
}
impl From<actix_web::error::PayloadError> for BytesPayloadError {
fn from(err: actix_web::error::PayloadError) -> Self {
Self::Payload(err)
}
}
impl ResponseError for BytesPayloadError {
    /// Maps the error to an HTTP status.
    ///
    /// Both overflow variants answer with `413 Payload Too Large`; a wrapped payload error
    /// reports whatever status the inner error defines.
    fn status_code(&self) -> StatusCode {
        match self {
            Self::Payload(err) => err.status_code(),
            Self::OverflowKnownLength { .. } | Self::Overflow { .. } => {
                StatusCode::PAYLOAD_TOO_LARGE
            }
        }
    }
}
#[cfg(test)]
mod tests {
use actix_web::{http::header, test::TestRequest, web};
use super::*;
// Test-only equality so overflow errors can be compared with `assert_eq!`.
#[cfg(test)]
impl PartialEq for BytesPayloadError {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Overflow { limit: lhs }, Self::Overflow { limit: rhs }) => lhs == rhs,
            (
                Self::OverflowKnownLength {
                    length: lhs_length,
                    limit: lhs_limit,
                },
                Self::OverflowKnownLength {
                    length: rhs_length,
                    limit: rhs_limit,
                },
            ) => (lhs_length, lhs_limit) == (rhs_length, rhs_limit),
            // Every other pairing, including two `Payload` errors, compares unequal.
            _ => false,
        }
    }
}
#[actix_web::test]
async fn extract() {
let (req, mut pl) = TestRequest::default()
.insert_header(header::ContentType::json())
.insert_header(crate::header::ContentLength::from(3))
.set_payload(web::Bytes::from_static(b"foo"))
.to_http_parts();
let s = Bytes::<DEFAULT_BYTES_LIMIT>::from_request(&req, &mut pl)
.await
.unwrap();
assert_eq!(s.as_ref(), "foo");
let (req, mut pl) = TestRequest::default()
.insert_header(header::ContentType::json())
.insert_header(crate::header::ContentLength::from(16))
.set_payload(web::Bytes::from_static(b"foo foo foo foo"))
.to_http_parts();
let s = Bytes::<10>::from_request(&req, &mut pl).await;
let err_str = s.unwrap_err().to_string();
assert_eq!(
err_str,
"Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
);
let (req, mut pl) = TestRequest::default()
.insert_header(header::ContentType::json())
.insert_header(crate::header::ContentLength::from(16))
.set_payload(web::Bytes::from_static(b"foo foo foo foo"))
.to_http_parts();
let s = Bytes::<10>::from_request(&req, &mut pl).await;
let err = format!("{}", s.unwrap_err());
assert!(
err.contains("larger than allowed"),
"unexpected error string: {err:?}",
);
}
#[actix_web::test]
async fn body() {
    // No headers and no payload: resolves successfully.
    let (req, mut pl) = TestRequest::default().to_http_parts();
    BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
        .await
        .unwrap();

    // A content type header alone does not affect the outcome.
    let (req, mut pl) = TestRequest::default()
        .insert_header(header::ContentType("application/text".parse().unwrap()))
        .to_http_parts();
    BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl)
        .await
        .unwrap();

    // Declared length above the limit is rejected with the known-length variant.
    let (req, mut pl) = TestRequest::default()
        .insert_header(header::ContentType::json())
        .insert_header(crate::header::ContentLength::from(10000))
        .to_http_parts();
    let outcome = BytesBody::<100>::new(&req, &mut pl).await;
    let expected = BytesPayloadError::OverflowKnownLength {
        length: 10000,
        limit: 100,
    };
    assert_eq!(outcome.unwrap_err(), expected);

    // No declared length: the overflow is detected while reading the payload.
    let (req, mut pl) = TestRequest::default()
        .insert_header(header::ContentType::json())
        .set_payload(web::Bytes::from_static(&[0u8; 1000]))
        .to_http_parts();
    let outcome = BytesBody::<100>::new(&req, &mut pl).await;
    assert_eq!(
        outcome.unwrap_err(),
        BytesPayloadError::Overflow { limit: 100 }
    );

    // Payload within the limit is returned unchanged.
    let (req, mut pl) = TestRequest::default()
        .insert_header(header::ContentType::json())
        .insert_header(crate::header::ContentLength::from(16))
        .set_payload(web::Bytes::from_static(b"foo foo foo foo"))
        .to_http_parts();
    let outcome = BytesBody::<DEFAULT_BYTES_LIMIT>::new(&req, &mut pl).await;
    assert_eq!(outcome.ok().unwrap(), "foo foo foo foo");
}
#[actix_web::test]
async fn test_with_config_in_data_wrapper() {
    // A `PayloadConfig` limit of 8 is registered as app data, yet the error below reports the
    // const-generic limit of 10.
    let payload_config = web::Data::new(web::PayloadConfig::default().limit(8));

    let (req, mut pl) = TestRequest::default()
        .app_data(payload_config)
        .insert_header(header::ContentType::json())
        .insert_header((header::CONTENT_LENGTH, 16))
        .set_payload(web::Bytes::from_static(b"{\"name\": \"test\"}"))
        .to_http_parts();

    let outcome = Bytes::<10>::from_request(&req, &mut pl).await;
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert_eq!(
        message,
        "Payload (16 bytes) is larger than allowed (limit: 10 bytes).",
    );
}
}