use std::task::Poll;
use actix_rt::pin;
use actix_utils::future::poll_fn;
use bytes::{Bytes, BytesMut};
use derive_more::{Display, Error};
use futures_core::ready;
use super::{BodySize, MessageBody};
pub async fn to_bytes<B: MessageBody>(body: B) -> Result<Bytes, B::Error> {
to_bytes_limited(body, usize::MAX)
.await
.expect("body should never yield more than usize::MAX bytes")
}
#[derive(Debug, Display, Error)]
#[display(fmt = "limit exceeded while collecting body bytes")]
#[non_exhaustive]
pub struct BodyLimitExceeded;
pub async fn to_bytes_limited<B: MessageBody>(
body: B,
limit: usize,
) -> Result<Result<Bytes, B::Error>, BodyLimitExceeded> {
const INITIAL_ALLOC_BYTES: usize = 32 * 1024;
let cap = match body.size() {
BodySize::None | BodySize::Sized(0) => return Ok(Ok(Bytes::new())),
BodySize::Sized(size) if size as usize > limit => return Err(BodyLimitExceeded),
BodySize::Sized(size) => (size as usize).min(INITIAL_ALLOC_BYTES),
BodySize::Stream => INITIAL_ALLOC_BYTES,
};
let mut exceeded_limit = false;
let mut buf = BytesMut::with_capacity(cap);
pin!(body);
match poll_fn(|cx| loop {
let body = body.as_mut();
match ready!(body.poll_next(cx)) {
Some(Ok(bytes)) => {
if buf.len() + bytes.len() > limit {
exceeded_limit = true;
return Poll::Ready(Ok(()));
}
buf.extend_from_slice(&bytes)
}
None => return Poll::Ready(Ok(())),
Some(Err(err)) => return Poll::Ready(Err(err)),
}
})
.await
{
Err(err) => Ok(Err(err)),
Ok(()) if exceeded_limit => Err(BodyLimitExceeded),
Ok(()) => Ok(Ok(buf.freeze())),
}
}
#[cfg(test)]
mod tests {
use std::io;
use futures_util::{stream, StreamExt as _};
use super::*;
use crate::{
body::{BodyStream, SizedStream},
Error,
};
#[actix_rt::test]
async fn to_bytes_complete() {
let bytes = to_bytes(()).await.unwrap();
assert!(bytes.is_empty());
let body = Bytes::from_static(b"123");
let bytes = to_bytes(body).await.unwrap();
assert_eq!(bytes, b"123"[..]);
}
#[actix_rt::test]
async fn to_bytes_streams() {
let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
.map(Ok::<_, Error>);
let body = BodyStream::new(stream);
let bytes = to_bytes(body).await.unwrap();
assert_eq!(bytes, b"123abc"[..]);
}
#[actix_rt::test]
async fn to_bytes_limited_complete() {
let bytes = to_bytes_limited((), 0).await.unwrap().unwrap();
assert!(bytes.is_empty());
let bytes = to_bytes_limited((), 1).await.unwrap().unwrap();
assert!(bytes.is_empty());
assert!(to_bytes_limited(Bytes::from_static(b"12"), 0)
.await
.is_err());
assert!(to_bytes_limited(Bytes::from_static(b"12"), 1)
.await
.is_err());
assert!(to_bytes_limited(Bytes::from_static(b"12"), 2).await.is_ok());
assert!(to_bytes_limited(Bytes::from_static(b"12"), 3).await.is_ok());
}
#[actix_rt::test]
async fn to_bytes_limited_streams() {
let body = SizedStream::new(8, stream::empty().map(Ok::<_, Error>));
assert!(to_bytes_limited(body, 3).await.is_err());
let body = SizedStream::new(3, stream::empty().map(Ok::<_, Error>));
assert!(to_bytes_limited(body, 3).await.unwrap().unwrap().is_empty());
let stream = stream::iter(vec![Bytes::from_static(b"1234")]).map(Ok::<_, Error>);
let body = SizedStream::new(3, stream);
assert!(to_bytes_limited(body, 3).await.is_err());
let stream = stream::iter(vec![Bytes::from_static(b"123"), Bytes::from_static(b"abc")])
.map(Ok::<_, Error>);
let body = BodyStream::new(stream);
assert!(to_bytes_limited(body, 3).await.is_err());
}
#[actix_rt::test]
async fn to_body_limit_error() {
let err_stream = stream::once(async { Err(io::Error::new(io::ErrorKind::Other, "")) });
let body = SizedStream::new(8, err_stream);
assert!(to_bytes_limited(body, 10).await.unwrap().is_err());
}
}