1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use actix_web::body::{BodySize, MessageBody};
use bytes::Bytes;
use tokio::sync::mpsc::{error::SendError, UnboundedReceiver, UnboundedSender};

use crate::BoxError;

/// Creates a streaming response body that is fed through a channel.
///
/// The first element of the returned pair is a [`Sender`] used to push chunks (or a terminal
/// error) into the body. The second element is the receiving half, which can be handed
/// directly to a response as its [`MessageBody`].
///
/// The underlying channel is unbounded. Chunks that are sent faster than the client consumes
/// them are buffered in memory without limit.
///
/// # Examples
/// ```
/// # use actix_web::{HttpResponse, web};
/// use std::convert::Infallible;
/// use actix_web_lab::body;
///
/// # async fn index() {
/// let (mut body_tx, body) = body::channel::<Infallible>();
///
/// let _ = web::block(move || {
///     body_tx.send(web::Bytes::from_static(b"body from another thread")).unwrap();
/// });
///
/// HttpResponse::Ok().body(body)
/// # ;}
/// ```
pub fn channel<E: Into<BoxError>>() -> (Sender<E>, impl MessageBody) {
    let (chunk_tx, chunk_rx) = tokio::sync::mpsc::unbounded_channel();
    let body = Receiver::new(chunk_rx);
    (Sender::new(chunk_tx), body)
}

/// A channel-like sender for body chunks.
///
/// Cloning a `Sender` produces another handle that feeds the same body stream. Cloning works
/// for every error type `E`.
#[derive(Debug)]
pub struct Sender<E> {
    tx: UnboundedSender<Result<Bytes, E>>,
}

// Implemented by hand because `#[derive(Clone)]` would add an `E: Clone` bound. That bound
// needlessly prevents cloning senders whose error type is not `Clone` (e.g. `io::Error` or a
// boxed error). Only the channel handle is cloned, which is valid for any `E`.
impl<E> Clone for Sender<E> {
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
        }
    }
}

impl<E> Sender<E> {
    fn new(tx: UnboundedSender<Result<Bytes, E>>) -> Self {
        Self { tx }
    }

    /// Submits a chunk of bytes to the response body stream.
    ///
    /// # Errors
    /// Errors if other side of channel body was dropped, returning `chunk`.
    pub fn send(&mut self, chunk: Bytes) -> Result<(), Bytes> {
        self.tx.send(Ok(chunk)).map_err(|SendError(err)| match err {
            Ok(chunk) => chunk,
            Err(_) => unreachable!(),
        })
    }

    /// Closes the stream, optionally sending an error.
    ///
    /// # Errors
    /// Errors if closing with error and other side of channel body was dropped, returning `error`.
    pub fn close(self, error: Option<E>) -> Result<(), E> {
        if let Some(err) = error {
            return self.tx.send(Err(err)).map_err(|SendError(err)| match err {
                Ok(_) => unreachable!(),
                Err(err) => err,
            });
        }

        Ok(())
    }
}

/// Receiving half of [`channel`]; yields the queued chunks as a response body.
#[derive(Debug)]
struct Receiver<E> {
    rx: UnboundedReceiver<Result<Bytes, E>>,
}

impl<E> Receiver<E> {
    /// Wraps the receiving half of the body's unbounded channel.
    fn new(inner: UnboundedReceiver<Result<Bytes, E>>) -> Self {
        Receiver { rx: inner }
    }
}

impl<E: Into<BoxError>> MessageBody for Receiver<E> {
    type Error = E;

    /// The total length is never known up front, so the body is always sent as a stream.
    fn size(&self) -> BodySize {
        BodySize::Stream
    }

    /// Forwards the next queued item (chunk or error) from the channel.
    ///
    /// Yields `Ready(None)` once the channel reports that no more items will arrive.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, Self::Error>>> {
        // `Receiver` is `Unpin`, so the pin can be dropped to reach the inner receiver.
        let this = self.get_mut();
        this.rx.poll_recv(cx)
    }
}

#[cfg(test)]
mod tests {
    use std::io;

    use static_assertions::assert_impl_all;

    use super::*;

    // Compile-time checks that both halves can cross threads and be used without pinning.
    assert_impl_all!(Sender<io::Error>: Send, Sync, Unpin);
    assert_impl_all!(Receiver<io::Error>: Send, Sync, Unpin, MessageBody);
}