Loading...
Note: File does not exist in v4.10.11.
1# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3from collections import namedtuple
4import functools
5import os
6import random
7import socket
8import struct
9from struct import Struct
10import sys
11import yaml
12import ipaddress
13import uuid
14
15from .nlspec import SpecFamily
16
17#
18# Generic Netlink code which should really be in some library, but I can't quickly find one.
19#
20
21
22class Netlink:
23 # Netlink socket
24 SOL_NETLINK = 270
25
26 NETLINK_ADD_MEMBERSHIP = 1
27 NETLINK_CAP_ACK = 10
28 NETLINK_EXT_ACK = 11
29 NETLINK_GET_STRICT_CHK = 12
30
31 # Netlink message
32 NLMSG_ERROR = 2
33 NLMSG_DONE = 3
34
35 NLM_F_REQUEST = 1
36 NLM_F_ACK = 4
37 NLM_F_ROOT = 0x100
38 NLM_F_MATCH = 0x200
39
40 NLM_F_REPLACE = 0x100
41 NLM_F_EXCL = 0x200
42 NLM_F_CREATE = 0x400
43 NLM_F_APPEND = 0x800
44
45 NLM_F_CAPPED = 0x100
46 NLM_F_ACK_TLVS = 0x200
47
48 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
49
50 NLA_F_NESTED = 0x8000
51 NLA_F_NET_BYTEORDER = 0x4000
52
53 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER
54
55 # Genetlink defines
56 NETLINK_GENERIC = 16
57
58 GENL_ID_CTRL = 0x10
59
60 # nlctrl
61 CTRL_CMD_GETFAMILY = 3
62
63 CTRL_ATTR_FAMILY_ID = 1
64 CTRL_ATTR_FAMILY_NAME = 2
65 CTRL_ATTR_MAXATTR = 5
66 CTRL_ATTR_MCAST_GROUPS = 7
67
68 CTRL_ATTR_MCAST_GRP_NAME = 1
69 CTRL_ATTR_MCAST_GRP_ID = 2
70
71 # Extack types
72 NLMSGERR_ATTR_MSG = 1
73 NLMSGERR_ATTR_OFFS = 2
74 NLMSGERR_ATTR_COOKIE = 3
75 NLMSGERR_ATTR_POLICY = 4
76 NLMSGERR_ATTR_MISS_TYPE = 5
77 NLMSGERR_ATTR_MISS_NEST = 6
78
79
80class NlError(Exception):
81 def __init__(self, nl_msg):
82 self.nl_msg = nl_msg
83
84 def __str__(self):
85 return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
86
87
88class ConfigError(Exception):
89 pass
90
91
92class NlAttr:
93 ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
94 type_formats = {
95 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
96 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
97 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
98 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
99 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
100 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
101 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
102 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
103 }
104
105 def __init__(self, raw, offset):
106 self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
107 self.type = self._type & ~Netlink.NLA_TYPE_MASK
108 self.is_nest = self._type & Netlink.NLA_F_NESTED
109 self.payload_len = self._len
110 self.full_len = (self.payload_len + 3) & ~3
111 self.raw = raw[offset + 4 : offset + self.payload_len]
112
113 @classmethod
114 def get_format(cls, attr_type, byte_order=None):
115 format = cls.type_formats[attr_type]
116 if byte_order:
117 return format.big if byte_order == "big-endian" \
118 else format.little
119 return format.native
120
121 def as_scalar(self, attr_type, byte_order=None):
122 format = self.get_format(attr_type, byte_order)
123 return format.unpack(self.raw)[0]
124
125 def as_auto_scalar(self, attr_type, byte_order=None):
126 if len(self.raw) != 4 and len(self.raw) != 8:
127 raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
128 real_type = attr_type[0] + str(len(self.raw) * 8)
129 format = self.get_format(real_type, byte_order)
130 return format.unpack(self.raw)[0]
131
132 def as_strz(self):
133 return self.raw.decode('ascii')[:-1]
134
135 def as_bin(self):
136 return self.raw
137
138 def as_c_array(self, type):
139 format = self.get_format(type)
140 return [ x[0] for x in format.iter_unpack(self.raw) ]
141
142 def __repr__(self):
143 return f"[type:{self.type} len:{self._len}] {self.raw}"
144
145
146class NlAttrs:
147 def __init__(self, msg, offset=0):
148 self.attrs = []
149
150 while offset < len(msg):
151 attr = NlAttr(msg, offset)
152 offset += attr.full_len
153 self.attrs.append(attr)
154
155 def __iter__(self):
156 yield from self.attrs
157
158 def __repr__(self):
159 msg = ''
160 for a in self.attrs:
161 if msg:
162 msg += '\n'
163 msg += repr(a)
164 return msg
165
166
167class NlMsg:
168 def __init__(self, msg, offset, attr_space=None):
169 self.hdr = msg[offset : offset + 16]
170
171 self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
172 struct.unpack("IHHII", self.hdr)
173
174 self.raw = msg[offset + 16 : offset + self.nl_len]
175
176 self.error = 0
177 self.done = 0
178
179 extack_off = None
180 if self.nl_type == Netlink.NLMSG_ERROR:
181 self.error = struct.unpack("i", self.raw[0:4])[0]
182 self.done = 1
183 extack_off = 20
184 elif self.nl_type == Netlink.NLMSG_DONE:
185 self.error = struct.unpack("i", self.raw[0:4])[0]
186 self.done = 1
187 extack_off = 4
188
189 self.extack = None
190 if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
191 self.extack = dict()
192 extack_attrs = NlAttrs(self.raw[extack_off:])
193 for extack in extack_attrs:
194 if extack.type == Netlink.NLMSGERR_ATTR_MSG:
195 self.extack['msg'] = extack.as_strz()
196 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
197 self.extack['miss-type'] = extack.as_scalar('u32')
198 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
199 self.extack['miss-nest'] = extack.as_scalar('u32')
200 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
201 self.extack['bad-attr-offs'] = extack.as_scalar('u32')
202 else:
203 if 'unknown' not in self.extack:
204 self.extack['unknown'] = []
205 self.extack['unknown'].append(extack)
206
207 if attr_space:
208 # We don't have the ability to parse nests yet, so only do global
209 if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
210 miss_type = self.extack['miss-type']
211 if miss_type in attr_space.attrs_by_val:
212 spec = attr_space.attrs_by_val[miss_type]
213 desc = spec['name']
214 if 'doc' in spec:
215 desc += f" ({spec['doc']})"
216 self.extack['miss-type'] = desc
217
218 def cmd(self):
219 return self.nl_type
220
221 def __repr__(self):
222 msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
223 if self.error:
224 msg += '\n\terror: ' + str(self.error)
225 if self.extack:
226 msg += '\n\textack: ' + repr(self.extack)
227 return msg
228
229
230class NlMsgs:
231 def __init__(self, data, attr_space=None):
232 self.msgs = []
233
234 offset = 0
235 while offset < len(data):
236 msg = NlMsg(data, offset, attr_space=attr_space)
237 offset += msg.nl_len
238 self.msgs.append(msg)
239
240 def __iter__(self):
241 yield from self.msgs
242
243
244genl_family_name_to_id = None
245
246
247def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
248 # we prepend length in _genl_msg_finalize()
249 if seq is None:
250 seq = random.randint(1, 1024)
251 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
252 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
253 return nlmsg + genlmsg
254
255
256def _genl_msg_finalize(msg):
257 return struct.pack("I", len(msg) + 4) + msg
258
259
260def _genl_load_families():
261 with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
262 sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
263
264 msg = _genl_msg(Netlink.GENL_ID_CTRL,
265 Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
266 Netlink.CTRL_CMD_GETFAMILY, 1)
267 msg = _genl_msg_finalize(msg)
268
269 sock.send(msg, 0)
270
271 global genl_family_name_to_id
272 genl_family_name_to_id = dict()
273
274 while True:
275 reply = sock.recv(128 * 1024)
276 nms = NlMsgs(reply)
277 for nl_msg in nms:
278 if nl_msg.error:
279 print("Netlink error:", nl_msg.error)
280 return
281 if nl_msg.done:
282 return
283
284 gm = GenlMsg(nl_msg)
285 fam = dict()
286 for attr in NlAttrs(gm.raw):
287 if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
288 fam['id'] = attr.as_scalar('u16')
289 elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
290 fam['name'] = attr.as_strz()
291 elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
292 fam['maxattr'] = attr.as_scalar('u32')
293 elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
294 fam['mcast'] = dict()
295 for entry in NlAttrs(attr.raw):
296 mcast_name = None
297 mcast_id = None
298 for entry_attr in NlAttrs(entry.raw):
299 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
300 mcast_name = entry_attr.as_strz()
301 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
302 mcast_id = entry_attr.as_scalar('u32')
303 if mcast_name and mcast_id is not None:
304 fam['mcast'][mcast_name] = mcast_id
305 if 'name' in fam and 'id' in fam:
306 genl_family_name_to_id[fam['name']] = fam
307
308
309class GenlMsg:
310 def __init__(self, nl_msg):
311 self.nl = nl_msg
312 self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
313 self.raw = nl_msg.raw[4:]
314
315 def cmd(self):
316 return self.genl_cmd
317
318 def __repr__(self):
319 msg = repr(self.nl)
320 msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
321 for a in self.raw_attrs:
322 msg += '\t\t' + repr(a) + '\n'
323 return msg
324
325
326class NetlinkProtocol:
327 def __init__(self, family_name, proto_num):
328 self.family_name = family_name
329 self.proto_num = proto_num
330
331 def _message(self, nl_type, nl_flags, seq=None):
332 if seq is None:
333 seq = random.randint(1, 1024)
334 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
335 return nlmsg
336
337 def message(self, flags, command, version, seq=None):
338 return self._message(command, flags, seq)
339
340 def _decode(self, nl_msg):
341 return nl_msg
342
343 def decode(self, ynl, nl_msg):
344 msg = self._decode(nl_msg)
345 fixed_header_size = 0
346 if ynl:
347 op = ynl.rsp_by_value[msg.cmd()]
348 fixed_header_size = ynl._struct_size(op.fixed_header)
349 msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
350 return msg
351
352 def get_mcast_id(self, mcast_name, mcast_groups):
353 if mcast_name not in mcast_groups:
354 raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
355 return mcast_groups[mcast_name].value
356
357 def msghdr_size(self):
358 return 16
359
360
361class GenlProtocol(NetlinkProtocol):
362 def __init__(self, family_name):
363 super().__init__(family_name, Netlink.NETLINK_GENERIC)
364
365 global genl_family_name_to_id
366 if genl_family_name_to_id is None:
367 _genl_load_families()
368
369 self.genl_family = genl_family_name_to_id[family_name]
370 self.family_id = genl_family_name_to_id[family_name]['id']
371
372 def message(self, flags, command, version, seq=None):
373 nlmsg = self._message(self.family_id, flags, seq)
374 genlmsg = struct.pack("BBH", command, version, 0)
375 return nlmsg + genlmsg
376
377 def _decode(self, nl_msg):
378 return GenlMsg(nl_msg)
379
380 def get_mcast_id(self, mcast_name, mcast_groups):
381 if mcast_name not in self.genl_family['mcast']:
382 raise Exception(f'Multicast group "{mcast_name}" not present in the family')
383 return self.genl_family['mcast'][mcast_name]
384
385 def msghdr_size(self):
386 return super().msghdr_size() + 4
387
388
389class SpaceAttrs:
390 SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])
391
392 def __init__(self, attr_space, attrs, outer = None):
393 outer_scopes = outer.scopes if outer else []
394 inner_scope = self.SpecValuesPair(attr_space, attrs)
395 self.scopes = [inner_scope] + outer_scopes
396
397 def lookup(self, name):
398 for scope in self.scopes:
399 if name in scope.spec:
400 if name in scope.values:
401 return scope.values[name]
402 spec_name = scope.spec.yaml['name']
403 raise Exception(
404 f"No value for '{name}' in attribute space '{spec_name}'")
405 raise Exception(f"Attribute '{name}' not defined in any attribute-set")
406
407
408#
409# YNL implementation details.
410#
411
412
413class YnlFamily(SpecFamily):
414 def __init__(self, def_path, schema=None, process_unknown=False,
415 recv_size=0):
416 super().__init__(def_path, schema)
417
418 self.include_raw = False
419 self.process_unknown = process_unknown
420
421 try:
422 if self.proto == "netlink-raw":
423 self.nlproto = NetlinkProtocol(self.yaml['name'],
424 self.yaml['protonum'])
425 else:
426 self.nlproto = GenlProtocol(self.yaml['name'])
427 except KeyError:
428 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
429
430 self._recv_dbg = False
431 # Note that netlink will use conservative (min) message size for
432 # the first dump recv() on the socket, our setting will only matter
433 # from the second recv() on.
434 self._recv_size = recv_size if recv_size else 131072
435 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
436 # for a message, so smaller receive sizes will lead to truncation.
437 # Note that the min size for other families may be larger than 4k!
438 if self._recv_size < 4000:
439 raise ConfigError()
440
441 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
442 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
443 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
444 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
445
446 self.async_msg_ids = set()
447 self.async_msg_queue = []
448
449 for msg in self.msgs.values():
450 if msg.is_async:
451 self.async_msg_ids.add(msg.rsp_value)
452
453 for op_name, op in self.ops.items():
454 bound_f = functools.partial(self._op, op_name)
455 setattr(self, op.ident_name, bound_f)
456
457
458 def ntf_subscribe(self, mcast_name):
459 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
460 self.sock.bind((0, 0))
461 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
462 mcast_id)
463
464 def set_recv_dbg(self, enabled):
465 self._recv_dbg = enabled
466
467 def _recv_dbg_print(self, reply, nl_msgs):
468 if not self._recv_dbg:
469 return
470 print("Recv: read", len(reply), "bytes,",
471 len(nl_msgs.msgs), "messages", file=sys.stderr)
472 for nl_msg in nl_msgs:
473 print(" ", nl_msg, file=sys.stderr)
474
475 def _encode_enum(self, attr_spec, value):
476 enum = self.consts[attr_spec['enum']]
477 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
478 scalar = 0
479 if isinstance(value, str):
480 value = [value]
481 for single_value in value:
482 scalar += enum.entries[single_value].user_value(as_flags = True)
483 return scalar
484 else:
485 return enum.entries[value].user_value()
486
487 def _get_scalar(self, attr_spec, value):
488 try:
489 return int(value)
490 except (ValueError, TypeError) as e:
491 if 'enum' not in attr_spec:
492 raise e
493 return self._encode_enum(attr_spec, value)
494
495 def _add_attr(self, space, name, value, search_attrs):
496 try:
497 attr = self.attr_sets[space][name]
498 except KeyError:
499 raise Exception(f"Space '{space}' has no attribute '{name}'")
500 nl_type = attr.value
501
502 if attr.is_multi and isinstance(value, list):
503 attr_payload = b''
504 for subvalue in value:
505 attr_payload += self._add_attr(space, name, subvalue, search_attrs)
506 return attr_payload
507
508 if attr["type"] == 'nest':
509 nl_type |= Netlink.NLA_F_NESTED
510 attr_payload = b''
511 sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
512 for subname, subvalue in value.items():
513 attr_payload += self._add_attr(attr['nested-attributes'],
514 subname, subvalue, sub_attrs)
515 elif attr["type"] == 'flag':
516 if not value:
517 # If value is absent or false then skip attribute creation.
518 return b''
519 attr_payload = b''
520 elif attr["type"] == 'string':
521 attr_payload = str(value).encode('ascii') + b'\x00'
522 elif attr["type"] == 'binary':
523 if isinstance(value, bytes):
524 attr_payload = value
525 elif isinstance(value, str):
526 attr_payload = bytes.fromhex(value)
527 elif isinstance(value, dict) and attr.struct_name:
528 attr_payload = self._encode_struct(attr.struct_name, value)
529 else:
530 raise Exception(f'Unknown type for binary attribute, value: {value}')
531 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
532 scalar = self._get_scalar(attr, value)
533 if attr.is_auto_scalar:
534 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
535 else:
536 attr_type = attr["type"]
537 format = NlAttr.get_format(attr_type, attr.byte_order)
538 attr_payload = format.pack(scalar)
539 elif attr['type'] in "bitfield32":
540 scalar_value = self._get_scalar(attr, value["value"])
541 scalar_selector = self._get_scalar(attr, value["selector"])
542 attr_payload = struct.pack("II", scalar_value, scalar_selector)
543 elif attr['type'] == 'sub-message':
544 msg_format = self._resolve_selector(attr, search_attrs)
545 attr_payload = b''
546 if msg_format.fixed_header:
547 attr_payload += self._encode_struct(msg_format.fixed_header, value)
548 if msg_format.attr_set:
549 if msg_format.attr_set in self.attr_sets:
550 nl_type |= Netlink.NLA_F_NESTED
551 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
552 for subname, subvalue in value.items():
553 attr_payload += self._add_attr(msg_format.attr_set,
554 subname, subvalue, sub_attrs)
555 else:
556 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
557 else:
558 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
559
560 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
561 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
562
563 def _decode_enum(self, raw, attr_spec):
564 enum = self.consts[attr_spec['enum']]
565 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
566 i = 0
567 value = set()
568 while raw:
569 if raw & 1:
570 value.add(enum.entries_by_val[i].name)
571 raw >>= 1
572 i += 1
573 else:
574 value = enum.entries_by_val[raw].name
575 return value
576
577 def _decode_binary(self, attr, attr_spec):
578 if attr_spec.struct_name:
579 decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
580 elif attr_spec.sub_type:
581 decoded = attr.as_c_array(attr_spec.sub_type)
582 else:
583 decoded = attr.as_bin()
584 if attr_spec.display_hint:
585 decoded = self._formatted_string(decoded, attr_spec.display_hint)
586 return decoded
587
588 def _decode_array_nest(self, attr, attr_spec):
589 decoded = []
590 offset = 0
591 while offset < len(attr.raw):
592 item = NlAttr(attr.raw, offset)
593 offset += item.full_len
594
595 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
596 decoded.append({ item.type: subattrs })
597 return decoded
598
599 def _decode_nest_type_value(self, attr, attr_spec):
600 decoded = {}
601 value = attr
602 for name in attr_spec['type-value']:
603 value = NlAttr(value.raw, 0)
604 decoded[name] = value.type
605 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
606 decoded.update(subattrs)
607 return decoded
608
609 def _decode_unknown(self, attr):
610 if attr.is_nest:
611 return self._decode(NlAttrs(attr.raw), None)
612 else:
613 return attr.as_bin()
614
615 def _rsp_add(self, rsp, name, is_multi, decoded):
616 if is_multi == None:
617 if name in rsp and type(rsp[name]) is not list:
618 rsp[name] = [rsp[name]]
619 is_multi = True
620 else:
621 is_multi = False
622
623 if not is_multi:
624 rsp[name] = decoded
625 elif name in rsp:
626 rsp[name].append(decoded)
627 else:
628 rsp[name] = [decoded]
629
630 def _resolve_selector(self, attr_spec, search_attrs):
631 sub_msg = attr_spec.sub_message
632 if sub_msg not in self.sub_msgs:
633 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
634 sub_msg_spec = self.sub_msgs[sub_msg]
635
636 selector = attr_spec.selector
637 value = search_attrs.lookup(selector)
638 if value not in sub_msg_spec.formats:
639 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
640
641 spec = sub_msg_spec.formats[value]
642 return spec
643
644 def _decode_sub_msg(self, attr, attr_spec, search_attrs):
645 msg_format = self._resolve_selector(attr_spec, search_attrs)
646 decoded = {}
647 offset = 0
648 if msg_format.fixed_header:
649 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
650 offset = self._struct_size(msg_format.fixed_header)
651 if msg_format.attr_set:
652 if msg_format.attr_set in self.attr_sets:
653 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
654 decoded.update(subdict)
655 else:
656 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
657 return decoded
658
659 def _decode(self, attrs, space, outer_attrs = None):
660 rsp = dict()
661 if space:
662 attr_space = self.attr_sets[space]
663 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
664
665 for attr in attrs:
666 try:
667 attr_spec = attr_space.attrs_by_val[attr.type]
668 except (KeyError, UnboundLocalError):
669 if not self.process_unknown:
670 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
671 attr_name = f"UnknownAttr({attr.type})"
672 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
673 continue
674
675 if attr_spec["type"] == 'nest':
676 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
677 decoded = subdict
678 elif attr_spec["type"] == 'string':
679 decoded = attr.as_strz()
680 elif attr_spec["type"] == 'binary':
681 decoded = self._decode_binary(attr, attr_spec)
682 elif attr_spec["type"] == 'flag':
683 decoded = True
684 elif attr_spec.is_auto_scalar:
685 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
686 elif attr_spec["type"] in NlAttr.type_formats:
687 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
688 if 'enum' in attr_spec:
689 decoded = self._decode_enum(decoded, attr_spec)
690 elif attr_spec["type"] == 'array-nest':
691 decoded = self._decode_array_nest(attr, attr_spec)
692 elif attr_spec["type"] == 'bitfield32':
693 value, selector = struct.unpack("II", attr.raw)
694 if 'enum' in attr_spec:
695 value = self._decode_enum(value, attr_spec)
696 selector = self._decode_enum(selector, attr_spec)
697 decoded = {"value": value, "selector": selector}
698 elif attr_spec["type"] == 'sub-message':
699 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
700 elif attr_spec["type"] == 'nest-type-value':
701 decoded = self._decode_nest_type_value(attr, attr_spec)
702 else:
703 if not self.process_unknown:
704 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
705 decoded = self._decode_unknown(attr)
706
707 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
708
709 return rsp
710
711 def _decode_extack_path(self, attrs, attr_set, offset, target):
712 for attr in attrs:
713 try:
714 attr_spec = attr_set.attrs_by_val[attr.type]
715 except KeyError:
716 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
717 if offset > target:
718 break
719 if offset == target:
720 return '.' + attr_spec.name
721
722 if offset + attr.full_len <= target:
723 offset += attr.full_len
724 continue
725 if attr_spec['type'] != 'nest':
726 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
727 offset += 4
728 subpath = self._decode_extack_path(NlAttrs(attr.raw),
729 self.attr_sets[attr_spec['nested-attributes']],
730 offset, target)
731 if subpath is None:
732 return None
733 return '.' + attr_spec.name + subpath
734
735 return None
736
737 def _decode_extack(self, request, op, extack):
738 if 'bad-attr-offs' not in extack:
739 return
740
741 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set))
742 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
743 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
744 extack['bad-attr-offs'])
745 if path:
746 del extack['bad-attr-offs']
747 extack['bad-attr'] = path
748
749 def _struct_size(self, name):
750 if name:
751 members = self.consts[name].members
752 size = 0
753 for m in members:
754 if m.type in ['pad', 'binary']:
755 if m.struct:
756 size += self._struct_size(m.struct)
757 else:
758 size += m.len
759 else:
760 format = NlAttr.get_format(m.type, m.byte_order)
761 size += format.size
762 return size
763 else:
764 return 0
765
766 def _decode_struct(self, data, name):
767 members = self.consts[name].members
768 attrs = dict()
769 offset = 0
770 for m in members:
771 value = None
772 if m.type == 'pad':
773 offset += m.len
774 elif m.type == 'binary':
775 if m.struct:
776 len = self._struct_size(m.struct)
777 value = self._decode_struct(data[offset : offset + len],
778 m.struct)
779 offset += len
780 else:
781 value = data[offset : offset + m.len]
782 offset += m.len
783 else:
784 format = NlAttr.get_format(m.type, m.byte_order)
785 [ value ] = format.unpack_from(data, offset)
786 offset += format.size
787 if value is not None:
788 if m.enum:
789 value = self._decode_enum(value, m)
790 elif m.display_hint:
791 value = self._formatted_string(value, m.display_hint)
792 attrs[m.name] = value
793 return attrs
794
795 def _encode_struct(self, name, vals):
796 members = self.consts[name].members
797 attr_payload = b''
798 for m in members:
799 value = vals.pop(m.name) if m.name in vals else None
800 if m.type == 'pad':
801 attr_payload += bytearray(m.len)
802 elif m.type == 'binary':
803 if m.struct:
804 if value is None:
805 value = dict()
806 attr_payload += self._encode_struct(m.struct, value)
807 else:
808 if value is None:
809 attr_payload += bytearray(m.len)
810 else:
811 attr_payload += bytes.fromhex(value)
812 else:
813 if value is None:
814 value = 0
815 format = NlAttr.get_format(m.type, m.byte_order)
816 attr_payload += format.pack(value)
817 return attr_payload
818
819 def _formatted_string(self, raw, display_hint):
820 if display_hint == 'mac':
821 formatted = ':'.join('%02x' % b for b in raw)
822 elif display_hint == 'hex':
823 formatted = bytes.hex(raw, ' ')
824 elif display_hint in [ 'ipv4', 'ipv6' ]:
825 formatted = format(ipaddress.ip_address(raw))
826 elif display_hint == 'uuid':
827 formatted = str(uuid.UUID(bytes=raw))
828 else:
829 formatted = raw
830 return formatted
831
832 def handle_ntf(self, decoded):
833 msg = dict()
834 if self.include_raw:
835 msg['raw'] = decoded
836 op = self.rsp_by_value[decoded.cmd()]
837 attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
838 if op.fixed_header:
839 attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
840
841 msg['name'] = op['name']
842 msg['msg'] = attrs
843 self.async_msg_queue.append(msg)
844
845 def check_ntf(self):
846 while True:
847 try:
848 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
849 except BlockingIOError:
850 return
851
852 nms = NlMsgs(reply)
853 self._recv_dbg_print(reply, nms)
854 for nl_msg in nms:
855 if nl_msg.error:
856 print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
857 print(nl_msg)
858 continue
859 if nl_msg.done:
860 print("Netlink done while checking for ntf!?")
861 continue
862
863 decoded = self.nlproto.decode(self, nl_msg)
864 if decoded.cmd() not in self.async_msg_ids:
865 print("Unexpected msg id done while checking for ntf", decoded)
866 continue
867
868 self.handle_ntf(decoded)
869
870 def operation_do_attributes(self, name):
871 """
872 For a given operation name, find and return a supported
873 set of attributes (as a dict).
874 """
875 op = self.find_operation(name)
876 if not op:
877 return None
878
879 return op['do']['request']['attributes'].copy()
880
881 def _op(self, method, vals, flags=None, dump=False):
882 op = self.ops[method]
883
884 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
885 for flag in flags or []:
886 nl_flags |= flag
887 if dump:
888 nl_flags |= Netlink.NLM_F_DUMP
889
890 req_seq = random.randint(1024, 65535)
891 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
892 if op.fixed_header:
893 msg += self._encode_struct(op.fixed_header, vals)
894 search_attrs = SpaceAttrs(op.attr_set, vals)
895 for name, value in vals.items():
896 msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
897 msg = _genl_msg_finalize(msg)
898
899 self.sock.send(msg, 0)
900
901 done = False
902 rsp = []
903 while not done:
904 reply = self.sock.recv(self._recv_size)
905 nms = NlMsgs(reply, attr_space=op.attr_set)
906 self._recv_dbg_print(reply, nms)
907 for nl_msg in nms:
908 if nl_msg.extack:
909 self._decode_extack(msg, op, nl_msg.extack)
910
911 if nl_msg.error:
912 raise NlError(nl_msg)
913 if nl_msg.done:
914 if nl_msg.extack:
915 print("Netlink warning:")
916 print(nl_msg)
917 done = True
918 break
919
920 decoded = self.nlproto.decode(self, nl_msg)
921
922 # Check if this is a reply to our request
923 if nl_msg.nl_seq != req_seq or decoded.cmd() != op.rsp_value:
924 if decoded.cmd() in self.async_msg_ids:
925 self.handle_ntf(decoded)
926 continue
927 else:
928 print('Unexpected message: ' + repr(decoded))
929 continue
930
931 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
932 if op.fixed_header:
933 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
934 rsp.append(rsp_msg)
935
936 if not rsp:
937 return None
938 if not dump and len(rsp) == 1:
939 return rsp[0]
940 return rsp
941
942 def do(self, method, vals, flags=None):
943 return self._op(method, vals, flags)
944
945 def dump(self, method, vals):
946 return self._op(method, vals, [], dump=True)