"""Source code for ``ddf.mqtt_communication.message``: MQTT message dataclasses and their ASN.1 (PER) encoding and decoding."""

import dataclasses
from abc import abstractmethod
from dataclasses import Field, asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import UUID

import asn1tools
import numpy as np

from ddf.ant_link.loss_channel import FuturePastProbe, LossChannel

# ASN.1 schema compiled from ``message.asn`` (located next to this module)
# with asn1tools' "per" (Packed Encoding Rules) codec. Used by encode() and
# decode() below.
SPEC = asn1tools.compile_files(
    [Path(__file__).parent.resolve() / "message.asn"],
    "per",
)


[docs] @dataclass class Message: sender: UUID recipient: UUID def __eq__(self, other: object) -> bool: if not isinstance(other, Message): return False for attr in self.__dict__: if isinstance(getattr(self, attr), np.ndarray): if not np.all(getattr(self, attr) == getattr(other, attr)): return False elif getattr(self, attr) != getattr(other, attr): return False return True def __hash__(self) -> int: return hash(self.__dict__)
[docs] def asdict(self) -> dict[str, Any]: """Convert the dataclass to a dictionary using dataclass's `asdict` method. This method will on top of `dataclass.asdict` convert data types for ASN.1 encoding. """ def convert_types(d: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0912 # convert types to ASN.1 types keys = list(d.keys()) for key in keys: value = d[key] match value: case UUID(): d[key] = value.bytes case np.ndarray(): d[key] = value.tolist() case ControlCommand(): match value: case ControlCommand.GET_NODE_INFO: d[key] = "getNodeInfo" case ControlCommand.GET_NODE_CONTEXT: d[key] = "getNodeContext" case ControlCommand.FREEZE_ZERO_BACKPROP: d[key] = "freezeZeroBackprop" case ControlCommand.UPDATE_UNFREEZE_INFORMATION: d[key] = "updateUnfreezeInformation" case ControlCommand.ACCEPT: d[key] = "accept" case ControlCommand.DECLINE: d[key] = "decline" case _: raise ValueError() case FuturePastProbe(): d[key] = value.name.lower() case dict(): d[key] = convert_types(value) return d d = asdict(self) d = convert_types(d) # convert python names to ASN.1 names translations = self.translate() for field, asn1_field in translations: if field in d: d[asn1_field] = d[field] del d[field] return d
[docs] @classmethod def fromdict(cls, data: dict[str, Any]) -> "Message": # noqa: PLR0912 """Create Message dataclass instance from dictionary. Reverses the process of `asdict` and converts the dictionary back to a dataclass, converting data types from ASN.1 encoding back to Python types including numpy arrays. """ # convert ASN.1 names to python names translations = cls.translate() for field, asn1_field in translations: if asn1_field in data: data[field] = data[asn1_field] del data[asn1_field] else: raise ValueError() # convert types from ASN.1 types to python types keys = list(data.keys()) fields: dict[str, Field] = dict((f.name, f) for f in dataclasses.fields(cls)) for key in keys: value = data[key] if key not in fields: raise ValueError(f"Unknown field {key} in {cls.__name__}") f = fields[key] if f.type is UUID: data[key] = UUID(bytes=value) if f.type is np.ndarray: data[key] = np.array(value, dtype=float) if f.type is ControlCommand: match value: case "getNodeInfo": data[key] = ControlCommand.GET_NODE_INFO case "getNodeContext": data[key] = ControlCommand.GET_NODE_CONTEXT case "freezeZeroBackprop": data[key] = ControlCommand.FREEZE_ZERO_BACKPROP case "updateUnfreezeInformation": data[key] = ControlCommand.UPDATE_UNFREEZE_INFORMATION case "accept": data[key] = ControlCommand.ACCEPT case "decline": data[key] = ControlCommand.DECLINE case _: raise ValueError(f"Unknown value {value} for field {key} in {cls.__name__}") if f.type is LossChannel: if not isinstance(value, dict): raise ValueError(f"Expected dict for LossChannel, got {type(value)}") else: mode_str = value["mode"] match mode_str: case "future": mode = FuturePastProbe.FUTURE case "past": mode = FuturePastProbe.PAST case "probe": mode = FuturePastProbe.PROBE case _: raise ValueError(f"Unknown mode {mode_str} for LossChannel") lossref = UUID(bytes=value["lossref"]) if "lossref" in value else None name = value["name"] data[key] = LossChannel(mode, lossref, name) return cls(**data)
[docs] @classmethod @abstractmethod def translate(cls) -> list[tuple[str, str]]: """Translation between naming conventions of python and ASN.1. Returns a list of tuples, where each tuple contains the name of the field in the dataclass and the name of the field in the ASN.1 specification. """ raise NotImplementedError()
[docs] @classmethod @abstractmethod def get_asn1_content_choice(cls) -> str: """Returns the name of the content choice in the ASN.1 specification. Should be one of info, gradient, nodeinfo, proberesponse, nodecontext, addneighbor, control. """ raise NotImplementedError()
[docs] @abstractmethod def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str: """Returns the topic to which the message should be sent.""" raise NotImplementedError()
@dataclass(eq=False)
class InfoMessage(Message):
    """Message for the ASN.1 content choice ``info``.

    Payload: a ``reference`` UUID, ``mean`` and ``covariance`` arrays, and an
    integer ``update_number``.
    """

    reference: UUID
    mean: np.ndarray
    covariance: np.ndarray
    update_number: int

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [("update_number", "updateNumber")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "info"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Build ``<recipient>/info/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when the
        corresponding flag is set.
        """
        src = "lossnode" if sender_is_lossnode else self.sender
        dst = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{dst}/info/{src}/{self.reference}"
@dataclass(eq=False)
class GradientMessage(Message):
    """Message for the ASN.1 content choice ``gradient``.

    Payload: a ``reference`` UUID, a ``gradient`` array, a ``loss_channel``
    and an integer ``update_number``.
    """

    reference: UUID
    gradient: np.ndarray
    loss_channel: LossChannel
    update_number: int

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [("loss_channel", "lossChannel"), ("update_number", "updateNumber")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "gradient"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Build ``<recipient>/gradient/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when the
        corresponding flag is set.
        """
        src = "lossnode" if sender_is_lossnode else self.sender
        dst = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{dst}/gradient/{src}/{self.reference}"

    def asdict(self) -> dict[str, Any]:
        """ASN.1-ready dict; the ``lossref`` key is dropped when it is unset.

        Presumably ``lossref`` is an OPTIONAL component in message.asn
        (Message.fromdict treats its absence as None) — confirm against the spec.
        """
        encoded = super().asdict()
        if self.loss_channel.lossref is None:
            encoded["lossChannel"].pop("lossref")
        return encoded
@dataclass(eq=False)
class NodeInfoMessage(Message):
    """Message for the ASN.1 content choice ``nodeinfo``.

    Payload: ``available_list`` and ``exchange_list`` (lists of strings), an
    ``alias`` and the request id ``req_id``. Published on the recipient's
    ``responses`` topic.
    """

    available_list: list[str]
    exchange_list: list[str]
    alias: str
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [
            ("available_list", "availableList"),
            ("exchange_list", "exchangeList"),
            ("req_id", "reqId"),
        ]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "nodeinfo"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Build ``<recipient>/responses``; ``sender_is_lossnode`` is not used."""
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/responses"
@dataclass(eq=False)
class ProbeResponseMessage(Message):
    """Message for the ASN.1 content choice ``proberesponse``.

    Payload: a ``reference`` UUID, a ``gradient`` array and a ``loss_channel``.
    """

    reference: UUID
    gradient: np.ndarray
    loss_channel: LossChannel

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [("loss_channel", "lossChannel")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "proberesponse"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Build ``<recipient>/proberesponse/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when the
        corresponding flag is set.
        """
        sender = self.sender if not sender_is_lossnode else "lossnode"
        recipient = self.recipient if not recipient_is_lossnode else "lossnode"
        return f"{recipient}/proberesponse/{sender}/{self.reference}"

    def asdict(self) -> dict[str, Any]:
        """ASN.1-ready dict; the ``lossref`` key is dropped when it is unset.

        Same treatment as GradientMessage.asdict. Without it, a None
        ``lossref`` would be handed to the encoder as-is. Message.fromdict
        already treats a missing ``lossref`` as None, so round-trips still work.
        """
        d = super().asdict()
        if self.loss_channel.lossref is None:
            d["lossChannel"].pop("lossref", None)
        return d
@dataclass(eq=False)
class NodeContextMessage(Message):
    """Message for the ASN.1 content choice ``nodecontext``.

    Payload: ``fixed_context`` and ``user_context`` strings and the request
    id ``req_id``. Published on the recipient's ``responses`` topic.
    """

    fixed_context: str
    user_context: str
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [
            ("fixed_context", "fixedContext"),
            ("user_context", "userContext"),
            ("req_id", "reqId"),
        ]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "nodecontext"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Build ``<recipient>/responses``; ``sender_is_lossnode`` is not used."""
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/responses"
@dataclass(eq=False)
class AddNeighborMessage(Message):
    """Message for the ASN.1 content choice ``addneighbor``.

    Payload: a ``neighbor`` UUID, an ``operator`` array and the request id
    ``req_id``.
    """

    neighbor: UUID
    operator: np.ndarray
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [("req_id", "reqId")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "addneighbor"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """No topic is defined for this message type; always raises NotImplementedError."""
        raise NotImplementedError()
class ControlCommand(Enum):
    """Commands carried by ControlMessage.

    The integer values are only used in Python. On the wire, each member is
    encoded as an ASN.1 enumeration name (e.g. ``GET_NODE_INFO`` ->
    ``"getNodeInfo"``) by Message.asdict / Message.fromdict.
    """

    # Requests for node information / context.
    GET_NODE_INFO = 0
    GET_NODE_CONTEXT = 1
    # Commands routed to the "royal" control topic by ControlMessage.get_topic.
    FREEZE_ZERO_BACKPROP = 2
    UPDATE_UNFREEZE_INFORMATION = 3
    # Replies.
    ACCEPT = 4
    DECLINE = 5
@dataclass(eq=False)
class ControlMessage(Message):
    """Message for the ASN.1 content choice ``control``.

    Payload: the request id ``req_id`` and a ControlCommand ``command``.
    """

    req_id: UUID
    command: ControlCommand

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Pairs of (Python field name, ASN.1 field name) that differ."""
        return [("req_id", "reqId")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """ASN.1 content alternative used for this message type."""
        return "control"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Pick the lossnode control topic for this command.

        FREEZE_ZERO_BACKPROP and UPDATE_UNFREEZE_INFORMATION go to
        ``lossnode/control_msgs_royal/<recipient>``. Every other command goes
        to ``lossnode/control_msgs_public/<recipient>``. Neither flag is used.
        """
        royal_commands = (ControlCommand.FREEZE_ZERO_BACKPROP, ControlCommand.UPDATE_UNFREEZE_INFORMATION)
        if self.command not in royal_commands:
            return f"lossnode/control_msgs_public/{self.recipient}"
        return f"lossnode/control_msgs_royal/{self.recipient}"
def encode(message: Message) -> bytes:
    """Serialize ``message`` to bytes with the compiled ASN.1 ``SPEC``.

    The top-level ``Message`` value has ``sender``, ``recipient`` and a
    ``content`` choice given as ``(choice_name, value)``. Here ``choice_name``
    comes from ``message.get_asn1_content_choice()``, and ``value`` is the rest
    of ``message.asdict()``.
    """
    content = message.asdict()
    # Dict literal evaluates left to right: sender, recipient, then the choice tuple.
    payload = {
        "sender": content.pop("sender"),
        "recipient": content.pop("recipient"),
        "content": (message.get_asn1_content_choice(), content),
    }
    encoded = SPEC.encode("Message", payload)
    assert isinstance(encoded, bytes), "Encoded message should be of type bytes"
    return encoded
def decode(encoded: bytes) -> Message | None:
    """Decode ASN.1 bytes produced by ``encode`` back into a message object.

    Returns None (after printing the error) when ``SPEC.decode`` raises
    ``asn1tools.Error``. Raises ValueError when the content choice matches
    none of the known message types. Any exception raised by the matching
    type's ``fromdict`` propagates.
    """
    try:
        decoded = SPEC.decode("Message", encoded)
    except asn1tools.Error as e:
        print(f"Error decoding message: {e}")
        return None

    sender = decoded["sender"]
    recipient = decoded["recipient"]
    choice, content = decoded["content"]
    content["sender"] = sender
    content["recipient"] = recipient

    candidates = (
        InfoMessage,
        GradientMessage,
        NodeInfoMessage,
        ProbeResponseMessage,
        NodeContextMessage,
        AddNeighborMessage,
        ControlMessage,
    )
    for message_type in candidates:
        if choice == message_type.get_asn1_content_choice():
            return message_type.fromdict(content)
    raise ValueError(f"Unknown message type: {choice}")