local_channel/
mpsc.rs

1//! A non-thread-safe multi-producer, single-consumer, futures-aware, FIFO queue.
2
3use alloc::{collections::VecDeque, rc::Rc};
4use core::{
5    cell::RefCell,
6    fmt,
7    future::poll_fn,
8    pin::Pin,
9    task::{Context, Poll},
10};
11use std::error::Error;
12
13use futures_core::stream::Stream;
14use futures_sink::Sink;
15use local_waker::LocalWaker;
16
17/// Creates a unbounded in-memory channel with buffered storage.
18///
19/// [Sender]s and [Receiver]s are `!Send`.
20pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
21    let shared = Rc::new(RefCell::new(Shared {
22        has_receiver: true,
23        buffer: VecDeque::new(),
24        blocked_recv: LocalWaker::new(),
25    }));
26
27    let sender = Sender {
28        shared: shared.clone(),
29    };
30
31    let receiver = Receiver { shared };
32
33    (sender, receiver)
34}
35
/// State shared by every [`Sender`] and the [`Receiver`] of a single channel.
#[derive(Debug)]
struct Shared<T> {
    /// Messages sent but not yet received; senders push to the back, the receiver pops from the
    /// front (FIFO).
    buffer: VecDeque<T>,
    /// Waker registered by the receiver when it polls an empty buffer; woken by `send` and by
    /// the last sender's `Drop`.
    blocked_recv: LocalWaker,
    /// While `false`, `send` rejects new messages with [`SendError`].
    ///
    /// NOTE(review): despite the name, this is cleared both when the receiver is dropped and
    /// when any sender calls [`Sender::close`] (the receiver may still be alive then).
    has_receiver: bool,
}
42
/// The transmission end of a channel.
///
/// This is created by the `channel` function. Cloning a sender produces another handle to the
/// same channel; because it holds an `Rc`, a sender is `!Send`.
#[derive(Debug)]
pub struct Sender<T> {
    // Shared channel state. The `Rc` strong count doubles as the sender count: the receiver
    // inspects it to detect that every sender has been dropped.
    shared: Rc<RefCell<Shared<T>>>,
}

// The `Sink` methods receive `Pin<&mut Self>`; a sender only holds an `Rc` handle, so pinning it
// imposes no real restriction and it is declared `Unpin` for every `T`.
impl<T> Unpin for Sender<T> {}
52
53impl<T> Sender<T> {
54    /// Sends the provided message along this channel.
55    pub fn send(&self, item: T) -> Result<(), SendError<T>> {
56        let mut shared = self.shared.borrow_mut();
57
58        if !shared.has_receiver {
59            // receiver was dropped
60            return Err(SendError(item));
61        };
62
63        shared.buffer.push_back(item);
64        shared.blocked_recv.wake();
65
66        Ok(())
67    }
68
69    /// Closes the sender half.
70    ///
71    /// This prevents any further messages from being sent on the channel, by any sender, while
72    /// still enabling the receiver to drain messages that are already buffered.
73    pub fn close(&mut self) {
74        self.shared.borrow_mut().has_receiver = false;
75    }
76}
77
78impl<T> Clone for Sender<T> {
79    fn clone(&self) -> Self {
80        Sender {
81            shared: self.shared.clone(),
82        }
83    }
84}
85
86impl<T> Sink<T> for Sender<T> {
87    type Error = SendError<T>;
88
89    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        Poll::Ready(Ok(()))
91    }
92
93    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), SendError<T>> {
94        self.send(item)
95    }
96
97    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
98        Poll::Ready(Ok(()))
99    }
100
101    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        Poll::Ready(Ok(()))
103    }
104}
105
106impl<T> Drop for Sender<T> {
107    fn drop(&mut self) {
108        let count = Rc::strong_count(&self.shared);
109        let shared = self.shared.borrow_mut();
110
111        // check is last sender is about to drop
112        if shared.has_receiver && count == 2 {
113            // Wake up receiver as its stream has ended
114            shared.blocked_recv.wake();
115        }
116    }
117}
118
/// The receiving end of a channel which implements the `Stream` trait.
///
/// This is created by the [`channel`] function. Additional senders can be created from it with
/// [`Receiver::sender`]. Dropping the receiver discards any buffered messages and makes later
/// [`Sender::send`] calls fail with [`SendError`].
#[derive(Debug)]
pub struct Receiver<T> {
    // Shared channel state; the receiver holds one strong reference to it, which the
    // strong-count checks in `Stream::poll_next` and `Sender::drop` account for.
    shared: Rc<RefCell<Shared<T>>>,
}
126
127impl<T> Receiver<T> {
128    /// Receive the next value.
129    ///
130    /// Returns `None` if the channel is empty and has been [closed](Sender::close) explicitly or
131    /// when all senders have been dropped and, therefore, no more values can ever be sent though
132    /// this channel.
133    pub async fn recv(&mut self) -> Option<T> {
134        let mut this = Pin::new(self);
135        poll_fn(|cx| this.as_mut().poll_next(cx)).await
136    }
137
138    /// Create an associated [Sender].
139    pub fn sender(&self) -> Sender<T> {
140        Sender {
141            shared: self.shared.clone(),
142        }
143    }
144}
145
// A receiver only holds an `Rc` handle, so it can be moved freely even after being pinned; this
// is what allows `recv` (and callers) to use `Pin::new` on a plain `&mut Receiver`.
impl<T> Unpin for Receiver<T> {}
147
148impl<T> Stream for Receiver<T> {
149    type Item = T;
150
151    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152        let mut shared = self.shared.borrow_mut();
153
154        if Rc::strong_count(&self.shared) == 1 {
155            // All senders have been dropped, so drain the buffer and end the stream.
156            return Poll::Ready(shared.buffer.pop_front());
157        }
158
159        if let Some(msg) = shared.buffer.pop_front() {
160            Poll::Ready(Some(msg))
161        } else {
162            shared.blocked_recv.register(cx.waker());
163            Poll::Pending
164        }
165    }
166}
167
168impl<T> Drop for Receiver<T> {
169    fn drop(&mut self) {
170        let mut shared = self.shared.borrow_mut();
171        shared.buffer.clear();
172        shared.has_receiver = false;
173    }
174}
175
/// Error returned by [`Sender::send`] when the channel's [Receiver] has been dropped or the
/// channel has been closed with [`Sender::close`].
///
/// The undelivered message can be recovered with [`into_inner`](Self::into_inner).
pub struct SendError<T>(pub T);

impl<T> SendError<T> {
    /// Consumes the error, handing back the message that could not be sent.
    pub fn into_inner(self) -> T {
        let SendError(message) = self;
        message
    }
}

impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The payload is elided so that no `T: Debug` bound is required.
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(&"...");
        tuple.finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("send failed because receiver is gone")
    }
}

impl<T> Error for SendError<T> {}
201
#[cfg(test)]
mod tests {
    use futures_util::{future::lazy, StreamExt as _};

    use super::*;

    /// Sending, receiving, stream termination, and `send` failures after the receiver is
    /// dropped or the channel is closed.
    #[tokio::test]
    async fn test_mpsc() {
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test");

        // A cloned sender feeds the same receiver.
        let tx2 = tx.clone();
        tx2.send("test2").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test2");

        // Empty buffer with live senders: the stream must stay open.
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // Dropping one of two senders does not end the stream...
        drop(tx2);
        assert_eq!(
            lazy(|cx| Pin::new(&mut rx).poll_next(cx)).await,
            Poll::Pending
        );
        // ...but dropping the last one does.
        drop(tx);
        assert_eq!(rx.next().await, None);

        // Sending fails once the receiver has been dropped.
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        drop(rx);
        assert!(tx.send("test").is_err());

        // Closing one sender makes every sender of the channel fail. The receiver is bound to
        // `_rx` rather than `_`. A bare `_` drops it at the end of the statement, and the sends
        // below would then fail because of the dropped receiver instead of `close`.
        let (mut tx, _rx) = channel();
        let tx2 = tx.clone();
        tx.close();
        assert!(tx.send("test").is_err());
        assert!(tx2.send("test").is_err());
    }

    /// `recv` yields buffered values and resolves to `None` once every sender is gone.
    #[tokio::test]
    async fn test_recv() {
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");
        drop(tx);

        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");
        drop(tx);
        assert!(rx.recv().await.is_none());
    }
}