Commit 2d8cb0bf authored by Jakub Kicinski's avatar Jakub Kicinski
Browse files

Merge branch 'tools-ynl-fix-enum-as-flags-in-the-generic-cli'

Jakub Kicinski says:

====================
tools: ynl: fix enum-as-flags in the generic CLI

The CLI needs to use proper classes when looking at Enum definitions
rather than interpreting the YAML spec ad-hoc, because we have more
than on format of the definition supported.
====================

Link: https://lore.kernel.org/r/20230308003923.445268-1-kuba@kernel.org


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 649c15c7 c311aaa7
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from .nlspec import SpecAttr, SpecAttrSet, SpecFamily, SpecOperation
from .nlspec import SpecAttr, SpecAttrSet, SpecEnumEntry, SpecEnumSet, \
    SpecFamily, SpecOperation
from .ynl import YnlFamily

__all__ = ["SpecAttr", "SpecAttrSet", "SpecFamily", "SpecOperation",
           "YnlFamily"]
__all__ = ["SpecAttr", "SpecAttrSet", "SpecEnumEntry", "SpecEnumSet",
           "SpecFamily", "SpecOperation", "YnlFamily"]
+99 −0
Original line number Diff line number Diff line
@@ -57,6 +57,94 @@ class SpecElement:
        pass


class SpecEnumEntry(SpecElement):
    """ Entry within an enum declared in the Netlink spec.

    Attributes:
        doc         documentation string
        enum_set    back reference to the enum
        value       numerical value of this enum (use accessors in most situations!)

    Methods:
        raw_value   raw value, i.e. the id in the enum, unlike user value which is a mask for flags
        user_value   user value, same as raw value for enums, for flags it's the mask
    """
    def __init__(self, enum_set, yaml, prev, value_start):
        if isinstance(yaml, str):
            yaml = {'name': yaml}
        super().__init__(enum_set.family, yaml)

        self.doc = yaml.get('doc', '')
        self.enum_set = enum_set

        if 'value' in yaml:
            self.value = yaml['value']
        elif prev:
            self.value = prev.value + 1
        else:
            self.value = value_start

    def has_doc(self):
        return bool(self.doc)

    def raw_value(self):
        return self.value

    def user_value(self):
        if self.enum_set['type'] == 'flags':
            return 1 << self.value
        else:
            return self.value


class SpecEnumSet(SpecElement):
    """ Enum type

    Represents an enumeration (list of numerical constants)
    as declared in the "definitions" section of the spec.

    Attributes:
        type            enum or flags
        entries         entries by name
        entries_by_val  entries by value
    Methods:
        get_mask      for flags compute the mask of all defined values
    """
    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        self.type = yaml['type']

        prev_entry = None
        value_start = self.yaml.get('value-start', 0)
        self.entries = dict()
        self.entries_by_val = dict()
        for entry in self.yaml['entries']:
            e = self.new_entry(entry, prev_entry, value_start)
            self.entries[e.name] = e
            self.entries_by_val[e.raw_value()] = e
            prev_entry = e

    def new_entry(self, entry, prev_entry, value_start):
        return SpecEnumEntry(self, entry, prev_entry, value_start)

    def has_doc(self):
        if 'doc' in self.yaml:
            return True
        for entry in self.entries.values():
            if entry.has_doc():
                return True
        return False

    def get_mask(self):
        mask = 0
        idx = self.yaml.get('value-start', 0)
        for _ in self.entries.values():
            mask |= 1 << idx
            idx += 1
        return mask


class SpecAttr(SpecElement):
    """ Single Netlink atttribute type

@@ -193,6 +281,7 @@ class SpecFamily(SpecElement):
        msgs       dict of all messages (index by name)
        msgs_by_value  dict of all messages (indexed by name)
        ops        dict of all valid requests / responses
        consts     dict of all constants/enums
    """
    def __init__(self, spec_path, schema_path=None):
        with open(spec_path, "r") as stream:
@@ -222,6 +311,7 @@ class SpecFamily(SpecElement):
        self.req_by_value = collections.OrderedDict()
        self.rsp_by_value = collections.OrderedDict()
        self.ops = collections.OrderedDict()
        self.consts = collections.OrderedDict()

        last_exception = None
        while len(self._resolution_list) > 0:
@@ -242,6 +332,9 @@ class SpecFamily(SpecElement):
            if len(resolved) == 0:
                raise last_exception

    def new_enum(self, elem):
        return SpecEnumSet(self, elem)

    def new_attr_set(self, elem):
        return SpecAttrSet(self, elem)

@@ -296,6 +389,12 @@ class SpecFamily(SpecElement):
    def resolve(self):
        self.resolve_up(super())

        for elem in self.yaml['definitions']:
            if elem['type'] == 'enum' or elem['type'] == 'flags':
                self.consts[elem['name']] = self.new_enum(elem)
            else:
                self.consts[elem['name']] = elem

        for elem in self.yaml['attribute-sets']:
            attr_set = self.new_attr_set(elem)
            self.attr_sets[elem['name']] = attr_set
+2 −7
Original line number Diff line number Diff line
@@ -303,11 +303,6 @@ class YnlFamily(SpecFamily):
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)

        self._types = dict()

        for elem in self.yaml.get('definitions', []):
            self._types[elem['name']] = elem

        self.async_msg_ids = set()
        self.async_msg_queue = []

@@ -353,13 +348,13 @@ class YnlFamily(SpecFamily):

    def _decode_enum(self, rsp, attr_spec):
        raw = rsp[attr_spec['name']]
        enum = self._types[attr_spec['enum']]
        enum = self.consts[attr_spec['enum']]
        i = attr_spec.get('value-start', 0)
        if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum['entries'][i])
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
+21 −86
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ import collections
import os
import yaml

from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation
from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry


def c_upper(name):
@@ -567,97 +567,37 @@ class Struct:
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]


class EnumEntry:
class EnumEntry(SpecEnumEntry):
    def __init__(self, enum_set, yaml, prev, value_start):
        if isinstance(yaml, str):
            self.name = yaml
            yaml = {}
            self.doc = ''
        else:
            self.name = yaml['name']
            self.doc = yaml.get('doc', '')

        self.yaml = yaml
        self.enum_set = enum_set
        self.c_name = c_upper(enum_set.value_pfx + self.name)
        super().__init__(enum_set, yaml, prev, value_start)

        if 'value' in yaml:
            self.value = yaml['value']
        if prev:
            self.value_change = (self.value != prev.value + 1)
        elif prev:
            self.value_change = False
            self.value = prev.value + 1
        else:
            self.value = value_start
            self.value_change = (self.value != 0)

        self.value_change = self.value_change or self.enum_set['type'] == 'flags'

    def __getitem__(self, key):
        return self.yaml[key]

    def __contains__(self, key):
        return key in self.yaml

    def has_doc(self):
        return bool(self.doc)
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    # raw value, i.e. the id in the enum, unlike user value which is a mask for flags
    def raw_value(self):
        return self.value
    def resolve(self):
        self.resolve_up(super())

    # user value, same as raw value for enums, for flags it's the mask
    def user_value(self):
        if self.enum_set['type'] == 'flags':
            return 1 << self.value
        else:
            return self.value
        self.c_name = c_upper(self.enum_set.value_pfx + self.name)


class EnumSet:
class EnumSet(SpecEnumSet):
    def __init__(self, family, yaml):
        self.yaml = yaml
        self.family = family

        self.render_name = c_lower(family.name + '-' + yaml['name'])
        self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        self.type = yaml['type']

        prev_entry = None
        value_start = self.yaml.get('value-start', 0)
        self.entries = {}
        self.entry_list = []
        for entry in self.yaml['entries']:
            e = EnumEntry(self, entry, prev_entry, value_start)
            self.entries[e.name] = e
            self.entry_list.append(e)
            prev_entry = e

    def __getitem__(self, key):
        return self.yaml[key]

    def __contains__(self, key):
        return key in self.yaml

    def has_doc(self):
        if 'doc' in self.yaml:
            return True
        for entry in self.entry_list:
            if entry.has_doc():
                return True
        return False
        super().__init__(family, yaml)

    def get_mask(self):
        mask = 0
        idx = self.yaml.get('value-start', 0)
        for _ in self.entry_list:
            mask |= 1 << idx
            idx += 1
        return mask
    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)


class AttrSet(SpecAttrSet):
@@ -792,8 +732,6 @@ class Family(SpecFamily):

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        self.consts = dict()

        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
@@ -820,6 +758,9 @@ class Family(SpecFamily):
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

@@ -837,12 +778,6 @@ class Family(SpecFamily):
                }

    def _dictify(self):
        for elem in self.yaml['definitions']:
            if elem['type'] == 'enum' or elem['type'] == 'flags':
                self.consts[elem['name']] = EnumSet(self, elem)
            else:
                self.consts[elem['name']] = elem

        ntf = []
        for msg in self.msgs.values():
            if 'notify' in msg:
@@ -1980,7 +1915,7 @@ def render_uapi(family, cw):
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entry_list:
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
@@ -1988,7 +1923,7 @@ def render_uapi(family, cw):

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entry_list:
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix