import time
from collections.abc import Callable
from dataclasses import dataclass
from uuid import UUID
import numpy as np
from loguru import logger
from ddf.ant_link.loss_channel import FuturePastProbe, LossChannel
from ddf.communicator import Communicator
from ddf.ddf import NodeId, ObjectId
from ddf.information import ExchangeVector
from ddf.message import ControlCommand, ControlMessage, GradientMessage, InfoMessage, Message
from ddf.node import ProcessingMode
[docs]
@dataclass
class Neighbor:
    """A neighbor of a LossNode together with its registered gradient functions.

    Attributes:
        node_id: id of the neighbor node.
        cost_fn_gradient: maps each loss channel to a function that evaluates the
            neighbor's cost function gradient at a given exchange vector, or to ``None``
            when no gradient function is registered for that channel.
    """

    node_id: NodeId
    # loss channel -> gradient function (calculates the gradient at a point) or None
    cost_fn_gradient: dict[LossChannel, Callable[[ExchangeVector], np.ndarray] | None]
[docs]
class LossNode:
communicator: Communicator
mode: ProcessingMode
node_id: NodeId
neighbors: dict[NodeId, Neighbor]
old_info_messages: list[InfoMessage]
# other messages we received will be dumped here if we need them later
__other_message_queue: list[Message]
def __init__(self, node_id: UUID, communicator: Communicator, neighbors: list[Neighbor]) -> None:
self.node_id = NodeId(node_id)
self.communicator = communicator
self.mode = ProcessingMode.INFO_MODE
self.__other_message_queue = []
self.neighbors = dict((neighbor.node_id, neighbor) for neighbor in neighbors)
self.old_info_messages = []
[docs]
def run(self) -> None:
"""Main loop, do whatever a lossnode does."""
period_start = time.time()
period_length = 3
while True:
if time.time() >= period_start + period_length:
switch = True
period_start = time.time()
else:
switch = False
self.step(switch)
time.sleep(1)
[docs]
def step(self, switch_modes: bool) -> None:
if switch_modes:
self._switch_modes()
messages = self._receive_info_messages()
if self.mode == ProcessingMode.INFO_MODE:
# save the messages for later
self.old_info_messages.extend(messages)
return
elif self.mode == ProcessingMode.GRADIENT_MODE:
messages.extend(self.old_info_messages)
self.old_info_messages.clear()
# only keep newest message per Neighbor and ObjectId
# filter for known neighbors as well
most_current_messages: dict[tuple[NodeId, ObjectId], InfoMessage] = dict()
for msg in messages:
# filter out messages from unknown nodes
if msg.sender not in self.neighbors:
continue
# no message for this key yet
if (msg.sender, msg.object_id) not in most_current_messages:
most_current_messages[(msg.sender, msg.object_id)] = msg
# counter is higher -> message is newer -> overwrite
elif msg.message_counter > most_current_messages[(msg.sender, msg.object_id)].message_counter:
most_current_messages[(msg.sender, msg.object_id)] = msg
for msg in most_current_messages.values():
assert msg.sender in self.neighbors, "we filtered for known neighbors before"
neighbor = self.neighbors[msg.sender]
exchange_vec = msg.state
for loss_channel, cost_fn_grad in neighbor.cost_fn_gradient.items():
if cost_fn_grad is None:
continue
gradient = cost_fn_grad(exchange_vec)
self.communicator.send_message(
GradientMessage(
self.node_id, neighbor.node_id, msg.object_id, np.atleast_1d(gradient), loss_channel
)
)
[docs]
def add_neighbor(
self,
node_id: NodeId,
loss_channel: LossChannel,
cost_fn_grad: Callable[[ExchangeVector], np.ndarray] | None = None,
) -> None:
"""Add this node as a neighbor to the lossnode or add a new cost_fn_grad.
you will **get gradients from the lossnode iff you provide a cost_fn_grad**
Args:
node_id (NodeId): the id of the (new) neighbor
loss_channel (LossChannel): loss channel this cost function gradient should be associated with
cost_fn_grad (Callable[[ExchangeVector], np.ndarray] | None):
supply a cost function gradient as in a function that calculates the gradient.
Raises:
ValueError: in case `cost_fn_grad is None` and `loss_channel.mode == Probe`
"""
if loss_channel.mode == FuturePastProbe.PROBE and cost_fn_grad is None:
raise ValueError("cannot have Probe and no cost_fn_gradient")
if node_id not in self.neighbors:
self.neighbors[node_id] = Neighbor(node_id, dict())
self.neighbors[node_id].cost_fn_gradient[loss_channel] = cost_fn_grad
def _get_exchange_vector_mean(self, messages: list[InfoMessage]) -> dict[NodeId, ExchangeVector]:
nodes = set(msg.sender for msg in messages)
results = dict()
for n in nodes:
exchange_vectors = [msg.state for msg in messages if msg.sender == n]
mean = np.mean(np.array([ev.mean for ev in exchange_vectors]), axis=0)
cov = np.mean(np.array([ev.covariance for ev in exchange_vectors]), axis=0)
results[n] = ExchangeVector(mean, cov)
return results
def _switch_modes(self) -> None:
control_command = ControlCommand.GetNodeInfo
match self.mode:
case ProcessingMode.INFO_MODE:
self.mode = ProcessingMode.GRADIENT_MODE
control_command = ControlCommand.FZB
case ProcessingMode.GRADIENT_MODE:
self.mode = ProcessingMode.INFO_MODE
control_command = ControlCommand.UUI
case _:
logger.error(f"invalid mode: {self.mode}")
logger.info(f"switch to mode {self.mode}")
for neighbor in self.neighbors.values():
self.communicator.send_message(ControlMessage(self.node_id, neighbor.node_id, UUID(int=0), control_command))
def _receive_info_messages(self) -> list[InfoMessage]:
all_messages = self.communicator.get_messages()
info_messages: list[InfoMessage] = [im for im in all_messages if isinstance(im, InfoMessage)]
other_messages = [msg for msg in all_messages if not isinstance(msg, InfoMessage)]
self.__other_message_queue.extend(other_messages)
return info_messages