# Source code for ddf.mqtt_communication.communicator

"""Communicator using ASN.1 as the transmission (wire) format."""

from queue import Empty
from uuid import UUID

import numpy as np
from loguru import logger


from .. import message as ddf_message
from ..communicator import Communicator
from ..ddf import NodeId, ObjectId
from ..information import ExchangeVector
from . import message as asn1_message
from .internal import InternalCommunicator


[docs] def translate_mode(mode: ddf_message.ControlCommand) -> asn1_message.ControlCommand: """Map a ControlCommand from ddf to asn1.""" match mode: case ddf_message.ControlCommand.FZB: return asn1_message.ControlCommand.UPDATE_UNFREEZE_INFORMATION case ddf_message.ControlCommand.UUI: return asn1_message.ControlCommand.FREEZE_ZERO_BACKPROP case _: raise ValueError(f"Unknown message processing mode: {mode}")
[docs] def translate_mode_inv(command: asn1_message.ControlCommand) -> ddf_message.ControlCommand: """Map a ControlCommand from asn1 to ddf. Inverse of `translate_mode`. """ match command: case asn1_message.ControlCommand.UPDATE_UNFREEZE_INFORMATION: return ddf_message.ControlCommand.FZB case asn1_message.ControlCommand.FREEZE_ZERO_BACKPROP: return ddf_message.ControlCommand.UUI case _: raise ValueError(f"Unknown control command: {command}")
[docs] class AntCommunicatorMqtt(Communicator): """Implementation of `Communicator` for ASN1.""" _internal: InternalCommunicator def __init__( # noqa: PLR0913 self, client_id: UUID, base_topic: str, alias: str, mqtt_host: str, mqtt_port: int = 1883, autorun: bool = True ) -> None: self._internal = InternalCommunicator(client_id, base_topic, alias, mqtt_host, mqtt_port, autorun)
[docs] def send_message(self, msg: ddf_message.Message) -> None: logger.debug(f"Sending message: {msg}") match msg: case ddf_message.InfoMessage(): asn1_msg: asn1_message.Message = asn1_message.InfoMessage( msg.sender.uuid, msg.recipient.uuid, msg.object_id.uuid, np.atleast_1d(np.squeeze(msg.state.mean)), msg.state.covariance, msg.message_counter, ) self._internal.send_message(asn1_msg) case ddf_message.GradientMessage(): update_number = 0 asn1_msg = asn1_message.GradientMessage( msg.sender.uuid, msg.recipient.uuid, msg.object_id.uuid, np.atleast_1d(np.squeeze(msg.gradient)), msg.loss_channel, update_number, ) self._internal.send_message(asn1_msg) case ddf_message.ControlMessage(): cmd = translate_mode(msg.command) if cmd is None: logger.warning(f"Unknown command: {msg.command}") return asn1_msg = asn1_message.ControlMessage( msg.sender.uuid, msg.recipient.uuid, msg.request_id, cmd, ) self._internal.send_message(asn1_msg) case ddf_message.NodeInfoMessage(): asn1_msg = asn1_message.NodeInfoMessage( msg.sender.uuid, msg.recipient.uuid, msg.available, msg.exchange, msg.alias, msg.req_id ) self._internal.send_message(asn1_msg) case ddf_message.ProbeResponseMessage(): asn1_msg = asn1_message.ProbeResponseMessage( msg.sender.uuid, msg.recipient.uuid, msg.reference.uuid, msg.gradient, msg.loss_channel, ) self._internal.send_message(asn1_msg) case ddf_message.NodeContextMessage(): asn1_msg = asn1_message.NodeContextMessage( msg.sender.uuid, msg.recipient.uuid, msg.fixed_context, msg.user_context, msg.req_id ) self._internal.send_message(asn1_msg) case _: logger.warning(f"Unknown message type: {type(msg)}")
[docs] def get_messages(self) -> list[ddf_message.Message]: msgs: list[asn1_message.Message] = [] while True: try: msg = self._internal._queue.get_nowait() msgs.append(msg) except Empty: break ddf_msgs = [] for msg in msgs: match msg: case asn1_message.InfoMessage(): ddf_msg: ddf_message.Message = ddf_message.InfoMessage( NodeId(msg.sender), NodeId(msg.recipient), ObjectId(msg.reference), ExchangeVector(msg.mean, msg.covariance), msg.update_number, ) case asn1_message.GradientMessage(): ddf_msg = ddf_message.GradientMessage( NodeId(msg.sender), NodeId(msg.recipient), ObjectId(msg.reference), np.atleast_1d(msg.gradient), msg.loss_channel, ) case asn1_message.ControlMessage(): mode = translate_mode_inv(msg.command) if mode is None: logger.warning(f"Unknown command: {msg.command}") continue ddf_msg = ddf_message.ControlMessage( NodeId(msg.sender), NodeId(msg.recipient), msg.req_id, mode, ) case asn1_message.NodeInfoMessage(): ddf_msg = ddf_message.NodeInfoMessage( NodeId(msg.sender), NodeId(msg.recipient), msg.available_list, msg.exchange_list, msg.alias, msg.req_id, ) case asn1_message.ProbeResponseMessage(): ddf_msg = ddf_message.ProbeResponseMessage( NodeId(msg.sender), NodeId(msg.recipient), ObjectId(msg.reference), msg.gradient, msg.loss_channel, ) case asn1_message.NodeContextMessage(): ddf_msg = ddf_message.NodeContextMessage( NodeId(msg.sender), NodeId(msg.recipient), msg.fixed_context, msg.user_context, msg.req_id ) case asn1_message.AddNeighborMessage(): ddf_msg = ddf_message.AddNeighborMessage( NodeId(msg.sender), NodeId(msg.recipient), NodeId(msg.neighbor), msg.req_id ) case _: logger.warning(f"Unknown message type: {type(msg)}") continue ddf_msgs.append(ddf_msg) return ddf_msgs