use std::alloc::Layout;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::ptr::{self, NonNull};
use std::task::{Context, Poll};
use std::{fmt, panic};
/// A reusable `Pin<Box<dyn Future<Output = T> + Send>>`.
///
/// The stored future can be replaced with [`set`](Self::set) or
/// [`try_set`](Self::try_set). When the new future has the same size and
/// alignment as the current one, the existing heap allocation is reused
/// instead of allocating a fresh box.
pub(crate) struct ReusableBoxFuture<T> {
    /// Owning pointer obtained from `Box::into_raw`.
    ///
    /// Invariant: it always points to a live, initialized future whose layout
    /// equals the layout the allocation was created with. The future is never
    /// moved out of the allocation; it is only dropped in place.
    boxed: NonNull<dyn Future<Output = T> + Send>,
}

impl<T> ReusableBoxFuture<T> {
    /// Creates a new `ReusableBoxFuture<T>` containing the provided future.
    pub(crate) fn new<F>(future: F) -> Self
    where
        F: Future<Output = T> + Send + 'static,
    {
        let boxed: Box<dyn Future<Output = T> + Send> = Box::new(future);
        let boxed = Box::into_raw(boxed);
        // SAFETY: `Box::into_raw` never returns a null pointer.
        let boxed = unsafe { NonNull::new_unchecked(boxed) };
        Self { boxed }
    }

    /// Replaces the stored future with `future`.
    ///
    /// The current allocation is reused when `F` has the same layout as the
    /// stored future; otherwise a new box is allocated and the old one is
    /// dropped.
    pub(crate) fn set<F>(&mut self, future: F)
    where
        F: Future<Output = T> + Send + 'static,
    {
        if let Err(future) = self.try_set(future) {
            *self = Self::new(future);
        }
    }

    /// Replaces the stored future with `future`, reusing the current allocation.
    ///
    /// # Errors
    ///
    /// Returns `Err(future)` and leaves `self` untouched when the layout (size
    /// and alignment) of `F` differs from that of the stored future, because
    /// the allocation cannot be reused in that case.
    ///
    /// # Panics
    ///
    /// If the old future's destructor panics, the new future is still stored
    /// before the panic is propagated, so `self` remains valid.
    pub(crate) fn try_set<F>(&mut self, future: F) -> Result<(), F>
    where
        F: Future<Output = T> + Send + 'static,
    {
        // SAFETY: by the struct invariant `self.boxed` points to a live future,
        // and `&mut self` guarantees no other reference to it exists.
        let self_layout = Layout::for_value(unsafe { self.boxed.as_ref() });

        if Layout::new::<F>() == self_layout {
            // SAFETY: the layouts were just checked to be equal.
            unsafe { self.set_same_layout(future) };
            Ok(())
        } else {
            Err(future)
        }
    }

    /// Overwrites the stored future with `future` in place, without reallocating.
    ///
    /// # Safety
    ///
    /// `Layout::new::<F>()` must equal the layout of the currently stored future.
    unsafe fn set_same_layout<F>(&mut self, future: F)
    where
        F: Future<Output = T> + Send + 'static,
    {
        // Drop the old future, catching a panic from its destructor. The slot
        // must be refilled regardless. Otherwise `self.boxed` would point to a
        // dropped value that `Drop` would drop a second time.
        let result = panic::catch_unwind(AssertUnwindSafe(|| {
            // SAFETY: `self.boxed` points to a live future. It is not accessed
            // again until it is overwritten below.
            unsafe { ptr::drop_in_place(self.boxed.as_ptr()) };
        }));

        // Reinterpret the allocation as holding an `F`.
        let typed: NonNull<F> = self.boxed.cast::<F>();
        // SAFETY: the caller guarantees the allocation has `F`'s size and
        // alignment. The previous value was dropped above, so writing without
        // reading the old contents neither leaks nor double-drops.
        unsafe { ptr::write(typed.as_ptr(), future) };

        // Unsizing coercion installs `F`'s vtable into the fat pointer.
        self.boxed = typed;

        // Propagate a panic from the old destructor only now that `self` is valid again.
        if let Err(payload) = result {
            panic::resume_unwind(payload);
        }
    }

    /// Returns a pinned mutable reference to the stored future.
    pub(crate) fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> {
        // SAFETY: the future lives in a heap allocation that is never moved;
        // it is only dropped in place (in `set_same_layout` or `Drop`), which
        // the `Pin` contract permits. `&mut self` ensures exclusive access.
        unsafe { Pin::new_unchecked(self.boxed.as_mut()) }
    }

    /// Polls the stored future once.
    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> {
        self.get_pin().poll(cx)
    }
}

impl<T> Future for ReusableBoxFuture<T> {
    type Output = T;

    /// Polls the stored future; same as the inherent [`ReusableBoxFuture::poll`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        Pin::into_inner(self).get_pin().poll(cx)
    }
}

// SAFETY: every stored future is required to be `Send` by `new`, `set` and
// `try_set` (`F: Send`), so moving the owning handle to another thread only
// moves a `Send` value.
unsafe impl<T> Send for ReusableBoxFuture<T> {}

// SAFETY: no method reaches the stored future through `&self`. `get_pin`,
// `poll`, `set` and `try_set` all take `&mut self`, and `Debug` never touches
// it. Sharing `&ReusableBoxFuture` across threads therefore cannot cause
// concurrent access to a possibly `!Sync` future.
unsafe impl<T> Sync for ReusableBoxFuture<T> {}

// Like `Pin<Box<dyn Future>>`, moving the handle never moves the heap-allocated
// future, so the handle itself is `Unpin`.
impl<T> Unpin for ReusableBoxFuture<T> {}

impl<T> Drop for ReusableBoxFuture<T> {
    fn drop(&mut self) {
        // SAFETY: `self.boxed` came from `Box::into_raw`. In-place replacement
        // only ever stores values with the allocation's original layout, so
        // rebuilding the `Box` frees it with the correct layout.
        unsafe { drop(Box::from_raw(self.boxed.as_ptr())) };
    }
}

impl<T> fmt::Debug for ReusableBoxFuture<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ReusableBoxFuture").finish()
    }
}
#[cfg(test)]
mod test {
    use super::ReusableBoxFuture;
    use futures::future::FutureExt;
    use std::alloc::Layout;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// A stateless future that always resolves immediately to `5`.
    struct ZeroSizedFuture {}

    impl Future for ZeroSizedFuture {
        type Output = u32;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<u32> {
            Poll::Ready(5)
        }
    }

    /// Swapping between distinct future types of identical layout must reuse the box.
    #[test]
    fn test_different_futures() {
        let first = async move { 10 };
        // The in-place reuse below relies on every async block here being one byte.
        assert_eq!(Layout::for_value(&first).size(), 1);

        let mut boxed = ReusableBoxFuture::new(first);
        assert_eq!(boxed.get_pin().now_or_never(), Some(10));

        assert!(boxed.try_set(async move { 20 }).is_ok(), "incorrect size");
        assert_eq!(boxed.get_pin().now_or_never(), Some(20));

        assert!(boxed.try_set(async move { 30 }).is_ok(), "incorrect size");
        assert_eq!(boxed.get_pin().now_or_never(), Some(30));
    }

    /// `set` must cope with futures of differing sizes, including zero-sized ones.
    #[test]
    fn test_different_sizes() {
        let small = async move { 10 };
        let array = [0u32; 1000];
        let large = async move { array[0] };
        let empty = ZeroSizedFuture {};

        let sizes = [
            Layout::for_value(&small).size(),
            Layout::for_value(&large).size(),
            Layout::for_value(&empty).size(),
        ];
        assert_eq!(sizes, [1, 4004, 0]);

        let mut boxed = ReusableBoxFuture::new(small);
        assert_eq!(boxed.get_pin().now_or_never(), Some(10));

        boxed.set(large);
        assert_eq!(boxed.get_pin().now_or_never(), Some(0));

        boxed.set(empty);
        assert_eq!(boxed.get_pin().now_or_never(), Some(5));
    }

    /// A zero-sized future can be polled repeatedly and replaced in place.
    #[test]
    fn test_zero_sized() {
        let zst = ZeroSizedFuture {};
        assert_eq!(Layout::for_value(&zst).size(), 0);

        let mut boxed = ReusableBoxFuture::new(zst);
        for _ in 0..2 {
            assert_eq!(boxed.get_pin().now_or_never(), Some(5));
        }

        assert!(boxed.try_set(ZeroSizedFuture {}).is_ok(), "incorrect size");
        for _ in 0..2 {
            assert_eq!(boxed.get_pin().now_or_never(), Some(5));
        }
    }
}