1use alloc::{collections::VecDeque, rc::Rc};
4use core::{
5 cell::RefCell,
6 fmt,
7 future::poll_fn,
8 pin::Pin,
9 task::{Context, Poll},
10};
11use std::error::Error;
12
13use futures_core::stream::Stream;
14use futures_sink::Sink;
15use local_waker::LocalWaker;
16
17pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
21 let shared = Rc::new(RefCell::new(Shared {
22 has_receiver: true,
23 buffer: VecDeque::new(),
24 blocked_recv: LocalWaker::new(),
25 }));
26
27 let sender = Sender {
28 shared: shared.clone(),
29 };
30
31 let receiver = Receiver { shared };
32
33 (sender, receiver)
34}
35
36#[derive(Debug)]
37struct Shared<T> {
38 buffer: VecDeque<T>,
39 blocked_recv: LocalWaker,
40 has_receiver: bool,
41}
42
43#[derive(Debug)]
47pub struct Sender<T> {
48 shared: Rc<RefCell<Shared<T>>>,
49}
50
51impl<T> Unpin for Sender<T> {}
52
53impl<T> Sender<T> {
54 pub fn send(&self, item: T) -> Result<(), SendError<T>> {
56 let mut shared = self.shared.borrow_mut();
57
58 if !shared.has_receiver {
59 return Err(SendError(item));
61 };
62
63 shared.buffer.push_back(item);
64 shared.blocked_recv.wake();
65
66 Ok(())
67 }
68
69 pub fn close(&mut self) {
74 self.shared.borrow_mut().has_receiver = false;
75 }
76}
77
78impl<T> Clone for Sender<T> {
79 fn clone(&self) -> Self {
80 Sender {
81 shared: self.shared.clone(),
82 }
83 }
84}
85
86impl<T> Sink<T> for Sender<T> {
87 type Error = SendError<T>;
88
89 fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 Poll::Ready(Ok(()))
91 }
92
93 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), SendError<T>> {
94 self.send(item)
95 }
96
97 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), SendError<T>>> {
98 Poll::Ready(Ok(()))
99 }
100
101 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102 Poll::Ready(Ok(()))
103 }
104}
105
106impl<T> Drop for Sender<T> {
107 fn drop(&mut self) {
108 let count = Rc::strong_count(&self.shared);
109 let shared = self.shared.borrow_mut();
110
111 if shared.has_receiver && count == 2 {
113 shared.blocked_recv.wake();
115 }
116 }
117}
118
119#[derive(Debug)]
123pub struct Receiver<T> {
124 shared: Rc<RefCell<Shared<T>>>,
125}
126
127impl<T> Receiver<T> {
128 pub async fn recv(&mut self) -> Option<T> {
134 let mut this = Pin::new(self);
135 poll_fn(|cx| this.as_mut().poll_next(cx)).await
136 }
137
138 pub fn sender(&self) -> Sender<T> {
140 Sender {
141 shared: self.shared.clone(),
142 }
143 }
144}
145
146impl<T> Unpin for Receiver<T> {}
147
148impl<T> Stream for Receiver<T> {
149 type Item = T;
150
151 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
152 let mut shared = self.shared.borrow_mut();
153
154 if Rc::strong_count(&self.shared) == 1 {
155 return Poll::Ready(shared.buffer.pop_front());
157 }
158
159 if let Some(msg) = shared.buffer.pop_front() {
160 Poll::Ready(Some(msg))
161 } else {
162 shared.blocked_recv.register(cx.waker());
163 Poll::Pending
164 }
165 }
166}
167
168impl<T> Drop for Receiver<T> {
169 fn drop(&mut self) {
170 let mut shared = self.shared.borrow_mut();
171 shared.buffer.clear();
172 shared.has_receiver = false;
173 }
174}
175
/// Error returned when a message cannot be delivered; it carries the rejected message back.
pub struct SendError<T>(pub T);

impl<T> SendError<T> {
    /// Recovers the message that could not be sent.
    pub fn into_inner(self) -> T {
        let SendError(inner) = self;
        inner
    }
}

// The payload is deliberately elided so that `SendError<T>` is `Debug` for every `T`.
impl<T> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("SendError").field(&"...").finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("send failed because receiver is gone")
    }
}

impl<T> Error for SendError<T> {}
201
#[cfg(test)]
mod tests {
    use futures_util::{future::lazy, StreamExt as _};

    use super::*;

    /// Polls `rx` exactly once from inside the current task, without waiting for a message.
    async fn poll_once<T>(rx: &mut Receiver<T>) -> Poll<Option<T>> {
        lazy(|cx| Pin::new(&mut *rx).poll_next(cx)).await
    }

    #[tokio::test]
    async fn test_mpsc() {
        // Messages from the original sender and from a clone are both delivered.
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test");

        let tx2 = tx.clone();
        tx2.send("test2").unwrap();
        assert_eq!(rx.next().await.unwrap(), "test2");

        // With an empty buffer and at least one live sender, the receiver stays pending.
        assert_eq!(poll_once(&mut rx).await, Poll::Pending);
        drop(tx2);
        assert_eq!(poll_once(&mut rx).await, Poll::Pending);

        // Dropping the last sender ends the stream.
        drop(tx);
        assert_eq!(rx.next().await, None);

        // Sending fails once the receiver has been dropped.
        let (tx, rx) = channel();
        tx.send("test").unwrap();
        drop(rx);
        assert!(tx.send("test").is_err());

        // After `close`, sends fail on the closed sender and on its clones.
        // NOTE(review): `_` drops the receiver immediately, so these sends would fail even without
        // `close`; the scenario does not isolate the effect of `close`.
        let (mut tx, _) = channel();
        let tx2 = tx.clone();
        tx.close();
        assert!(tx.send("test").is_err());
        assert!(tx2.send("test").is_err());
    }

    #[tokio::test]
    async fn test_recv() {
        // A buffered message is returned by `recv`.
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");
        drop(tx);

        // Once the buffer is drained and the sender dropped, `recv` yields `None`.
        let (tx, mut rx) = channel();
        tx.send("test").unwrap();
        assert_eq!(rx.recv().await.unwrap(), "test");
        drop(tx);
        assert!(rx.recv().await.is_none());
    }
}