# Source code for ddf.mqtt_communication.internal

from dataclasses import dataclass, field
from queue import Queue, ShutDown
from time import sleep
from uuid import UUID

import paho.mqtt.client as mqtt
from loguru import logger
from paho.mqtt.enums import CallbackAPIVersion
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCode

from . import message as asn1_message

# Every message handled by this module must use MQTT QoS 2 ("exactly once").
QOS_REQUIRED: int = 2
# Upper bound, in seconds, for blocking while a single publish completes.
PUBLISH_TIMEOUT: int = 30


@dataclass
class InternalCommunicator:
    """MQTT client for a single node.

    On every (re)connect the client subscribes to the node's own topic tree and
    to the lossnode control topics below ``base_topic``. Incoming payloads are
    decoded with ``asn1_message.decode``, and messages addressed to this node are
    put into an internal queue. Outgoing messages are published with
    ``QOS_REQUIRED``.
    """

    client_id: UUID
    """Node Id."""
    base_topic: str
    """Base topic for all connected nodes."""
    alias: str
    """Human readable name for this node."""
    # Host and port of the MQTT broker, passed to ``mqtt.Client.connect`` in ``start()``.
    mqtt_host: str
    mqtt_port: int = 1883
    # When true, ``__post_init__`` calls ``start()`` and waits up to ~5 s for the
    # connection. Otherwise ``start()`` must be called manually.
    autorun: bool = True
    # Id of the lossnode. ``send_message`` compares sender/recipient against it
    # to choose the publish topic.
    _lossnode_id: UUID | None = None
    # Decoded messages addressed to this node, filled by ``_on_message``.
    _queue: Queue[asn1_message.Message] = field(default_factory=Queue)
    _client: mqtt.Client = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Create the paho client, wire up callbacks and optionally connect."""
        self._client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2, client_id=str(self.client_id))
        self._client.on_connect = self._on_connect  # type: ignore
        self._client.on_connect_fail = self._on_connect_fail
        self._client.on_message = self._on_message
        self._client.on_subscribe = self._on_subscribe
        self._client.on_unsubscribe = self._on_unsubscribe  # type: ignore

        if not self.autorun:
            logger.warning("Autorun is disabled. Call start() manually.")
            return

        self.start()
        # The connection is completed by the background loop, so poll briefly.
        wait_max = 5
        while not self._client.is_connected() and wait_max > 0:
            logger.info("Waiting for MQTT client to connect...")
            sleep(1)
            wait_max -= 1
        if not self._client.is_connected():
            logger.warning(f"{self}: still not connected to MQTT broker after waiting.")

    def __str__(self) -> str:
        return f"{self.alias} ({self.client_id})"

    def _on_connect(
        self,
        client: mqtt.Client,
        userdata: None,
        connect_flags: mqtt.ConnectFlags,
        reason_code: ReasonCode,
        properties: Properties,
    ) -> None:
        """Log the connection result and, on success, (re)subscribe to this node's topics."""
        if reason_code.is_failure:
            logger.error(f"{self}: failed to connect. {reason_code=}.")
            return

        logger.info(f"{self}: connected to MQTT broker. ({connect_flags.session_present=})")
        # we should always subscribe from _on_connect callback to be sure that
        # our subscriptions are the same for any (re)connection.
        topics = [
            f"{self.client_id}/#",
            "lossnode/control_msgs_royal",
            f"lossnode/control_msgs_royal/{self.client_id}",
            "lossnode/control_msgs_public",
            f"lossnode/control_msgs_public/{self.client_id}",
        ]
        subs: list[tuple[str, int]] = [(f"{self.base_topic}/{topic}", QOS_REQUIRED) for topic in topics]
        client.subscribe(subs)

    def _on_connect_fail(self, client: mqtt.Client, userdata: None) -> None:
        logger.error(f"Connection failed for client {self.alias} ({self.client_id})")

    def _on_message(self, client: mqtt.Client, userdata: None, msg: mqtt.MQTTMessage) -> None:
        """Callback function to handle incoming messages.

        This will deserialize the message into its specific kind of Message
        subclass and put it into the queue if it is addressed to this node.
        Undecodable or foreign messages are logged and dropped. This callback
        never raises: an exception escaping a paho callback is re-raised into
        the network loop.
        """
        logger.debug(f"{self.alias} received message: {msg.topic=}, {msg.payload=}, {msg.qos=}")
        topic = msg.topic
        payload = msg.payload

        # Explicit check instead of `assert`: asserts vanish under `python -O`
        # and raising here would propagate into the paho loop.
        if not topic.startswith(self.base_topic):
            logger.warning(f"Topic does not start with {self.base_topic}: {topic=}")
            return

        if msg.qos != QOS_REQUIRED:
            logger.warning(f"QoS is not {QOS_REQUIRED}: {topic=}, {payload=}")

        try:
            m = asn1_message.decode(payload)
        except NotImplementedError:
            logger.debug(f"{self.alias}: decoder does not support this message, ignoring: {topic=}")
            return
        except ValueError as error:
            logger.warning(f"Could not parse message: {topic=}, {payload=}, error: {error}")
            return

        if m is None:
            logger.warning(f"Could not parse message: {topic=}, {payload=}")
            return

        # check if the message is for us, TODO: we should do this via subscriptions
        if m.recipient != self.client_id:
            return

        logger.debug(f"{self.alias}: queue message: {m=}")
        try:
            self._queue.put(m)
        except ShutDown:
            logger.error("queue shut down")

    def _on_subscribe(
        self, client: mqtt.Client, userdata: None, mid: int, reason_code_list: list[ReasonCode], properties: Properties
    ) -> None:
        """Check the broker's answer to a subscribe request.

        ``_on_connect`` subscribes to several topics in one request, and the
        broker answers with one reason code per topic, in request order. Every
        entry is checked, not only the first one.
        """
        for index, reason_code in enumerate(reason_code_list):
            if reason_code.is_failure:
                logger.warning(f"{self}: broker rejected subscription #{index}: {reason_code}")
            elif reason_code.value < QOS_REQUIRED:
                # Granted, but at a lower QoS than this module requires.
                logger.warning(
                    f"{self}: broker granted subscription #{index} only QoS {reason_code.value}, "
                    f"expected {QOS_REQUIRED}"
                )
            else:
                logger.debug(f"{self}: broker granted subscription #{index} the following QoS: {reason_code.value}")

    def _on_unsubscribe(
        self, client: mqtt.Client, userdata: None, mid: int, reason_code_list: list[ReasonCode], properties: Properties
    ) -> None:
        """Log the unsubscribe result, then disconnect the client."""
        # Be careful, the reason_code_list is only present in MQTTv5.
        # In MQTTv3 it will always be empty
        if len(reason_code_list) == 0 or not reason_code_list[0].is_failure:
            logger.debug("unsubscribe succeeded (if SUBACK is received in MQTTv3 it success)")
        else:
            logger.warning(f"Broker replied with failure: {reason_code_list[0]}")
        client.disconnect()

    def start(self) -> None:
        """Connect and start client loop."""
        self._client.connect(host=self.mqtt_host, port=self.mqtt_port)
        logger.debug(f"Starting MQTT client loop for {self}")
        self._client.loop_start()

    def stop(self) -> None:
        """Disconnect cleanly from the broker, then stop the client loop."""
        # Send DISCONNECT first so the broker sees a clean shutdown instead of
        # a dropped connection; loop_stop() then joins the network thread.
        self._client.disconnect()
        self._client.loop_stop()

    def _publish(self, topic: str, payload: bytes | bytearray | str) -> None:
        """Publish *payload* on *topic* with ``QOS_REQUIRED``.

        Blocks for at most ``PUBLISH_TIMEOUT`` seconds and logs a warning if
        the publish is not completed in that time.
        """
        info = self._client.publish(topic=topic, payload=payload, qos=QOS_REQUIRED)
        info.wait_for_publish(PUBLISH_TIMEOUT)
        if not info.is_published():
            logger.warning(f"{self}: publish on {topic=} not completed within {PUBLISH_TIMEOUT} s")

    def send_message(self, msg: asn1_message.Message) -> None:
        """Encode *msg* and publish it below ``base_topic``.

        Messages whose sender is not this node are rejected with an error log.
        """
        if msg.sender != self.client_id:
            logger.error(f"Sender {msg.sender} does not match client id {self.client_id}")
            return
        sender_is_lossnode = self.client_id == self._lossnode_id
        recipient_is_lossnode = msg.recipient == self._lossnode_id
        topic = f"{self.base_topic}/{msg.get_topic(sender_is_lossnode, recipient_is_lossnode)}"
        self._publish(topic, asn1_message.encode(msg))

    def test_send(self, msg_text: str, msg_topic: str) -> None:
        """Publish raw *msg_text* on *msg_topic* as given (no ``base_topic`` prefix)."""
        self._publish(msg_topic, msg_text)