// tower/buffer/worker.rs

1use super::{
2    error::{Closed, ServiceError},
3    message::Message,
4};
5use std::sync::{Arc, Mutex};
6use std::{
7    future::Future,
8    pin::Pin,
9    task::{ready, Context, Poll},
10};
11use tokio::sync::mpsc;
12use tower_service::Service;
13
14pin_project_lite::pin_project! {
15    /// Task that handles processing the buffer. This type should not be used
16    /// directly, instead `Buffer` requires an `Executor` that can accept this task.
17    ///
18    /// The struct is `pub` in the private module and the type is *not* re-exported
19    /// as part of the public API. This is the "sealed" pattern to include "private"
20    /// types in public traits that are not meant for consumers of the library to
21    /// implement (only call).
22    #[derive(Debug)]
23    pub struct Worker<T, Request>
24    where
25        T: Service<Request>,
26    {
27        current_message: Option<Message<Request, T::Future>>,
28        rx: mpsc::Receiver<Message<Request, T::Future>>,
29        service: T,
30        finish: bool,
31        failed: Option<ServiceError>,
32        handle: Handle,
33    }
34}
35
36/// Get the error out
37#[derive(Debug)]
38pub(crate) struct Handle {
39    inner: Arc<Mutex<Option<ServiceError>>>,
40}
41
42impl<T, Request> Worker<T, Request>
43where
44    T: Service<Request>,
45    T::Error: Into<crate::BoxError>,
46{
47    pub(crate) fn new(
48        service: T,
49        rx: mpsc::Receiver<Message<Request, T::Future>>,
50    ) -> (Handle, Worker<T, Request>) {
51        let handle = Handle {
52            inner: Arc::new(Mutex::new(None)),
53        };
54
55        let worker = Worker {
56            current_message: None,
57            finish: false,
58            failed: None,
59            rx,
60            service,
61            handle: handle.clone(),
62        };
63
64        (handle, worker)
65    }
66
67    /// Return the next queued Message that hasn't been canceled.
68    ///
69    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
70    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
71    fn poll_next_msg(
72        &mut self,
73        cx: &mut Context<'_>,
74    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
75        if self.finish {
76            // We've already received None and are shutting down
77            return Poll::Ready(None);
78        }
79
80        tracing::trace!("worker polling for next message");
81        if let Some(msg) = self.current_message.take() {
82            // If the oneshot sender is closed, then the receiver is dropped,
83            // and nobody cares about the response. If this is the case, we
84            // should continue to the next request.
85            if !msg.tx.is_closed() {
86                tracing::trace!("resuming buffered request");
87                return Poll::Ready(Some((msg, false)));
88            }
89
90            tracing::trace!("dropping cancelled buffered request");
91        }
92
93        // Get the next request
94        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
95            if !msg.tx.is_closed() {
96                tracing::trace!("processing new request");
97                return Poll::Ready(Some((msg, true)));
98            }
99            // Otherwise, request is canceled, so pop the next one.
100            tracing::trace!("dropping cancelled request");
101        }
102
103        Poll::Ready(None)
104    }
105
106    fn failed(&mut self, error: crate::BoxError) {
107        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
108        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
109        // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
110        // requests will also fail with the same error.
111
112        // Note that we need to handle the case where some handle is concurrently trying to send us
113        // a request. We need to make sure that *either* the send of the request fails *or* it
114        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
115        // case where we send errors to all outstanding requests, and *then* the caller sends its
116        // request. We do this by *first* exposing the error, *then* closing the channel used to
117        // send more requests (so the client will see the error when the send fails), and *then*
118        // sending the error to all outstanding requests.
119        let error = ServiceError::new(error);
120
121        let mut inner = self.handle.inner.lock().unwrap();
122
123        if inner.is_some() {
124            // Future::poll was called after we've already errored out!
125            return;
126        }
127
128        *inner = Some(error.clone());
129        drop(inner);
130
131        self.rx.close();
132
133        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
134        // which will trigger the `self.finish == true` phase. We just need to make sure that any
135        // requests that we receive before we've exhausted the receiver receive the error:
136        self.failed = Some(error);
137    }
138}
139
140impl<T, Request> Future for Worker<T, Request>
141where
142    T: Service<Request>,
143    T::Error: Into<crate::BoxError>,
144{
145    type Output = ();
146
147    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148        if self.finish {
149            return Poll::Ready(());
150        }
151
152        loop {
153            match ready!(self.poll_next_msg(cx)) {
154                Some((msg, first)) => {
155                    let _guard = msg.span.enter();
156                    if let Some(ref failed) = self.failed {
157                        tracing::trace!("notifying caller about worker failure");
158                        let _ = msg.tx.send(Err(failed.clone()));
159                        continue;
160                    }
161
162                    // Wait for the service to be ready
163                    tracing::trace!(
164                        resumed = !first,
165                        message = "worker received request; waiting for service readiness"
166                    );
167                    match self.service.poll_ready(cx) {
168                        Poll::Ready(Ok(())) => {
169                            tracing::debug!(service.ready = true, message = "processing request");
170                            let response = self.service.call(msg.request);
171
172                            // Send the response future back to the sender.
173                            //
174                            // An error means the request had been canceled in-between
175                            // our calls, the response future will just be dropped.
176                            tracing::trace!("returning response future");
177                            let _ = msg.tx.send(Ok(response));
178                        }
179                        Poll::Pending => {
180                            tracing::trace!(service.ready = false, message = "delay");
181                            // Put out current message back in its slot.
182                            drop(_guard);
183                            self.current_message = Some(msg);
184                            return Poll::Pending;
185                        }
186                        Poll::Ready(Err(e)) => {
187                            let error = e.into();
188                            tracing::debug!({ %error }, "service failed");
189                            drop(_guard);
190                            self.failed(error);
191                            let _ = msg.tx.send(Err(self
192                                .failed
193                                .as_ref()
194                                .expect("Worker::failed did not set self.failed?")
195                                .clone()));
196                        }
197                    }
198                }
199                None => {
200                    // No more more requests _ever_.
201                    self.finish = true;
202                    return Poll::Ready(());
203                }
204            }
205        }
206    }
207}
208
209impl Handle {
210    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
211        self.inner
212            .lock()
213            .unwrap()
214            .as_ref()
215            .map(|svc_err| svc_err.clone().into())
216            .unwrap_or_else(|| Closed::new().into())
217    }
218}
219
220impl Clone for Handle {
221    fn clone(&self) -> Handle {
222        Handle {
223            inner: self.inner.clone(),
224        }
225    }
226}