tokio/runtime/task/
core.rs

1//! Core task module.
2//!
3//! # Safety
4//!
5//! The functions in this module are private to the `task` module. All of them
6//! should be considered `unsafe` to use, but are not marked as such since it
7//! would be too noisy.
8//!
9//! Make sure to consult the relevant safety section of each function before
10//! use.
11
12// It doesn't make sense to enforce `unsafe_op_in_unsafe_fn` for this module because
13//
14// * This module is doing the low-level task management that requires tons of unsafe
15//   operations.
16// * Excessive `unsafe {}` blocks hurt readability significantly.
// TODO: replace with `#[expect(unsafe_op_in_unsafe_fn)]` after bumping
// the MSRV to 1.81.0.
19#![allow(unsafe_op_in_unsafe_fn)]
20
21use crate::future::Future;
22use crate::loom::cell::UnsafeCell;
23use crate::runtime::context;
24use crate::runtime::task::raw::{self, Vtable};
25use crate::runtime::task::state::State;
26use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks};
27use crate::util::linked_list;
28
29use std::num::NonZeroU64;
30#[cfg(tokio_unstable)]
31use std::panic::Location;
32use std::pin::Pin;
33use std::ptr::NonNull;
34use std::task::{Context, Poll, Waker};
35
/// The task cell. Contains the components of the task.
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
///
// # This struct should be cache padded to avoid false sharing. The cache padding rules are copied
// from crossbeam-utils/src/cache_padded.rs
//
// Starting from Intel's Sandy Bridge, spatial prefetcher is now pulling pairs of 64-byte cache
// lines at a time, so we have to align to 128 bytes rather than 64.
//
// Sources:
// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf
// - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107
//
// ARM's big.LITTLE architecture has asymmetric cores and "big" cores have 128-byte cache line size.
//
// Sources:
// - https://www.mono-project.com/news/2016/09/12/arm64-icache/
//
// powerpc64 has 128-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_ppc64x.go#L9
#[cfg_attr(
    any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
    ),
    repr(align(128))
)]
// arm, mips, mips64, sparc, and hexagon have 32-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_arm.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mipsle.go#L7
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips64x.go#L9
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L17
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/hexagon/include/asm/cache.h#L12
#[cfg_attr(
    any(
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
    ),
    repr(align(32))
)]
// m68k has 16-byte cache line size.
//
// Sources:
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/m68k/include/asm/cache.h#L9
#[cfg_attr(target_arch = "m68k", repr(align(16)))]
// s390x has 256-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_s390x.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/s390/include/asm/cache.h#L13
#[cfg_attr(target_arch = "s390x", repr(align(256)))]
// x86, riscv, wasm, and sparc64 have 64-byte cache line size.
//
// Sources:
// - https://github.com/golang/go/blob/dda2991c2ea0c5914714469c4defc2562a907230/src/internal/cpu/cpu_x86.go#L9
// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_wasm.go#L7
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L19
// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/riscv/include/asm/cache.h#L10
//
// All others are assumed to have 64-byte cache line size.
#[cfg_attr(
    not(any(
        target_arch = "x86_64",
        target_arch = "aarch64",
        target_arch = "powerpc64",
        target_arch = "arm",
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "sparc",
        target_arch = "hexagon",
        target_arch = "m68k",
        target_arch = "s390x",
    )),
    repr(align(64))
)]
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
    /// Hot task state data. Must remain the first field so a `*mut Header`
    /// and a `*mut Cell` point at the same address (see struct docs).
    pub(super) header: Header,

    /// Either the future or output, depending on the execution stage.
    pub(super) core: Core<T, S>,

    /// Cold data, stored after the future (see `Trailer` docs).
    pub(super) trailer: Trailer,
}
136
/// Storage for the future/output [`Stage`] of a task.
///
/// The `UnsafeCell` provides no synchronization of its own; mutual exclusion
/// is enforced externally by the task state machinery (see the safety
/// comments on `Core::poll` and related methods).
pub(super) struct CoreStage<T: Future> {
    // Either the running future or its finished output; see `Stage`.
    stage: UnsafeCell<Stage<T>>,
}
140
/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// `const` fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
    /// Scheduler used to drive this future.
    pub(super) scheduler: S,

    /// The task's ID, used for populating `JoinError`s.
    pub(super) task_id: Id,

    /// The source code location where the task was spawned.
    ///
    /// This is used for populating the `TaskMeta` passed to the task runtime
    /// hooks.
    #[cfg(tokio_unstable)]
    pub(super) spawned_at: &'static Location<'static>,

    /// Either the future or the output.
    pub(super) stage: CoreStage<T>,
}
165
/// Crate public as this is also needed by the pool.
#[repr(C)]
pub(crate) struct Header {
    /// Task state.
    pub(super) state: State,

    /// Pointer to next task, used with the injection queue.
    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,

    /// Table of function pointers for executing actions on the task. Also
    /// records the field offsets used by `Header::get_trailer` and friends.
    pub(super) vtable: &'static Vtable,

    /// This integer contains the id of the `OwnedTasks` or `LocalOwnedTasks`
    /// that this task is stored in. If the task is not in any list, should be
    /// the id of the list that it was previously in, or `None` if it has never
    /// been in any list.
    ///
    /// Once a task has been bound to a list, it can never be bound to another
    /// list, even if removed from the first list.
    ///
    /// The id is not unset when removed from a list because we want to be able
    /// to read the id without synchronization, even if it is concurrently being
    /// removed from the list.
    pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,

    /// The tracing ID for this instrumented task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) tracing_id: Option<tracing::Id>,
}
195
// Safety: `Header` contains `UnsafeCell` fields, which are not `Sync` by
// default. Synchronized access is delegated to the callers of the unsafe
// accessors (see e.g. the contract on `set_owner_id`), which must guarantee
// the required exclusivity themselves.
unsafe impl Send for Header {}
unsafe impl Sync for Header {}
198
/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task.
    pub(super) waker: UnsafeCell<Option<Waker>>,
    /// Optional hooks needed in the harness.
    #[cfg_attr(not(tokio_unstable), allow(dead_code))] //TODO: remove when hooks are stabilized
    pub(super) hooks: TaskHarnessScheduleHooks,
}
210
// Generates `Trailer::addr_of_owned`, a raw-pointer projection to the `owned`
// field — presumably without materializing a reference to the whole `Trailer`
// (see the `generate_addr_of_methods!` macro definition for the mechanism).
generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}
218
/// Either the future or the output.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
pub(super) enum Stage<T: Future> {
    /// The future is still being polled.
    Running(T),
    /// The future completed; holds the output until it is taken (see
    /// `Core::take_output`).
    Finished(super::Result<T::Output>),
    /// The future or its output has been dropped or taken; nothing is stored.
    Consumed,
}
226
impl<T: Future, S: Schedule> Cell<T, S> {
    /// Allocates a new task cell, containing the header, trailer, and core
    /// structures.
    ///
    /// The future starts in the `Stage::Running` state and `state` is the
    /// initial task state word.
    pub(super) fn new(
        future: T,
        scheduler: S,
        state: State,
        task_id: Id,
        #[cfg(tokio_unstable)] spawned_at: &'static Location<'static>,
    ) -> Box<Cell<T, S>> {
        // Separated into a non-generic function to reduce LLVM codegen
        fn new_header(
            state: State,
            vtable: &'static Vtable,
            #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option<tracing::Id>,
        ) -> Header {
            Header {
                state,
                queue_next: UnsafeCell::new(None),
                vtable,
                owner_id: UnsafeCell::new(None),
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing_id,
            }
        }

        #[cfg(all(tokio_unstable, feature = "tracing"))]
        let tracing_id = future.id();
        // The vtable is specialized per (future, scheduler) pair and records
        // the field offsets validated below.
        let vtable = raw::vtable::<T, S>();
        let result = Box::new(Cell {
            trailer: Trailer::new(scheduler.hooks()),
            header: new_header(
                state,
                vtable,
                #[cfg(all(tokio_unstable, feature = "tracing"))]
                tracing_id,
            ),
            core: Core {
                scheduler,
                stage: CoreStage {
                    stage: UnsafeCell::new(Stage::Running(future)),
                },
                task_id,
                #[cfg(tokio_unstable)]
                spawned_at,
            },
        });

        // In debug builds, verify that the offsets stored in the vtable agree
        // with the actual layout of this allocation.
        #[cfg(debug_assertions)]
        {
            // Using a separate function for this code avoids instantiating it separately for every `T`.
            unsafe fn check<S>(
                header: &Header,
                trailer: &Trailer,
                scheduler: &S,
                task_id: &Id,
                #[cfg(tokio_unstable)] spawn_location: &&'static Location<'static>,
            ) {
                let trailer_addr = trailer as *const Trailer as usize;
                let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) };
                assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

                let scheduler_addr = scheduler as *const S as usize;
                let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) };
                assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

                let id_addr = task_id as *const Id as usize;
                let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) };
                assert_eq!(id_addr, id_ptr.as_ptr() as usize);

                #[cfg(tokio_unstable)]
                {
                    let spawn_location_addr =
                        spawn_location as *const &'static Location<'static> as usize;
                    let spawn_location_ptr =
                        unsafe { Header::get_spawn_location_ptr(NonNull::from(header)) };
                    assert_eq!(spawn_location_addr, spawn_location_ptr.as_ptr() as usize);
                }
            }
            unsafe {
                check(
                    &result.header,
                    &result.trailer,
                    &result.core.scheduler,
                    &result.core.task_id,
                    #[cfg(tokio_unstable)]
                    &result.core.spawned_at,
                );
            }
        }

        result
    }
}
321
322impl<T: Future> CoreStage<T> {
323    pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
324        self.stage.with_mut(f)
325    }
326}
327
/// Set and clear the task id in the context when the future is executed or
/// dropped, or when the output produced by the future is dropped.
pub(crate) struct TaskIdGuard {
    // The task id that was current before this guard was entered; restored
    // on drop.
    parent_task_id: Option<Id>,
}
333
334impl TaskIdGuard {
335    fn enter(id: Id) -> Self {
336        TaskIdGuard {
337            parent_task_id: context::set_current_task_id(Some(id)),
338        }
339    }
340}
341
impl Drop for TaskIdGuard {
    fn drop(&mut self) {
        // Restore the task id that was current before `enter` replaced it.
        context::set_current_task_id(self.parent_task_id);
    }
}
347
impl<T: Future, S: Schedule> Core<T, S> {
    /// Polls the future.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `state` field. This
    /// requires ensuring mutual exclusion between any concurrent thread that
    /// might modify the future or output field.
    ///
    /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
    /// component of the task state.
    ///
    /// `self` must also be pinned. This is handled by storing the task on the
    /// heap.
    pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
        let res = {
            self.stage.stage.with_mut(|ptr| {
                // Safety: The caller ensures mutual exclusion to the field.
                let future = match unsafe { &mut *ptr } {
                    Stage::Running(future) => future,
                    // Polling is only valid while the future is running.
                    _ => unreachable!("unexpected stage"),
                };

                // Safety: The caller ensures the future is pinned.
                let future = unsafe { Pin::new_unchecked(future) };

                // Expose this task's id to the context for the duration of
                // the poll.
                let _guard = TaskIdGuard::enter(self.task_id);
                future.poll(&mut cx)
            })
        };

        // A completed future must be dropped eagerly so its resources are
        // released before the output is consumed.
        if res.is_ready() {
            self.drop_future_or_output();
        }

        res
    }

    /// Drops the future.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn drop_future_or_output(&self) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Consumed);
        }
    }

    /// Stores the task output.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn store_output(&self, output: super::Result<T::Output>) {
        // Safety: the caller ensures mutual exclusion to the field.
        unsafe {
            self.set_stage(Stage::Finished(output));
        }
    }

    /// Takes the task output.
    ///
    /// # Panics
    ///
    /// Panics if the stage is not `Finished`, i.e. if the output was already
    /// taken or never stored.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    pub(super) fn take_output(&self) -> super::Result<T::Output> {
        use std::mem;

        self.stage.stage.with_mut(|ptr| {
            // Safety: the caller ensures mutual exclusion to the field.
            match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
                Stage::Finished(output) => output,
                _ => panic!("JoinHandle polled after completion"),
            }
        })
    }

    /// Replaces the stage, dropping the previously stored value with the
    /// current task id set in the context.
    ///
    /// # Safety
    ///
    /// The caller must ensure it is safe to mutate the `stage` field.
    unsafe fn set_stage(&self, stage: Stage<T>) {
        let _guard = TaskIdGuard::enter(self.task_id);
        self.stage.stage.with_mut(|ptr| *ptr = stage);
    }
}
432
impl Header {
    /// Sets the injection-queue `next` pointer.
    ///
    /// # Safety
    ///
    /// The caller must guarantee exclusive access to the `queue_next` field.
    pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
        self.queue_next.with_mut(|ptr| *ptr = next);
    }

    // safety: The caller must guarantee exclusive access to this field, and
    // must ensure that the id is either `None` or the id of the OwnedTasks
    // containing this task.
    pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
        self.owner_id.with_mut(|ptr| *ptr = Some(owner));
    }

    pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
        // safety: If there are concurrent writes, then that write has violated
        // the safety requirements on `set_owner_id`.
        unsafe { self.owner_id.with(|ptr| *ptr) }
    }

    /// Gets a pointer to the `Trailer` of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
        // The trailer's byte offset from the header is recorded in the vtable
        // when the task is created (see `Cell::new`'s debug checks).
        let offset = me.as_ref().vtable.trailer_offset;
        let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
        NonNull::new_unchecked(trailer)
    }

    /// Gets a pointer to the scheduler of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    ///
    /// The generic type S must be set to the correct scheduler type for this
    /// task.
    pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
        let offset = me.as_ref().vtable.scheduler_offset;
        let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
        NonNull::new_unchecked(scheduler)
    }

    /// Gets a pointer to the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
        let offset = me.as_ref().vtable.id_offset;
        let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
        NonNull::new_unchecked(id)
    }

    /// Gets the id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
        let ptr = Header::get_id_ptr(me).as_ptr();
        *ptr
    }

    /// Gets a pointer to the source code location where the task containing
    /// this `Header` was spawned.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    #[cfg(tokio_unstable)]
    pub(super) unsafe fn get_spawn_location_ptr(
        me: NonNull<Header>,
    ) -> NonNull<&'static Location<'static>> {
        let offset = me.as_ref().vtable.spawn_location_offset;
        let spawned_at = me
            .as_ptr()
            .cast::<u8>()
            .add(offset)
            .cast::<&'static Location<'static>>();
        NonNull::new_unchecked(spawned_at)
    }

    /// Gets the source code location where the task containing
    /// this `Header` was spawned
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    #[cfg(tokio_unstable)]
    pub(super) unsafe fn get_spawn_location(me: NonNull<Header>) -> &'static Location<'static> {
        let ptr = Header::get_spawn_location_ptr(me).as_ptr();
        *ptr
    }

    /// Gets the tracing id of the task containing this `Header`.
    ///
    /// # Safety
    ///
    /// The provided raw pointer must point at the header of a task.
    #[cfg(all(tokio_unstable, feature = "tracing"))]
    pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
        me.as_ref().tracing_id.as_ref()
    }
}
538
impl Trailer {
    /// Creates a trailer with no registered waker and the provided hooks.
    fn new(hooks: TaskHarnessScheduleHooks) -> Self {
        Trailer {
            waker: UnsafeCell::new(None),
            owned: linked_list::Pointers::new(),
            hooks,
        }
    }

    /// Stores (or clears) the waker of the task awaiting completion.
    ///
    /// # Safety
    ///
    /// The caller must guarantee exclusive access to the `waker` field.
    pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
        self.waker.with_mut(|ptr| {
            *ptr = waker;
        });
    }

    /// Returns whether the stored waker would wake the same task as `waker`.
    ///
    /// # Safety
    ///
    /// The caller must ensure no concurrent mutation of the `waker` field.
    /// Panics if no waker is stored.
    pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
        self.waker
            .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
    }

    /// Wakes the consumer task waiting on completion of this task.
    ///
    /// # Panics
    ///
    /// Panics if no waker has been registered.
    pub(super) fn wake_join(&self) {
        // Safety (presumed): callers are expected to synchronize access to
        // the `waker` field, mirroring the contract on `set_waker` — verify
        // against the harness call sites.
        self.waker.with(|ptr| match unsafe { &*ptr } {
            Some(waker) => waker.wake_by_ref(),
            None => panic!("waker missing"),
        });
    }
}
566
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    // Keep `Header` small: it must not exceed eight machine words.
    let word = std::mem::size_of::<*const ()>();
    let header = std::mem::size_of::<Header>();
    assert!(header <= 8 * word);
}