actix_web_lab/
load_shed.rs

1// Code mostly copied from `tower`:
2// https://github.com/tower-rs/tower/tree/5064987f/tower/src/load_shed
3
4//! Load-shedding middleware.
5
6use std::{
7    cell::Cell,
8    error::Error as StdError,
9    fmt,
10    future::Future,
11    pin::Pin,
12    task::{ready, Context, Poll},
13};
14
15use actix_service::{Service, Transform};
16use actix_utils::future::{ok, Ready};
17use actix_web::ResponseError;
18use pin_project_lite::pin_project;
19
/// Middleware that rejects requests immediately, instead of waiting, whenever the wrapped
/// service reports that it is not ready.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct LoadShed;

impl LoadShed {
    /// Constructs the load-shedding middleware.
    pub fn new() -> Self {
        Self
    }
}
31
32impl<S: Service<Req>, Req> Transform<S, Req> for LoadShed {
33    type Response = S::Response;
34    type Error = Overloaded<S::Error>;
35    type Transform = LoadShedService<S>;
36    type InitError = ();
37    type Future = Ready<Result<Self::Transform, Self::InitError>>;
38
39    fn new_transform(&self, service: S) -> Self::Future {
40        ok(LoadShedService::new(service))
41    }
42}
43
/// Service produced by [`LoadShed`]: forwards a request to the wrapped service only if the most
/// recent readiness check saw it ready, and sheds it otherwise.
#[derive(Debug)]
pub struct LoadShedService<S> {
    /// The wrapped service.
    inner: S,
    /// Readiness observed by the latest `poll_ready`; consumed by the next `call`.
    is_ready: Cell<bool>,
}

impl<S> LoadShedService<S> {
    /// Wraps `inner`, starting in the "not ready" state so that a `call` made before any
    /// successful readiness check is shed.
    pub(crate) fn new(inner: S) -> Self {
        let is_ready = Cell::new(false);
        Self { inner, is_ready }
    }
}
60
61impl<S, Req> Service<Req> for LoadShedService<S>
62where
63    S: Service<Req>,
64{
65    type Response = S::Response;
66    type Error = Overloaded<S::Error>;
67    type Future = LoadShedFuture<S::Future>;
68
69    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        // We check for readiness here, so that we can know in `call` if
71        // the inner service is overloaded or not.
72        let is_ready = match self.inner.poll_ready(cx) {
73            Poll::Ready(Err(err)) => return Poll::Ready(Err(Overloaded::Service(err))),
74            res => res.is_ready(),
75        };
76
77        self.is_ready.set(is_ready);
78
79        // But we always report Ready, so that layers above don't wait until
80        // the inner service is ready (the entire point of this layer!)
81        Poll::Ready(Ok(()))
82    }
83
84    fn call(&self, req: Req) -> Self::Future {
85        if self.is_ready.get() {
86            // readiness only counts once, you need to check again!
87            self.is_ready.set(false);
88            LoadShedFuture::called(self.inner.call(req))
89        } else {
90            LoadShedFuture::overloaded()
91        }
92    }
93}
94
95pin_project! {
96    /// Future for [`LoadShedService`].
97    pub struct LoadShedFuture<F> {
98        #[pin]
99        state: LoadShedFutureState<F>,
100    }
101}
102
103pin_project! {
104    #[project = LoadShedFutureStateProj]
105    enum LoadShedFutureState<F> {
106        Called { #[pin] fut: F },
107        Overloaded,
108    }
109}
110
111impl<F> LoadShedFuture<F> {
112    pub(crate) fn called(fut: F) -> Self {
113        LoadShedFuture {
114            state: LoadShedFutureState::Called { fut },
115        }
116    }
117
118    pub(crate) fn overloaded() -> Self {
119        LoadShedFuture {
120            state: LoadShedFutureState::Overloaded,
121        }
122    }
123}
124
125impl<F, T, E> Future for LoadShedFuture<F>
126where
127    F: Future<Output = Result<T, E>>,
128{
129    type Output = Result<T, Overloaded<E>>;
130
131    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132        match self.project().state.project() {
133            LoadShedFutureStateProj::Called { fut } => {
134                Poll::Ready(ready!(fut.poll(cx)).map_err(Overloaded::Service))
135            }
136            LoadShedFutureStateProj::Overloaded => Poll::Ready(Err(Overloaded::Overloaded)),
137        }
138    }
139}
140
141impl<F> fmt::Debug for LoadShedFuture<F>
142where
143    // bounds for future-proofing...
144    F: fmt::Debug,
145{
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        f.write_str("LoadShedFuture")
148    }
149}
150
/// Error produced by [`LoadShedService`].
///
/// Either the request was shed because the inner service was not ready when it was called
/// ([`Overloaded::Overloaded`]), or the inner service itself failed ([`Overloaded::Service`]).
#[derive(Debug)]
#[non_exhaustive]
pub enum Overloaded<E> {
    /// The inner service returned an error. It is wrapped transparently.
    Service(E),

    /// The inner service was not ready, so the request was rejected without being handled.
    Overloaded,
}

impl<E: fmt::Display> fmt::Display for Overloaded<E> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Transparent: defer entirely to the inner error, including formatter flags.
            Overloaded::Service(err) => fmt::Display::fmt(err, f),
            Overloaded::Overloaded => f.write_str("service overloaded"),
        }
    }
}

impl<E: StdError + 'static> StdError for Overloaded<E> {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            // `Display` already prints the inner error. Returning that error here as well would
            // make error-chain reporters print its message twice, so forward the inner error's
            // own source instead.
            Overloaded::Service(err) => err.source(),
            Overloaded::Overloaded => None,
        }
    }
}
180
181impl<E> ResponseError for Overloaded<E>
182where
183    E: fmt::Debug + fmt::Display,
184{
185    fn status_code(&self) -> actix_http::StatusCode {
186        actix_web::http::StatusCode::SERVICE_UNAVAILABLE
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::{convert::Infallible, future::poll_fn};

    use actix_web::middleware::{Compat, Logger};

    use super::*;

    /// Inner service whose readiness is toggled by the test; it doubles its input.
    struct Gate {
        open: Cell<bool>,
    }

    impl Service<u32> for Gate {
        type Response = u32;
        type Error = Infallible;
        type Future = Ready<Result<u32, Infallible>>;

        fn poll_ready(&self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            if self.open.get() {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        }

        fn call(&self, req: u32) -> Self::Future {
            ok(req * 2)
        }
    }

    #[test]
    fn integration() {
        actix_web::App::new()
            .wrap(Compat::new(LoadShed::new()))
            .wrap(Logger::default());
    }

    #[actix_web::test]
    async fn sheds_when_inner_not_ready() {
        let svc = LoadShedService::new(Gate {
            open: Cell::new(false),
        });

        // Inner service not ready: this layer still reports ready, but the call is shed.
        poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        assert!(matches!(svc.call(1).await, Err(Overloaded::Overloaded)));

        // Inner service ready: the call is forwarded.
        svc.inner.open.set(true);
        poll_fn(|cx| svc.poll_ready(cx)).await.unwrap();
        assert_eq!(svc.call(2).await.unwrap(), 4);

        // Readiness is consumed by a single call.
        assert!(matches!(svc.call(3).await, Err(Overloaded::Overloaded)));
    }
}