Commit e46dd903 authored by Donald Hunter's avatar Donald Hunter Committed by Jakub Kicinski
Browse files

tools/net/ynl: Add support for netlink-raw families



Refactor the ynl code to encapsulate protocol specifics into
NetlinkProtocol and GenlProtocol.

Signed-off-by: default avatarDonald Hunter <donald.hunter@gmail.com>
Link: https://lore.kernel.org/r/20230825122756.7603-8-donald.hunter@gmail.com


Signed-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parent fb0a06d4
Loading
Loading
Loading
Loading
+91 −33
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ class Netlink:
    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
@@ -228,6 +229,9 @@ class NlMsg:
                            desc += f" ({spec['doc']})"
                        self.extack['miss-type'] = desc

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
        if self.error:
@@ -322,6 +326,9 @@ class GenlMsg:
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
@@ -330,9 +337,41 @@ class GenlMsg:
        return msg


class GenlFamily:
    def __init__(self, family_name):
class NetlinkProtocol:
    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg):
        msg = self._decode(nl_msg)
        fixed_header_size = 0
        if ynl:
            op = ynl.rsp_by_value[msg.cmd()]
            fixed_header_size = ynl._fixed_header_size(op)
        msg.raw_attrs = NlAttrs(msg.raw[fixed_header_size:])
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value


class GenlProtocol(NetlinkProtocol):
    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
@@ -341,6 +380,19 @@ class GenlFamily:
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]


#
# YNL implementation details.
@@ -353,9 +405,19 @@ class YnlFamily(SpecFamily):

        self.include_raw = False

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []
@@ -368,18 +430,12 @@ class YnlFamily(SpecFamily):
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)

        try:
            self.family = GenlFamily(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    def ntf_subscribe(self, mcast_name):
        if mcast_name not in self.family.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')

        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             self.family.genl_family['mcast'][mcast_name])
                             mcast_id)

    def _add_attr(self, space, name, value):
        try:
@@ -505,11 +561,9 @@ class YnlFamily(SpecFamily):
        if 'bad-attr-offs' not in extack:
            return

        genl_req = GenlMsg(NlMsg(request, 0, op.attr_set))
        fixed_header_size = self._fixed_header_size(op)
        offset = 20 + fixed_header_size
        path = self._decode_extack_path(NlAttrs(genl_req.raw[fixed_header_size:]),
                                        op.attr_set, offset,
        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
        offset = 20 + self._fixed_header_size(op)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
@@ -539,14 +593,17 @@ class YnlFamily(SpecFamily):
            fixed_header_attrs[m.name] = value
        return fixed_header_attrs

    def handle_ntf(self, nl_msg, genl_msg):
    def handle_ntf(self, decoded):
        msg = dict()
        if self.include_raw:
            msg['nlmsg'] = nl_msg
            msg['genlmsg'] = genl_msg
        op = self.rsp_by_value[genl_msg.genl_cmd]
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_fixed_header(decoded, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
@@ -566,12 +623,12 @@ class YnlFamily(SpecFamily):
                    print("Netlink done while checking for ntf!?")
                    continue

                gm = GenlMsg(nl_msg)
                if gm.genl_cmd not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", gm)
                decoded = self.nlproto.decode(self, nl_msg)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(nl_msg, gm)
                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
      """
@@ -592,7 +649,7 @@ class YnlFamily(SpecFamily):
            nl_flags |= Netlink.NLM_F_DUMP

        req_seq = random.randint(1024, 65535)
        msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        fixed_header_members = []
        if op.fixed_header:
            fixed_header_members = self.consts[op.fixed_header].members
@@ -624,19 +681,20 @@ class YnlFamily(SpecFamily):
                    done = True
                    break

                gm = GenlMsg(nl_msg)
                decoded = self.nlproto.decode(self, nl_msg)

                # Check if this is a reply to our request
                if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
                    if gm.genl_cmd in self.async_msg_ids:
                        self.handle_ntf(nl_msg, gm)
                if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(gm))
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(NlAttrs(gm.raw), op.attr_set.name)
                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_fixed_header(gm, op.fixed_header))
                    rsp_msg.update(self._decode_fixed_header(decoded, op.fixed_header))
                rsp.append(rsp_msg)

        if not rsp: