Linux Audio

Check our new training course

Loading...
Note: File does not exist in v4.10.11.
   1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
   2
   3from collections import namedtuple
   4from enum import Enum
   5import functools
   6import os
   7import random
   8import socket
   9import struct
  10from struct import Struct
  11import sys
  12import yaml
  13import ipaddress
  14import uuid
  15import queue
  16import selectors
  17import time
  18
  19from .nlspec import SpecFamily
  20
  21#
  22# Generic Netlink code which should really be in some library, but I can't quickly find one.
  23#
  24
  25
  26class Netlink:
  27    # Netlink socket
  28    SOL_NETLINK = 270
  29
  30    NETLINK_ADD_MEMBERSHIP = 1
  31    NETLINK_CAP_ACK = 10
  32    NETLINK_EXT_ACK = 11
  33    NETLINK_GET_STRICT_CHK = 12
  34
  35    # Netlink message
  36    NLMSG_ERROR = 2
  37    NLMSG_DONE = 3
  38
  39    NLM_F_REQUEST = 1
  40    NLM_F_ACK = 4
  41    NLM_F_ROOT = 0x100
  42    NLM_F_MATCH = 0x200
  43
  44    NLM_F_REPLACE = 0x100
  45    NLM_F_EXCL = 0x200
  46    NLM_F_CREATE = 0x400
  47    NLM_F_APPEND = 0x800
  48
  49    NLM_F_CAPPED = 0x100
  50    NLM_F_ACK_TLVS = 0x200
  51
  52    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
  53
  54    NLA_F_NESTED = 0x8000
  55    NLA_F_NET_BYTEORDER = 0x4000
  56
  57    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER
  58
  59    # Genetlink defines
  60    NETLINK_GENERIC = 16
  61
  62    GENL_ID_CTRL = 0x10
  63
  64    # nlctrl
  65    CTRL_CMD_GETFAMILY = 3
  66
  67    CTRL_ATTR_FAMILY_ID = 1
  68    CTRL_ATTR_FAMILY_NAME = 2
  69    CTRL_ATTR_MAXATTR = 5
  70    CTRL_ATTR_MCAST_GROUPS = 7
  71
  72    CTRL_ATTR_MCAST_GRP_NAME = 1
  73    CTRL_ATTR_MCAST_GRP_ID = 2
  74
  75    # Extack types
  76    NLMSGERR_ATTR_MSG = 1
  77    NLMSGERR_ATTR_OFFS = 2
  78    NLMSGERR_ATTR_COOKIE = 3
  79    NLMSGERR_ATTR_POLICY = 4
  80    NLMSGERR_ATTR_MISS_TYPE = 5
  81    NLMSGERR_ATTR_MISS_NEST = 6
  82
  83    # Policy types
  84    NL_POLICY_TYPE_ATTR_TYPE = 1
  85    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
  86    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
  87    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
  88    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
  89    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
  90    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
  91    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
  92    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
  93    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
  94    NL_POLICY_TYPE_ATTR_PAD = 11
  95    NL_POLICY_TYPE_ATTR_MASK = 12
  96
  97    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
  98                                  's8', 's16', 's32', 's64',
  99                                  'binary', 'string', 'nul-string',
 100                                  'nested', 'nested-array',
 101                                  'bitfield32', 'sint', 'uint'])
 102
 103class NlError(Exception):
 104  def __init__(self, nl_msg):
 105    self.nl_msg = nl_msg
 106    self.error = -nl_msg.error
 107
 108  def __str__(self):
 109    return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
 110
 111
 112class ConfigError(Exception):
 113    pass
 114
 115
 116class NlAttr:
 117    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
 118    type_formats = {
 119        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
 120        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
 121        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
 122        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
 123        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
 124        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
 125        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
 126        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
 127    }
 128
 129    def __init__(self, raw, offset):
 130        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
 131        self.type = self._type & ~Netlink.NLA_TYPE_MASK
 132        self.is_nest = self._type & Netlink.NLA_F_NESTED
 133        self.payload_len = self._len
 134        self.full_len = (self.payload_len + 3) & ~3
 135        self.raw = raw[offset + 4 : offset + self.payload_len]
 136
 137    @classmethod
 138    def get_format(cls, attr_type, byte_order=None):
 139        format = cls.type_formats[attr_type]
 140        if byte_order:
 141            return format.big if byte_order == "big-endian" \
 142                else format.little
 143        return format.native
 144
 145    def as_scalar(self, attr_type, byte_order=None):
 146        format = self.get_format(attr_type, byte_order)
 147        return format.unpack(self.raw)[0]
 148
 149    def as_auto_scalar(self, attr_type, byte_order=None):
 150        if len(self.raw) != 4 and len(self.raw) != 8:
 151            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
 152        real_type = attr_type[0] + str(len(self.raw) * 8)
 153        format = self.get_format(real_type, byte_order)
 154        return format.unpack(self.raw)[0]
 155
 156    def as_strz(self):
 157        return self.raw.decode('ascii')[:-1]
 158
 159    def as_bin(self):
 160        return self.raw
 161
 162    def as_c_array(self, type):
 163        format = self.get_format(type)
 164        return [ x[0] for x in format.iter_unpack(self.raw) ]
 165
 166    def __repr__(self):
 167        return f"[type:{self.type} len:{self._len}] {self.raw}"
 168
 169
 170class NlAttrs:
 171    def __init__(self, msg, offset=0):
 172        self.attrs = []
 173
 174        while offset < len(msg):
 175            attr = NlAttr(msg, offset)
 176            offset += attr.full_len
 177            self.attrs.append(attr)
 178
 179    def __iter__(self):
 180        yield from self.attrs
 181
 182    def __repr__(self):
 183        msg = ''
 184        for a in self.attrs:
 185            if msg:
 186                msg += '\n'
 187            msg += repr(a)
 188        return msg
 189
 190
 191class NlMsg:
 192    def __init__(self, msg, offset, attr_space=None):
 193        self.hdr = msg[offset : offset + 16]
 194
 195        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
 196            struct.unpack("IHHII", self.hdr)
 197
 198        self.raw = msg[offset + 16 : offset + self.nl_len]
 199
 200        self.error = 0
 201        self.done = 0
 202
 203        extack_off = None
 204        if self.nl_type == Netlink.NLMSG_ERROR:
 205            self.error = struct.unpack("i", self.raw[0:4])[0]
 206            self.done = 1
 207            extack_off = 20
 208        elif self.nl_type == Netlink.NLMSG_DONE:
 209            self.error = struct.unpack("i", self.raw[0:4])[0]
 210            self.done = 1
 211            extack_off = 4
 212
 213        self.extack = None
 214        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
 215            self.extack = dict()
 216            extack_attrs = NlAttrs(self.raw[extack_off:])
 217            for extack in extack_attrs:
 218                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
 219                    self.extack['msg'] = extack.as_strz()
 220                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
 221                    self.extack['miss-type'] = extack.as_scalar('u32')
 222                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
 223                    self.extack['miss-nest'] = extack.as_scalar('u32')
 224                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
 225                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
 226                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
 227                    self.extack['policy'] = self._decode_policy(extack.raw)
 228                else:
 229                    if 'unknown' not in self.extack:
 230                        self.extack['unknown'] = []
 231                    self.extack['unknown'].append(extack)
 232
 233            if attr_space:
 234                # We don't have the ability to parse nests yet, so only do global
 235                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
 236                    miss_type = self.extack['miss-type']
 237                    if miss_type in attr_space.attrs_by_val:
 238                        spec = attr_space.attrs_by_val[miss_type]
 239                        self.extack['miss-type'] = spec['name']
 240                        if 'doc' in spec:
 241                            self.extack['miss-type-doc'] = spec['doc']
 242
 243    def _decode_policy(self, raw):
 244        policy = {}
 245        for attr in NlAttrs(raw):
 246            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
 247                type = attr.as_scalar('u32')
 248                policy['type'] = Netlink.AttrType(type).name
 249            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
 250                policy['min-value'] = attr.as_scalar('s64')
 251            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
 252                policy['max-value'] = attr.as_scalar('s64')
 253            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
 254                policy['min-value'] = attr.as_scalar('u64')
 255            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
 256                policy['max-value'] = attr.as_scalar('u64')
 257            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
 258                policy['min-length'] = attr.as_scalar('u32')
 259            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
 260                policy['max-length'] = attr.as_scalar('u32')
 261            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
 262                policy['bitfield32-mask'] = attr.as_scalar('u32')
 263            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
 264                policy['mask'] = attr.as_scalar('u64')
 265        return policy
 266
 267    def cmd(self):
 268        return self.nl_type
 269
 270    def __repr__(self):
 271        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
 272        if self.error:
 273            msg += '\n\terror: ' + str(self.error)
 274        if self.extack:
 275            msg += '\n\textack: ' + repr(self.extack)
 276        return msg
 277
 278
 279class NlMsgs:
 280    def __init__(self, data, attr_space=None):
 281        self.msgs = []
 282
 283        offset = 0
 284        while offset < len(data):
 285            msg = NlMsg(data, offset, attr_space=attr_space)
 286            offset += msg.nl_len
 287            self.msgs.append(msg)
 288
 289    def __iter__(self):
 290        yield from self.msgs
 291
 292
 293genl_family_name_to_id = None
 294
 295
 296def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
 297    # we prepend length in _genl_msg_finalize()
 298    if seq is None:
 299        seq = random.randint(1, 1024)
 300    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
 301    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
 302    return nlmsg + genlmsg
 303
 304
 305def _genl_msg_finalize(msg):
 306    return struct.pack("I", len(msg) + 4) + msg
 307
 308
 309def _genl_load_families():
 310    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
 311        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
 312
 313        msg = _genl_msg(Netlink.GENL_ID_CTRL,
 314                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
 315                        Netlink.CTRL_CMD_GETFAMILY, 1)
 316        msg = _genl_msg_finalize(msg)
 317
 318        sock.send(msg, 0)
 319
 320        global genl_family_name_to_id
 321        genl_family_name_to_id = dict()
 322
 323        while True:
 324            reply = sock.recv(128 * 1024)
 325            nms = NlMsgs(reply)
 326            for nl_msg in nms:
 327                if nl_msg.error:
 328                    print("Netlink error:", nl_msg.error)
 329                    return
 330                if nl_msg.done:
 331                    return
 332
 333                gm = GenlMsg(nl_msg)
 334                fam = dict()
 335                for attr in NlAttrs(gm.raw):
 336                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
 337                        fam['id'] = attr.as_scalar('u16')
 338                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
 339                        fam['name'] = attr.as_strz()
 340                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
 341                        fam['maxattr'] = attr.as_scalar('u32')
 342                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
 343                        fam['mcast'] = dict()
 344                        for entry in NlAttrs(attr.raw):
 345                            mcast_name = None
 346                            mcast_id = None
 347                            for entry_attr in NlAttrs(entry.raw):
 348                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
 349                                    mcast_name = entry_attr.as_strz()
 350                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
 351                                    mcast_id = entry_attr.as_scalar('u32')
 352                            if mcast_name and mcast_id is not None:
 353                                fam['mcast'][mcast_name] = mcast_id
 354                if 'name' in fam and 'id' in fam:
 355                    genl_family_name_to_id[fam['name']] = fam
 356
 357
 358class GenlMsg:
 359    def __init__(self, nl_msg):
 360        self.nl = nl_msg
 361        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
 362        self.raw = nl_msg.raw[4:]
 363
 364    def cmd(self):
 365        return self.genl_cmd
 366
 367    def __repr__(self):
 368        msg = repr(self.nl)
 369        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
 370        for a in self.raw_attrs:
 371            msg += '\t\t' + repr(a) + '\n'
 372        return msg
 373
 374
 375class NetlinkProtocol:
 376    def __init__(self, family_name, proto_num):
 377        self.family_name = family_name
 378        self.proto_num = proto_num
 379
 380    def _message(self, nl_type, nl_flags, seq=None):
 381        if seq is None:
 382            seq = random.randint(1, 1024)
 383        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
 384        return nlmsg
 385
 386    def message(self, flags, command, version, seq=None):
 387        return self._message(command, flags, seq)
 388
 389    def _decode(self, nl_msg):
 390        return nl_msg
 391
 392    def decode(self, ynl, nl_msg, op):
 393        msg = self._decode(nl_msg)
 394        if op is None:
 395            op = ynl.rsp_by_value[msg.cmd()]
 396        fixed_header_size = ynl._struct_size(op.fixed_header)
 397        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
 398        return msg
 399
 400    def get_mcast_id(self, mcast_name, mcast_groups):
 401        if mcast_name not in mcast_groups:
 402            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
 403        return mcast_groups[mcast_name].value
 404
 405    def msghdr_size(self):
 406        return 16
 407
 408
 409class GenlProtocol(NetlinkProtocol):
 410    def __init__(self, family_name):
 411        super().__init__(family_name, Netlink.NETLINK_GENERIC)
 412
 413        global genl_family_name_to_id
 414        if genl_family_name_to_id is None:
 415            _genl_load_families()
 416
 417        self.genl_family = genl_family_name_to_id[family_name]
 418        self.family_id = genl_family_name_to_id[family_name]['id']
 419
 420    def message(self, flags, command, version, seq=None):
 421        nlmsg = self._message(self.family_id, flags, seq)
 422        genlmsg = struct.pack("BBH", command, version, 0)
 423        return nlmsg + genlmsg
 424
 425    def _decode(self, nl_msg):
 426        return GenlMsg(nl_msg)
 427
 428    def get_mcast_id(self, mcast_name, mcast_groups):
 429        if mcast_name not in self.genl_family['mcast']:
 430            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
 431        return self.genl_family['mcast'][mcast_name]
 432
 433    def msghdr_size(self):
 434        return super().msghdr_size() + 4
 435
 436
 437class SpaceAttrs:
 438    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])
 439
 440    def __init__(self, attr_space, attrs, outer = None):
 441        outer_scopes = outer.scopes if outer else []
 442        inner_scope = self.SpecValuesPair(attr_space, attrs)
 443        self.scopes = [inner_scope] + outer_scopes
 444
 445    def lookup(self, name):
 446        for scope in self.scopes:
 447            if name in scope.spec:
 448                if name in scope.values:
 449                    return scope.values[name]
 450                spec_name = scope.spec.yaml['name']
 451                raise Exception(
 452                    f"No value for '{name}' in attribute space '{spec_name}'")
 453        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
 454
 455
 456#
 457# YNL implementation details.
 458#
 459
 460
 461class YnlFamily(SpecFamily):
 462    def __init__(self, def_path, schema=None, process_unknown=False,
 463                 recv_size=0):
 464        super().__init__(def_path, schema)
 465
 466        self.include_raw = False
 467        self.process_unknown = process_unknown
 468
 469        try:
 470            if self.proto == "netlink-raw":
 471                self.nlproto = NetlinkProtocol(self.yaml['name'],
 472                                               self.yaml['protonum'])
 473            else:
 474                self.nlproto = GenlProtocol(self.yaml['name'])
 475        except KeyError:
 476            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
 477
 478        self._recv_dbg = False
 479        # Note that netlink will use conservative (min) message size for
 480        # the first dump recv() on the socket, our setting will only matter
 481        # from the second recv() on.
 482        self._recv_size = recv_size if recv_size else 131072
 483        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
 484        # for a message, so smaller receive sizes will lead to truncation.
 485        # Note that the min size for other families may be larger than 4k!
 486        if self._recv_size < 4000:
 487            raise ConfigError()
 488
 489        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
 490        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
 491        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
 492        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
 493
 494        self.async_msg_ids = set()
 495        self.async_msg_queue = queue.Queue()
 496
 497        for msg in self.msgs.values():
 498            if msg.is_async:
 499                self.async_msg_ids.add(msg.rsp_value)
 500
 501        for op_name, op in self.ops.items():
 502            bound_f = functools.partial(self._op, op_name)
 503            setattr(self, op.ident_name, bound_f)
 504
 505
 506    def ntf_subscribe(self, mcast_name):
 507        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
 508        self.sock.bind((0, 0))
 509        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
 510                             mcast_id)
 511
 512    def set_recv_dbg(self, enabled):
 513        self._recv_dbg = enabled
 514
 515    def _recv_dbg_print(self, reply, nl_msgs):
 516        if not self._recv_dbg:
 517            return
 518        print("Recv: read", len(reply), "bytes,",
 519              len(nl_msgs.msgs), "messages", file=sys.stderr)
 520        for nl_msg in nl_msgs:
 521            print("  ", nl_msg, file=sys.stderr)
 522
 523    def _encode_enum(self, attr_spec, value):
 524        enum = self.consts[attr_spec['enum']]
 525        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
 526            scalar = 0
 527            if isinstance(value, str):
 528                value = [value]
 529            for single_value in value:
 530                scalar += enum.entries[single_value].user_value(as_flags = True)
 531            return scalar
 532        else:
 533            return enum.entries[value].user_value()
 534
 535    def _get_scalar(self, attr_spec, value):
 536        try:
 537            return int(value)
 538        except (ValueError, TypeError) as e:
 539            if 'enum' not in attr_spec:
 540                raise e
 541        return self._encode_enum(attr_spec, value)
 542
 543    def _add_attr(self, space, name, value, search_attrs):
 544        try:
 545            attr = self.attr_sets[space][name]
 546        except KeyError:
 547            raise Exception(f"Space '{space}' has no attribute '{name}'")
 548        nl_type = attr.value
 549
 550        if attr.is_multi and isinstance(value, list):
 551            attr_payload = b''
 552            for subvalue in value:
 553                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
 554            return attr_payload
 555
 556        if attr["type"] == 'nest':
 557            nl_type |= Netlink.NLA_F_NESTED
 558            attr_payload = b''
 559            sub_space = attr['nested-attributes']
 560            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
 561            for subname, subvalue in value.items():
 562                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
 563        elif attr["type"] == 'flag':
 564            if not value:
 565                # If value is absent or false then skip attribute creation.
 566                return b''
 567            attr_payload = b''
 568        elif attr["type"] == 'string':
 569            attr_payload = str(value).encode('ascii') + b'\x00'
 570        elif attr["type"] == 'binary':
 571            if isinstance(value, bytes):
 572                attr_payload = value
 573            elif isinstance(value, str):
 574                attr_payload = bytes.fromhex(value)
 575            elif isinstance(value, dict) and attr.struct_name:
 576                attr_payload = self._encode_struct(attr.struct_name, value)
 577            else:
 578                raise Exception(f'Unknown type for binary attribute, value: {value}')
 579        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
 580            scalar = self._get_scalar(attr, value)
 581            if attr.is_auto_scalar:
 582                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
 583            else:
 584                attr_type = attr["type"]
 585            format = NlAttr.get_format(attr_type, attr.byte_order)
 586            attr_payload = format.pack(scalar)
 587        elif attr['type'] in "bitfield32":
 588            scalar_value = self._get_scalar(attr, value["value"])
 589            scalar_selector = self._get_scalar(attr, value["selector"])
 590            attr_payload = struct.pack("II", scalar_value, scalar_selector)
 591        elif attr['type'] == 'sub-message':
 592            msg_format = self._resolve_selector(attr, search_attrs)
 593            attr_payload = b''
 594            if msg_format.fixed_header:
 595                attr_payload += self._encode_struct(msg_format.fixed_header, value)
 596            if msg_format.attr_set:
 597                if msg_format.attr_set in self.attr_sets:
 598                    nl_type |= Netlink.NLA_F_NESTED
 599                    sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
 600                    for subname, subvalue in value.items():
 601                        attr_payload += self._add_attr(msg_format.attr_set,
 602                                                       subname, subvalue, sub_attrs)
 603                else:
 604                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
 605        else:
 606            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
 607
 608        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
 609        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
 610
 611    def _decode_enum(self, raw, attr_spec):
 612        enum = self.consts[attr_spec['enum']]
 613        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
 614            i = 0
 615            value = set()
 616            while raw:
 617                if raw & 1:
 618                    value.add(enum.entries_by_val[i].name)
 619                raw >>= 1
 620                i += 1
 621        else:
 622            value = enum.entries_by_val[raw].name
 623        return value
 624
 625    def _decode_binary(self, attr, attr_spec):
 626        if attr_spec.struct_name:
 627            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
 628        elif attr_spec.sub_type:
 629            decoded = attr.as_c_array(attr_spec.sub_type)
 630        else:
 631            decoded = attr.as_bin()
 632            if attr_spec.display_hint:
 633                decoded = self._formatted_string(decoded, attr_spec.display_hint)
 634        return decoded
 635
 636    def _decode_array_attr(self, attr, attr_spec):
 637        decoded = []
 638        offset = 0
 639        while offset < len(attr.raw):
 640            item = NlAttr(attr.raw, offset)
 641            offset += item.full_len
 642
 643            if attr_spec["sub-type"] == 'nest':
 644                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
 645                decoded.append({ item.type: subattrs })
 646            elif attr_spec["sub-type"] == 'binary':
 647                subattrs = item.as_bin()
 648                if attr_spec.display_hint:
 649                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
 650                decoded.append(subattrs)
 651            elif attr_spec["sub-type"] in NlAttr.type_formats:
 652                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
 653                if attr_spec.display_hint:
 654                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
 655                decoded.append(subattrs)
 656            else:
 657                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
 658        return decoded
 659
 660    def _decode_nest_type_value(self, attr, attr_spec):
 661        decoded = {}
 662        value = attr
 663        for name in attr_spec['type-value']:
 664            value = NlAttr(value.raw, 0)
 665            decoded[name] = value.type
 666        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
 667        decoded.update(subattrs)
 668        return decoded
 669
 670    def _decode_unknown(self, attr):
 671        if attr.is_nest:
 672            return self._decode(NlAttrs(attr.raw), None)
 673        else:
 674            return attr.as_bin()
 675
 676    def _rsp_add(self, rsp, name, is_multi, decoded):
 677        if is_multi == None:
 678            if name in rsp and type(rsp[name]) is not list:
 679                rsp[name] = [rsp[name]]
 680                is_multi = True
 681            else:
 682                is_multi = False
 683
 684        if not is_multi:
 685            rsp[name] = decoded
 686        elif name in rsp:
 687            rsp[name].append(decoded)
 688        else:
 689            rsp[name] = [decoded]
 690
 691    def _resolve_selector(self, attr_spec, search_attrs):
 692        sub_msg = attr_spec.sub_message
 693        if sub_msg not in self.sub_msgs:
 694            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
 695        sub_msg_spec = self.sub_msgs[sub_msg]
 696
 697        selector = attr_spec.selector
 698        value = search_attrs.lookup(selector)
 699        if value not in sub_msg_spec.formats:
 700            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
 701
 702        spec = sub_msg_spec.formats[value]
 703        return spec
 704
 705    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
 706        msg_format = self._resolve_selector(attr_spec, search_attrs)
 707        decoded = {}
 708        offset = 0
 709        if msg_format.fixed_header:
 710            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
 711            offset = self._struct_size(msg_format.fixed_header)
 712        if msg_format.attr_set:
 713            if msg_format.attr_set in self.attr_sets:
 714                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
 715                decoded.update(subdict)
 716            else:
 717                raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
 718        return decoded
 719
 720    def _decode(self, attrs, space, outer_attrs = None):
 721        rsp = dict()
 722        if space:
 723            attr_space = self.attr_sets[space]
 724            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
 725
 726        for attr in attrs:
 727            try:
 728                attr_spec = attr_space.attrs_by_val[attr.type]
 729            except (KeyError, UnboundLocalError):
 730                if not self.process_unknown:
 731                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
 732                attr_name = f"UnknownAttr({attr.type})"
 733                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
 734                continue
 735
 736            if attr_spec["type"] == 'nest':
 737                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
 738                decoded = subdict
 739            elif attr_spec["type"] == 'string':
 740                decoded = attr.as_strz()
 741            elif attr_spec["type"] == 'binary':
 742                decoded = self._decode_binary(attr, attr_spec)
 743            elif attr_spec["type"] == 'flag':
 744                decoded = True
 745            elif attr_spec.is_auto_scalar:
 746                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
 747            elif attr_spec["type"] in NlAttr.type_formats:
 748                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
 749                if 'enum' in attr_spec:
 750                    decoded = self._decode_enum(decoded, attr_spec)
 751                elif attr_spec.display_hint:
 752                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
 753            elif attr_spec["type"] == 'indexed-array':
 754                decoded = self._decode_array_attr(attr, attr_spec)
 755            elif attr_spec["type"] == 'bitfield32':
 756                value, selector = struct.unpack("II", attr.raw)
 757                if 'enum' in attr_spec:
 758                    value = self._decode_enum(value, attr_spec)
 759                    selector = self._decode_enum(selector, attr_spec)
 760                decoded = {"value": value, "selector": selector}
 761            elif attr_spec["type"] == 'sub-message':
 762                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
 763            elif attr_spec["type"] == 'nest-type-value':
 764                decoded = self._decode_nest_type_value(attr, attr_spec)
 765            else:
 766                if not self.process_unknown:
 767                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
 768                decoded = self._decode_unknown(attr)
 769
 770            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
 771
 772        return rsp
 773
 774    def _decode_extack_path(self, attrs, attr_set, offset, target):
 775        for attr in attrs:
 776            try:
 777                attr_spec = attr_set.attrs_by_val[attr.type]
 778            except KeyError:
 779                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
 780            if offset > target:
 781                break
 782            if offset == target:
 783                return '.' + attr_spec.name
 784
 785            if offset + attr.full_len <= target:
 786                offset += attr.full_len
 787                continue
 788            if attr_spec['type'] != 'nest':
 789                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
 790            offset += 4
 791            subpath = self._decode_extack_path(NlAttrs(attr.raw),
 792                                               self.attr_sets[attr_spec['nested-attributes']],
 793                                               offset, target)
 794            if subpath is None:
 795                return None
 796            return '.' + attr_spec.name + subpath
 797
 798        return None
 799
 800    def _decode_extack(self, request, op, extack):
 801        if 'bad-attr-offs' not in extack:
 802            return
 803
 804        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
 805        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
 806        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
 807                                        extack['bad-attr-offs'])
 808        if path:
 809            del extack['bad-attr-offs']
 810            extack['bad-attr'] = path
 811
 812    def _struct_size(self, name):
 813        if name:
 814            members = self.consts[name].members
 815            size = 0
 816            for m in members:
 817                if m.type in ['pad', 'binary']:
 818                    if m.struct:
 819                        size += self._struct_size(m.struct)
 820                    else:
 821                        size += m.len
 822                else:
 823                    format = NlAttr.get_format(m.type, m.byte_order)
 824                    size += format.size
 825            return size
 826        else:
 827            return 0
 828
 829    def _decode_struct(self, data, name):
 830        members = self.consts[name].members
 831        attrs = dict()
 832        offset = 0
 833        for m in members:
 834            value = None
 835            if m.type == 'pad':
 836                offset += m.len
 837            elif m.type == 'binary':
 838                if m.struct:
 839                    len = self._struct_size(m.struct)
 840                    value = self._decode_struct(data[offset : offset + len],
 841                                                m.struct)
 842                    offset += len
 843                else:
 844                    value = data[offset : offset + m.len]
 845                    offset += m.len
 846            else:
 847                format = NlAttr.get_format(m.type, m.byte_order)
 848                [ value ] = format.unpack_from(data, offset)
 849                offset += format.size
 850            if value is not None:
 851                if m.enum:
 852                    value = self._decode_enum(value, m)
 853                elif m.display_hint:
 854                    value = self._formatted_string(value, m.display_hint)
 855                attrs[m.name] = value
 856        return attrs
 857
 858    def _encode_struct(self, name, vals):
 859        members = self.consts[name].members
 860        attr_payload = b''
 861        for m in members:
 862            value = vals.pop(m.name) if m.name in vals else None
 863            if m.type == 'pad':
 864                attr_payload += bytearray(m.len)
 865            elif m.type == 'binary':
 866                if m.struct:
 867                    if value is None:
 868                        value = dict()
 869                    attr_payload += self._encode_struct(m.struct, value)
 870                else:
 871                    if value is None:
 872                        attr_payload += bytearray(m.len)
 873                    else:
 874                        attr_payload += bytes.fromhex(value)
 875            else:
 876                if value is None:
 877                    value = 0
 878                format = NlAttr.get_format(m.type, m.byte_order)
 879                attr_payload += format.pack(value)
 880        return attr_payload
 881
 882    def _formatted_string(self, raw, display_hint):
 883        if display_hint == 'mac':
 884            formatted = ':'.join('%02x' % b for b in raw)
 885        elif display_hint == 'hex':
 886            if isinstance(raw, int):
 887                formatted = hex(raw)
 888            else:
 889                formatted = bytes.hex(raw, ' ')
 890        elif display_hint in [ 'ipv4', 'ipv6' ]:
 891            formatted = format(ipaddress.ip_address(raw))
 892        elif display_hint == 'uuid':
 893            formatted = str(uuid.UUID(bytes=raw))
 894        else:
 895            formatted = raw
 896        return formatted
 897
 898    def handle_ntf(self, decoded):
 899        msg = dict()
 900        if self.include_raw:
 901            msg['raw'] = decoded
 902        op = self.rsp_by_value[decoded.cmd()]
 903        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
 904        if op.fixed_header:
 905            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
 906
 907        msg['name'] = op['name']
 908        msg['msg'] = attrs
 909        self.async_msg_queue.put(msg)
 910
 911    def check_ntf(self):
 912        while True:
 913            try:
 914                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
 915            except BlockingIOError:
 916                return
 917
 918            nms = NlMsgs(reply)
 919            self._recv_dbg_print(reply, nms)
 920            for nl_msg in nms:
 921                if nl_msg.error:
 922                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
 923                    print(nl_msg)
 924                    continue
 925                if nl_msg.done:
 926                    print("Netlink done while checking for ntf!?")
 927                    continue
 928
 929                decoded = self.nlproto.decode(self, nl_msg, None)
 930                if decoded.cmd() not in self.async_msg_ids:
 931                    print("Unexpected msg id while checking for ntf", decoded)
 932                    continue
 933
 934                self.handle_ntf(decoded)
 935
 936    def poll_ntf(self, duration=None):
 937        start_time = time.time()
 938        selector = selectors.DefaultSelector()
 939        selector.register(self.sock, selectors.EVENT_READ)
 940
 941        while True:
 942            try:
 943                yield self.async_msg_queue.get_nowait()
 944            except queue.Empty:
 945                if duration is not None:
 946                    timeout = start_time + duration - time.time()
 947                    if timeout <= 0:
 948                        return
 949                else:
 950                    timeout = None
 951                events = selector.select(timeout)
 952                if events:
 953                    self.check_ntf()
 954
 955    def operation_do_attributes(self, name):
 956      """
 957      For a given operation name, find and return a supported
 958      set of attributes (as a dict).
 959      """
 960      op = self.find_operation(name)
 961      if not op:
 962        return None
 963
 964      return op['do']['request']['attributes'].copy()
 965
 966    def _encode_message(self, op, vals, flags, req_seq):
 967        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
 968        for flag in flags or []:
 969            nl_flags |= flag
 970
 971        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
 972        if op.fixed_header:
 973            msg += self._encode_struct(op.fixed_header, vals)
 974        search_attrs = SpaceAttrs(op.attr_set, vals)
 975        for name, value in vals.items():
 976            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
 977        msg = _genl_msg_finalize(msg)
 978        return msg
 979
 980    def _ops(self, ops):
 981        reqs_by_seq = {}
 982        req_seq = random.randint(1024, 65535)
 983        payload = b''
 984        for (method, vals, flags) in ops:
 985            op = self.ops[method]
 986            msg = self._encode_message(op, vals, flags, req_seq)
 987            reqs_by_seq[req_seq] = (op, msg, flags)
 988            payload += msg
 989            req_seq += 1
 990
 991        self.sock.send(payload, 0)
 992
 993        done = False
 994        rsp = []
 995        op_rsp = []
 996        while not done:
 997            reply = self.sock.recv(self._recv_size)
 998            nms = NlMsgs(reply, attr_space=op.attr_set)
 999            self._recv_dbg_print(reply, nms)
1000            for nl_msg in nms:
1001                if nl_msg.nl_seq in reqs_by_seq:
1002                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
1003                    if nl_msg.extack:
1004                        self._decode_extack(req_msg, op, nl_msg.extack)
1005                else:
1006                    op = None
1007                    req_flags = []
1008
1009                if nl_msg.error:
1010                    raise NlError(nl_msg)
1011                if nl_msg.done:
1012                    if nl_msg.extack:
1013                        print("Netlink warning:")
1014                        print(nl_msg)
1015
1016                    if Netlink.NLM_F_DUMP in req_flags:
1017                        rsp.append(op_rsp)
1018                    elif not op_rsp:
1019                        rsp.append(None)
1020                    elif len(op_rsp) == 1:
1021                        rsp.append(op_rsp[0])
1022                    else:
1023                        rsp.append(op_rsp)
1024                    op_rsp = []
1025
1026                    del reqs_by_seq[nl_msg.nl_seq]
1027                    done = len(reqs_by_seq) == 0
1028                    break
1029
1030                decoded = self.nlproto.decode(self, nl_msg, op)
1031
1032                # Check if this is a reply to our request
1033                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1034                    if decoded.cmd() in self.async_msg_ids:
1035                        self.handle_ntf(decoded)
1036                        continue
1037                    else:
1038                        print('Unexpected message: ' + repr(decoded))
1039                        continue
1040
1041                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1042                if op.fixed_header:
1043                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1044                op_rsp.append(rsp_msg)
1045
1046        return rsp
1047
1048    def _op(self, method, vals, flags=None, dump=False):
1049        req_flags = flags or []
1050        if dump:
1051            req_flags.append(Netlink.NLM_F_DUMP)
1052
1053        ops = [(method, vals, req_flags)]
1054        return self._ops(ops)[0]
1055
1056    def do(self, method, vals, flags=None):
1057        return self._op(method, vals, flags)
1058
1059    def dump(self, method, vals):
1060        return self._op(method, vals, dump=True)
1061
1062    def do_multi(self, ops):
1063        return self._ops(ops)