Commit e957b9cd authored by Benno Lossin's avatar Benno Lossin Committed by Miguel Ojeda
Browse files

rust: macros: refactor generics parsing of `#[pin_data]` into its own function



Other macros might also want to parse generics. Additionally this makes
the code easier to read, as the next commit will introduce more code in
`#[pin_data]`. Also add more comments to explain how parsing generics
work.

Signed-off-by: default avatarBenno Lossin <benno.lossin@proton.me>
Reviewed-by: default avatarAlice Ryhl <aliceryhl@google.com>
Reviewed-by: default avatarGary Guo <gary@garyguo.net>
Reviewed-by: default avatarMartin Rodriguez Reboredo <yakoyoku@gmail.com>
Link: https://lore.kernel.org/r/20230424081112.99890-2-benno.lossin@proton.me


Signed-off-by: default avatarMiguel Ojeda <ojeda@kernel.org>
parent b8342add
Loading
Loading
Loading
Loading
+85 −1
Original line number Diff line number Diff line
// SPDX-License-Identifier: GPL-2.0

use proc_macro::{token_stream, Group, TokenTree};
use proc_macro::{token_stream, Group, Punct, Spacing, TokenStream, TokenTree};

pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> {
    if let Some(TokenTree::Ident(ident)) = it.next() {
@@ -69,3 +69,87 @@ pub(crate) fn expect_end(it: &mut token_stream::IntoIter) {
        panic!("Expected end");
    }
}

pub(crate) struct Generics {
    pub(crate) impl_generics: Vec<TokenTree>,
    pub(crate) ty_generics: Vec<TokenTree>,
}

/// Parses the given `TokenStream` into `Generics` and the rest.
///
/// The generics are not present in the rest, but a where clause might remain.
pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) {
    // `impl_generics`, the declared generics with their bounds.
    let mut impl_generics = vec![];
    // Only the names of the generics, without any bounds.
    let mut ty_generics = vec![];
    // Tokens not related to the generics e.g. the `where` token and definition.
    let mut rest = vec![];
    // The current level of `<`.
    let mut nesting = 0;
    let mut toks = input.into_iter();
    // If we are at the beginning of a generic parameter.
    let mut at_start = true;
    for tt in &mut toks {
        match tt.clone() {
            TokenTree::Punct(p) if p.as_char() == '<' => {
                if nesting >= 1 {
                    // This is inside of the generics and part of some bound.
                    impl_generics.push(tt);
                }
                nesting += 1;
            }
            TokenTree::Punct(p) if p.as_char() == '>' => {
                // This is a parsing error, so we just end it here.
                if nesting == 0 {
                    break;
                } else {
                    nesting -= 1;
                    if nesting >= 1 {
                        // We are still inside of the generics and part of some bound.
                        impl_generics.push(tt);
                    }
                    if nesting == 0 {
                        break;
                    }
                }
            }
            tt => {
                if nesting == 1 {
                    // Here depending on the token, it might be a generic variable name.
                    match &tt {
                        // Ignore const.
                        TokenTree::Ident(i) if i.to_string() == "const" => {}
                        TokenTree::Ident(_) if at_start => {
                            ty_generics.push(tt.clone());
                            // We also already push the `,` token, this makes it easier to append
                            // generics.
                            ty_generics.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
                            at_start = false;
                        }
                        TokenTree::Punct(p) if p.as_char() == ',' => at_start = true,
                        // Lifetimes begin with `'`.
                        TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
                            ty_generics.push(tt.clone());
                        }
                        _ => {}
                    }
                }
                if nesting >= 1 {
                    impl_generics.push(tt);
                } else if nesting == 0 {
                    // If we haven't entered the generics yet, we still want to keep these tokens.
                    rest.push(tt);
                }
            }
        }
    }
    rest.extend(toks);
    (
        Generics {
            impl_generics,
            ty_generics,
        },
        rest,
    )
}
+9 −61
Original line number Diff line number Diff line
// SPDX-License-Identifier: Apache-2.0 OR MIT

use proc_macro::{Punct, Spacing, TokenStream, TokenTree};
use crate::helpers::{parse_generics, Generics};
use proc_macro::TokenStream;

pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
    // This proc-macro only does some pre-parsing and then delegates the actual parsing to
    // `kernel::__pin_data!`.
    //
    // In here we only collect the generics, since parsing them in declarative macros is very
    // elaborate. We also do not need to analyse their structure, we only need to collect them.

    // `impl_generics`, the declared generics with their bounds.
    let mut impl_generics = vec![];
    // Only the names of the generics, without any bounds.
    let mut ty_generics = vec![];
    // Tokens not related to the generics e.g. the `impl` token.
    let mut rest = vec![];
    // The current level of `<`.
    let mut nesting = 0;
    let mut toks = input.into_iter();
    // If we are at the beginning of a generic parameter.
    let mut at_start = true;
    for tt in &mut toks {
        match tt.clone() {
            TokenTree::Punct(p) if p.as_char() == '<' => {
                if nesting >= 1 {
                    impl_generics.push(tt);
                }
                nesting += 1;
            }
            TokenTree::Punct(p) if p.as_char() == '>' => {
                if nesting == 0 {
                    break;
                } else {
                    nesting -= 1;
                    if nesting >= 1 {
                        impl_generics.push(tt);
                    }
                    if nesting == 0 {
                        break;
                    }
                }
            }
            tt => {
                if nesting == 1 {
                    match &tt {
                        TokenTree::Ident(i) if i.to_string() == "const" => {}
                        TokenTree::Ident(_) if at_start => {
                            ty_generics.push(tt.clone());
                            ty_generics.push(TokenTree::Punct(Punct::new(',', Spacing::Alone)));
                            at_start = false;
                        }
                        TokenTree::Punct(p) if p.as_char() == ',' => at_start = true,
                        TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
                            ty_generics.push(tt.clone());
                        }
                        _ => {}
                    }
                }
                if nesting >= 1 {
                    impl_generics.push(tt);
                } else if nesting == 0 {
                    rest.push(tt);
                }
            }
        }
    }
    rest.extend(toks);
    let (
        Generics {
            impl_generics,
            ty_generics,
        },
        mut rest,
    ) = parse_generics(input);
    // This should be the body of the struct `{...}`.
    let last = rest.pop();
    quote!(::kernel::__pin_data! {