use std::{
pin::Pin,
task::{Context, Poll},
};
use actix_web::body::{BodySize, MessageBody};
use bytes::Bytes;
use tokio::sync::mpsc::{error::SendError, UnboundedReceiver, UnboundedSender};
use crate::BoxError;
pub fn channel<E: Into<BoxError>>() -> (Sender<E>, impl MessageBody) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
(Sender::new(tx), Receiver::new(rx))
}
/// Sending half of a body channel created by [`channel`].
///
/// Each item placed on the channel is either a chunk of body bytes (`Ok`) or an error of type
/// `E` (`Err`).
#[derive(Debug)]
pub struct Sender<E> {
    /// Transmitting half of the underlying unbounded channel.
    tx: UnboundedSender<Result<Bytes, E>>,
}

// `Clone` is written by hand rather than derived: `#[derive(Clone)]` would add an `E: Clone`
// bound, which would make senders for non-`Clone` error types (such as `std::io::Error`)
// impossible to clone even though only the channel handle is duplicated.
impl<E> Clone for Sender<E> {
    /// Returns another sender feeding the same channel.
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}
impl<E> Sender<E> {
fn new(tx: UnboundedSender<Result<Bytes, E>>) -> Self {
Self { tx }
}
pub fn send(&mut self, chunk: Bytes) -> Result<(), Bytes> {
self.tx.send(Ok(chunk)).map_err(|SendError(err)| match err {
Ok(chunk) => chunk,
Err(_) => unreachable!(),
})
}
pub fn close(self, error: Option<E>) -> Result<(), E> {
if let Some(err) = error {
return self.tx.send(Err(err)).map_err(|SendError(err)| match err {
Ok(_) => unreachable!(),
Err(err) => err,
});
}
Ok(())
}
}
/// Receiving half of a body channel.
///
/// Holds the receiving end of the unbounded channel whose items are either a chunk of bytes
/// (`Ok`) or an error of type `E` (`Err`).
#[derive(Debug)]
struct Receiver<E> {
/// Receiving half of the underlying unbounded channel.
rx: UnboundedReceiver<Result<Bytes, E>>,
}
impl<E> Receiver<E> {
fn new(rx: UnboundedReceiver<Result<Bytes, E>>) -> Self {
Self { rx }
}
}
/// Exposes the receiving half of the channel as a streaming body.
impl<E: Into<BoxError>> MessageBody for Receiver<E> {
    type Error = E;

    /// Always reports [`BodySize::Stream`].
    fn size(&self) -> BodySize {
        BodySize::Stream
    }

    /// Polls the underlying channel and forwards its result unchanged.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, Self::Error>>> {
        // Obtain a plain mutable reference so the channel can be polled directly.
        let body = self.get_mut();
        body.rx.poll_recv(cx)
    }
}
#[cfg(test)]
mod tests {
    use std::io;

    use static_assertions::assert_impl_all;

    use super::*;

    // Compile-time checks only: both halves must keep these auto traits (and the receiver must
    // remain a `MessageBody`) when instantiated with a common error type.
    assert_impl_all!(Receiver<io::Error>: Send, Sync, Unpin, MessageBody);
    assert_impl_all!(Sender<io::Error>: Send, Sync, Unpin);
}