use std::{
cell::Cell,
error::Error as StdError,
fmt,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use actix_service::{Service, Transform};
use actix_utils::future::{ok, Ready};
use actix_web::ResponseError;
use pin_project_lite::pin_project;
/// Middleware factory that adds load shedding to the service it wraps.
///
/// The type carries no configuration. Its [`Transform`] implementation wraps
/// the next service in a [`LoadShedService`].
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct LoadShed;

impl LoadShed {
    /// Creates a new load-shedding middleware.
    pub fn new() -> Self {
        Self
    }
}
/// Builds a [`LoadShedService`] around the service that follows this
/// middleware in the chain.
///
/// Responses pass through unchanged; errors are widened to
/// [`Overloaded<S::Error>`](Overloaded).
impl<S, Req> Transform<S, Req> for LoadShed
where
    S: Service<Req>,
{
    type Response = S::Response;
    type Error = Overloaded<S::Error>;
    type Transform = LoadShedService<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` immediately; construction never fails.
    fn new_transform(&self, service: S) -> Self::Future {
        let wrapped = LoadShedService::new(service);
        ok(wrapped)
    }
}
/// Service wrapper that sheds requests while the inner service is not ready.
#[derive(Debug)]
pub struct LoadShedService<S> {
    /// The wrapped service.
    inner: S,
    /// Readiness of `inner` as recorded by the most recent readiness poll.
    is_ready: Cell<bool>,
}

impl<S> LoadShedService<S> {
    /// Wraps `inner`.
    ///
    /// The readiness flag starts out `false`, so the wrapper treats the inner
    /// service as not ready until a readiness poll says otherwise.
    pub(crate) fn new(inner: S) -> Self {
        let is_ready = Cell::new(false);
        Self { inner, is_ready }
    }
}
impl<S, Req> Service<Req> for LoadShedService<S>
where
    S: Service<Req>,
{
    type Response = S::Response;
    type Error = Overloaded<S::Error>;
    type Future = LoadShedFuture<S::Future>;

    /// Polls the inner service and records whether it is ready.
    ///
    /// A pending inner service is never propagated as back-pressure: unless
    /// the inner service reports an error (forwarded as
    /// [`Overloaded::Service`]), this always returns `Poll::Ready(Ok(()))`.
    fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner_ready = match self.inner.poll_ready(cx) {
            Poll::Ready(Ok(())) => true,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(Overloaded::Service(err))),
            Poll::Pending => false,
        };

        self.is_ready.set(inner_ready);
        Poll::Ready(Ok(()))
    }

    /// Forwards `req` to the inner service if the last readiness poll found it
    /// ready; otherwise returns a future that fails with
    /// [`Overloaded::Overloaded`].
    fn call(&self, req: Req) -> Self::Future {
        // Take the recorded readiness and clear it in one step, so each
        // successful readiness poll admits at most one forwarded call.
        if self.is_ready.replace(false) {
            LoadShedFuture::called(self.inner.call(req))
        } else {
            LoadShedFuture::overloaded()
        }
    }
}
pin_project! {
/// Future returned by `LoadShedService::call`.
///
/// Resolves to the inner service's result with its error wrapped in
/// `Overloaded::Service`, or to `Err(Overloaded::Overloaded)` when the request
/// was shed.
pub struct LoadShedFuture<F> {
// Either the in-flight inner future or the "request was shed" marker.
// Marked `#[pin]` so the inner future can be polled through the projection.
#[pin]
state: LoadShedFutureState<F>,
}
}
pin_project! {
// Internal state of `LoadShedFuture`.
// `#[project = ...]` names the pinned projection enum that `poll` matches on.
#[project = LoadShedFutureStateProj]
enum LoadShedFutureState<F> {
// The request was forwarded; `fut` is the inner service's response future.
Called { #[pin] fut: F },
// The request was rejected without calling the inner service.
Overloaded,
}
}
impl<F> LoadShedFuture<F> {
    /// Creates a future that drives `fut`, the inner service's response future.
    pub(crate) fn called(fut: F) -> Self {
        let state = LoadShedFutureState::Called { fut };
        Self { state }
    }

    /// Creates a future for a request that was shed instead of forwarded.
    pub(crate) fn overloaded() -> Self {
        let state = LoadShedFutureState::Overloaded;
        Self { state }
    }
}
impl<F, T, E> Future for LoadShedFuture<F>
where
    F: Future<Output = Result<T, E>>,
{
    type Output = Result<T, Overloaded<E>>;

    /// Drives the inner future when the request was forwarded, wrapping its
    /// error in [`Overloaded::Service`]; resolves immediately with
    /// [`Overloaded::Overloaded`] when the request was shed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = self.project().state.project();

        match state {
            LoadShedFutureStateProj::Overloaded => Poll::Ready(Err(Overloaded::Overloaded)),
            LoadShedFutureStateProj::Called { fut } => {
                // `ready!` returns `Poll::Pending` early while the inner
                // future has not completed.
                let outcome = ready!(fut.poll(cx));
                Poll::Ready(outcome.map_err(Overloaded::Service))
            }
        }
    }
}
/// Opaque `Debug` output for [`LoadShedFuture`].
///
/// Only the type name is printed and the inner future is never inspected, so
/// no `F: Debug` bound is required; the future can be formatted whatever the
/// inner future type is.
impl<F> fmt::Debug for LoadShedFuture<F> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("LoadShedFuture")
    }
}
/// Error produced by the load-shedding service.
#[non_exhaustive]
#[derive(Debug)]
pub enum Overloaded<E> {
    /// The inner service produced an error, carried here unchanged.
    Service(E),

    /// The request was rejected because the inner service was not ready.
    Overloaded,
}

impl<E> fmt::Display for Overloaded<E>
where
    E: fmt::Display,
{
    /// Shows the inner error's own message for `Service`, and a fixed message
    /// for `Overloaded`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Overloaded => f.write_str("service overloaded"),
            Self::Service(err) => write!(f, "{err}"),
        }
    }
}

impl<E> StdError for Overloaded<E>
where
    E: StdError + 'static,
{
    /// Exposes the inner service error as the source; a shed request has none.
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        match self {
            Self::Overloaded => None,
            Self::Service(err) => Some(err),
        }
    }
}
/// Lets [`Overloaded`] be rendered as an HTTP error response.
impl<E> ResponseError for Overloaded<E>
where
    E: fmt::Debug + fmt::Display,
{
    /// Always answers `503 Service Unavailable`.
    ///
    /// The variant is not inspected, so a wrapped `Service` error is reported
    /// with the same status as a shed request.
    // The return type is named through the `actix_web::http` re-export, the
    // same path used in the body, so this impl only relies on `actix_web`.
    fn status_code(&self) -> actix_web::http::StatusCode {
        actix_web::http::StatusCode::SERVICE_UNAVAILABLE
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::middleware::{Compat, Logger};

    /// Smoke test: builds an `App` with `LoadShed` wrapped in `Compat`,
    /// followed by `Logger`. No request is sent; the test only checks that
    /// the middleware stack can be constructed.
    #[test]
    fn integration() {
        let _app = actix_web::App::new()
            .wrap(Compat::new(LoadShed::new()))
            .wrap(Logger::default());
    }
}