1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//! Simple "poll function" future and factory.

use core::{
    fmt,
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};

/// Builds a future whose polling logic is supplied by the closure `f`.
///
/// Each poll of the returned [`PollFn`] calls `f` with the current task [`Context`] and hands
/// back whatever `f` returns, so the future completes with the value carried by the
/// [`Poll::Ready`] that `f` produces.
///
/// # Examples
/// ```
/// # use std::task::Poll;
/// # use actix_utils::future::poll_fn;
/// # async fn test_poll_fn() {
/// // A closure that is ready on the very first poll.
/// let res = poll_fn(|_| Poll::Ready(42)).await;
/// assert_eq!(res, 42);
///
/// // A closure that stays pending for a few polls, waking its own task each time.
/// let mut i = 5;
/// let res = poll_fn(|cx| {
///     i -= 1;
///
///     if i > 0 {
///         cx.waker().wake_by_ref();
///         Poll::Pending
///     } else {
///         Poll::Ready(42)
///     }
/// })
/// .await;
/// assert_eq!(res, 42);
/// # }
/// # actix_rt::Runtime::new().unwrap().block_on(test_poll_fn());
/// ```
#[inline]
pub fn poll_fn<F: FnMut(&mut Context<'_>) -> Poll<T>, T>(f: F) -> PollFn<F> {
    // Only wraps the closure; it is not invoked here.
    PollFn { f }
}

/// Future for the [`poll_fn`] function.
///
/// Holds the closure passed to [`poll_fn`]. Like any future it does nothing until it is polled
/// or awaited, so dropping it unused is flagged via `must_use`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollFn<F> {
    /// Closure that is called each time this future is polled.
    f: F,
}

impl<F> fmt::Debug for PollFn<F> {
    /// Writes only the type name; the wrapped closure is not required to be `Debug`, so it is
    /// left out of the output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("PollFn")
    }
}

impl<F, T> Future for PollFn<F>
where
    F: FnMut(&mut Context<'_>) -> Poll<T>,
{
    type Output = T;

    /// Calls the wrapped closure with `cx` and returns its result unchanged.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: the closure is only reached through this `&mut` borrow and is never moved
        // out of the pinned struct, so a `!Unpin` closure keeps its address.
        // see https://github.com/rust-lang/rust/pull/102737
        let this = unsafe { self.get_unchecked_mut() };
        (this.f)(cx)
    }
}

// Unit tests plus compile-time checks of the auto-trait behaviour (`Unpin`, `Send`, `Sync`)
// of `PollFn`.
#[cfg(test)]
mod tests {
    use std::marker::PhantomPinned;

    use super::*;

    // `PollFn` must not be unconditionally `Unpin`: it is `Unpin` when its payload is (`()`),
    // and `!Unpin` when its payload is not (`PhantomPinned`).
    static_assertions::assert_impl_all!(PollFn<()>: Unpin);
    static_assertions::assert_not_impl_all!(PollFn<PhantomPinned>: Unpin);

    // Runtime check mirroring the doc example on `poll_fn`.
    #[actix_rt::test]
    async fn test_poll_fn() {
        // Closure that is ready immediately.
        let res = poll_fn(|_| Poll::Ready(42)).await;
        assert_eq!(res, 42);

        // Closure that returns `Pending` four times (waking its own task each time) and
        // `Ready` on the fifth call.
        let mut i = 5;
        let res = poll_fn(|cx| {
            i -= 1;

            if i > 0 {
                cx.waker().wake_by_ref();
                Poll::Pending
            } else {
                Poll::Ready(42)
            }
        })
        .await;
        assert_eq!(res, 42);
    }

    // The following soundness tests are taken from https://github.com/tokio-rs/tokio/pull/5087
    // They are compile-time only: the code below is type-checked but never executed.

    // Compiles only when `T: Send`.
    // NOTE(review): the `allow` is presumably needed because the helper is only referenced
    // from the never-called closure below — confirm.
    #[allow(dead_code)]
    fn require_send<T: Send>(_t: &T) {}
    // Compiles only when `T: Sync`.
    #[allow(dead_code)]
    fn require_sync<T: Sync>(_t: &T) {}

    // Negative-bound trick: every type implements `AmbiguousIfUnpin<()>`, and `Unpin` types
    // additionally implement `AmbiguousIfUnpin<[u8; 0]>`. A call to `some_item` that leaves
    // `A` to inference therefore only compiles for `!Unpin` types; for `Unpin` types two
    // impls apply and the compiler reports an ambiguity.
    trait AmbiguousIfUnpin<A> {
        fn some_item(&self) {}
    }
    impl<T: ?Sized> AmbiguousIfUnpin<()> for T {}
    impl<T: ?Sized + Unpin> AmbiguousIfUnpin<[u8; 0]> for T {}

    // Anonymous constant holding a closure that is never called; it exists so that its body
    // is type-checked.
    const _: fn() = || {
        let pinned = std::marker::PhantomPinned;
        let f = poll_fn(move |_| {
            // Use `pinned` to take ownership of it.
            let _ = &pinned;
            std::task::Poll::Pending::<()>
        });
        // A `PollFn` wrapping a closure that owns a `PhantomPinned` must still be `Send` and
        // `Sync`...
        require_send(&f);
        require_sync(&f);
        // ...but must not be `Unpin` (this line fails to compile if it is).
        AmbiguousIfUnpin::some_item(&f);
    };
}