tower/buffer/worker.rs
1use super::{
2 error::{Closed, ServiceError},
3 message::Message,
4};
5use std::sync::{Arc, Mutex};
6use std::{
7 future::Future,
8 pin::Pin,
9 task::{ready, Context, Poll},
10};
11use tokio::sync::mpsc;
12use tower_service::Service;
13
14pin_project_lite::pin_project! {
15 /// Task that handles processing the buffer. This type should not be used
16 /// directly, instead `Buffer` requires an `Executor` that can accept this task.
17 ///
18 /// The struct is `pub` in the private module and the type is *not* re-exported
19 /// as part of the public API. This is the "sealed" pattern to include "private"
20 /// types in public traits that are not meant for consumers of the library to
21 /// implement (only call).
22 #[derive(Debug)]
23 pub struct Worker<T, Request>
24 where
25 T: Service<Request>,
26 {
27 current_message: Option<Message<Request, T::Future>>,
28 rx: mpsc::Receiver<Message<Request, T::Future>>,
29 service: T,
30 finish: bool,
31 failed: Option<ServiceError>,
32 handle: Handle,
33 }
34}
35
36/// Get the error out
37#[derive(Debug)]
38pub(crate) struct Handle {
39 inner: Arc<Mutex<Option<ServiceError>>>,
40}
41
42impl<T, Request> Worker<T, Request>
43where
44 T: Service<Request>,
45 T::Error: Into<crate::BoxError>,
46{
47 pub(crate) fn new(
48 service: T,
49 rx: mpsc::Receiver<Message<Request, T::Future>>,
50 ) -> (Handle, Worker<T, Request>) {
51 let handle = Handle {
52 inner: Arc::new(Mutex::new(None)),
53 };
54
55 let worker = Worker {
56 current_message: None,
57 finish: false,
58 failed: None,
59 rx,
60 service,
61 handle: handle.clone(),
62 };
63
64 (handle, worker)
65 }
66
67 /// Return the next queued Message that hasn't been canceled.
68 ///
69 /// If a `Message` is returned, the `bool` is true if this is the first time we received this
70 /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
71 fn poll_next_msg(
72 &mut self,
73 cx: &mut Context<'_>,
74 ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
75 if self.finish {
76 // We've already received None and are shutting down
77 return Poll::Ready(None);
78 }
79
80 tracing::trace!("worker polling for next message");
81 if let Some(msg) = self.current_message.take() {
82 // If the oneshot sender is closed, then the receiver is dropped,
83 // and nobody cares about the response. If this is the case, we
84 // should continue to the next request.
85 if !msg.tx.is_closed() {
86 tracing::trace!("resuming buffered request");
87 return Poll::Ready(Some((msg, false)));
88 }
89
90 tracing::trace!("dropping cancelled buffered request");
91 }
92
93 // Get the next request
94 while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
95 if !msg.tx.is_closed() {
96 tracing::trace!("processing new request");
97 return Poll::Ready(Some((msg, true)));
98 }
99 // Otherwise, request is canceled, so pop the next one.
100 tracing::trace!("dropping cancelled request");
101 }
102
103 Poll::Ready(None)
104 }
105
106 fn failed(&mut self, error: crate::BoxError) {
107 // The underlying service failed when we called `poll_ready` on it with the given `error`. We
108 // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
109 // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
110 // requests will also fail with the same error.
111
112 // Note that we need to handle the case where some handle is concurrently trying to send us
113 // a request. We need to make sure that *either* the send of the request fails *or* it
114 // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
115 // case where we send errors to all outstanding requests, and *then* the caller sends its
116 // request. We do this by *first* exposing the error, *then* closing the channel used to
117 // send more requests (so the client will see the error when the send fails), and *then*
118 // sending the error to all outstanding requests.
119 let error = ServiceError::new(error);
120
121 let mut inner = self.handle.inner.lock().unwrap();
122
123 if inner.is_some() {
124 // Future::poll was called after we've already errored out!
125 return;
126 }
127
128 *inner = Some(error.clone());
129 drop(inner);
130
131 self.rx.close();
132
133 // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
134 // which will trigger the `self.finish == true` phase. We just need to make sure that any
135 // requests that we receive before we've exhausted the receiver receive the error:
136 self.failed = Some(error);
137 }
138}
139
140impl<T, Request> Future for Worker<T, Request>
141where
142 T: Service<Request>,
143 T::Error: Into<crate::BoxError>,
144{
145 type Output = ();
146
147 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
148 if self.finish {
149 return Poll::Ready(());
150 }
151
152 loop {
153 match ready!(self.poll_next_msg(cx)) {
154 Some((msg, first)) => {
155 let _guard = msg.span.enter();
156 if let Some(ref failed) = self.failed {
157 tracing::trace!("notifying caller about worker failure");
158 let _ = msg.tx.send(Err(failed.clone()));
159 continue;
160 }
161
162 // Wait for the service to be ready
163 tracing::trace!(
164 resumed = !first,
165 message = "worker received request; waiting for service readiness"
166 );
167 match self.service.poll_ready(cx) {
168 Poll::Ready(Ok(())) => {
169 tracing::debug!(service.ready = true, message = "processing request");
170 let response = self.service.call(msg.request);
171
172 // Send the response future back to the sender.
173 //
174 // An error means the request had been canceled in-between
175 // our calls, the response future will just be dropped.
176 tracing::trace!("returning response future");
177 let _ = msg.tx.send(Ok(response));
178 }
179 Poll::Pending => {
180 tracing::trace!(service.ready = false, message = "delay");
181 // Put out current message back in its slot.
182 drop(_guard);
183 self.current_message = Some(msg);
184 return Poll::Pending;
185 }
186 Poll::Ready(Err(e)) => {
187 let error = e.into();
188 tracing::debug!({ %error }, "service failed");
189 drop(_guard);
190 self.failed(error);
191 let _ = msg.tx.send(Err(self
192 .failed
193 .as_ref()
194 .expect("Worker::failed did not set self.failed?")
195 .clone()));
196 }
197 }
198 }
199 None => {
200 // No more more requests _ever_.
201 self.finish = true;
202 return Poll::Ready(());
203 }
204 }
205 }
206 }
207}
208
209impl Handle {
210 pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
211 self.inner
212 .lock()
213 .unwrap()
214 .as_ref()
215 .map(|svc_err| svc_err.clone().into())
216 .unwrap_or_else(|| Closed::new().into())
217 }
218}
219
220impl Clone for Handle {
221 fn clone(&self) -> Handle {
222 Handle {
223 inner: self.inner.clone(),
224 }
225 }
226}