Source code for ddf.lossnode.lossnode

import time
from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID

import numpy as np
from loguru import logger

from ddf.ant_link.loss_channel import FuturePastProbe, LossChannel
from ddf.communicator import Communicator
from ddf.ddf import NodeId, ObjectId
from ddf.information import ExchangeVector
from ddf.message import ControlCommand, ControlMessage, GradientMessage, InfoMessage, Message
from ddf.node import ProcessingMode


@dataclass
class Neighbor:
    """A neighboring node together with its per-channel gradient functions.

    Attributes:
        node_id: id of the neighbor.
        cost_fn_gradient: maps a loss channel to a callable that evaluates the
            cost function gradient at a given exchange vector, or to ``None``
            when no gradient function is registered for that channel.
    """

    node_id: NodeId
    # gradients in the form of a function (calculates the gradient at point)
    cost_fn_gradient: dict[LossChannel, Callable[[ExchangeVector], np.ndarray] | None]
[docs] class LossNode: communicator: Communicator mode: ProcessingMode node_id: NodeId neighbors: dict[NodeId, Neighbor] old_info_messages: list[InfoMessage] # other messages we received will be dumped here if we need them later __other_message_queue: list[Message] def __init__(self, node_id: UUID, communicator: Communicator, neighbors: list[Neighbor]) -> None: self.node_id = NodeId(node_id) self.communicator = communicator self.mode = ProcessingMode.INFO_MODE self.__other_message_queue = [] self.neighbors = dict((neighbor.node_id, neighbor) for neighbor in neighbors) self.old_info_messages = []
[docs] def run(self) -> None: """Main loop, do whatever a lossnode does.""" period_start = time.time() period_length = 3 while True: if time.time() >= period_start + period_length: switch = True period_start = time.time() else: switch = False self.step(switch) time.sleep(1)
[docs] def step(self, switch_modes: bool) -> None: if switch_modes: self._switch_modes() messages = self._receive_info_messages() if self.mode == ProcessingMode.INFO_MODE: # save the messages for later self.old_info_messages.extend(messages) return elif self.mode == ProcessingMode.GRADIENT_MODE: messages.extend(self.old_info_messages) self.old_info_messages.clear() # only keep newest message per Neighbor and ObjectId # filter for known neighbors as well most_current_messages: dict[tuple[NodeId, ObjectId], InfoMessage] = dict() for msg in messages: # filter out messages from unknown nodes if msg.sender not in self.neighbors: continue # no message for this key yet if (msg.sender, msg.object_id) not in most_current_messages: most_current_messages[(msg.sender, msg.object_id)] = msg # counter is higher -> message is newer -> overwrite elif msg.message_counter > most_current_messages[(msg.sender, msg.object_id)].message_counter: most_current_messages[(msg.sender, msg.object_id)] = msg for msg in most_current_messages.values(): assert msg.sender in self.neighbors, "we filtered for known neighbors before" neighbor = self.neighbors[msg.sender] exchange_vec = msg.state for loss_channel, cost_fn_grad in neighbor.cost_fn_gradient.items(): if cost_fn_grad is None: continue gradient = cost_fn_grad(exchange_vec) self.communicator.send_message( GradientMessage( self.node_id, neighbor.node_id, msg.object_id, np.atleast_1d(gradient), loss_channel ) )
[docs] def add_neighbor( self, node_id: NodeId, loss_channel: LossChannel, cost_fn_grad: Callable[[ExchangeVector], np.ndarray] | None = None, ) -> None: """Add this node as a neighbor to the lossnode or add a new cost_fn_grad. you will **get gradients from the lossnode iff you provide a cost_fn_grad** Args: node_id (NodeId): the id of the (new) neighbor loss_channel (LossChannel): loss channel this cost function gradient should be associated with cost_fn_grad (Callable[[ExchangeVector], np.ndarray] | None): supply a cost function gradient as in a function that calculates the gradient. Raises: ValueError: in case `cost_fn_grad is None` and `loss_channel.mode == Probe` """ if loss_channel.mode == FuturePastProbe.PROBE and cost_fn_grad is None: raise ValueError("cannot have Probe and no cost_fn_gradient") if node_id not in self.neighbors: self.neighbors[node_id] = Neighbor(node_id, dict()) self.neighbors[node_id].cost_fn_gradient[loss_channel] = cost_fn_grad
def _get_exchange_vector_mean(self, messages: list[InfoMessage]) -> dict[NodeId, ExchangeVector]: nodes = set(msg.sender for msg in messages) results = dict() for n in nodes: exchange_vectors = [msg.state for msg in messages if msg.sender == n] mean = np.mean(np.array([ev.mean for ev in exchange_vectors]), axis=0) cov = np.mean(np.array([ev.covariance for ev in exchange_vectors]), axis=0) results[n] = ExchangeVector(mean, cov) return results def _switch_modes(self) -> None: control_command = ControlCommand.GetNodeInfo match self.mode: case ProcessingMode.INFO_MODE: self.mode = ProcessingMode.GRADIENT_MODE control_command = ControlCommand.FZB case ProcessingMode.GRADIENT_MODE: self.mode = ProcessingMode.INFO_MODE control_command = ControlCommand.UUI case _: logger.error(f"invalid mode: {self.mode}") logger.info(f"switch to mode {self.mode}") for neighbor in self.neighbors.values(): self.communicator.send_message(ControlMessage(self.node_id, neighbor.node_id, UUID(int=0), control_command)) def _receive_info_messages(self) -> list[InfoMessage]: all_messages = self.communicator.get_messages() info_messages: list[InfoMessage] = [im for im in all_messages if isinstance(im, InfoMessage)] other_messages = [msg for msg in all_messages if not isinstance(msg, InfoMessage)] self.__other_message_queue.extend(other_messages) return info_messages