from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np

from ddf.ant_link.loss_channel import LossChannel

from ..information import ExchangeVector
from ..measurement.gradient_storage import GradientStorage
from ..message import GradientMessage, InfoMessage
from .ant_link import AntLink

if TYPE_CHECKING:
    from ..node import Node
@dataclass
class AntLinkExternal(AntLink):
    """Refer to a different ObjectId on a different Node.

    Besides the target id inherited from ``AntLink``, an external link
    remembers, per loss channel, the elementwise maximum absolute value of the
    gradients it has already sent. New gradients that are negligible relative
    to that history are not sent again.
    """

    update_counter: int
    # Relative send threshold: a gradient is sent only if at least one element
    # exceeds ``gradient_threshold`` times the largest magnitude previously
    # sent for that element on the same loss channel.
    gradient_threshold: float = 1e-6
    # Keep track of already sent gradients and their elementwise maximum
    # absolute value. Excluded from the dataclass-generated ``__eq__``:
    # comparing dicts of multi-element numpy arrays raises ValueError
    # ("truth value of an array ... is ambiguous"), and this is send
    # bookkeeping rather than part of the link's identity.
    max_gradients: dict[LossChannel, np.ndarray] = field(default_factory=dict, compare=False)

    def generate_gradient_messages(self, node: "Node", gradient_storage: GradientStorage) -> list[GradientMessage]:
        """Create gradient messages for the remote object this link refers to.

        For every loss channel in ``gradient_storage.storage``, the flattened
        gradient is sent if at least one element's absolute value exceeds
        ``gradient_threshold`` times the largest absolute value previously
        sent for that element. After a send, the per-channel maxima in
        ``max_gradients`` are updated elementwise.

        Args:
            node: The local node; its ``node_id`` is used as the sender.
            gradient_storage: Gradients to consider, keyed by loss channel.

        Returns:
            One ``GradientMessage`` per loss channel whose gradient passed the
            threshold. The list is empty when this link has no target
            ``object_id``. Note that ``generate_info_message`` raises in that
            case instead.
        """
        sender = node.node_id
        recipient = self.id.node_id
        object_id = self.id.object_id
        result: list[GradientMessage] = []
        if object_id is None:
            return result
        for loss_channel, gradient in gradient_storage.storage.items():
            # flatten() copies, so the message does not alias the stored gradient.
            grad_current = gradient.flatten()
            grad_abs = np.abs(grad_current)
            grad_max = self.max_gradients.get(loss_channel)
            if grad_max is None or grad_max.size != grad_current.size:
                # No history yet, or the gradient changed size (the old maxima
                # no longer line up elementwise). Start from zero, so any
                # nonzero element triggers a send.
                grad_max = np.zeros_like(grad_abs)
                self.max_gradients[loss_channel] = grad_max
            else:
                grad_max = grad_max.ravel()
            if (grad_abs > self.gradient_threshold * grad_max).any():
                result.append(GradientMessage(sender, recipient, object_id, grad_current, loss_channel))
                # Update the max elementwise with the gradient just sent.
                self.max_gradients[loss_channel] = np.maximum(grad_abs, grad_max)
        return result

    def generate_info_message(self, node: "Node", state: ExchangeVector) -> InfoMessage:
        """Generate an info message carrying ``state`` to the remote object.

        Args:
            node: The local node; its ``node_id`` is used as the sender.
            state: The exchange vector to send.

        Returns:
            An ``InfoMessage`` whose counter is ``update_counter + 1``.
            ``update_counter`` itself is not modified by this method.

        Raises:
            ValueError: If this link has no target ``object_id``.
        """
        sender = node.node_id
        recipient = self.id.node_id
        object_id = self.id.object_id
        if object_id is None:
            raise ValueError(
                f"Cannot generate an info message: link to node {recipient!r} has no object_id."
            )
        message_counter = self.update_counter + 1
        return InfoMessage(sender, recipient, object_id, state, message_counter)