Commit 8da7a2b7 authored by Wedson Almeida Filho's avatar Wedson Almeida Filho Committed by Miguel Ojeda
Browse files

rust: introduce `current`



This allows Rust code to get a reference to the current task without
having to increment the refcount, but still guaranteeing memory safety.

Cc: Ingo Molnar <mingo@redhat.com>
Cc: Peter Zijlstra <peterz@infradead.org>
Reviewed-by: default avatarMartin Rodriguez Reboredo <yakoyoku@gmail.com>
Signed-off-by: default avatarWedson Almeida Filho <walmeida@microsoft.com>
Link: https://lore.kernel.org/r/20230411054543.21278-10-wedsonaf@gmail.com


Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent 313c4281
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -103,6 +103,12 @@ long rust_helper_PTR_ERR(__force const void *ptr)
}
EXPORT_SYMBOL_GPL(rust_helper_PTR_ERR);

struct task_struct *rust_helper_get_current(void)
{
	return current;
}
EXPORT_SYMBOL_GPL(rust_helper_get_current);

void rust_helper_get_task_struct(struct task_struct *t)
{
	get_task_struct(t);
+2 −0
Original line number Diff line number Diff line
@@ -36,3 +36,5 @@ pub use super::error::{code::*, Error, Result};
pub use super::{str::CStr, ThisModule};

pub use super::init::{InPlaceInit, Init, PinInit};

pub use super::current;
+81 −1
Original line number Diff line number Diff line
@@ -5,7 +5,17 @@
//! C header: [`include/linux/sched.h`](../../../../include/linux/sched.h).

use crate::{bindings, types::Opaque};
use core::ptr;
use core::{marker::PhantomData, ops::Deref, ptr};

/// Returns the currently running task.
#[macro_export]
macro_rules! current {
    () => {
        // SAFETY: Deref + addr-of below create a temporary `TaskRef` that cannot outlive the
        // caller.
        unsafe { &*$crate::task::Task::current() }
    };
}

/// Wraps the kernel's `struct task_struct`.
///
@@ -15,6 +25,42 @@ use core::ptr;
///
/// Instances of this type are always ref-counted, that is, a call to `get_task_struct` ensures
/// that the allocation remains valid at least until the matching call to `put_task_struct`.
///
/// # Examples
///
/// The following is an example of getting the PID of the current thread with zero additional cost
/// when compared to the C version:
///
/// ```
/// let pid = current!().pid();
/// ```
///
/// Getting the PID of the current process, also zero additional cost:
///
/// ```
/// let pid = current!().group_leader().pid();
/// ```
///
/// Getting the current task and storing it in some struct. The reference count is automatically
/// incremented when creating `State` and decremented when it is dropped:
///
/// ```
/// use kernel::{task::Task, types::ARef};
///
/// struct State {
///     creator: ARef<Task>,
///     index: u32,
/// }
///
/// impl State {
///     fn new() -> Self {
///         Self {
///             creator: current!().into(),
///             index: 0,
///         }
///     }
/// }
/// ```
#[repr(transparent)]
pub struct Task(pub(crate) Opaque<bindings::task_struct>);

@@ -27,6 +73,40 @@ unsafe impl Sync for Task {}
type Pid = bindings::pid_t;

impl Task {
    /// Returns a task reference for the currently executing task/thread.
    ///
    /// The recommended way to get the current task/thread is to use the
    /// [`current`](crate::current) macro because it is safe.
    ///
    /// # Safety
    ///
    /// Callers must ensure that the returned object doesn't outlive the current task/thread.
    pub unsafe fn current() -> impl Deref<Target = Task> {
        struct TaskRef<'a> {
            task: &'a Task,
            _not_send: PhantomData<*mut ()>,
        }

        impl Deref for TaskRef<'_> {
            type Target = Task;

            fn deref(&self) -> &Self::Target {
                self.task
            }
        }

        // SAFETY: Just an FFI call with no additional safety requirements.
        let ptr = unsafe { bindings::get_current() };

        TaskRef {
            // SAFETY: If the current thread is still running, the current task is valid. Given
            // that `TaskRef` is not `Send`, we know it cannot be transferred to another thread
            // (where it could potentially outlive the caller).
            task: unsafe { &*ptr.cast() },
            _not_send: PhantomData,
        }
    }

    /// Returns the group leader of the given task.
    pub fn group_leader(&self) -> &Task {
        // SAFETY: By the type invariant, we know that `self.0` is a valid task. Valid tasks always