from dataclasses import dataclass, field
from queue import Queue, ShutDown
from time import sleep
from uuid import UUID
import paho.mqtt.client as mqtt
from loguru import logger
from paho.mqtt.enums import CallbackAPIVersion
from paho.mqtt.properties import Properties
from paho.mqtt.reasoncodes import ReasonCode
from . import message as asn1_message
# _all_ messages should be qos=2 (MQTT QoS 2, "exactly once" delivery).
# InternalCommunicator uses this level for every subscription and publish, and
# logs a warning for any received message whose QoS differs from it.
QOS_REQUIRED: int = 2
# Maximum time to block in wait_for_publish() after each publish.
PUBLISH_TIMEOUT: int = 30 # seconds
@dataclass
class InternalCommunicator:
    """MQTT-backed communicator for a single node.

    Connects to the broker at ``mqtt_host:mqtt_port``, subscribes to this node's topics below
    ``base_topic`` and puts every decoded ASN.1 message addressed to this node into an internal
    queue. Outgoing messages are encoded and published with ``QOS_REQUIRED``.
    """

    client_id: UUID
    """Node Id."""
    base_topic: str
    """Base topic for all connected nodes."""
    alias: str
    """Human readable name for this node."""
    mqtt_host: str
    """Host name or address of the MQTT broker."""
    mqtt_port: int = 1883
    """TCP port of the MQTT broker."""
    autorun: bool = True
    """If true, connect and start the client loop from ``__post_init__``."""
    _lossnode_id: UUID | None = None
    """Id of the lossnode; decides the publish topic in :meth:`send_message`."""
    _queue: Queue[asn1_message.Message] = field(default_factory=Queue)
    """Decoded messages addressed to this node, filled by :meth:`_on_message`."""
    _client: mqtt.Client = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Create the paho client, register callbacks and, if ``autorun``, connect."""
        self._client = mqtt.Client(callback_api_version=CallbackAPIVersion.VERSION2, client_id=str(self.client_id))
        self._client.on_connect = self._on_connect  # type: ignore
        self._client.on_connect_fail = self._on_connect_fail
        self._client.on_message = self._on_message
        self._client.on_subscribe = self._on_subscribe
        self._client.on_unsubscribe = self._on_unsubscribe  # type: ignore
        if not self.autorun:
            logger.warning("Autorun is disabled. Call start() manually.")
            return
        self.start()
        # Give the background loop a few seconds to finish the handshake so callers can
        # publish right away. Construction does not fail if it takes longer.
        wait_max = 5
        while not self._client.is_connected() and wait_max > 0:
            logger.info("Waiting for MQTT client to connect...")
            sleep(1)
            wait_max -= 1
        if not self._client.is_connected():
            logger.warning(f"{self}: still not connected to MQTT broker after waiting; continuing anyway.")

    def __str__(self) -> str:
        return f"{self.alias} ({self.client_id})"

    def _on_connect(
        self,
        client: mqtt.Client,
        userdata: None,
        connect_flags: mqtt.ConnectFlags,
        reason_code: ReasonCode,
        properties: Properties,
    ) -> None:
        """Log the connection result and (re)subscribe to this node's topics on success."""
        if reason_code.is_failure:
            logger.error(f"{self}: failed to connect. {reason_code=}.")
            return
        logger.info(f"{self}: connected to MQTT broker. ({connect_flags.session_present=})")
        # we should always subscribe from _on_connect callback to be sure that
        # our subscriptions are the same for any (re)connection.
        topics = [
            f"{self.client_id}/#",
            "lossnode/control_msgs_royal",
            f"lossnode/control_msgs_royal/{self.client_id}",
            "lossnode/control_msgs_public",
            f"lossnode/control_msgs_public/{self.client_id}",
        ]
        subs: list[tuple[str, int]] = [(f"{self.base_topic}/{topic}", QOS_REQUIRED) for topic in topics]
        client.subscribe(subs)

    def _on_connect_fail(self, client: mqtt.Client, userdata: None) -> None:
        """Log a failed connection attempt."""
        logger.error(f"Connection failed for client {self}")

    def _on_message(self, client: mqtt.Client, userdata: None, msg: mqtt.MQTTMessage) -> None:
        """Callback function to handle incoming messages.

        This will deserialize the message into its specific kind of Message subclass and put it
        into the queue. Messages outside ``base_topic``, undecodable messages and messages whose
        recipient is another node are logged and dropped.

        Exceptions escaping this callback are re-raised by paho and stop the network loop started
        in :meth:`start`, so invalid input is logged instead of raised.
        """
        logger.debug(f"{self.alias} received message: {msg.topic=}, {msg.payload=}, {msg.qos=}")
        topic = msg.topic
        payload = msg.payload
        if not topic.startswith(self.base_topic):
            # Not an assert: asserts vanish under ``-O``, and raising here would kill the loop.
            logger.warning(f"Topic does not start with {self.base_topic}: {topic=}")
            return
        if msg.qos != QOS_REQUIRED:
            warning = f"QoS is not {QOS_REQUIRED}: {topic=}, {payload=}"
            logger.warning(warning)
        try:
            m = asn1_message.decode(payload)
            if m is None:
                logger.warning(f"Could not parse message: {topic=}, {payload=}")
                return
            # check if the message is for us, TODO: we should do this via subscriptions
            if m.recipient != self.client_id:
                return
            logger.debug(f"{self.alias}: queue message: {m=}")
            self._queue.put(m)
        except NotImplementedError:
            # Deliberately ignored (message kind not supported); keep a trace for debugging.
            logger.debug(f"{self.alias}: ignoring unsupported message: {topic=}")
        except ValueError as error:
            logger.warning(f"Could not parse message: {topic=}, {payload=}, error: {error}")
        except ShutDown:
            logger.error("queue shut down")

    def _on_subscribe(
        self, client: mqtt.Client, userdata: None, mid: int, reason_code_list: list[ReasonCode], properties: Properties
    ) -> None:
        """Log the broker's answer for every topic filter of a SUBSCRIBE request."""
        # _on_connect subscribes to several topics in one request, so the broker returns one
        # reason code per topic filter; check all of them, not just the first.
        for reason_code in reason_code_list:
            if reason_code.is_failure:
                logger.warning(f"{self}: broker rejected your subscription: {reason_code}")
            else:
                logger.debug(f"{self}: broker granted the following QoS: {reason_code.value}")

    def _on_unsubscribe(
        self, client: mqtt.Client, userdata: None, mid: int, reason_code_list: list[ReasonCode], properties: Properties
    ) -> None:
        """Log the broker's answer to an UNSUBSCRIBE request, then disconnect."""
        # Be careful, the reason_code_list is only present in MQTTv5.
        # In MQTTv3 it will always be empty
        if not reason_code_list or not reason_code_list[0].is_failure:
            logger.debug("unsubscribe succeeded (if SUBACK is received in MQTTv3 it success)")
        else:
            logger.warning(f"Broker replied with failure: {reason_code_list[0]}")
        # NOTE(review): disconnects after every UNSUBACK, success or failure — confirm intended.
        client.disconnect()

    def start(self) -> None:
        """Connect and start client loop."""
        self._client.connect(host=self.mqtt_host, port=self.mqtt_port)
        logger.debug(f"Starting MQTT client loop for {self}")
        self._client.loop_start()

    def stop(self) -> None:
        """Disconnect from the broker and stop the client loop."""
        # Send DISCONNECT before stopping the loop so the broker sees a clean shutdown instead
        # of a half-open connection that only times out via keep-alive.
        self._client.disconnect()
        self._client.loop_stop()

    def _publish(self, topic: str, payload: str | bytes) -> None:
        """Publish ``payload`` on ``topic`` with ``QOS_REQUIRED`` and wait for completion.

        Blocks for at most ``PUBLISH_TIMEOUT`` seconds and logs a warning if the publish has not
        completed by then. Errors raised by paho's ``wait_for_publish`` (message could not be
        queued or sent) propagate to the caller.
        """
        info = self._client.publish(topic=topic, payload=payload, qos=QOS_REQUIRED)
        info.wait_for_publish(PUBLISH_TIMEOUT)
        if not info.is_published():
            logger.warning(f"{self}: message {info.mid} on {topic=} not published within {PUBLISH_TIMEOUT} s")

    def send_message(self, msg: asn1_message.Message) -> None:
        """Encode ``msg`` and publish it on the topic derived from sender/recipient.

        Messages whose sender is not this node are rejected with an error log.
        """
        if msg.sender != self.client_id:
            logger.error(f"Sender {msg.sender} does not match client id {self.client_id}")
            return
        sender_is_lossnode = self.client_id == self._lossnode_id
        recipient_is_lossnode = msg.recipient == self._lossnode_id
        topic = f"{self.base_topic}/{msg.get_topic(sender_is_lossnode, recipient_is_lossnode)}"
        self._publish(topic, asn1_message.encode(msg))

    def test_send(self, msg_text: str, msg_topic: str) -> None:
        """Publish raw ``msg_text`` on ``msg_topic`` verbatim (no ``base_topic`` prefix, no encoding)."""
        self._publish(msg_topic, msg_text)