Source code for ddf.ant_link.external

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np

from ddf.ant_link.loss_channel import LossChannel

from ..information import ExchangeVector
from ..measurement.gradient_storage import GradientStorage
from ..message import GradientMessage, InfoMessage
from .ant_link import AntLink

if TYPE_CHECKING:
    from ..node import Node


@dataclass
class AntLinkExternal(AntLink):
    """Refer to a different ObjectId on a different Node.

    Outgoing messages are addressed to ``self.id.node_id`` and
    ``self.id.object_id``.
    """

    # Base value for the counter attached to outgoing info messages; each
    # info message carries ``update_counter + 1``. This class only reads it.
    update_counter: int
    # Relative send threshold: a gradient is sent only if at least one element's
    # magnitude exceeds ``gradient_threshold`` times the largest magnitude
    # previously observed for that element on the same loss channel.
    gradient_threshold: float = 1e-6
    # Per loss channel, the elementwise maximum absolute value (flattened to 1-D)
    # of every gradient processed so far. Updated on every call to
    # ``generate_gradient_messages``, whether or not a message was sent.
    max_gradients: dict[LossChannel, np.ndarray] = field(default_factory=dict)

    def generate_gradient_messages(self, node: "Node", gradient_storage: GradientStorage) -> list[GradientMessage]:
        """Create gradient messages for every loss channel with a non-negligible gradient.

        A channel's flattened gradient is sent when any element's absolute value
        exceeds ``gradient_threshold`` times the running elementwise maximum for
        that channel. On the first call for a channel, the running maximum is
        all zeros, so any nonzero gradient is sent. The running maximum is then
        updated regardless of whether a message was produced.

        Args:
            node: The node sending the messages; its ``node_id`` is used as sender.
            gradient_storage: Source of ``(loss_channel, gradient)`` pairs.

        Returns:
            The gradient messages to send. The list is empty if this link has
            no ``object_id``.

        Raises:
            ValueError: If a channel's gradient size differs from the size
                recorded for that channel on an earlier call.
        """
        sender = node.node_id
        recipient = self.id.node_id
        object_id = self.id.object_id
        result: list[GradientMessage] = []
        if object_id is None:
            return result

        for loss_channel, gradient in gradient_storage.storage.items():
            # flatten() copies, so the message never aliases storage-owned memory.
            grad_current = gradient.flatten()
            grad_abs = np.abs(grad_current)

            stored = self.max_gradients.get(loss_channel)
            if stored is None:
                grad_max = np.zeros_like(grad_current)
            else:
                grad_max = stored.ravel()
                # Fail loudly instead of letting numpy broadcast mismatched sizes.
                if grad_max.shape != grad_current.shape:
                    raise ValueError(
                        f"gradient size for loss channel {loss_channel!r} changed "
                        f"from {grad_max.size} to {grad_current.size}"
                    )

            if (grad_abs > self.gradient_threshold * grad_max).any():
                result.append(GradientMessage(sender, recipient, object_id, grad_current, loss_channel))

            # Update the running elementwise maximum of absolute gradient values.
            self.max_gradients[loss_channel] = np.maximum(grad_abs, grad_max)
        return result

    def generate_info_message(self, node: "Node", state: ExchangeVector) -> InfoMessage:
        """Generate an info message carrying ``state`` to the linked object.

        Args:
            node: The node sending the message; its ``node_id`` is used as sender.
            state: The exchange vector to transmit.

        Returns:
            An ``InfoMessage`` whose counter is ``update_counter + 1``.

        Raises:
            ValueError: If this link has no ``object_id`` to address.
        """
        sender = node.node_id
        recipient = self.id.node_id
        object_id = self.id.object_id
        if object_id is None:
            raise ValueError("cannot generate an info message: link target has no object_id")
        message_counter = self.update_counter + 1
        return InfoMessage(sender, recipient, object_id, state, message_counter)