Commit 071cedc8 authored by Benno Lossin's avatar Benno Lossin Committed by Miguel Ojeda
Browse files

rust: add derive macro for `Zeroable`



Add a derive proc-macro for the `Zeroable` trait. The macro supports
structs where every field implements the `Zeroable` trait. This way
`unsafe` implementations can be avoided.

The macro is split into two parts:
- a proc-macro to parse generics into impl and ty generics,
- a declarative macro that expands to the impl block.

Suggested-by: default avatarAsahi Lina <lina@asahilina.net>
Signed-off-by: default avatarBenno Lossin <benno.lossin@proton.me>
Reviewed-by: default avatarGary Guo <gary@garyguo.net>
Reviewed-by: default avatarMartin Rodriguez Reboredo <yakoyoku@gmail.com>
Link: https://lore.kernel.org/r/20230814084602.25699-4-benno.lossin@proton.me


[ Added `ignore` to the `lib.rs` example and cleaned trivial nit. ]
Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent f8badd15
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -1215,3 +1215,38 @@ macro_rules! __init_internal {
        );
    };
}

#[doc(hidden)]
#[macro_export]
macro_rules! __derive_zeroable {
    (parse_input:
        @sig(
            $(#[$($struct_attr:tt)*])*
            $vis:vis struct $name:ident
            $(where $($whr:tt)*)?
        ),
        @impl_generics($($impl_generics:tt)*),
        @ty_generics($($ty_generics:tt)*),
        @body({
            $(
                $(#[$($field_attr:tt)*])*
                $field:ident : $field_ty:ty
            ),* $(,)?
        }),
    ) => {
        // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
        #[automatically_derived]
        unsafe impl<$($impl_generics)*> $crate::init::Zeroable for $name<$($ty_generics)*>
        where
            $($($whr)*)?
        {}
        const _: () = {
            fn assert_zeroable<T: ?::core::marker::Sized + $crate::init::Zeroable>() {}
            fn ensure_zeroable<$($impl_generics)*>()
                where $($($whr)*)?
            {
                $(assert_zeroable::<$field_ty>();)*
            }
        };
    };
}
+1 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ pub use core::pin::Pin;
pub use alloc::{boxed::Box, vec::Vec};

#[doc(no_inline)]
pub use macros::{module, pin_data, pinned_drop, vtable};
pub use macros::{module, pin_data, pinned_drop, vtable, Zeroable};

pub use super::build_assert;

+20 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ mod paste;
mod pin_data;
mod pinned_drop;
mod vtable;
mod zeroable;

use proc_macro::TokenStream;

@@ -343,3 +344,22 @@ pub fn paste(input: TokenStream) -> TokenStream {
    paste::expand(&mut tokens);
    tokens.into_iter().collect()
}

/// Derives the [`Zeroable`] trait for the given struct.
///
/// This can only be used for structs where every field implements the [`Zeroable`] trait.
///
/// # Examples
///
/// ```rust,ignore
/// #[derive(Zeroable)]
/// pub struct DriverData {
///     id: i64,
///     buf_ptr: *mut u8,
///     len: usize,
/// }
/// ```
#[proc_macro_derive(Zeroable)]
pub fn derive_zeroable(input: TokenStream) -> TokenStream {
    zeroable::derive(input)
}
+12 −0
Original line number Diff line number Diff line
@@ -124,6 +124,18 @@ macro_rules! quote_spanned {
        ));
        quote_spanned!(@proc $v $span $($tt)*);
    };
    (@proc $v:ident $span:ident ; $($tt:tt)*) => {
        $v.push(::proc_macro::TokenTree::Punct(
                ::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone)
        ));
        quote_spanned!(@proc $v $span $($tt)*);
    };
    (@proc $v:ident $span:ident + $($tt:tt)*) => {
        $v.push(::proc_macro::TokenTree::Punct(
                ::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone)
        ));
        quote_spanned!(@proc $v $span $($tt)*);
    };
    (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => {
        $v.push(::proc_macro::TokenTree::Ident(::proc_macro::Ident::new(stringify!($id), $span)));
        quote_spanned!(@proc $v $span $($tt)*);
+72 −0
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0

use crate::helpers::{parse_generics, Generics};
use proc_macro::{TokenStream, TokenTree};

pub(crate) fn derive(input: TokenStream) -> TokenStream {
    let (
        Generics {
            impl_generics,
            ty_generics,
        },
        mut rest,
    ) = parse_generics(input);
    // This should be the body of the struct `{...}`.
    let last = rest.pop();
    // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
    let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
    // Are we inside of a generic where we want to add `Zeroable`?
    let mut in_generic = !impl_generics.is_empty();
    // Have we already inserted `Zeroable`?
    let mut inserted = false;
    // Level of `<>` nestings.
    let mut nested = 0;
    for tt in impl_generics {
        match &tt {
            // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
            TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
                if in_generic && !inserted {
                    new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
                }
                in_generic = true;
                inserted = false;
                new_impl_generics.push(tt);
            }
            // If we find `'`, then we are entering a lifetime.
            TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
                in_generic = false;
                new_impl_generics.push(tt);
            }
            TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
                new_impl_generics.push(tt);
                if in_generic {
                    new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
                    inserted = true;
                }
            }
            TokenTree::Punct(p) if p.as_char() == '<' => {
                nested += 1;
                new_impl_generics.push(tt);
            }
            TokenTree::Punct(p) if p.as_char() == '>' => {
                assert!(nested > 0);
                nested -= 1;
                new_impl_generics.push(tt);
            }
            _ => new_impl_generics.push(tt),
        }
    }
    assert_eq!(nested, 0);
    if in_generic && !inserted {
        new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
    }
    quote! {
        ::kernel::__derive_zeroable!(
            parse_input:
                @sig(#(#rest)*),
                @impl_generics(#(#new_impl_generics)*),
                @ty_generics(#(#ty_generics)*),
                @body(#last),
        );
    }
}