"""Communicator using ASN1 as transmission file format."""
from queue import Empty
from uuid import UUID
import numpy as np
from loguru import logger
from .. import message as ddf_message
from ..communicator import Communicator
from ..ddf import NodeId, ObjectId
from ..information import ExchangeVector
from . import message as asn1_message
from .internal import InternalCommunicator
[docs]
def translate_mode(mode: ddf_message.ControlCommand) -> asn1_message.ControlCommand:
"""Map a ControlCommand from ddf to asn1."""
match mode:
case ddf_message.ControlCommand.FZB:
return asn1_message.ControlCommand.UPDATE_UNFREEZE_INFORMATION
case ddf_message.ControlCommand.UUI:
return asn1_message.ControlCommand.FREEZE_ZERO_BACKPROP
case _:
raise ValueError(f"Unknown message processing mode: {mode}")
[docs]
def translate_mode_inv(command: asn1_message.ControlCommand) -> ddf_message.ControlCommand:
"""Map a ControlCommand from asn1 to ddf.
Inverse of `translate_mode`.
"""
match command:
case asn1_message.ControlCommand.UPDATE_UNFREEZE_INFORMATION:
return ddf_message.ControlCommand.FZB
case asn1_message.ControlCommand.FREEZE_ZERO_BACKPROP:
return ddf_message.ControlCommand.UUI
case _:
raise ValueError(f"Unknown control command: {command}")
[docs]
class AntCommunicatorMqtt(Communicator):
"""Implementation of `Communicator` for ASN1."""
_internal: InternalCommunicator
def __init__( # noqa: PLR0913
self, client_id: UUID, base_topic: str, alias: str, mqtt_host: str, mqtt_port: int = 1883, autorun: bool = True
) -> None:
self._internal = InternalCommunicator(client_id, base_topic, alias, mqtt_host, mqtt_port, autorun)
[docs]
def send_message(self, msg: ddf_message.Message) -> None:
logger.debug(f"Sending message: {msg}")
match msg:
case ddf_message.InfoMessage():
asn1_msg: asn1_message.Message = asn1_message.InfoMessage(
msg.sender.uuid,
msg.recipient.uuid,
msg.object_id.uuid,
np.atleast_1d(np.squeeze(msg.state.mean)),
msg.state.covariance,
msg.message_counter,
)
self._internal.send_message(asn1_msg)
case ddf_message.GradientMessage():
update_number = 0
asn1_msg = asn1_message.GradientMessage(
msg.sender.uuid,
msg.recipient.uuid,
msg.object_id.uuid,
np.atleast_1d(np.squeeze(msg.gradient)),
msg.loss_channel,
update_number,
)
self._internal.send_message(asn1_msg)
case ddf_message.ControlMessage():
cmd = translate_mode(msg.command)
if cmd is None:
logger.warning(f"Unknown command: {msg.command}")
return
asn1_msg = asn1_message.ControlMessage(
msg.sender.uuid,
msg.recipient.uuid,
msg.request_id,
cmd,
)
self._internal.send_message(asn1_msg)
case ddf_message.NodeInfoMessage():
asn1_msg = asn1_message.NodeInfoMessage(
msg.sender.uuid, msg.recipient.uuid, msg.available, msg.exchange, msg.alias, msg.req_id
)
self._internal.send_message(asn1_msg)
case ddf_message.ProbeResponseMessage():
asn1_msg = asn1_message.ProbeResponseMessage(
msg.sender.uuid,
msg.recipient.uuid,
msg.reference.uuid,
msg.gradient,
msg.loss_channel,
)
self._internal.send_message(asn1_msg)
case ddf_message.NodeContextMessage():
asn1_msg = asn1_message.NodeContextMessage(
msg.sender.uuid, msg.recipient.uuid, msg.fixed_context, msg.user_context, msg.req_id
)
self._internal.send_message(asn1_msg)
case _:
logger.warning(f"Unknown message type: {type(msg)}")
[docs]
def get_messages(self) -> list[ddf_message.Message]:
msgs: list[asn1_message.Message] = []
while True:
try:
msg = self._internal._queue.get_nowait()
msgs.append(msg)
except Empty:
break
ddf_msgs = []
for msg in msgs:
match msg:
case asn1_message.InfoMessage():
ddf_msg: ddf_message.Message = ddf_message.InfoMessage(
NodeId(msg.sender),
NodeId(msg.recipient),
ObjectId(msg.reference),
ExchangeVector(msg.mean, msg.covariance),
msg.update_number,
)
case asn1_message.GradientMessage():
ddf_msg = ddf_message.GradientMessage(
NodeId(msg.sender),
NodeId(msg.recipient),
ObjectId(msg.reference),
np.atleast_1d(msg.gradient),
msg.loss_channel,
)
case asn1_message.ControlMessage():
mode = translate_mode_inv(msg.command)
if mode is None:
logger.warning(f"Unknown command: {msg.command}")
continue
ddf_msg = ddf_message.ControlMessage(
NodeId(msg.sender),
NodeId(msg.recipient),
msg.req_id,
mode,
)
case asn1_message.NodeInfoMessage():
ddf_msg = ddf_message.NodeInfoMessage(
NodeId(msg.sender),
NodeId(msg.recipient),
msg.available_list,
msg.exchange_list,
msg.alias,
msg.req_id,
)
case asn1_message.ProbeResponseMessage():
ddf_msg = ddf_message.ProbeResponseMessage(
NodeId(msg.sender),
NodeId(msg.recipient),
ObjectId(msg.reference),
msg.gradient,
msg.loss_channel,
)
case asn1_message.NodeContextMessage():
ddf_msg = ddf_message.NodeContextMessage(
NodeId(msg.sender), NodeId(msg.recipient), msg.fixed_context, msg.user_context, msg.req_id
)
case asn1_message.AddNeighborMessage():
ddf_msg = ddf_message.AddNeighborMessage(
NodeId(msg.sender), NodeId(msg.recipient), NodeId(msg.neighbor), msg.req_id
)
case _:
logger.warning(f"Unknown message type: {type(msg)}")
continue
ddf_msgs.append(ddf_msg)
return ddf_msgs