import dataclasses
from abc import abstractmethod
from dataclasses import Field, asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Any
from uuid import UUID
import asn1tools
import numpy as np
from ddf.ant_link.loss_channel import FuturePastProbe, LossChannel
# Compiled ASN.1 specification (PER codec) built from the message.asn file that
# lives in the same directory as this module; used by encode() and decode().
SPEC = asn1tools.compile_files(
    [Path(__file__).parent.resolve() / "message.asn"],
    codec="per",
)
[docs]
@dataclass
class Message:
sender: UUID
recipient: UUID
def __eq__(self, other: object) -> bool:
if not isinstance(other, Message):
return False
for attr in self.__dict__:
if isinstance(getattr(self, attr), np.ndarray):
if not np.all(getattr(self, attr) == getattr(other, attr)):
return False
elif getattr(self, attr) != getattr(other, attr):
return False
return True
def __hash__(self) -> int:
return hash(self.__dict__)
[docs]
def asdict(self) -> dict[str, Any]:
"""Convert the dataclass to a dictionary using dataclass's `asdict` method.
This method will on top of `dataclass.asdict` convert data types for
ASN.1 encoding.
"""
def convert_types(d: dict[str, Any]) -> dict[str, Any]: # noqa: PLR0912
# convert types to ASN.1 types
keys = list(d.keys())
for key in keys:
value = d[key]
match value:
case UUID():
d[key] = value.bytes
case np.ndarray():
d[key] = value.tolist()
case ControlCommand():
match value:
case ControlCommand.GET_NODE_INFO:
d[key] = "getNodeInfo"
case ControlCommand.GET_NODE_CONTEXT:
d[key] = "getNodeContext"
case ControlCommand.FREEZE_ZERO_BACKPROP:
d[key] = "freezeZeroBackprop"
case ControlCommand.UPDATE_UNFREEZE_INFORMATION:
d[key] = "updateUnfreezeInformation"
case ControlCommand.ACCEPT:
d[key] = "accept"
case ControlCommand.DECLINE:
d[key] = "decline"
case _:
raise ValueError()
case FuturePastProbe():
d[key] = value.name.lower()
case dict():
d[key] = convert_types(value)
return d
d = asdict(self)
d = convert_types(d)
# convert python names to ASN.1 names
translations = self.translate()
for field, asn1_field in translations:
if field in d:
d[asn1_field] = d[field]
del d[field]
return d
[docs]
@classmethod
def fromdict(cls, data: dict[str, Any]) -> "Message": # noqa: PLR0912
"""Create Message dataclass instance from dictionary.
Reverses the process of `asdict` and converts the dictionary back to a dataclass,
converting data types from ASN.1 encoding back to Python types including numpy arrays.
"""
# convert ASN.1 names to python names
translations = cls.translate()
for field, asn1_field in translations:
if asn1_field in data:
data[field] = data[asn1_field]
del data[asn1_field]
else:
raise ValueError()
# convert types from ASN.1 types to python types
keys = list(data.keys())
fields: dict[str, Field] = dict((f.name, f) for f in dataclasses.fields(cls))
for key in keys:
value = data[key]
if key not in fields:
raise ValueError(f"Unknown field {key} in {cls.__name__}")
f = fields[key]
if f.type is UUID:
data[key] = UUID(bytes=value)
if f.type is np.ndarray:
data[key] = np.array(value, dtype=float)
if f.type is ControlCommand:
match value:
case "getNodeInfo":
data[key] = ControlCommand.GET_NODE_INFO
case "getNodeContext":
data[key] = ControlCommand.GET_NODE_CONTEXT
case "freezeZeroBackprop":
data[key] = ControlCommand.FREEZE_ZERO_BACKPROP
case "updateUnfreezeInformation":
data[key] = ControlCommand.UPDATE_UNFREEZE_INFORMATION
case "accept":
data[key] = ControlCommand.ACCEPT
case "decline":
data[key] = ControlCommand.DECLINE
case _:
raise ValueError(f"Unknown value {value} for field {key} in {cls.__name__}")
if f.type is LossChannel:
if not isinstance(value, dict):
raise ValueError(f"Expected dict for LossChannel, got {type(value)}")
else:
mode_str = value["mode"]
match mode_str:
case "future":
mode = FuturePastProbe.FUTURE
case "past":
mode = FuturePastProbe.PAST
case "probe":
mode = FuturePastProbe.PROBE
case _:
raise ValueError(f"Unknown mode {mode_str} for LossChannel")
lossref = UUID(bytes=value["lossref"]) if "lossref" in value else None
name = value["name"]
data[key] = LossChannel(mode, lossref, name)
return cls(**data)
[docs]
@classmethod
@abstractmethod
def translate(cls) -> list[tuple[str, str]]:
"""Translation between naming conventions of python and ASN.1.
Returns a list of tuples, where each tuple contains the name of the field in the
dataclass and the name of the field in the ASN.1 specification.
"""
raise NotImplementedError()
[docs]
@classmethod
@abstractmethod
def get_asn1_content_choice(cls) -> str:
"""Returns the name of the content choice in the ASN.1 specification.
Should be one of info, gradient, nodeinfo, proberesponse, nodecontext, addneighbor, control.
"""
raise NotImplementedError()
[docs]
@abstractmethod
def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
"""Returns the topic to which the message should be sent."""
raise NotImplementedError()
@dataclass(eq=False)
class InfoMessage(Message):
    """Message of the ``info`` content choice.

    Carries ``mean`` and ``covariance`` arrays associated with ``reference``, plus an
    ``update_number`` counter.
    """

    reference: UUID
    mean: np.ndarray
    covariance: np.ndarray
    update_number: int

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map ``update_number`` to the ASN.1 field ``updateNumber``."""
        renames = [("update_number", "updateNumber")]
        return renames

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "info"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return ``<recipient>/info/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when its flag is set.
        """
        source = "lossnode" if sender_is_lossnode else self.sender
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/info/{source}/{self.reference}"
@dataclass(eq=False)
class GradientMessage(Message):
    """Message of the ``gradient`` content choice.

    Carries a ``gradient`` array for ``reference``, the ``loss_channel`` it belongs to,
    and an ``update_number`` counter.
    """

    reference: UUID
    gradient: np.ndarray
    loss_channel: LossChannel
    update_number: int

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map snake_case fields to their camelCase ASN.1 names."""
        return [
            ("loss_channel", "lossChannel"),
            ("update_number", "updateNumber"),
        ]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "gradient"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return ``<recipient>/gradient/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when its flag is set.
        """
        source = "lossnode" if sender_is_lossnode else self.sender
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/gradient/{source}/{self.reference}"

    def asdict(self) -> dict[str, Any]:
        """Like :meth:`Message.asdict`, but omit ``lossref`` from the loss channel when it is None."""
        result = super().asdict()
        # A None lossref is left out entirely; Message.fromdict treats a missing
        # "lossref" as None (presumably OPTIONAL in message.asn — confirm).
        if self.loss_channel.lossref is None:
            result["lossChannel"].pop("lossref")
        return result
@dataclass(eq=False)
class NodeInfoMessage(Message):
    """Message of the ``nodeinfo`` content choice, delivered on the recipient's responses topic.

    ``req_id`` presumably identifies the request being answered (verify with callers).
    """

    available_list: list[str]
    exchange_list: list[str]
    alias: str
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map snake_case fields to their camelCase ASN.1 names."""
        return [
            ("available_list", "availableList"),
            ("exchange_list", "exchangeList"),
            ("req_id", "reqId"),
        ]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "nodeinfo"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return ``<recipient>/responses``; the recipient becomes ``lossnode`` if flagged."""
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/responses"
@dataclass(eq=False)
class ProbeResponseMessage(Message):
    """Message of the ``proberesponse`` content choice.

    Carries a ``gradient`` array for ``reference`` together with its ``loss_channel``.
    """

    reference: UUID
    gradient: np.ndarray
    loss_channel: LossChannel

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map ``loss_channel`` to the ASN.1 field ``lossChannel``."""
        return [("loss_channel", "lossChannel")]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "proberesponse"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return ``<recipient>/proberesponse/<sender>/<reference>``.

        Either endpoint is replaced by the literal ``lossnode`` when its flag is set.
        """
        sender = self.sender if not sender_is_lossnode else "lossnode"
        recipient = self.recipient if not recipient_is_lossnode else "lossnode"
        return f"{recipient}/proberesponse/{sender}/{self.reference}"

    def asdict(self) -> dict[str, Any]:
        """Like :meth:`Message.asdict`, but omit ``lossref`` from the loss channel when it is None.

        Mirrors ``GradientMessage.asdict``, which carries the same ``LossChannel``. Without
        this, a None ``lossref`` would be handed to the ASN.1 encoder. Message.fromdict
        treats a missing "lossref" as None (presumably OPTIONAL in message.asn —
        confirm).
        """
        d = super().asdict()
        if self.loss_channel.lossref is None:
            d["lossChannel"].pop("lossref", None)
        return d
@dataclass(eq=False)
class NodeContextMessage(Message):
    """Message of the ``nodecontext`` content choice, delivered on the recipient's responses topic.

    Holds ``fixed_context`` and ``user_context`` strings. ``req_id`` presumably
    identifies the request being answered (verify with callers).
    """

    fixed_context: str
    user_context: str
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map snake_case fields to their camelCase ASN.1 names."""
        return [
            ("fixed_context", "fixedContext"),
            ("user_context", "userContext"),
            ("req_id", "reqId"),
        ]

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "nodecontext"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return ``<recipient>/responses``; the recipient becomes ``lossnode`` if flagged."""
        target = "lossnode" if recipient_is_lossnode else self.recipient
        return f"{target}/responses"
@dataclass(eq=False)
class AddNeighborMessage(Message):
    """Message of the ``addneighbor`` content choice.

    Names a ``neighbor`` node and carries an ``operator`` array. ``req_id`` presumably
    identifies the originating request (verify with callers).
    """

    neighbor: UUID
    operator: np.ndarray
    req_id: UUID

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map ``req_id`` to the ASN.1 field ``reqId``."""
        renames = [("req_id", "reqId")]
        return renames

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "addneighbor"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Not supported: this message type has no topic and always raises."""
        raise NotImplementedError()
class ControlCommand(Enum):
    """Commands that a ControlMessage can carry.

    The trailing comments give the ASN.1 enumeration name used by Message.asdict
    and Message.fromdict.
    """

    GET_NODE_INFO = 0  # "getNodeInfo"
    GET_NODE_CONTEXT = 1  # "getNodeContext"
    FREEZE_ZERO_BACKPROP = 2  # "freezeZeroBackprop"
    UPDATE_UNFREEZE_INFORMATION = 3  # "updateUnfreezeInformation"
    ACCEPT = 4  # "accept"
    DECLINE = 5  # "decline"
@dataclass(eq=False)
class ControlMessage(Message):
    """Message of the ``control`` content choice carrying a ControlCommand.

    ``req_id`` presumably identifies the related request (verify with callers).
    """

    req_id: UUID
    command: ControlCommand

    @classmethod
    def translate(cls) -> list[tuple[str, str]]:
        """Map ``req_id`` to the ASN.1 field ``reqId``."""
        renames = [("req_id", "reqId")]
        return renames

    @classmethod
    def get_asn1_content_choice(cls) -> str:
        """Return the ASN.1 content choice name for this message."""
        return "control"

    def get_topic(self, sender_is_lossnode: bool, recipient_is_lossnode: bool) -> str:
        """Return the loss-node control topic for this message's recipient.

        FREEZE_ZERO_BACKPROP and UPDATE_UNFREEZE_INFORMATION go to
        ``lossnode/control_msgs_royal/<recipient>``. Every other command goes to
        ``lossnode/control_msgs_public/<recipient>``. The ``*_is_lossnode`` flags are
        not used.
        """
        royal = self.command in (
            ControlCommand.FREEZE_ZERO_BACKPROP,
            ControlCommand.UPDATE_UNFREEZE_INFORMATION,
        )
        if royal:
            return f"lossnode/control_msgs_royal/{self.recipient}"
        return f"lossnode/control_msgs_public/{self.recipient}"
def encode(message: Message) -> bytes:
    """Serialize ``message`` to bytes using the ``Message`` type of SPEC (PER codec).

    ``sender`` and ``recipient`` are lifted into the envelope. All other fields become
    the content, tagged with the message's ASN.1 content choice.
    """
    content: dict[str, Any] = message.asdict()
    sender = content.pop("sender")
    recipient = content.pop("recipient")
    envelope = {
        "sender": sender,
        "recipient": recipient,
        "content": (message.get_asn1_content_choice(), content),
    }
    encoded = SPEC.encode("Message", envelope)
    assert isinstance(encoded, bytes), "Encoded message should be of type bytes"
    return encoded
def decode(encoded: bytes) -> Message | None:
    """Decode bytes produced by ``encode`` back into the matching Message subclass.

    Returns None (after printing the error) when asn1tools fails to decode the bytes.

    Raises:
        ValueError: if the decoded content choice matches no known message type.
    """
    try:
        data = SPEC.decode("Message", encoded)
    except asn1tools.Error as e:
        print(f"Error decoding message: {e}")
        return None
    choice, content = data["content"]
    # Re-attach the envelope fields so fromdict sees a complete message dict.
    content["sender"] = data["sender"]
    content["recipient"] = data["recipient"]
    by_choice: dict[str, type[Message]] = {
        m_type.get_asn1_content_choice(): m_type
        for m_type in (
            InfoMessage,
            GradientMessage,
            NodeInfoMessage,
            ProbeResponseMessage,
            NodeContextMessage,
            AddNeighborMessage,
            ControlMessage,
        )
    }
    m_type = by_choice.get(choice)
    if m_type is None:
        raise ValueError(f"Unknown message type: {choice}")
    return m_type.fromdict(content)