actix_web_lab/
load_shed.rs1use std::{
7 cell::Cell,
8 error::Error as StdError,
9 fmt,
10 future::Future,
11 pin::Pin,
12 task::{ready, Context, Poll},
13};
14
15use actix_service::{Service, Transform};
16use actix_utils::future::{ok, Ready};
17use actix_web::ResponseError;
18use pin_project_lite::pin_project;
19
/// Middleware that rejects requests instead of waiting when the wrapped service is busy.
///
/// Applying `LoadShed` to a service yields a [`LoadShedService`]. If the inner service did not
/// report readiness on the most recent `poll_ready`, the next request fails straight away with
/// [`Overloaded::Overloaded`] instead of being held until capacity frees up.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct LoadShed;

impl LoadShed {
    /// Constructs the load-shedding middleware (same value as [`LoadShed::default`]).
    pub fn new() -> Self {
        Self
    }
}
31
32impl<S: Service<Req>, Req> Transform<S, Req> for LoadShed {
33 type Response = S::Response;
34 type Error = Overloaded<S::Error>;
35 type Transform = LoadShedService<S>;
36 type InitError = ();
37 type Future = Ready<Result<Self::Transform, Self::InitError>>;
38
39 fn new_transform(&self, service: S) -> Self::Future {
40 ok(LoadShedService::new(service))
41 }
42}
43
/// Service produced by the [`LoadShed`] middleware.
///
/// A request is forwarded to the inner service only when the preceding `poll_ready` observed the
/// inner service as ready; otherwise it is rejected with [`Overloaded::Overloaded`].
#[derive(Debug)]
pub struct LoadShedService<S> {
    /// The wrapped service.
    inner: S,

    /// Readiness seen by the last `poll_ready`; cleared again by `call`.
    is_ready: Cell<bool>,
}

impl<S> LoadShedService<S> {
    /// Wraps `inner`. Starts out "not ready", so a `call` made before any `poll_ready` is shed.
    pub(crate) fn new(inner: S) -> Self {
        Self {
            is_ready: Cell::new(false),
            inner,
        }
    }
}
60
61impl<S, Req> Service<Req> for LoadShedService<S>
62where
63 S: Service<Req>,
64{
65 type Response = S::Response;
66 type Error = Overloaded<S::Error>;
67 type Future = LoadShedFuture<S::Future>;
68
69 fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70 let is_ready = match self.inner.poll_ready(cx) {
73 Poll::Ready(Err(err)) => return Poll::Ready(Err(Overloaded::Service(err))),
74 res => res.is_ready(),
75 };
76
77 self.is_ready.set(is_ready);
78
79 Poll::Ready(Ok(()))
82 }
83
84 fn call(&self, req: Req) -> Self::Future {
85 if self.is_ready.get() {
86 self.is_ready.set(false);
88 LoadShedFuture::called(self.inner.call(req))
89 } else {
90 LoadShedFuture::overloaded()
91 }
92 }
93}
94
95pin_project! {
96 pub struct LoadShedFuture<F> {
98 #[pin]
99 state: LoadShedFutureState<F>,
100 }
101}
102
103pin_project! {
104 #[project = LoadShedFutureStateProj]
105 enum LoadShedFutureState<F> {
106 Called { #[pin] fut: F },
107 Overloaded,
108 }
109}
110
111impl<F> LoadShedFuture<F> {
112 pub(crate) fn called(fut: F) -> Self {
113 LoadShedFuture {
114 state: LoadShedFutureState::Called { fut },
115 }
116 }
117
118 pub(crate) fn overloaded() -> Self {
119 LoadShedFuture {
120 state: LoadShedFutureState::Overloaded,
121 }
122 }
123}
124
125impl<F, T, E> Future for LoadShedFuture<F>
126where
127 F: Future<Output = Result<T, E>>,
128{
129 type Output = Result<T, Overloaded<E>>;
130
131 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132 match self.project().state.project() {
133 LoadShedFutureStateProj::Called { fut } => {
134 Poll::Ready(ready!(fut.poll(cx)).map_err(Overloaded::Service))
135 }
136 LoadShedFutureStateProj::Overloaded => Poll::Ready(Err(Overloaded::Overloaded)),
137 }
138 }
139}
140
141impl<F> fmt::Debug for LoadShedFuture<F>
142where
143 F: fmt::Debug,
145{
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 f.write_str("LoadShedFuture")
148 }
149}
150
/// Error produced by [`LoadShedService`] and its future.
#[derive(Debug)]
#[non_exhaustive]
pub enum Overloaded<E> {
    /// The inner service failed, either in `poll_ready` or in its response future.
    Service(E),

    /// The request was rejected because the inner service was not ready.
    Overloaded,
}

impl<E: fmt::Display> fmt::Display for Overloaded<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Transparent: reuse the inner error's message verbatim.
            Overloaded::Service(err) => write!(f, "{err}"),
            Overloaded::Overloaded => f.write_str("service overloaded"),
        }
    }
}

impl<E: StdError + 'static> StdError for Overloaded<E> {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            // `Display` already prints the inner error's message, so returning `err` itself here
            // would make error-chain reporters print that message twice. Act as a transparent
            // wrapper and forward to the inner error's own source instead.
            Overloaded::Service(err) => err.source(),
            Overloaded::Overloaded => None,
        }
    }
}
180
181impl<E> ResponseError for Overloaded<E>
182where
183 E: fmt::Debug + fmt::Display,
184{
185 fn status_code(&self) -> actix_http::StatusCode {
186 actix_web::http::StatusCode::SERVICE_UNAVAILABLE
187 }
188}
189
#[cfg(test)]
mod tests {
    use actix_web::{
        middleware::{Compat, Logger},
        App,
    };

    use super::*;

    /// Smoke test: `LoadShed` (adapted via `Compat`) must type-check as app middleware next to
    /// other middleware.
    #[test]
    fn integration() {
        let _app = App::new()
            .wrap(Compat::new(LoadShed::new()))
            .wrap(Logger::default());
    }
}