Commit 9e494390 authored by Benno Lossin's avatar Benno Lossin Committed by Miguel Ojeda
Browse files

rust: init: add functions to create array initializers



Add two functions `pin_init_array_from_fn` and `init_array_from_fn` that
take a function that generates initializers for `T` from `usize`, the added
functions then return an initializer for `[T; N]` where every element is
initialized by an element returned from the generator function.

Suggested-by: default avatarAsahi Lina <lina@asahilina.net>
Reviewed-by: default avatarBjörn Roy Baron <bjorn3_gh@protonmail.com>
Reviewed-by: default avatarAlice Ryhl <aliceryhl@google.com>
Reviewed-by: default avatarMartin Rodriguez Reboredo <yakoyoku@gmail.com>
Signed-off-by: default avatarBenno Lossin <benno.lossin@proton.me>
Link: https://lore.kernel.org/r/20230814084602.25699-9-benno.lossin@proton.me


[ Cleaned a couple trivial nits. ]
Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent 35e7fca2
Loading
Loading
Loading
Loading
+88 −0
Original line number Diff line number Diff line
@@ -202,6 +202,7 @@
use crate::{
    error::{self, Error},
    sync::UniqueArc,
    types::ScopeGuard,
};
use alloc::boxed::Box;
use core::{
@@ -867,6 +868,93 @@ pub fn uninit<T, E>() -> impl Init<MaybeUninit<T>, E> {
    unsafe { init_from_closure(|_| Ok(())) }
}

/// Initializes an array by initializing each element via the provided initializer.
///
/// # Examples
///
/// ```rust
/// use kernel::{error::Error, init::init_array_from_fn};
/// let array: Box<[usize; 1_000]>= Box::init::<Error>(init_array_from_fn(|i| i)).unwrap();
/// assert_eq!(array.len(), 1_000);
/// ```
pub fn init_array_from_fn<I, const N: usize, T, E>(
    mut make_init: impl FnMut(usize) -> I,
) -> impl Init<[T; N], E>
where
    I: Init<T, E>,
{
    let init = move |slot: *mut [T; N]| {
        let slot = slot.cast::<T>();
        // Counts the number of initialized elements and when dropped drops that many elements from
        // `slot`.
        let mut init_count = ScopeGuard::new_with_data(0, |i| {
            // We now free every element that has been initialized before:
            // SAFETY: The loop initialized exactly the values from 0..i and since we
            // return `Err` below, the caller will consider the memory at `slot` as
            // uninitialized.
            unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) };
        });
        for i in 0..N {
            let init = make_init(i);
            // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`.
            let ptr = unsafe { slot.add(i) };
            // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init`
            // requirements.
            unsafe { init.__init(ptr) }?;
            *init_count += 1;
        }
        init_count.dismiss();
        Ok(())
    };
    // SAFETY: The initializer above initializes every element of the array. On failure it drops
    // any initialized elements and returns `Err`.
    unsafe { init_from_closure(init) }
}

/// Initializes an array by initializing each element via the provided initializer.
///
/// # Examples
///
/// ```rust
/// use kernel::{sync::{Arc, Mutex}, init::pin_init_array_from_fn, new_mutex};
/// let array: Arc<[Mutex<usize>; 1_000]>=
///     Arc::pin_init(pin_init_array_from_fn(|i| new_mutex!(i))).unwrap();
/// assert_eq!(array.len(), 1_000);
/// ```
pub fn pin_init_array_from_fn<I, const N: usize, T, E>(
    mut make_init: impl FnMut(usize) -> I,
) -> impl PinInit<[T; N], E>
where
    I: PinInit<T, E>,
{
    let init = move |slot: *mut [T; N]| {
        let slot = slot.cast::<T>();
        // Counts the number of initialized elements and when dropped drops that many elements from
        // `slot`.
        let mut init_count = ScopeGuard::new_with_data(0, |i| {
            // We now free every element that has been initialized before:
            // SAFETY: The loop initialized exactly the values from 0..i and since we
            // return `Err` below, the caller will consider the memory at `slot` as
            // uninitialized.
            unsafe { ptr::drop_in_place(ptr::slice_from_raw_parts_mut(slot, i)) };
        });
        for i in 0..N {
            let init = make_init(i);
            // SAFETY: Since 0 <= `i` < N, it is still in bounds of `[T; N]`.
            let ptr = unsafe { slot.add(i) };
            // SAFETY: The pointer is derived from `slot` and thus satisfies the `__init`
            // requirements.
            unsafe { init.__pinned_init(ptr) }?;
            *init_count += 1;
        }
        init_count.dismiss();
        Ok(())
    };
    // SAFETY: The initializer above initializes every element of the array. On failure it drops
    // any initialized elements and returns `Err`.
    unsafe { pin_init_from_closure(init) }
}

// SAFETY: Every type can be initialized by-value.
unsafe impl<T, E> Init<T, E> for T {
    unsafe fn __init(self, slot: *mut T) -> Result<(), E> {