tokio/runtime/task/list.rs

//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement `Send`. The `LocalOwnedTasks` container is not thread-safe, but can
//! store non-`Send` tasks.
//!
//! Both collections can be closed to prevent new tasks from being added during
//! shutdown of the scheduler that owns the collection.
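//!
//! A rough sketch of the intended flow, assuming a hypothetical scheduler
//! type `MyScheduler` (illustrative only; not an API exposed by this module):
//!
//! ```ignore
//! let owned: OwnedTasks<MyScheduler> = OwnedTasks::new(num_cores);
//!
//! // Spawning binds the task to the collection. `None` means the collection
//! // was already closed, in which case the task is shut down instead of stored.
//! let (join_handle, maybe_notified) = owned.bind(future, scheduler, id, spawned_at);
//!
//! // During shutdown: close the collection and cancel every remaining task.
//! owned.close_and_shutdown_all(0);
//! ```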

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
use crate::util::linked_list::{Link, LinkedList};
use crate::util::sharded_list;

use crate::loom::sync::atomic::{AtomicBool, Ordering};
use std::marker::PhantomData;
use std::num::NonZeroU64;

// The id generated below is used to verify whether a given task is stored in
// this `OwnedTasks`, or in some other list. The counter starts at one so we
// can use `None` for tasks not owned by any list. (A condensed sketch of the
// ownership check appears after the two counter definitions below.)
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are never supposed to fail unless there is a bug
// in Tokio, so we accept that certain bugs would go uncaught if two mixed-up
// runtimes happen to end up with the same id.

cfg_has_atomic_u64! {
    use std::sync::atomic::AtomicU64;

    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    fn get_next_id() -> NonZeroU64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(id) {
                return id;
            }
        }
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::AtomicU32;

    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    fn get_next_id() -> NonZeroU64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(u64::from(id)) {
                return id;
            }
        }
    }
}
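
// A condensed sketch of how these ids are used for the ownership checks
// further down in this file (this paraphrases `OwnedTasks::remove` and
// `assert_owner`; it is not additional API):
//
//     let task_id = task.header().get_owner_id()?; // `None`: not in any list.
//     assert_eq!(task_id, self.id);                // Must be owned by *this* list.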

pub(crate) struct OwnedTasks<S: 'static> {
    list: List<S>,
    pub(crate) id: NonZeroU64,
    closed: AtomicBool,
}

type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;

pub(crate) struct LocalOwnedTasks<S: 'static> {
    inner: UnsafeCell<OwnedTasksInner<S>>,
    pub(crate) id: NonZeroU64,
    _not_send_or_sync: PhantomData<*const ()>,
}

struct OwnedTasksInner<S: 'static> {
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}

impl<S: 'static> OwnedTasks<S> {
    pub(crate) fn new(num_cores: usize) -> Self {
        let shard_size = Self::gen_shared_list_size(num_cores);
        Self {
            list: List::new(shard_size),
            closed: AtomicBool::new(false),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this `OwnedTasks` instance. If the
    /// `OwnedTasks` has been closed, the task is immediately shut down and no
    /// `Notified` is returned.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// Binds a task that isn't safe to transfer across thread boundaries.
    ///
    /// # Safety
    ///
    /// Only use this in `LocalRuntime`, where the task cannot be moved to
    /// another thread.
    pub(crate) unsafe fn bind_local<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// The part of `bind` that's the same for every type of future.
    unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
    where
        S: Schedule,
    {
        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let shard = self.list.lock_shard(&task);
        // Check the closed flag while holding the shard lock. This ensures
        // that every task either observes the flag here and shuts itself
        // down, or is pushed in time for `close_and_shutdown_all` to find it
        // when draining the shards.
        if self.closed.load(Ordering::Acquire) {
            drop(shard);
            task.shutdown();
            return None;
        }
        shard.push(task);
        Some(notified)
    }

    /// Asserts that the given task is owned by this `OwnedTasks` and converts
    /// it to a `LocalNotified`, giving the thread permission to poll this
    /// task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll it on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    ///
    /// The parameter `start` determines which shard this method starts
    /// draining at; giving each worker thread a different value (for example,
    /// its worker index) reduces contention between workers shutting down
    /// concurrently.
    pub(crate) fn close_and_shutdown_all(&self, start: usize)
    where
        S: Schedule,
    {
        self.closed.store(true, Ordering::Release);
        for i in start..self.get_shard_size() + start {
            while let Some(task) = self.list.pop_back(i) {
                task.shutdown();
            }
        }
    }

    #[inline]
    pub(crate) fn get_shard_size(&self) -> usize {
        self.list.shard_size()
    }

    pub(crate) fn num_alive_tasks(&self) -> usize {
        self.list.len()
    }

    cfg_unstable_metrics! {
        cfg_64bit_metrics! {
            pub(crate) fn spawned_tasks_count(&self) -> u64 {
                self.list.added()
            }
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.list.remove(task.header_ptr()) }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Generates the size of the sharded list based on the number of worker
    /// threads.
    ///
    /// Sharding the lock effectively alleviates the lock contention caused by
    /// high concurrency.
    ///
    /// However, as the number of shards increases, the memory locality
    /// between nodes of the intrusive linked list diminishes, and the
    /// construction time of the sharded list grows.
    ///
    /// For these reasons, we cap the shard count at `MAX_SHARED_LIST_SIZE`.
    fn gen_shared_list_size(num_cores: usize) -> usize {
        const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
        usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
    }
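
    // Worked examples of the formula above (the values follow directly from
    // `next_power_of_two` and the cap; see also the sanity test at the bottom
    // of this file):
    //
    //     gen_shared_list_size(1)       == min(65536, 1 * 4)   == 4
    //     gen_shared_list_size(6)       == min(65536, 8 * 4)   == 32
    //     gen_shared_list_size(1 << 20) == min(65536, 1 << 22) == 65536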
}

cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Locks the tasks, and calls `f` on an iterator over them.
        pub(crate) fn for_each<F>(&self, f: F)
        where
            F: FnMut(&Task<S>),
        {
            self.list.for_each(f);
        }
    }
}

impl<S: 'static> LocalOwnedTasks<S> {
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Asserts that the given task is owned by this `LocalOwnedTasks` and
    /// converts it to a `LocalNotified`, giving the thread permission to poll
    /// this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), Some(self.id));

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen. Furthermore, all uses of this method in this file
        // make sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that ids
    // come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
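
    // A minimal sanity-check sketch of the shard-sizing formula documented on
    // `gen_shared_list_size`. The `()` scheduler type is only a stand-in to
    // name the generic parameter; the expected values follow directly from
    // `next_power_of_two` and the `MAX_SHARED_LIST_SIZE` cap.
    #[test]
    fn test_shard_size_formula() {
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(1), 4);
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(6), 32);
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(1 << 20), 1 << 16);
    }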
}