Linux Audio

Check our new training course

Loading...
Note: File does not exist in v6.2.
   1#!/usr/bin/env python3
   2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
   3
   4import argparse
   5import collections
   6import filecmp
   7import os
   8import re
   9import shutil
  10import tempfile
  11import yaml
  12
  13from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
  14
  15
  16def c_upper(name):
  17    return name.upper().replace('-', '_')
  18
  19
  20def c_lower(name):
  21    return name.lower().replace('-', '_')
  22
  23
  24def limit_to_number(name):
  25    """
  26    Turn a string limit like u32-max or s64-min into its numerical value
  27    """
  28    if name[0] == 'u' and name.endswith('-min'):
  29        return 0
  30    width = int(name[1:-4])
  31    if name[0] == 's':
  32        width -= 1
  33    value = (1 << width) - 1
  34    if name[0] == 's' and name.endswith('-min'):
  35        value = -value - 1
  36    return value
  37
  38
  39class BaseNlLib:
  40    def get_family_id(self):
  41        return 'ys->family_id'
  42
  43    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
  44        ind = '\n\t\t' + '\t' * indent + ' '
  45        if is_dump:
  46            return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{ind}ynl_cb_array, NLMSG_MIN_TYPE)"
  47        else:
  48            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{ind}{cb}, {data},{ind}" + \
  49                   "ynl_cb_array, NLMSG_MIN_TYPE)"
  50
  51
  52class Type(SpecAttr):
  53    def __init__(self, family, attr_set, attr, value):
  54        super().__init__(family, attr_set, attr, value)
  55
  56        self.attr = attr
  57        self.attr_set = attr_set
  58        self.type = attr['type']
  59        self.checks = attr.get('checks', {})
  60
  61        self.request = False
  62        self.reply = False
  63
  64        if 'len' in attr:
  65            self.len = attr['len']
  66
  67        if 'nested-attributes' in attr:
  68            self.nested_attrs = attr['nested-attributes']
  69            if self.nested_attrs == family.name:
  70                self.nested_render_name = c_lower(f"{family.name}")
  71            else:
  72                self.nested_render_name = c_lower(f"{family.name}_{self.nested_attrs}")
  73
  74            if self.nested_attrs in self.family.consts:
  75                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
  76            else:
  77                self.nested_struct_type = 'struct ' + self.nested_render_name
  78
  79        self.c_name = c_lower(self.name)
  80        if self.c_name in _C_KW:
  81            self.c_name += '_'
  82
  83        # Added by resolve():
  84        self.enum_name = None
  85        delattr(self, "enum_name")
  86
  87    def get_limit(self, limit, default=None):
  88        value = self.checks.get(limit, default)
  89        if value is None:
  90            return value
  91        if not isinstance(value, int):
  92            value = limit_to_number(value)
  93        return value
  94
  95    def resolve(self):
  96        if 'name-prefix' in self.attr:
  97            enum_name = f"{self.attr['name-prefix']}{self.name}"
  98        else:
  99            enum_name = f"{self.attr_set.name_prefix}{self.name}"
 100        self.enum_name = c_upper(enum_name)
 101
 102    def is_multi_val(self):
 103        return None
 104
 105    def is_scalar(self):
 106        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
 107
 108    def is_recursive(self):
 109        return False
 110
 111    def is_recursive_for_op(self, ri):
 112        return self.is_recursive() and not ri.op
 113
 114    def presence_type(self):
 115        return 'bit'
 116
 117    def presence_member(self, space, type_filter):
 118        if self.presence_type() != type_filter:
 119            return
 120
 121        if self.presence_type() == 'bit':
 122            pfx = '__' if space == 'user' else ''
 123            return f"{pfx}u32 {self.c_name}:1;"
 124
 125        if self.presence_type() == 'len':
 126            pfx = '__' if space == 'user' else ''
 127            return f"{pfx}u32 {self.c_name}_len;"
 128
 129    def _complex_member_type(self, ri):
 130        return None
 131
 132    def free_needs_iter(self):
 133        return False
 134
 135    def free(self, ri, var, ref):
 136        if self.is_multi_val() or self.presence_type() == 'len':
 137            ri.cw.p(f'free({var}->{ref}{self.c_name});')
 138
 139    def arg_member(self, ri):
 140        member = self._complex_member_type(ri)
 141        if member:
 142            arg = [member + ' *' + self.c_name]
 143            if self.presence_type() == 'count':
 144                arg += ['unsigned int n_' + self.c_name]
 145            return arg
 146        raise Exception(f"Struct member not implemented for class type {self.type}")
 147
 148    def struct_member(self, ri):
 149        if self.is_multi_val():
 150            ri.cw.p(f"unsigned int n_{self.c_name};")
 151        member = self._complex_member_type(ri)
 152        if member:
 153            ptr = '*' if self.is_multi_val() else ''
 154            if self.is_recursive_for_op(ri):
 155                ptr = '*'
 156            ri.cw.p(f"{member} {ptr}{self.c_name};")
 157            return
 158        members = self.arg_member(ri)
 159        for one in members:
 160            ri.cw.p(one + ';')
 161
 162    def _attr_policy(self, policy):
 163        return '{ .type = ' + policy + ', }'
 164
 165    def attr_policy(self, cw):
 166        policy = c_upper('nla-' + self.attr['type'])
 167
 168        spec = self._attr_policy(policy)
 169        cw.p(f"\t[{self.enum_name}] = {spec},")
 170
 171    def _mnl_type(self):
 172        # mnl does not have helpers for signed integer types
 173        # turn signed type into unsigned
 174        # this only makes sense for scalar types
 175        t = self.type
 176        if t[0] == 's':
 177            t = 'u' + t[1:]
 178        return t
 179
 180    def _attr_typol(self):
 181        raise Exception(f"Type policy not implemented for class type {self.type}")
 182
 183    def attr_typol(self, cw):
 184        typol = self._attr_typol()
 185        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')
 186
 187    def _attr_put_line(self, ri, var, line):
 188        if self.presence_type() == 'bit':
 189            ri.cw.p(f"if ({var}->_present.{self.c_name})")
 190        elif self.presence_type() == 'len':
 191            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
 192        ri.cw.p(f"{line};")
 193
 194    def _attr_put_simple(self, ri, var, put_type):
 195        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
 196        self._attr_put_line(ri, var, line)
 197
 198    def attr_put(self, ri, var):
 199        raise Exception(f"Put not implemented for class type {self.type}")
 200
 201    def _attr_get(self, ri, var):
 202        raise Exception(f"Attr get not implemented for class type {self.type}")
 203
 204    def attr_get(self, ri, var, first):
 205        lines, init_lines, local_vars = self._attr_get(ri, var)
 206        if type(lines) is str:
 207            lines = [lines]
 208        if type(init_lines) is str:
 209            init_lines = [init_lines]
 210
 211        kw = 'if' if first else 'else if'
 212        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
 213        if local_vars:
 214            for local in local_vars:
 215                ri.cw.p(local)
 216            ri.cw.nl()
 217
 218        if not self.is_multi_val():
 219            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
 220            ri.cw.p("return MNL_CB_ERROR;")
 221            if self.presence_type() == 'bit':
 222                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")
 223
 224        if init_lines:
 225            ri.cw.nl()
 226            for line in init_lines:
 227                ri.cw.p(line)
 228
 229        for line in lines:
 230            ri.cw.p(line)
 231        ri.cw.block_end()
 232        return True
 233
 234    def _setter_lines(self, ri, member, presence):
 235        raise Exception(f"Setter not implemented for class type {self.type}")
 236
 237    def setter(self, ri, space, direction, deref=False, ref=None):
 238        ref = (ref if ref else []) + [self.c_name]
 239        var = "req"
 240        member = f"{var}->{'.'.join(ref)}"
 241
 242        code = []
 243        presence = ''
 244        for i in range(0, len(ref)):
 245            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{ref[i]}"
 246            if self.presence_type() == 'bit':
 247                code.append(presence + ' = 1;')
 248        code += self._setter_lines(ri, member, presence)
 249
 250        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
 251        free = bool([x for x in code if 'free(' in x])
 252        alloc = bool([x for x in code if 'alloc(' in x])
 253        if free and not alloc:
 254            func_name = '__' + func_name
 255        ri.cw.write_func('static inline void', func_name, body=code,
 256                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
 257
 258
 259class TypeUnused(Type):
 260    def presence_type(self):
 261        return ''
 262
 263    def arg_member(self, ri):
 264        return []
 265
 266    def _attr_get(self, ri, var):
 267        return ['return MNL_CB_ERROR;'], None, None
 268
 269    def _attr_typol(self):
 270        return '.type = YNL_PT_REJECT, '
 271
 272    def attr_policy(self, cw):
 273        pass
 274
 275    def attr_put(self, ri, var):
 276        pass
 277
 278    def attr_get(self, ri, var, first):
 279        pass
 280
 281    def setter(self, ri, space, direction, deref=False, ref=None):
 282        pass
 283
 284
 285class TypePad(Type):
 286    def presence_type(self):
 287        return ''
 288
 289    def arg_member(self, ri):
 290        return []
 291
 292    def _attr_typol(self):
 293        return '.type = YNL_PT_IGNORE, '
 294
 295    def attr_put(self, ri, var):
 296        pass
 297
 298    def attr_get(self, ri, var, first):
 299        pass
 300
 301    def attr_policy(self, cw):
 302        pass
 303
 304    def setter(self, ri, space, direction, deref=False, ref=None):
 305        pass
 306
 307
 308class TypeScalar(Type):
 309    def __init__(self, family, attr_set, attr, value):
 310        super().__init__(family, attr_set, attr, value)
 311
 312        self.byte_order_comment = ''
 313        if 'byte-order' in attr:
 314            self.byte_order_comment = f" /* {attr['byte-order']} */"
 315
 316        if 'enum' in self.attr:
 317            enum = self.family.consts[self.attr['enum']]
 318            low, high = enum.value_range()
 319            if 'min' not in self.checks:
 320                if low != 0 or self.type[0] == 's':
 321                    self.checks['min'] = low
 322            if 'max' not in self.checks:
 323                self.checks['max'] = high
 324
 325        if 'min' in self.checks and 'max' in self.checks:
 326            if self.get_limit('min') > self.get_limit('max'):
 327                raise Exception(f'Invalid limit for "{self.name}" min: {self.get_limit("min")} max: {self.get_limit("max")}')
 328            self.checks['range'] = True
 329
 330        low = min(self.get_limit('min', 0), self.get_limit('max', 0))
 331        high = max(self.get_limit('min', 0), self.get_limit('max', 0))
 332        if low < 0 and self.type[0] == 'u':
 333            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
 334        if low < -32768 or high > 32767:
 335            self.checks['full-range'] = True
 336
 337        # Added by resolve():
 338        self.is_bitfield = None
 339        delattr(self, "is_bitfield")
 340        self.type_name = None
 341        delattr(self, "type_name")
 342
 343    def resolve(self):
 344        self.resolve_up(super())
 345
 346        if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
 347            self.is_bitfield = True
 348        elif 'enum' in self.attr:
 349            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
 350        else:
 351            self.is_bitfield = False
 352
 353        if not self.is_bitfield and 'enum' in self.attr:
 354            self.type_name = self.family.consts[self.attr['enum']].user_type
 355        elif self.is_auto_scalar:
 356            self.type_name = '__' + self.type[0] + '64'
 357        else:
 358            self.type_name = '__' + self.type
 359
 360    def mnl_type(self):
 361        return self._mnl_type()
 362
 363    def _attr_policy(self, policy):
 364        if 'flags-mask' in self.checks or self.is_bitfield:
 365            if self.is_bitfield:
 366                enum = self.family.consts[self.attr['enum']]
 367                mask = enum.get_mask(as_flags=True)
 368            else:
 369                flags = self.family.consts[self.checks['flags-mask']]
 370                flag_cnt = len(flags['entries'])
 371                mask = (1 << flag_cnt) - 1
 372            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
 373        elif 'full-range' in self.checks:
 374            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
 375        elif 'range' in self.checks:
 376            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
 377        elif 'min' in self.checks:
 378            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
 379        elif 'max' in self.checks:
 380            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
 381        return super()._attr_policy(policy)
 382
 383    def _attr_typol(self):
 384        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '
 385
 386    def arg_member(self, ri):
 387        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']
 388
 389    def attr_put(self, ri, var):
 390        self._attr_put_simple(ri, var, self.mnl_type())
 391
 392    def _attr_get(self, ri, var):
 393        return f"{var}->{self.c_name} = mnl_attr_get_{self.mnl_type()}(attr);", None, None
 394
 395    def _setter_lines(self, ri, member, presence):
 396        return [f"{member} = {self.c_name};"]
 397
 398
 399class TypeFlag(Type):
 400    def arg_member(self, ri):
 401        return []
 402
 403    def _attr_typol(self):
 404        return '.type = YNL_PT_FLAG, '
 405
 406    def attr_put(self, ri, var):
 407        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)")
 408
 409    def _attr_get(self, ri, var):
 410        return [], None, None
 411
 412    def _setter_lines(self, ri, member, presence):
 413        return []
 414
 415
 416class TypeString(Type):
 417    def arg_member(self, ri):
 418        return [f"const char *{self.c_name}"]
 419
 420    def presence_type(self):
 421        return 'len'
 422
 423    def struct_member(self, ri):
 424        ri.cw.p(f"char *{self.c_name};")
 425
 426    def _attr_typol(self):
 427        return f'.type = YNL_PT_NUL_STR, '
 428
 429    def _attr_policy(self, policy):
 430        if 'exact-len' in self.checks:
 431            mem = 'NLA_POLICY_EXACT_LEN(' + str(self.checks['exact-len']) + ')'
 432        else:
 433            mem = '{ .type = ' + policy
 434            if 'max-len' in self.checks:
 435                mem += ', .len = ' + str(self.get_limit('max-len'))
 436            mem += ', }'
 437        return mem
 438
 439    def attr_policy(self, cw):
 440        if self.checks.get('unterminated-ok', False):
 441            policy = 'NLA_STRING'
 442        else:
 443            policy = 'NLA_NUL_STRING'
 444
 445        spec = self._attr_policy(policy)
 446        cw.p(f"\t[{self.enum_name}] = {spec},")
 447
 448    def attr_put(self, ri, var):
 449        self._attr_put_simple(ri, var, 'strz')
 450
 451    def _attr_get(self, ri, var):
 452        len_mem = var + '->_present.' + self.c_name + '_len'
 453        return [f"{len_mem} = len;",
 454                f"{var}->{self.c_name} = malloc(len + 1);",
 455                f"memcpy({var}->{self.c_name}, mnl_attr_get_str(attr), len);",
 456                f"{var}->{self.c_name}[len] = 0;"], \
 457               ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));'], \
 458               ['unsigned int len;']
 459
 460    def _setter_lines(self, ri, member, presence):
 461        return [f"free({member});",
 462                f"{presence}_len = strlen({self.c_name});",
 463                f"{member} = malloc({presence}_len + 1);",
 464                f'memcpy({member}, {self.c_name}, {presence}_len);',
 465                f'{member}[{presence}_len] = 0;']
 466
 467
 468class TypeBinary(Type):
 469    def arg_member(self, ri):
 470        return [f"const void *{self.c_name}", 'size_t len']
 471
 472    def presence_type(self):
 473        return 'len'
 474
 475    def struct_member(self, ri):
 476        ri.cw.p(f"void *{self.c_name};")
 477
 478    def _attr_typol(self):
 479        return f'.type = YNL_PT_BINARY,'
 480
 481    def _attr_policy(self, policy):
 482        if 'exact-len' in self.checks:
 483            mem = 'NLA_POLICY_EXACT_LEN(' + str(self.checks['exact-len']) + ')'
 484        else:
 485            mem = '{ '
 486            if len(self.checks) == 1 and 'min-len' in self.checks:
 487                mem += '.len = ' + str(self.get_limit('min-len'))
 488            elif len(self.checks) == 0:
 489                mem += '.type = NLA_BINARY'
 490            else:
 491                raise Exception('One or more of binary type checks not implemented, yet')
 492            mem += ', }'
 493        return mem
 494
 495    def attr_put(self, ri, var):
 496        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
 497                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
 498
 499    def _attr_get(self, ri, var):
 500        len_mem = var + '->_present.' + self.c_name + '_len'
 501        return [f"{len_mem} = len;",
 502                f"{var}->{self.c_name} = malloc(len);",
 503                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
 504               ['len = mnl_attr_get_payload_len(attr);'], \
 505               ['unsigned int len;']
 506
 507    def _setter_lines(self, ri, member, presence):
 508        return [f"free({member});",
 509                f"{presence}_len = len;",
 510                f"{member} = malloc({presence}_len);",
 511                f'memcpy({member}, {self.c_name}, {presence}_len);']
 512
 513
 514class TypeBitfield32(Type):
 515    def _complex_member_type(self, ri):
 516        return "struct nla_bitfield32"
 517
 518    def _attr_typol(self):
 519        return f'.type = YNL_PT_BITFIELD32, '
 520
 521    def _attr_policy(self, policy):
 522        if not 'enum' in self.attr:
 523            raise Exception('Enum required for bitfield32 attr')
 524        enum = self.family.consts[self.attr['enum']]
 525        mask = enum.get_mask(as_flags=True)
 526        return f"NLA_POLICY_BITFIELD32({mask})"
 527
 528    def attr_put(self, ri, var):
 529        line = f"mnl_attr_put(nlh, {self.enum_name}, sizeof(struct nla_bitfield32), &{var}->{self.c_name})"
 530        self._attr_put_line(ri, var, line)
 531
 532    def _attr_get(self, ri, var):
 533        return f"memcpy(&{var}->{self.c_name}, mnl_attr_get_payload(attr), sizeof(struct nla_bitfield32));", None, None
 534
 535    def _setter_lines(self, ri, member, presence):
 536        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
 537
 538
 539class TypeNest(Type):
 540    def is_recursive(self):
 541        return self.family.pure_nested_structs[self.nested_attrs].recursive
 542
 543    def _complex_member_type(self, ri):
 544        return self.nested_struct_type
 545
 546    def free(self, ri, var, ref):
 547        at = '&'
 548        if self.is_recursive_for_op(ri):
 549            at = ''
 550            ri.cw.p(f'if ({var}->{ref}{self.c_name})')
 551        ri.cw.p(f'{self.nested_render_name}_free({at}{var}->{ref}{self.c_name});')
 552
 553    def _attr_typol(self):
 554        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
 555
 556    def _attr_policy(self, policy):
 557        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'
 558
 559    def attr_put(self, ri, var):
 560        at = '' if self.is_recursive_for_op(ri) else '&'
 561        self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
 562                            f"{self.enum_name}, {at}{var}->{self.c_name})")
 563
 564    def _attr_get(self, ri, var):
 565        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
 566                     "return MNL_CB_ERROR;"]
 567        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
 568                      f"parg.data = &{var}->{self.c_name};"]
 569        return get_lines, init_lines, None
 570
 571    def setter(self, ri, space, direction, deref=False, ref=None):
 572        ref = (ref if ref else []) + [self.c_name]
 573
 574        for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
 575            if attr.is_recursive():
 576                continue
 577            attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
 578
 579
 580class TypeMultiAttr(Type):
 581    def __init__(self, family, attr_set, attr, value, base_type):
 582        super().__init__(family, attr_set, attr, value)
 583
 584        self.base_type = base_type
 585
 586    def is_multi_val(self):
 587        return True
 588
 589    def presence_type(self):
 590        return 'count'
 591
 592    def mnl_type(self):
 593        return self._mnl_type()
 594
 595    def _complex_member_type(self, ri):
 596        if 'type' not in self.attr or self.attr['type'] == 'nest':
 597            return self.nested_struct_type
 598        elif self.attr['type'] in scalars:
 599            scalar_pfx = '__' if ri.ku_space == 'user' else ''
 600            return scalar_pfx + self.attr['type']
 601        else:
 602            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
 603
 604    def free_needs_iter(self):
 605        return 'type' not in self.attr or self.attr['type'] == 'nest'
 606
 607    def free(self, ri, var, ref):
 608        if self.attr['type'] in scalars:
 609            ri.cw.p(f"free({var}->{ref}{self.c_name});")
 610        elif 'type' not in self.attr or self.attr['type'] == 'nest':
 611            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
 612            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
 613            ri.cw.p(f"free({var}->{ref}{self.c_name});")
 614        else:
 615            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
 616
 617    def _attr_policy(self, policy):
 618        return self.base_type._attr_policy(policy)
 619
 620    def _attr_typol(self):
 621        return self.base_type._attr_typol()
 622
 623    def _attr_get(self, ri, var):
 624        return f'n_{self.c_name}++;', None, None
 625
 626    def attr_put(self, ri, var):
 627        if self.attr['type'] in scalars:
 628            put_type = self.mnl_type()
 629            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
 630            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
 631        elif 'type' not in self.attr or self.attr['type'] == 'nest':
 632            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
 633            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
 634                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
 635        else:
 636            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")
 637
 638    def _setter_lines(self, ri, member, presence):
 639        # For multi-attr we have a count, not presence, hack up the presence
 640        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
 641        return [f"free({member});",
 642                f"{member} = {self.c_name};",
 643                f"{presence} = n_{self.c_name};"]
 644
 645
 646class TypeArrayNest(Type):
 647    def is_multi_val(self):
 648        return True
 649
 650    def presence_type(self):
 651        return 'count'
 652
 653    def _complex_member_type(self, ri):
 654        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
 655            return self.nested_struct_type
 656        elif self.attr['sub-type'] in scalars:
 657            scalar_pfx = '__' if ri.ku_space == 'user' else ''
 658            return scalar_pfx + self.attr['sub-type']
 659        else:
 660            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
 661
 662    def _attr_typol(self):
 663        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
 664
 665    def _attr_get(self, ri, var):
 666        local_vars = ['const struct nlattr *attr2;']
 667        get_lines = [f'attr_{self.c_name} = attr;',
 668                     'mnl_attr_for_each_nested(attr2, attr)',
 669                     f'\t{var}->n_{self.c_name}++;']
 670        return get_lines, None, local_vars
 671
 672
 673class TypeNestTypeValue(Type):
 674    def _complex_member_type(self, ri):
 675        return self.nested_struct_type
 676
 677    def _attr_typol(self):
 678        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
 679
 680    def _attr_get(self, ri, var):
 681        prev = 'attr'
 682        tv_args = ''
 683        get_lines = []
 684        local_vars = []
 685        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
 686                      f"parg.data = &{var}->{self.c_name};"]
 687        if 'type-value' in self.attr:
 688            tv_names = [c_lower(x) for x in self.attr["type-value"]]
 689            local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
 690            local_vars += [f'__u32 {", ".join(tv_names)};']
 691            for level in self.attr["type-value"]:
 692                level = c_lower(level)
 693                get_lines += [f'attr_{level} = mnl_attr_get_payload({prev});']
 694                get_lines += [f'{level} = mnl_attr_get_type(attr_{level});']
 695                prev = 'attr_' + level
 696
 697            tv_args = f", {', '.join(tv_names)}"
 698
 699        get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
 700        return get_lines, init_lines, local_vars
 701
 702
 703class Struct:
 704    def __init__(self, family, space_name, type_list=None, inherited=None):
 705        self.family = family
 706        self.space_name = space_name
 707        self.attr_set = family.attr_sets[space_name]
 708        # Use list to catch comparisons with empty sets
 709        self._inherited = inherited if inherited is not None else []
 710        self.inherited = []
 711
 712        self.nested = type_list is None
 713        if family.name == c_lower(space_name):
 714            self.render_name = c_lower(family.name)
 715        else:
 716            self.render_name = c_lower(family.name + '-' + space_name)
 717        self.struct_name = 'struct ' + self.render_name
 718        if self.nested and space_name in family.consts:
 719            self.struct_name += '_'
 720        self.ptr_name = self.struct_name + ' *'
 721        # All attr sets this one contains, directly or multiple levels down
 722        self.child_nests = set()
 723
 724        self.request = False
 725        self.reply = False
 726        self.recursive = False
 727
 728        self.attr_list = []
 729        self.attrs = dict()
 730        if type_list is not None:
 731            for t in type_list:
 732                self.attr_list.append((t, self.attr_set[t]),)
 733        else:
 734            for t in self.attr_set:
 735                self.attr_list.append((t, self.attr_set[t]),)
 736
 737        max_val = 0
 738        self.attr_max_val = None
 739        for name, attr in self.attr_list:
 740            if attr.value >= max_val:
 741                max_val = attr.value
 742                self.attr_max_val = attr
 743            self.attrs[name] = attr
 744
 745    def __iter__(self):
 746        yield from self.attrs
 747
 748    def __getitem__(self, key):
 749        return self.attrs[key]
 750
 751    def member_list(self):
 752        return self.attr_list
 753
 754    def set_inherited(self, new_inherited):
 755        if self._inherited != new_inherited:
 756            raise Exception("Inheriting different members not supported")
 757        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
 758
 759
 760class EnumEntry(SpecEnumEntry):
 761    def __init__(self, enum_set, yaml, prev, value_start):
 762        super().__init__(enum_set, yaml, prev, value_start)
 763
 764        if prev:
 765            self.value_change = (self.value != prev.value + 1)
 766        else:
 767            self.value_change = (self.value != 0)
 768        self.value_change = self.value_change or self.enum_set['type'] == 'flags'
 769
 770        # Added by resolve:
 771        self.c_name = None
 772        delattr(self, "c_name")
 773
 774    def resolve(self):
 775        self.resolve_up(super())
 776
 777        self.c_name = c_upper(self.enum_set.value_pfx + self.name)
 778
 779
 780class EnumSet(SpecEnumSet):
 781    def __init__(self, family, yaml):
 782        self.render_name = c_lower(family.name + '-' + yaml['name'])
 783
 784        if 'enum-name' in yaml:
 785            if yaml['enum-name']:
 786                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
 787                self.user_type = self.enum_name
 788            else:
 789                self.enum_name = None
 790        else:
 791            self.enum_name = 'enum ' + self.render_name
 792
 793        if self.enum_name:
 794            self.user_type = self.enum_name
 795        else:
 796            self.user_type = 'int'
 797
 798        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")
 799
 800        super().__init__(family, yaml)
 801
 802    def new_entry(self, entry, prev_entry, value_start):
 803        return EnumEntry(self, entry, prev_entry, value_start)
 804
 805    def value_range(self):
 806        low = min([x.value for x in self.entries.values()])
 807        high = max([x.value for x in self.entries.values()])
 808
 809        if high - low + 1 != len(self.entries):
 810            raise Exception("Can't get value range for a noncontiguous enum")
 811
 812        return low, high
 813
 814
 815class AttrSet(SpecAttrSet):
 816    def __init__(self, family, yaml):
 817        super().__init__(family, yaml)
 818
 819        if self.subset_of is None:
 820            if 'name-prefix' in yaml:
 821                pfx = yaml['name-prefix']
 822            elif self.name == family.name:
 823                pfx = family.name + '-a-'
 824            else:
 825                pfx = f"{family.name}-a-{self.name}-"
 826            self.name_prefix = c_upper(pfx)
 827            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
 828            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))
 829        else:
 830            self.name_prefix = family.attr_sets[self.subset_of].name_prefix
 831            self.max_name = family.attr_sets[self.subset_of].max_name
 832            self.cnt_name = family.attr_sets[self.subset_of].cnt_name
 833
 834        # Added by resolve:
 835        self.c_name = None
 836        delattr(self, "c_name")
 837
 838    def resolve(self):
 839        self.c_name = c_lower(self.name)
 840        if self.c_name in _C_KW:
 841            self.c_name += '_'
 842        if self.c_name == self.family.c_name:
 843            self.c_name = ''
 844
 845    def new_attr(self, elem, value):
 846        if elem['type'] in scalars:
 847            t = TypeScalar(self.family, self, elem, value)
 848        elif elem['type'] == 'unused':
 849            t = TypeUnused(self.family, self, elem, value)
 850        elif elem['type'] == 'pad':
 851            t = TypePad(self.family, self, elem, value)
 852        elif elem['type'] == 'flag':
 853            t = TypeFlag(self.family, self, elem, value)
 854        elif elem['type'] == 'string':
 855            t = TypeString(self.family, self, elem, value)
 856        elif elem['type'] == 'binary':
 857            t = TypeBinary(self.family, self, elem, value)
 858        elif elem['type'] == 'bitfield32':
 859            t = TypeBitfield32(self.family, self, elem, value)
 860        elif elem['type'] == 'nest':
 861            t = TypeNest(self.family, self, elem, value)
 862        elif elem['type'] == 'array-nest':
 863            t = TypeArrayNest(self.family, self, elem, value)
 864        elif elem['type'] == 'nest-type-value':
 865            t = TypeNestTypeValue(self.family, self, elem, value)
 866        else:
 867            raise Exception(f"No typed class for type {elem['type']}")
 868
 869        if 'multi-attr' in elem and elem['multi-attr']:
 870            t = TypeMultiAttr(self.family, self, elem, value, t)
 871
 872        return t
 873
 874
 875class Operation(SpecOperation):
 876    def __init__(self, family, yaml, req_value, rsp_value):
 877        super().__init__(family, yaml, req_value, rsp_value)
 878
 879        self.render_name = c_lower(family.name + '_' + self.name)
 880
 881        self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \
 882                         ('dump' in yaml and 'request' in yaml['dump'])
 883
 884        self.has_ntf = False
 885
 886        # Added by resolve:
 887        self.enum_name = None
 888        delattr(self, "enum_name")
 889
 890    def resolve(self):
 891        self.resolve_up(super())
 892
 893        if not self.is_async:
 894            self.enum_name = self.family.op_prefix + c_upper(self.name)
 895        else:
 896            self.enum_name = self.family.async_op_prefix + c_upper(self.name)
 897
 898    def mark_has_ntf(self):
 899        self.has_ntf = True
 900
 901
 902class Family(SpecFamily):
 903    def __init__(self, file_name, exclude_ops):
 904        # Added by resolve:
 905        self.c_name = None
 906        delattr(self, "c_name")
 907        self.op_prefix = None
 908        delattr(self, "op_prefix")
 909        self.async_op_prefix = None
 910        delattr(self, "async_op_prefix")
 911        self.mcgrps = None
 912        delattr(self, "mcgrps")
 913        self.consts = None
 914        delattr(self, "consts")
 915        self.hooks = None
 916        delattr(self, "hooks")
 917
 918        super().__init__(file_name, exclude_ops=exclude_ops)
 919
 920        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
 921        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
 922
 923        if 'definitions' not in self.yaml:
 924            self.yaml['definitions'] = []
 925
 926        if 'uapi-header' in self.yaml:
 927            self.uapi_header = self.yaml['uapi-header']
 928        else:
 929            self.uapi_header = f"linux/{self.name}.h"
 930        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
 931            self.uapi_header_name = self.uapi_header[6:-2]
 932        else:
 933            self.uapi_header_name = self.name
 934
 935    def resolve(self):
 936        self.resolve_up(super())
 937
 938        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
 939            raise Exception("Codegen only supported for genetlink")
 940
 941        self.c_name = c_lower(self.name)
 942        if 'name-prefix' in self.yaml['operations']:
 943            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
 944        else:
 945            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
 946        if 'async-prefix' in self.yaml['operations']:
 947            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
 948        else:
 949            self.async_op_prefix = self.op_prefix
 950
 951        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
 952
 953        self.hooks = dict()
 954        for when in ['pre', 'post']:
 955            self.hooks[when] = dict()
 956            for op_mode in ['do', 'dump']:
 957                self.hooks[when][op_mode] = dict()
 958                self.hooks[when][op_mode]['set'] = set()
 959                self.hooks[when][op_mode]['list'] = []
 960
 961        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
 962        self.root_sets = dict()
 963        # dict space-name -> set('request', 'reply')
 964        self.pure_nested_structs = dict()
 965
 966        self._mark_notify()
 967        self._mock_up_events()
 968
 969        self._load_root_sets()
 970        self._load_nested_sets()
 971        self._load_attr_use()
 972        self._load_hooks()
 973
 974        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
 975        if self.kernel_policy == 'global':
 976            self._load_global_policy()
 977
 978    def new_enum(self, elem):
 979        return EnumSet(self, elem)
 980
 981    def new_attr_set(self, elem):
 982        return AttrSet(self, elem)
 983
 984    def new_operation(self, elem, req_value, rsp_value):
 985        return Operation(self, elem, req_value, rsp_value)
 986
 987    def _mark_notify(self):
 988        for op in self.msgs.values():
 989            if 'notify' in op:
 990                self.ops[op['notify']].mark_has_ntf()
 991
 992    # Fake a 'do' equivalent of all events, so that we can render their response parsing
 993    def _mock_up_events(self):
 994        for op in self.yaml['operations']['list']:
 995            if 'event' in op:
 996                op['do'] = {
 997                    'reply': {
 998                        'attributes': op['event']['attributes']
 999                    }
1000                }
1001
1002    def _load_root_sets(self):
1003        for op_name, op in self.msgs.items():
1004            if 'attribute-set' not in op:
1005                continue
1006
1007            req_attrs = set()
1008            rsp_attrs = set()
1009            for op_mode in ['do', 'dump']:
1010                if op_mode in op and 'request' in op[op_mode]:
1011                    req_attrs.update(set(op[op_mode]['request']['attributes']))
1012                if op_mode in op and 'reply' in op[op_mode]:
1013                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
1014            if 'event' in op:
1015                rsp_attrs.update(set(op['event']['attributes']))
1016
1017            if op['attribute-set'] not in self.root_sets:
1018                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
1019            else:
1020                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
1021                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
1022
1023    def _sort_pure_types(self):
1024        # Try to reorder according to dependencies
1025        pns_key_list = list(self.pure_nested_structs.keys())
1026        pns_key_seen = set()
1027        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
1028        for _ in range(rounds):
1029            if len(pns_key_list) == 0:
1030                break
1031            name = pns_key_list.pop(0)
1032            finished = True
1033            for _, spec in self.attr_sets[name].items():
1034                if 'nested-attributes' in spec:
1035                    nested = spec['nested-attributes']
1036                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
1037                    if self.pure_nested_structs[nested].recursive:
1038                        continue
1039                    if nested not in pns_key_seen:
1040                        # Dicts are sorted, this will make struct last
1041                        struct = self.pure_nested_structs.pop(name)
1042                        self.pure_nested_structs[name] = struct
1043                        finished = False
1044                        break
1045            if finished:
1046                pns_key_seen.add(name)
1047            else:
1048                pns_key_list.append(name)
1049
1050    def _load_nested_sets(self):
1051        attr_set_queue = list(self.root_sets.keys())
1052        attr_set_seen = set(self.root_sets.keys())
1053
1054        while len(attr_set_queue):
1055            a_set = attr_set_queue.pop(0)
1056            for attr, spec in self.attr_sets[a_set].items():
1057                if 'nested-attributes' not in spec:
1058                    continue
1059
1060                nested = spec['nested-attributes']
1061                if nested not in attr_set_seen:
1062                    attr_set_queue.append(nested)
1063                    attr_set_seen.add(nested)
1064
1065                inherit = set()
1066                if nested not in self.root_sets:
1067                    if nested not in self.pure_nested_structs:
1068                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
1069                else:
1070                    raise Exception(f'Using attr set as root and nested not supported - {nested}')
1071
1072                if 'type-value' in spec:
1073                    if nested in self.root_sets:
1074                        raise Exception("Inheriting members to a space used as root not supported")
1075                    inherit.update(set(spec['type-value']))
1076                elif spec['type'] == 'array-nest':
1077                    inherit.add('idx')
1078                self.pure_nested_structs[nested].set_inherited(inherit)
1079
1080        for root_set, rs_members in self.root_sets.items():
1081            for attr, spec in self.attr_sets[root_set].items():
1082                if 'nested-attributes' in spec:
1083                    nested = spec['nested-attributes']
1084                    if attr in rs_members['request']:
1085                        self.pure_nested_structs[nested].request = True
1086                    if attr in rs_members['reply']:
1087                        self.pure_nested_structs[nested].reply = True
1088
1089        self._sort_pure_types()
1090
1091        # Propagate the request / reply / recursive
1092        for attr_set, struct in reversed(self.pure_nested_structs.items()):
1093            for _, spec in self.attr_sets[attr_set].items():
1094                if 'nested-attributes' in spec:
1095                    child_name = spec['nested-attributes']
1096                    struct.child_nests.add(child_name)
1097                    child = self.pure_nested_structs.get(child_name)
1098                    if child:
1099                        if not child.recursive:
1100                            struct.child_nests.update(child.child_nests)
1101                        child.request |= struct.request
1102                        child.reply |= struct.reply
1103                if attr_set in struct.child_nests:
1104                    struct.recursive = True
1105
1106        self._sort_pure_types()
1107
1108    def _load_attr_use(self):
1109        for _, struct in self.pure_nested_structs.items():
1110            if struct.request:
1111                for _, arg in struct.member_list():
1112                    arg.request = True
1113            if struct.reply:
1114                for _, arg in struct.member_list():
1115                    arg.reply = True
1116
1117        for root_set, rs_members in self.root_sets.items():
1118            for attr, spec in self.attr_sets[root_set].items():
1119                if attr in rs_members['request']:
1120                    spec.request = True
1121                if attr in rs_members['reply']:
1122                    spec.reply = True
1123
1124    def _load_global_policy(self):
1125        global_set = set()
1126        attr_set_name = None
1127        for op_name, op in self.ops.items():
1128            if not op:
1129                continue
1130            if 'attribute-set' not in op:
1131                continue
1132
1133            if attr_set_name is None:
1134                attr_set_name = op['attribute-set']
1135            if attr_set_name != op['attribute-set']:
1136                raise Exception('For a global policy all ops must use the same set')
1137
1138            for op_mode in ['do', 'dump']:
1139                if op_mode in op:
1140                    req = op[op_mode].get('request')
1141                    if req:
1142                        global_set.update(req.get('attributes', []))
1143
1144        self.global_policy = []
1145        self.global_policy_set = attr_set_name
1146        for attr in self.attr_sets[attr_set_name]:
1147            if attr in global_set:
1148                self.global_policy.append(attr)
1149
1150    def _load_hooks(self):
1151        for op in self.ops.values():
1152            for op_mode in ['do', 'dump']:
1153                if op_mode not in op:
1154                    continue
1155                for when in ['pre', 'post']:
1156                    if when not in op[op_mode]:
1157                        continue
1158                    name = op[op_mode][when]
1159                    if name in self.hooks[when][op_mode]['set']:
1160                        continue
1161                    self.hooks[when][op_mode]['set'].add(name)
1162                    self.hooks[when][op_mode]['list'].append(name)
1163
1164
1165class RenderInfo:
1166    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
1167        self.family = family
1168        self.nl = cw.nlib
1169        self.ku_space = ku_space
1170        self.op_mode = op_mode
1171        self.op = op
1172
1173        self.fixed_hdr = None
1174        if op and op.fixed_header:
1175            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)
1176
1177        # 'do' and 'dump' response parsing is identical
1178        self.type_consistent = True
1179        if op_mode != 'do' and 'dump' in op:
1180            if 'do' in op:
1181                if ('reply' in op['do']) != ('reply' in op["dump"]):
1182                    self.type_consistent = False
1183                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
1184                    self.type_consistent = False
1185            else:
1186                self.type_consistent = False
1187
1188        self.attr_set = attr_set
1189        if not self.attr_set:
1190            self.attr_set = op['attribute-set']
1191
1192        self.type_name_conflict = False
1193        if op:
1194            self.type_name = c_lower(op.name)
1195        else:
1196            self.type_name = c_lower(attr_set)
1197            if attr_set in family.consts:
1198                self.type_name_conflict = True
1199
1200        self.cw = cw
1201
1202        self.struct = dict()
1203        if op_mode == 'notify':
1204            op_mode = 'do'
1205        for op_dir in ['request', 'reply']:
1206            if op:
1207                type_list = []
1208                if op_dir in op[op_mode]:
1209                    type_list = op[op_mode][op_dir]['attributes']
1210                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
1211        if op_mode == 'event':
1212            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1213
1214
1215class CodeWriter:
1216    def __init__(self, nlib, out_file=None, overwrite=True):
1217        self.nlib = nlib
1218        self._overwrite = overwrite
1219
1220        self._nl = False
1221        self._block_end = False
1222        self._silent_block = False
1223        self._ind = 0
1224        self._ifdef_block = None
1225        if out_file is None:
1226            self._out = os.sys.stdout
1227        else:
1228            self._out = tempfile.NamedTemporaryFile('w+')
1229            self._out_file = out_file
1230
1231    def __del__(self):
1232        self.close_out_file()
1233
1234    def close_out_file(self):
1235        if self._out == os.sys.stdout:
1236            return
1237        # Avoid modifying the file if contents didn't change
1238        self._out.flush()
1239        if not self._overwrite and os.path.isfile(self._out_file):
1240            if filecmp.cmp(self._out.name, self._out_file, shallow=False):
1241                return
1242        with open(self._out_file, 'w+') as out_file:
1243            self._out.seek(0)
1244            shutil.copyfileobj(self._out, out_file)
1245            self._out.close()
1246        self._out = os.sys.stdout
1247
1248    @classmethod
1249    def _is_cond(cls, line):
1250        return line.startswith('if') or line.startswith('while') or line.startswith('for')
1251
1252    def p(self, line, add_ind=0):
1253        if self._block_end:
1254            self._block_end = False
1255            if line.startswith('else'):
1256                line = '} ' + line
1257            else:
1258                self._out.write('\t' * self._ind + '}\n')
1259
1260        if self._nl:
1261            self._out.write('\n')
1262            self._nl = False
1263
1264        ind = self._ind
1265        if line[-1] == ':':
1266            ind -= 1
1267        if self._silent_block:
1268            ind += 1
1269        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
1270        if line[0] == '#':
1271            ind = 0
1272        if add_ind:
1273            ind += add_ind
1274        self._out.write('\t' * ind + line + '\n')
1275
1276    def nl(self):
1277        self._nl = True
1278
1279    def block_start(self, line=''):
1280        if line:
1281            line = line + ' '
1282        self.p(line + '{')
1283        self._ind += 1
1284
1285    def block_end(self, line=''):
1286        if line and line[0] not in {';', ','}:
1287            line = ' ' + line
1288        self._ind -= 1
1289        self._nl = False
1290        if not line:
1291            # Delay printing closing bracket in case "else" comes next
1292            if self._block_end:
1293                self._out.write('\t' * (self._ind + 1) + '}\n')
1294            self._block_end = True
1295        else:
1296            self.p('}' + line)
1297
1298    def write_doc_line(self, doc, indent=True):
1299        words = doc.split()
1300        line = ' *'
1301        for word in words:
1302            if len(line) + len(word) >= 79:
1303                self.p(line)
1304                line = ' *'
1305                if indent:
1306                    line += '  '
1307            line += ' ' + word
1308        self.p(line)
1309
1310    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
1311        if not args:
1312            args = ['void']
1313
1314        if doc:
1315            self.p('/*')
1316            self.p(' * ' + doc)
1317            self.p(' */')
1318
1319        oneline = qual_ret
1320        if qual_ret[-1] != '*':
1321            oneline += ' '
1322        oneline += f"{name}({', '.join(args)}){suffix}"
1323
1324        if len(oneline) < 80:
1325            self.p(oneline)
1326            return
1327
1328        v = qual_ret
1329        if len(v) > 3:
1330            self.p(v)
1331            v = ''
1332        elif qual_ret[-1] != '*':
1333            v += ' '
1334        v += name + '('
1335        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
1336        delta_ind = len(v) - len(ind)
1337        v += args[0]
1338        i = 1
1339        while i < len(args):
1340            next_len = len(v) + len(args[i])
1341            if v[0] == '\t':
1342                next_len += delta_ind
1343            if next_len > 76:
1344                self.p(v + ',')
1345                v = ind
1346            else:
1347                v += ', '
1348            v += args[i]
1349            i += 1
1350        self.p(v + ')' + suffix)
1351
1352    def write_func_lvar(self, local_vars):
1353        if not local_vars:
1354            return
1355
1356        if type(local_vars) is str:
1357            local_vars = [local_vars]
1358
1359        local_vars.sort(key=len, reverse=True)
1360        for var in local_vars:
1361            self.p(var)
1362        self.nl()
1363
1364    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1365        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1366        self.write_func_lvar(local_vars=local_vars)
1367
1368        self.block_start()
1369        for line in body:
1370            self.p(line)
1371        self.block_end()
1372
1373    def writes_defines(self, defines):
1374        longest = 0
1375        for define in defines:
1376            if len(define[0]) > longest:
1377                longest = len(define[0])
1378        longest = ((longest + 8) // 8) * 8
1379        for define in defines:
1380            line = '#define ' + define[0]
1381            line += '\t' * ((longest - len(define[0]) + 7) // 8)
1382            if type(define[1]) is int:
1383                line += str(define[1])
1384            elif type(define[1]) is str:
1385                line += '"' + define[1] + '"'
1386            self.p(line)
1387
1388    def write_struct_init(self, members):
1389        longest = max([len(x[0]) for x in members])
1390        longest += 1  # because we prepend a .
1391        longest = ((longest + 8) // 8) * 8
1392        for one in members:
1393            line = '.' + one[0]
1394            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1395            line += '= ' + str(one[1]) + ','
1396            self.p(line)
1397
1398    def ifdef_block(self, config):
1399        config_option = None
1400        if config:
1401            config_option = 'CONFIG_' + c_upper(config)
1402        if self._ifdef_block == config_option:
1403            return
1404
1405        if self._ifdef_block:
1406            self.p('#endif /* ' + self._ifdef_block + ' */')
1407        if config_option:
1408            self.p('#ifdef ' + config_option)
1409        self._ifdef_block = config_option
1410
1411
1412scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}
1413
1414direction_to_suffix = {
1415    'reply': '_rsp',
1416    'request': '_req',
1417    '': ''
1418}
1419
1420op_mode_to_wrapper = {
1421    'do': '',
1422    'dump': '_list',
1423    'notify': '_ntf',
1424    'event': '',
1425}
1426
1427_C_KW = {
1428    'auto',
1429    'bool',
1430    'break',
1431    'case',
1432    'char',
1433    'const',
1434    'continue',
1435    'default',
1436    'do',
1437    'double',
1438    'else',
1439    'enum',
1440    'extern',
1441    'float',
1442    'for',
1443    'goto',
1444    'if',
1445    'inline',
1446    'int',
1447    'long',
1448    'register',
1449    'return',
1450    'short',
1451    'signed',
1452    'sizeof',
1453    'static',
1454    'struct',
1455    'switch',
1456    'typedef',
1457    'union',
1458    'unsigned',
1459    'void',
1460    'volatile',
1461    'while'
1462}
1463
1464
1465def rdir(direction):
1466    if direction == 'reply':
1467        return 'request'
1468    if direction == 'request':
1469        return 'reply'
1470    return direction
1471
1472
1473def op_prefix(ri, direction, deref=False):
1474    suffix = f"_{ri.type_name}"
1475
1476    if not ri.op_mode or ri.op_mode == 'do':
1477        suffix += f"{direction_to_suffix[direction]}"
1478    else:
1479        if direction == 'request':
1480            suffix += '_req_dump'
1481        else:
1482            if ri.type_consistent:
1483                if deref:
1484                    suffix += f"{direction_to_suffix[direction]}"
1485                else:
1486                    suffix += op_mode_to_wrapper[ri.op_mode]
1487            else:
1488                suffix += '_rsp'
1489                suffix += '_dump' if deref else '_list'
1490
1491    return f"{ri.family.c_name}{suffix}"
1492
1493
1494def type_name(ri, direction, deref=False):
1495    return f"struct {op_prefix(ri, direction, deref=deref)}"
1496
1497
1498def print_prototype(ri, direction, terminate=True, doc=None):
1499    suffix = ';' if terminate else ''
1500
1501    fname = ri.op.render_name
1502    if ri.op_mode == 'dump':
1503        fname += '_dump'
1504
1505    args = ['struct ynl_sock *ys']
1506    if 'request' in ri.op[ri.op_mode]:
1507        args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")
1508
1509    ret = 'int'
1510    if 'reply' in ri.op[ri.op_mode]:
1511        ret = f"{type_name(ri, rdir(direction))} *"
1512
1513    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1514
1515
1516def print_req_prototype(ri):
1517    print_prototype(ri, "request", doc=ri.op['doc'])
1518
1519
1520def print_dump_prototype(ri):
1521    print_prototype(ri, "request")
1522
1523
1524def put_typol_fwd(cw, struct):
1525    cw.p(f'extern struct ynl_policy_nest {struct.render_name}_nest;')
1526
1527
1528def put_typol(cw, struct):
1529    type_max = struct.attr_set.max_name
1530    cw.block_start(line=f'struct ynl_policy_attr {struct.render_name}_policy[{type_max} + 1] =')
1531
1532    for _, arg in struct.member_list():
1533        arg.attr_typol(cw)
1534
1535    cw.block_end(line=';')
1536    cw.nl()
1537
1538    cw.block_start(line=f'struct ynl_policy_nest {struct.render_name}_nest =')
1539    cw.p(f'.max_attr = {type_max},')
1540    cw.p(f'.table = {struct.render_name}_policy,')
1541    cw.block_end(line=';')
1542    cw.nl()
1543
1544
1545def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1546    args = [f'int {arg_name}']
1547    if enum:
1548        args = [enum.user_type + ' ' + arg_name]
1549    cw.write_func_prot('const char *', f'{render_name}_str', args)
1550    cw.block_start()
1551    if enum and enum.type == 'flags':
1552        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1553    cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
1554    cw.p('return NULL;')
1555    cw.p(f'return {map_name}[{arg_name}];')
1556    cw.block_end()
1557    cw.nl()
1558
1559
1560def put_op_name_fwd(family, cw):
1561    cw.write_func_prot('const char *', f'{family.c_name}_op_str', ['int op'], suffix=';')
1562
1563
1564def put_op_name(family, cw):
1565    map_name = f'{family.c_name}_op_strmap'
1566    cw.block_start(line=f"static const char * const {map_name}[] =")
1567    for op_name, op in family.msgs.items():
1568        if op.rsp_value:
1569            # Make sure we don't add duplicated entries, if multiple commands
1570            # produce the same response in legacy families.
1571            if family.rsp_by_value[op.rsp_value] != op:
1572                cw.p(f'// skip "{op_name}", duplicate reply value')
1573                continue
1574
1575            if op.req_value == op.rsp_value:
1576                cw.p(f'[{op.enum_name}] = "{op_name}",')
1577            else:
1578                cw.p(f'[{op.rsp_value}] = "{op_name}",')
1579    cw.block_end(line=';')
1580    cw.nl()
1581
1582    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1583
1584
1585def put_enum_to_str_fwd(family, cw, enum):
1586    args = [enum.user_type + ' value']
1587    cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
1588
1589
1590def put_enum_to_str(family, cw, enum):
1591    map_name = f'{enum.render_name}_strmap'
1592    cw.block_start(line=f"static const char * const {map_name}[] =")
1593    for entry in enum.entries.values():
1594        cw.p(f'[{entry.value}] = "{entry.name}",')
1595    cw.block_end(line=';')
1596    cw.nl()
1597
1598    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1599
1600
1601def put_req_nested_prototype(ri, struct, suffix=';'):
1602    func_args = ['struct nlmsghdr *nlh',
1603                 'unsigned int attr_type',
1604                 f'{struct.ptr_name}obj']
1605
1606    ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args,
1607                          suffix=suffix)
1608
1609
1610def put_req_nested(ri, struct):
1611    put_req_nested_prototype(ri, struct, suffix='')
1612    ri.cw.block_start()
1613    ri.cw.write_func_lvar('struct nlattr *nest;')
1614
1615    ri.cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
1616
1617    for _, arg in struct.member_list():
1618        arg.attr_put(ri, "obj")
1619
1620    ri.cw.p("mnl_attr_nest_end(nlh, nest);")
1621
1622    ri.cw.nl()
1623    ri.cw.p('return 0;')
1624    ri.cw.block_end()
1625    ri.cw.nl()
1626
1627
1628def _multi_parse(ri, struct, init_lines, local_vars):
1629    if struct.nested:
1630        iter_line = "mnl_attr_for_each_nested(attr, nested)"
1631    else:
1632        if ri.fixed_hdr:
1633            local_vars += ['void *hdr;']
1634        iter_line = "mnl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"
1635
1636    array_nests = set()
1637    multi_attrs = set()
1638    needs_parg = False
1639    for arg, aspec in struct.member_list():
1640        if aspec['type'] == 'array-nest':
1641            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1642            array_nests.add(arg)
1643        if 'multi-attr' in aspec:
1644            multi_attrs.add(arg)
1645        needs_parg |= 'nested-attributes' in aspec
1646    if array_nests or multi_attrs:
1647        local_vars.append('int i;')
1648    if needs_parg:
1649        local_vars.append('struct ynl_parse_arg parg;')
1650        init_lines.append('parg.ys = yarg->ys;')
1651
1652    all_multi = array_nests | multi_attrs
1653
1654    for anest in sorted(all_multi):
1655        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1656
1657    ri.cw.block_start()
1658    ri.cw.write_func_lvar(local_vars)
1659
1660    for line in init_lines:
1661        ri.cw.p(line)
1662    ri.cw.nl()
1663
1664    for arg in struct.inherited:
1665        ri.cw.p(f'dst->{arg} = {arg};')
1666
1667    if ri.fixed_hdr:
1668        ri.cw.p('hdr = mnl_nlmsg_get_payload_offset(nlh, sizeof(struct genlmsghdr));')
1669        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
1670    for anest in sorted(all_multi):
1671        aspec = struct[anest]
1672        ri.cw.p(f"if (dst->{aspec.c_name})")
1673        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1674
1675    ri.cw.nl()
1676    ri.cw.block_start(line=iter_line)
1677    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
1678    ri.cw.nl()
1679
1680    first = True
1681    for _, arg in struct.member_list():
1682        good = arg.attr_get(ri, 'dst', first=first)
1683        # First may be 'unused' or 'pad', ignore those
1684        first &= not good
1685
1686    ri.cw.block_end()
1687    ri.cw.nl()
1688
1689    for anest in sorted(array_nests):
1690        aspec = struct[anest]
1691
1692        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1693        ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1694        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1695        ri.cw.p('i = 0;')
1696        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1697        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1698        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1699        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1700        ri.cw.p('return MNL_CB_ERROR;')
1701        ri.cw.p('i++;')
1702        ri.cw.block_end()
1703        ri.cw.block_end()
1704    ri.cw.nl()
1705
1706    for anest in sorted(multi_attrs):
1707        aspec = struct[anest]
1708        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1709        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1710        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1711        ri.cw.p('i = 0;')
1712        if 'nested-attributes' in aspec:
1713            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1714        ri.cw.block_start(line=iter_line)
1715        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1716        if 'nested-attributes' in aspec:
1717            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1718            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1719            ri.cw.p('return MNL_CB_ERROR;')
1720        elif aspec.type in scalars:
1721            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{aspec.mnl_type()}(attr);")
1722        else:
1723            raise Exception('Nest parsing type not supported yet')
1724        ri.cw.p('i++;')
1725        ri.cw.block_end()
1726        ri.cw.block_end()
1727        ri.cw.block_end()
1728    ri.cw.nl()
1729
1730    if struct.nested:
1731        ri.cw.p('return 0;')
1732    else:
1733        ri.cw.p('return MNL_CB_OK;')
1734    ri.cw.block_end()
1735    ri.cw.nl()
1736
1737
1738def parse_rsp_nested_prototype(ri, struct, suffix=';'):
1739    func_args = ['struct ynl_parse_arg *yarg',
1740                 'const struct nlattr *nested']
1741    for arg in struct.inherited:
1742        func_args.append('__u32 ' + arg)
1743
1744    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args,
1745                          suffix=suffix)
1746
1747
1748def parse_rsp_nested(ri, struct):
1749    parse_rsp_nested_prototype(ri, struct, suffix='')
1750
1751    local_vars = ['const struct nlattr *attr;',
1752                  f'{struct.ptr_name}dst = yarg->data;']
1753    init_lines = []
1754
1755    _multi_parse(ri, struct, init_lines, local_vars)
1756
1757
1758def parse_rsp_msg(ri, deref=False):
1759    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
1760        return
1761
1762    func_args = ['const struct nlmsghdr *nlh',
1763                 'void *data']
1764
1765    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
1766                  'struct ynl_parse_arg *yarg = data;',
1767                  'const struct nlattr *attr;']
1768    init_lines = ['dst = yarg->data;']
1769
1770    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
1771
1772    if ri.struct["reply"].member_list():
1773        _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1774    else:
1775        # Empty reply
1776        ri.cw.block_start()
1777        ri.cw.p('return MNL_CB_OK;')
1778        ri.cw.block_end()
1779        ri.cw.nl()
1780
1781
1782def print_req(ri):
1783    ret_ok = '0'
1784    ret_err = '-1'
1785    direction = "request"
1786    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
1787                  'struct nlmsghdr *nlh;',
1788                  'int err;']
1789
1790    if 'reply' in ri.op[ri.op_mode]:
1791        ret_ok = 'rsp'
1792        ret_err = 'NULL'
1793        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']
1794
1795    if ri.fixed_hdr:
1796        local_vars += ['size_t hdr_len;',
1797                       'void *hdr;']
1798
1799    print_prototype(ri, direction, terminate=False)
1800    ri.cw.block_start()
1801    ri.cw.write_func_lvar(local_vars)
1802
1803    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1804
1805    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1806    if 'reply' in ri.op[ri.op_mode]:
1807        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1808    ri.cw.nl()
1809
1810    if ri.fixed_hdr:
1811        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1812        ri.cw.p("hdr = mnl_nlmsg_put_extra_header(nlh, hdr_len);")
1813        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1814        ri.cw.nl()
1815
1816    for _, attr in ri.struct["request"].member_list():
1817        attr.attr_put(ri, "req")
1818    ri.cw.nl()
1819
1820    if 'reply' in ri.op[ri.op_mode]:
1821        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
1822        ri.cw.p('yrs.yarg.data = rsp;')
1823        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
1824        if ri.op.value is not None:
1825            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
1826        else:
1827            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
1828        ri.cw.nl()
1829    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
1830    ri.cw.p('if (err < 0)')
1831    if 'reply' in ri.op[ri.op_mode]:
1832        ri.cw.p('goto err_free;')
1833    else:
1834        ri.cw.p('return -1;')
1835    ri.cw.nl()
1836
1837    ri.cw.p(f"return {ret_ok};")
1838    ri.cw.nl()
1839
1840    if 'reply' in ri.op[ri.op_mode]:
1841        ri.cw.p('err_free:')
1842        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
1843        ri.cw.p(f"return {ret_err};")
1844
1845    ri.cw.block_end()
1846
1847
1848def print_dump(ri):
1849    direction = "request"
1850    print_prototype(ri, direction, terminate=False)
1851    ri.cw.block_start()
1852    local_vars = ['struct ynl_dump_state yds = {};',
1853                  'struct nlmsghdr *nlh;',
1854                  'int err;']
1855
1856    if ri.fixed_hdr:
1857        local_vars += ['size_t hdr_len;',
1858                       'void *hdr;']
1859
1860    ri.cw.write_func_lvar(local_vars)
1861
1862    ri.cw.p('yds.ys = ys;')
1863    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
1864    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
1865    if ri.op.value is not None:
1866        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
1867    else:
1868        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
1869    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1870    ri.cw.nl()
1871    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1872
1873    if ri.fixed_hdr:
1874        ri.cw.p("hdr_len = sizeof(req->_hdr);")
1875        ri.cw.p("hdr = mnl_nlmsg_put_extra_header(nlh, hdr_len);")
1876        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
1877        ri.cw.nl()
1878
1879    if "request" in ri.op[ri.op_mode]:
1880        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1881        ri.cw.nl()
1882        for _, attr in ri.struct["request"].member_list():
1883            attr.attr_put(ri, "req")
1884    ri.cw.nl()
1885
1886    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
1887    ri.cw.p('if (err < 0)')
1888    ri.cw.p('goto free_list;')
1889    ri.cw.nl()
1890
1891    ri.cw.p('return yds.first;')
1892    ri.cw.nl()
1893    ri.cw.p('free_list:')
1894    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
1895    ri.cw.p('return NULL;')
1896    ri.cw.block_end()
1897
1898
1899def call_free(ri, direction, var):
1900    return f"{op_prefix(ri, direction)}_free({var});"
1901
1902
1903def free_arg_name(direction):
1904    if direction:
1905        return direction_to_suffix[direction][1:]
1906    return 'obj'
1907
1908
1909def print_alloc_wrapper(ri, direction):
1910    name = op_prefix(ri, direction)
1911    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
1912    ri.cw.block_start()
1913    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
1914    ri.cw.block_end()
1915
1916
1917def print_free_prototype(ri, direction, suffix=';'):
1918    name = op_prefix(ri, direction)
1919    struct_name = name
1920    if ri.type_name_conflict:
1921        struct_name += '_'
1922    arg = free_arg_name(direction)
1923    ri.cw.write_func_prot('void', f"{name}_free", [f"struct {struct_name} *{arg}"], suffix=suffix)
1924
1925
1926def _print_type(ri, direction, struct):
1927    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
1928    if not direction and ri.type_name_conflict:
1929        suffix += '_'
1930
1931    if ri.op_mode == 'dump':
1932        suffix += '_dump'
1933
1934    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")
1935
1936    if ri.fixed_hdr:
1937        ri.cw.p(ri.fixed_hdr + ' _hdr;')
1938        ri.cw.nl()
1939
1940    meta_started = False
1941    for _, attr in struct.member_list():
1942        for type_filter in ['len', 'bit']:
1943            line = attr.presence_member(ri.ku_space, type_filter)
1944            if line:
1945                if not meta_started:
1946                    ri.cw.block_start(line=f"struct")
1947                    meta_started = True
1948                ri.cw.p(line)
1949    if meta_started:
1950        ri.cw.block_end(line='_present;')
1951        ri.cw.nl()
1952
1953    for arg in struct.inherited:
1954        ri.cw.p(f"__u32 {arg};")
1955
1956    for _, attr in struct.member_list():
1957        attr.struct_member(ri)
1958
1959    ri.cw.block_end(line=';')
1960    ri.cw.nl()
1961
1962
1963def print_type(ri, direction):
1964    _print_type(ri, direction, ri.struct[direction])
1965
1966
1967def print_type_full(ri, struct):
1968    _print_type(ri, "", struct)
1969
1970
1971def print_type_helpers(ri, direction, deref=False):
1972    print_free_prototype(ri, direction)
1973    ri.cw.nl()
1974
1975    if ri.ku_space == 'user' and direction == 'request':
1976        for _, attr in ri.struct[direction].member_list():
1977            attr.setter(ri, ri.attr_set, direction, deref=deref)
1978    ri.cw.nl()
1979
1980
1981def print_req_type_helpers(ri):
1982    if len(ri.struct["request"].attr_list) == 0:
1983        return
1984    print_alloc_wrapper(ri, "request")
1985    print_type_helpers(ri, "request")
1986
1987
1988def print_rsp_type_helpers(ri):
1989    if 'reply' not in ri.op[ri.op_mode]:
1990        return
1991    print_type_helpers(ri, "reply")
1992
1993
1994def print_parse_prototype(ri, direction, terminate=True):
1995    suffix = "_rsp" if direction == "reply" else "_req"
1996    term = ';' if terminate else ''
1997
1998    ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
1999                          ['const struct nlattr **tb',
2000                           f"struct {ri.op.render_name}{suffix} *req"],
2001                          suffix=term)
2002
2003
2004def print_req_type(ri):
2005    if len(ri.struct["request"].attr_list) == 0:
2006        return
2007    print_type(ri, "request")
2008
2009
2010def print_req_free(ri):
2011    if 'request' not in ri.op[ri.op_mode]:
2012        return
2013    _free_type(ri, 'request', ri.struct['request'])
2014
2015
2016def print_rsp_type(ri):
2017    if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
2018        direction = 'reply'
2019    elif ri.op_mode == 'event':
2020        direction = 'reply'
2021    else:
2022        return
2023    print_type(ri, direction)
2024
2025
2026def print_wrapped_type(ri):
2027    ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
2028    if ri.op_mode == 'dump':
2029        ri.cw.p(f"{type_name(ri, 'reply')} *next;")
2030    elif ri.op_mode == 'notify' or ri.op_mode == 'event':
2031        ri.cw.p('__u16 family;')
2032        ri.cw.p('__u8 cmd;')
2033        ri.cw.p('struct ynl_ntf_base_type *next;')
2034        ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
2035    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__((aligned(8)));")
2036    ri.cw.block_end(line=';')
2037    ri.cw.nl()
2038    print_free_prototype(ri, 'reply')
2039    ri.cw.nl()
2040
2041
2042def _free_type_members_iter(ri, struct):
2043    for _, attr in struct.member_list():
2044        if attr.free_needs_iter():
2045            ri.cw.p('unsigned int i;')
2046            ri.cw.nl()
2047            break
2048
2049
2050def _free_type_members(ri, var, struct, ref=''):
2051    for _, attr in struct.member_list():
2052        attr.free(ri, var, ref)
2053
2054
2055def _free_type(ri, direction, struct):
2056    var = free_arg_name(direction)
2057
2058    print_free_prototype(ri, direction, suffix='')
2059    ri.cw.block_start()
2060    _free_type_members_iter(ri, struct)
2061    _free_type_members(ri, var, struct)
2062    if direction:
2063        ri.cw.p(f'free({var});')
2064    ri.cw.block_end()
2065    ri.cw.nl()
2066
2067
2068def free_rsp_nested_prototype(ri):
2069        print_free_prototype(ri, "")
2070
2071
2072def free_rsp_nested(ri, struct):
2073    _free_type(ri, "", struct)
2074
2075
2076def print_rsp_free(ri):
2077    if 'reply' not in ri.op[ri.op_mode]:
2078        return
2079    _free_type(ri, 'reply', ri.struct['reply'])
2080
2081
2082def print_dump_type_free(ri):
2083    sub_type = type_name(ri, 'reply')
2084
2085    print_free_prototype(ri, 'reply', suffix='')
2086    ri.cw.block_start()
2087    ri.cw.p(f"{sub_type} *next = rsp;")
2088    ri.cw.nl()
2089    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
2090    _free_type_members_iter(ri, ri.struct['reply'])
2091    ri.cw.p('rsp = next;')
2092    ri.cw.p('next = rsp->next;')
2093    ri.cw.nl()
2094
2095    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2096    ri.cw.p(f'free(rsp);')
2097    ri.cw.block_end()
2098    ri.cw.block_end()
2099    ri.cw.nl()
2100
2101
2102def print_ntf_type_free(ri):
2103    print_free_prototype(ri, 'reply', suffix='')
2104    ri.cw.block_start()
2105    _free_type_members_iter(ri, ri.struct['reply'])
2106    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
2107    ri.cw.p(f'free(rsp);')
2108    ri.cw.block_end()
2109    ri.cw.nl()
2110
2111
2112def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
2113    if terminate and ri and policy_should_be_static(struct.family):
2114        return
2115
2116    if terminate:
2117        prefix = 'extern '
2118    else:
2119        if ri and policy_should_be_static(struct.family):
2120            prefix = 'static '
2121        else:
2122            prefix = ''
2123
2124    suffix = ';' if terminate else ' = {'
2125
2126    max_attr = struct.attr_max_val
2127    if ri:
2128        name = ri.op.render_name
2129        if ri.op.dual_policy:
2130            name += '_' + ri.op_mode
2131    else:
2132        name = struct.render_name
2133    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2134
2135
2136def print_req_policy(cw, struct, ri=None):
2137    if ri and ri.op:
2138        cw.ifdef_block(ri.op.get('config-cond', None))
2139    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
2140    for _, arg in struct.member_list():
2141        arg.attr_policy(cw)
2142    cw.p("};")
2143    cw.ifdef_block(None)
2144    cw.nl()
2145
2146
2147def kernel_can_gen_family_struct(family):
2148    return family.proto == 'genetlink'
2149
2150
2151def policy_should_be_static(family):
2152    return family.kernel_policy == 'split' or kernel_can_gen_family_struct(family)
2153
2154
2155def print_kernel_policy_ranges(family, cw):
2156    first = True
2157    for _, attr_set in family.attr_sets.items():
2158        if attr_set.subset_of:
2159            continue
2160
2161        for _, attr in attr_set.items():
2162            if not attr.request:
2163                continue
2164            if 'full-range' not in attr.checks:
2165                continue
2166
2167            if first:
2168                cw.p('/* Integer value ranges */')
2169                first = False
2170
2171            sign = '' if attr.type[0] == 'u' else '_signed'
2172            suffix = 'ULL' if attr.type[0] == 'u' else 'LL'
2173            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
2174            members = []
2175            if 'min' in attr.checks:
2176                members.append(('min', str(attr.get_limit('min')) + suffix))
2177            if 'max' in attr.checks:
2178                members.append(('max', str(attr.get_limit('max')) + suffix))
2179            cw.write_struct_init(members)
2180            cw.block_end(line=';')
2181            cw.nl()
2182
2183
2184def print_kernel_op_table_fwd(family, cw, terminate):
2185    exported = not kernel_can_gen_family_struct(family)
2186
2187    if not terminate or exported:
2188        cw.p(f"/* Ops table for {family.name} */")
2189
2190        pol_to_struct = {'global': 'genl_small_ops',
2191                         'per-op': 'genl_ops',
2192                         'split': 'genl_split_ops'}
2193        struct_type = pol_to_struct[family.kernel_policy]
2194
2195        if not exported:
2196            cnt = ""
2197        elif family.kernel_policy == 'split':
2198            cnt = 0
2199            for op in family.ops.values():
2200                if 'do' in op:
2201                    cnt += 1
2202                if 'dump' in op:
2203                    cnt += 1
2204        else:
2205            cnt = len(family.ops)
2206
2207        qual = 'static const' if not exported else 'const'
2208        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
2209        if terminate:
2210            cw.p(f"extern {line};")
2211        else:
2212            cw.block_start(line=line + ' =')
2213
2214    if not terminate:
2215        return
2216
2217    cw.nl()
2218    for name in family.hooks['pre']['do']['list']:
2219        cw.write_func_prot('int', c_lower(name),
2220                           ['const struct genl_split_ops *ops',
2221                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2222    for name in family.hooks['post']['do']['list']:
2223        cw.write_func_prot('void', c_lower(name),
2224                           ['const struct genl_split_ops *ops',
2225                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2226    for name in family.hooks['pre']['dump']['list']:
2227        cw.write_func_prot('int', c_lower(name),
2228                           ['struct netlink_callback *cb'], suffix=';')
2229    for name in family.hooks['post']['dump']['list']:
2230        cw.write_func_prot('int', c_lower(name),
2231                           ['struct netlink_callback *cb'], suffix=';')
2232
2233    cw.nl()
2234
2235    for op_name, op in family.ops.items():
2236        if op.is_async:
2237            continue
2238
2239        if 'do' in op:
2240            name = c_lower(f"{family.name}-nl-{op_name}-doit")
2241            cw.write_func_prot('int', name,
2242                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
2243
2244        if 'dump' in op:
2245            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
2246            cw.write_func_prot('int', name,
2247                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
2248    cw.nl()
2249
2250
2251def print_kernel_op_table_hdr(family, cw):
2252    print_kernel_op_table_fwd(family, cw, terminate=True)
2253
2254
2255def print_kernel_op_table(family, cw):
2256    print_kernel_op_table_fwd(family, cw, terminate=False)
2257    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
2258        for op_name, op in family.ops.items():
2259            if op.is_async:
2260                continue
2261
2262            cw.ifdef_block(op.get('config-cond', None))
2263            cw.block_start()
2264            members = [('cmd', op.enum_name)]
2265            if 'dont-validate' in op:
2266                members.append(('validate',
2267                                ' | '.join([c_upper('genl-dont-validate-' + x)
2268                                            for x in op['dont-validate']])), )
2269            for op_mode in ['do', 'dump']:
2270                if op_mode in op:
2271                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
2272                    members.append((op_mode + 'it', name))
2273            if family.kernel_policy == 'per-op':
2274                struct = Struct(family, op['attribute-set'],
2275                                type_list=op['do']['request']['attributes'])
2276
2277                name = c_lower(f"{family.name}-{op_name}-nl-policy")
2278                members.append(('policy', name))
2279                members.append(('maxattr', struct.attr_max_val.enum_name))
2280            if 'flags' in op:
2281                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
2282            cw.write_struct_init(members)
2283            cw.block_end(line=',')
2284    elif family.kernel_policy == 'split':
2285        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
2286                    'dump': {'pre': 'start', 'post': 'done'}}
2287
2288        for op_name, op in family.ops.items():
2289            for op_mode in ['do', 'dump']:
2290                if op.is_async or op_mode not in op:
2291                    continue
2292
2293                cw.ifdef_block(op.get('config-cond', None))
2294                cw.block_start()
2295                members = [('cmd', op.enum_name)]
2296                if 'dont-validate' in op:
2297                    dont_validate = []
2298                    for x in op['dont-validate']:
2299                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
2300                            continue
2301                        if op_mode == "dump" and x == 'strict':
2302                            continue
2303                        dont_validate.append(x)
2304
2305                    if dont_validate:
2306                        members.append(('validate',
2307                                        ' | '.join([c_upper('genl-dont-validate-' + x)
2308                                                    for x in dont_validate])), )
2309                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
2310                if 'pre' in op[op_mode]:
2311                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
2312                members.append((op_mode + 'it', name))
2313                if 'post' in op[op_mode]:
2314                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
2315                if 'request' in op[op_mode]:
2316                    struct = Struct(family, op['attribute-set'],
2317                                    type_list=op[op_mode]['request']['attributes'])
2318
2319                    if op.dual_policy:
2320                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
2321                    else:
2322                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
2323                    members.append(('policy', name))
2324                    members.append(('maxattr', struct.attr_max_val.enum_name))
2325                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
2326                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
2327                cw.write_struct_init(members)
2328                cw.block_end(line=',')
2329    cw.ifdef_block(None)
2330
2331    cw.block_end(line=';')
2332    cw.nl()
2333
2334
2335def print_kernel_mcgrp_hdr(family, cw):
2336    if not family.mcgrps['list']:
2337        return
2338
2339    cw.block_start('enum')
2340    for grp in family.mcgrps['list']:
2341        grp_id = c_upper(f"{family.name}-nlgrp-{grp['name']},")
2342        cw.p(grp_id)
2343    cw.block_end(';')
2344    cw.nl()
2345
2346
2347def print_kernel_mcgrp_src(family, cw):
2348    if not family.mcgrps['list']:
2349        return
2350
2351    cw.block_start('static const struct genl_multicast_group ' + family.c_name + '_nl_mcgrps[] =')
2352    for grp in family.mcgrps['list']:
2353        name = grp['name']
2354        grp_id = c_upper(f"{family.name}-nlgrp-{name}")
2355        cw.p('[' + grp_id + '] = { "' + name + '", },')
2356    cw.block_end(';')
2357    cw.nl()
2358
2359
2360def print_kernel_family_struct_hdr(family, cw):
2361    if not kernel_can_gen_family_struct(family):
2362        return
2363
2364    cw.p(f"extern struct genl_family {family.c_name}_nl_family;")
2365    cw.nl()
2366
2367
2368def print_kernel_family_struct_src(family, cw):
2369    if not kernel_can_gen_family_struct(family):
2370        return
2371
2372    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
2373    cw.p('.name\t\t= ' + family.fam_key + ',')
2374    cw.p('.version\t= ' + family.ver_key + ',')
2375    cw.p('.netnsok\t= true,')
2376    cw.p('.parallel_ops\t= true,')
2377    cw.p('.module\t\t= THIS_MODULE,')
2378    if family.kernel_policy == 'per-op':
2379        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
2380        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2381    elif family.kernel_policy == 'split':
2382        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
2383        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
2384    if family.mcgrps['list']:
2385        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
2386        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
2387    cw.block_end(';')
2388
2389
2390def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
2391    start_line = 'enum'
2392    if enum_name in obj:
2393        if obj[enum_name]:
2394            start_line = 'enum ' + c_lower(obj[enum_name])
2395    elif ckey and ckey in obj:
2396        start_line = 'enum ' + family.c_name + '_' + c_lower(obj[ckey])
2397    cw.block_start(line=start_line)
2398
2399
2400def render_uapi(family, cw):
2401    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
2402    cw.p('#ifndef ' + hdr_prot)
2403    cw.p('#define ' + hdr_prot)
2404    cw.nl()
2405
2406    defines = [(family.fam_key, family["name"]),
2407               (family.ver_key, family.get('version', 1))]
2408    cw.writes_defines(defines)
2409    cw.nl()
2410
2411    defines = []
2412    for const in family['definitions']:
2413        if const['type'] != 'const':
2414            cw.writes_defines(defines)
2415            defines = []
2416            cw.nl()
2417
2418        # Write kdoc for enum and flags (one day maybe also structs)
2419        if const['type'] == 'enum' or const['type'] == 'flags':
2420            enum = family.consts[const['name']]
2421
2422            if enum.has_doc():
2423                cw.p('/**')
2424                doc = ''
2425                if 'doc' in enum:
2426                    doc = ' - ' + enum['doc']
2427                cw.write_doc_line(enum.enum_name + doc)
2428                for entry in enum.entries.values():
2429                    if entry.has_doc():
2430                        doc = '@' + entry.c_name + ': ' + entry['doc']
2431                        cw.write_doc_line(doc)
2432                cw.p(' */')
2433
2434            uapi_enum_start(family, cw, const, 'name')
2435            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
2436            for entry in enum.entries.values():
2437                suffix = ','
2438                if entry.value_change:
2439                    suffix = f" = {entry.user_value()}" + suffix
2440                cw.p(entry.c_name + suffix)
2441
2442            if const.get('render-max', False):
2443                cw.nl()
2444                cw.p('/* private: */')
2445                if const['type'] == 'flags':
2446                    max_name = c_upper(name_pfx + 'mask')
2447                    max_val = f' = {enum.get_mask()},'
2448                    cw.p(max_name + max_val)
2449                else:
2450                    max_name = c_upper(name_pfx + 'max')
2451                    cw.p('__' + max_name + ',')
2452                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
2453            cw.block_end(line=';')
2454            cw.nl()
2455        elif const['type'] == 'const':
2456            defines.append([c_upper(family.get('c-define-name',
2457                                               f"{family.name}-{const['name']}")),
2458                            const['value']])
2459
2460    if defines:
2461        cw.writes_defines(defines)
2462        cw.nl()
2463
2464    max_by_define = family.get('max-by-define', False)
2465
2466    for _, attr_set in family.attr_sets.items():
2467        if attr_set.subset_of:
2468            continue
2469
2470        max_value = f"({attr_set.cnt_name} - 1)"
2471
2472        val = 0
2473        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
2474        for _, attr in attr_set.items():
2475            suffix = ','
2476            if attr.value != val:
2477                suffix = f" = {attr.value},"
2478                val = attr.value
2479            val += 1
2480            cw.p(attr.enum_name + suffix)
2481        cw.nl()
2482        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
2483        if not max_by_define:
2484            cw.p(f"{attr_set.max_name} = {max_value}")
2485        cw.block_end(line=';')
2486        if max_by_define:
2487            cw.p(f"#define {attr_set.max_name} {max_value}")
2488        cw.nl()
2489
2490    # Commands
2491    separate_ntf = 'async-prefix' in family['operations']
2492
2493    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
2494    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
2495    max_value = f"({cnt_name} - 1)"
2496
2497    uapi_enum_start(family, cw, family['operations'], 'enum-name')
2498    val = 0
2499    for op in family.msgs.values():
2500        if separate_ntf and ('notify' in op or 'event' in op):
2501            continue
2502
2503        suffix = ','
2504        if op.value != val:
2505            suffix = f" = {op.value},"
2506            val = op.value
2507        cw.p(op.enum_name + suffix)
2508        val += 1
2509    cw.nl()
2510    cw.p(cnt_name + ('' if max_by_define else ','))
2511    if not max_by_define:
2512        cw.p(f"{max_name} = {max_value}")
2513    cw.block_end(line=';')
2514    if max_by_define:
2515        cw.p(f"#define {max_name} {max_value}")
2516    cw.nl()
2517
2518    if separate_ntf:
2519        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
2520        for op in family.msgs.values():
2521            if separate_ntf and not ('notify' in op or 'event' in op):
2522                continue
2523
2524            suffix = ','
2525            if 'value' in op:
2526                suffix = f" = {op['value']},"
2527            cw.p(op.enum_name + suffix)
2528        cw.block_end(line=';')
2529        cw.nl()
2530
2531    # Multicast
2532    defines = []
2533    for grp in family.mcgrps['list']:
2534        name = grp['name']
2535        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
2536                        f'{name}'])
2537    cw.nl()
2538    if defines:
2539        cw.writes_defines(defines)
2540        cw.nl()
2541
2542    cw.p(f'#endif /* {hdr_prot} */')
2543
2544
2545def _render_user_ntf_entry(ri, op):
2546    ri.cw.block_start(line=f"[{op.enum_name}] = ")
2547    ri.cw.p(f".alloc_sz\t= sizeof({type_name(ri, 'event')}),")
2548    ri.cw.p(f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,")
2549    ri.cw.p(f".policy\t\t= &{ri.struct['reply'].render_name}_nest,")
2550    ri.cw.p(f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,")
2551    ri.cw.block_end(line=',')
2552
2553
2554def render_user_family(family, cw, prototype):
2555    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
2556    if prototype:
2557        cw.p(f'extern {symbol};')
2558        return
2559
2560    if family.ntfs:
2561        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
2562        for ntf_op_name, ntf_op in family.ntfs.items():
2563            if 'notify' in ntf_op:
2564                op = family.ops[ntf_op['notify']]
2565                ri = RenderInfo(cw, family, "user", op, "notify")
2566            elif 'event' in ntf_op:
2567                ri = RenderInfo(cw, family, "user", ntf_op, "event")
2568            else:
2569                raise Exception('Invalid notification ' + ntf_op_name)
2570            _render_user_ntf_entry(ri, ntf_op)
2571        for op_name, op in family.ops.items():
2572            if 'event' not in op:
2573                continue
2574            ri = RenderInfo(cw, family, "user", op, "event")
2575            _render_user_ntf_entry(ri, op)
2576        cw.block_end(line=";")
2577        cw.nl()
2578
2579    cw.block_start(f'{symbol} = ')
2580    cw.p(f'.name\t\t= "{family.c_name}",')
2581    if family.fixed_header:
2582        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
2583    else:
2584        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
2585    if family.ntfs:
2586        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
2587        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({family['name']}_ntf_info),")
2588    cw.block_end(line=';')
2589
2590
2591def family_contains_bitfield32(family):
2592    for _, attr_set in family.attr_sets.items():
2593        if attr_set.subset_of:
2594            continue
2595        for _, attr in attr_set.items():
2596            if attr.type == "bitfield32":
2597                return True
2598    return False
2599
2600
2601def find_kernel_root(full_path):
2602    sub_path = ''
2603    while True:
2604        sub_path = os.path.join(os.path.basename(full_path), sub_path)
2605        full_path = os.path.dirname(full_path)
2606        maintainers = os.path.join(full_path, "MAINTAINERS")
2607        if os.path.exists(maintainers):
2608            return full_path, sub_path[:-1]
2609
2610
2611def main():
2612    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2613    parser.add_argument('--mode', dest='mode', type=str, required=True)
2614    parser.add_argument('--spec', dest='spec', type=str, required=True)
2615    parser.add_argument('--header', dest='header', action='store_true', default=None)
2616    parser.add_argument('--source', dest='header', action='store_false')
2617    parser.add_argument('--user-header', nargs='+', default=[])
2618    parser.add_argument('--cmp-out', action='store_true', default=None,
2619                        help='Do not overwrite the output file if the new output is identical to the old')
2620    parser.add_argument('--exclude-op', action='append', default=[])
2621    parser.add_argument('-o', dest='out_file', type=str, default=None)
2622    args = parser.parse_args()
2623
2624    if args.header is None:
2625        parser.error("--header or --source is required")
2626
2627    exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2628
2629    try:
2630        parsed = Family(args.spec, exclude_ops)
2631        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2632            print('Spec license:', parsed.license)
2633            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2634            os.sys.exit(1)
2635    except yaml.YAMLError as exc:
2636        print(exc)
2637        os.sys.exit(1)
2638        return
2639
2640    supported_models = ['unified']
2641    if args.mode in ['user', 'kernel']:
2642        supported_models += ['directional']
2643    if parsed.msg_id_model not in supported_models:
2644        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
2645        os.sys.exit(1)
2646
2647    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))
2648
2649    _, spec_kernel = find_kernel_root(args.spec)
2650    if args.mode == 'uapi' or args.header:
2651        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2652    else:
2653        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2654    cw.p("/* Do not edit directly, auto-generated from: */")
2655    cw.p(f"/*\t{spec_kernel} */")
2656    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2657    if args.exclude_op or args.user_header:
2658        line = ''
2659        line += ' --user-header '.join([''] + args.user_header)
2660        line += ' --exclude-op '.join([''] + args.exclude_op)
2661        cw.p(f'/* YNL-ARG{line} */')
2662    cw.nl()
2663
2664    if args.mode == 'uapi':
2665        render_uapi(parsed, cw)
2666        return
2667
2668    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
2669    if args.header:
2670        cw.p('#ifndef ' + hdr_prot)
2671        cw.p('#define ' + hdr_prot)
2672        cw.nl()
2673
2674    if args.mode == 'kernel':
2675        cw.p('#include <net/netlink.h>')
2676        cw.p('#include <net/genetlink.h>')
2677        cw.nl()
2678        if not args.header:
2679            if args.out_file:
2680                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
2681            cw.nl()
2682        headers = ['uapi/' + parsed.uapi_header]
2683    else:
2684        cw.p('#include <stdlib.h>')
2685        cw.p('#include <string.h>')
2686        if args.header:
2687            cw.p('#include <linux/types.h>')
2688            if family_contains_bitfield32(parsed):
2689                cw.p('#include <linux/netlink.h>')
2690        else:
2691            cw.p(f'#include "{parsed.name}-user.h"')
2692            cw.p('#include "ynl.h"')
2693        headers = [parsed.uapi_header]
2694    for definition in parsed['definitions']:
2695        if 'header' in definition:
2696            headers.append(definition['header'])
2697    for one in headers:
2698        cw.p(f"#include <{one}>")
2699    cw.nl()
2700
2701    if args.mode == "user":
2702        if not args.header:
2703            cw.p("#include <libmnl/libmnl.h>")
2704            cw.p("#include <linux/genetlink.h>")
2705            cw.nl()
2706            for one in args.user_header:
2707                cw.p(f'#include "{one}"')
2708        else:
2709            cw.p('struct ynl_sock;')
2710            cw.nl()
2711            render_user_family(parsed, cw, True)
2712        cw.nl()
2713
2714    if args.mode == "kernel":
2715        if args.header:
2716            for _, struct in sorted(parsed.pure_nested_structs.items()):
2717                if struct.request:
2718                    cw.p('/* Common nested types */')
2719                    break
2720            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2721                if struct.request:
2722                    print_req_policy_fwd(cw, struct)
2723            cw.nl()
2724
2725            if parsed.kernel_policy == 'global':
2726                cw.p(f"/* Global operation policy for {parsed.name} */")
2727
2728                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2729                print_req_policy_fwd(cw, struct)
2730                cw.nl()
2731
2732            if parsed.kernel_policy in {'per-op', 'split'}:
2733                for op_name, op in parsed.ops.items():
2734                    if 'do' in op and 'event' not in op:
2735                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
2736                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2737                        cw.nl()
2738
2739            print_kernel_op_table_hdr(parsed, cw)
2740            print_kernel_mcgrp_hdr(parsed, cw)
2741            print_kernel_family_struct_hdr(parsed, cw)
2742        else:
2743            print_kernel_policy_ranges(parsed, cw)
2744
2745            for _, struct in sorted(parsed.pure_nested_structs.items()):
2746                if struct.request:
2747                    cw.p('/* Common nested types */')
2748                    break
2749            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2750                if struct.request:
2751                    print_req_policy(cw, struct)
2752            cw.nl()
2753
2754            if parsed.kernel_policy == 'global':
2755                cw.p(f"/* Global operation policy for {parsed.name} */")
2756
2757                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2758                print_req_policy(cw, struct)
2759                cw.nl()
2760
2761            for op_name, op in parsed.ops.items():
2762                if parsed.kernel_policy in {'per-op', 'split'}:
2763                    for op_mode in ['do', 'dump']:
2764                        if op_mode in op and 'request' in op[op_mode]:
2765                            cw.p(f"/* {op.enum_name} - {op_mode} */")
2766                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2767                            print_req_policy(cw, ri.struct['request'], ri=ri)
2768                            cw.nl()
2769
2770            print_kernel_op_table(parsed, cw)
2771            print_kernel_mcgrp_src(parsed, cw)
2772            print_kernel_family_struct_src(parsed, cw)
2773
2774    if args.mode == "user":
2775        if args.header:
2776            cw.p('/* Enums */')
2777            put_op_name_fwd(parsed, cw)
2778
2779            for name, const in parsed.consts.items():
2780                if isinstance(const, EnumSet):
2781                    put_enum_to_str_fwd(parsed, cw, const)
2782            cw.nl()
2783
2784            cw.p('/* Common nested types */')
2785            for attr_set, struct in parsed.pure_nested_structs.items():
2786                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2787                print_type_full(ri, struct)
2788
2789            for op_name, op in parsed.ops.items():
2790                cw.p(f"/* ============== {op.enum_name} ============== */")
2791
2792                if 'do' in op and 'event' not in op:
2793                    cw.p(f"/* {op.enum_name} - do */")
2794                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2795                    print_req_type(ri)
2796                    print_req_type_helpers(ri)
2797                    cw.nl()
2798                    print_rsp_type(ri)
2799                    print_rsp_type_helpers(ri)
2800                    cw.nl()
2801                    print_req_prototype(ri)
2802                    cw.nl()
2803
2804                if 'dump' in op:
2805                    cw.p(f"/* {op.enum_name} - dump */")
2806                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
2807                    print_req_type(ri)
2808                    print_req_type_helpers(ri)
2809                    if not ri.type_consistent:
2810                        print_rsp_type(ri)
2811                    print_wrapped_type(ri)
2812                    print_dump_prototype(ri)
2813                    cw.nl()
2814
2815                if op.has_ntf:
2816                    cw.p(f"/* {op.enum_name} - notify */")
2817                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2818                    if not ri.type_consistent:
2819                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2820                    print_wrapped_type(ri)
2821
2822            for op_name, op in parsed.ntfs.items():
2823                if 'event' in op:
2824                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
2825                    cw.p(f"/* {op.enum_name} - event */")
2826                    print_rsp_type(ri)
2827                    cw.nl()
2828                    print_wrapped_type(ri)
2829            cw.nl()
2830        else:
2831            cw.p('/* Enums */')
2832            put_op_name(parsed, cw)
2833
2834            for name, const in parsed.consts.items():
2835                if isinstance(const, EnumSet):
2836                    put_enum_to_str(parsed, cw, const)
2837            cw.nl()
2838
2839            has_recursive_nests = False
2840            cw.p('/* Policies */')
2841            for struct in parsed.pure_nested_structs.values():
2842                if struct.recursive:
2843                    put_typol_fwd(cw, struct)
2844                    has_recursive_nests = True
2845            if has_recursive_nests:
2846                cw.nl()
2847            for name in parsed.pure_nested_structs:
2848                struct = Struct(parsed, name)
2849                put_typol(cw, struct)
2850            for name in parsed.root_sets:
2851                struct = Struct(parsed, name)
2852                put_typol(cw, struct)
2853
2854            cw.p('/* Common nested types */')
2855            if has_recursive_nests:
2856                for attr_set, struct in parsed.pure_nested_structs.items():
2857                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2858                    free_rsp_nested_prototype(ri)
2859                    if struct.request:
2860                        put_req_nested_prototype(ri, struct)
2861                    if struct.reply:
2862                        parse_rsp_nested_prototype(ri, struct)
2863                cw.nl()
2864            for attr_set, struct in parsed.pure_nested_structs.items():
2865                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2866
2867                free_rsp_nested(ri, struct)
2868                if struct.request:
2869                    put_req_nested(ri, struct)
2870                if struct.reply:
2871                    parse_rsp_nested(ri, struct)
2872
2873            for op_name, op in parsed.ops.items():
2874                cw.p(f"/* ============== {op.enum_name} ============== */")
2875                if 'do' in op and 'event' not in op:
2876                    cw.p(f"/* {op.enum_name} - do */")
2877                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2878                    print_req_free(ri)
2879                    print_rsp_free(ri)
2880                    parse_rsp_msg(ri)
2881                    print_req(ri)
2882                    cw.nl()
2883
2884                if 'dump' in op:
2885                    cw.p(f"/* {op.enum_name} - dump */")
2886                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
2887                    if not ri.type_consistent:
2888                        parse_rsp_msg(ri, deref=True)
2889                    print_req_free(ri)
2890                    print_dump_type_free(ri)
2891                    print_dump(ri)
2892                    cw.nl()
2893
2894                if op.has_ntf:
2895                    cw.p(f"/* {op.enum_name} - notify */")
2896                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2897                    if not ri.type_consistent:
2898                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
2899                    print_ntf_type_free(ri)
2900
2901            for op_name, op in parsed.ntfs.items():
2902                if 'event' in op:
2903                    cw.p(f"/* {op.enum_name} - event */")
2904
2905                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
2906                    parse_rsp_msg(ri)
2907
2908                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
2909                    print_ntf_type_free(ri)
2910            cw.nl()
2911            render_user_family(parsed, cw, False)
2912
2913    if args.header:
2914        cw.p(f'#endif /* {hdr_prot} */')
2915
2916
2917if __name__ == "__main__":
2918    main()