# Source code for ddf.node

"""ANT Node Module."""

from dataclasses import dataclass, field
from enum import Enum
from time import sleep
from uuid import UUID

import uuid_utils
from loguru import logger

from ddf.ant_link.loss_channel import FuturePastProbe

from .ant_link.ant_link import GradientSubscription
from .ant_link.local import AntLinkLocal
from .ant_link.external import AntLinkExternal
from .ant_object_reference import AntObjectReference, InfoSubscriber
from .communicator import Communicator
from .ddf import Id, NodeId, ObjectId
from .information import ExchangeVector, InternalState
from .message import (
    ControlCommand,
    ControlMessage,
    GradientMessage,
    InfoMessage,
    Message,
    NodeContextMessage,
    NodeInfoMessage,
    ProbeResponseMessage,
)
from .operator import LinearOperator, Operator


@dataclass
class Neighbor:
    """A peer node paired with the operator used to exchange data with it.

    Attributes:
        node_id: id of the peer. It is only used to *send* data; data received
            from the peer is taken as it is.
        operator: translates this node's internal state into the exchange vector
            shared with the peer (the exchange vector shape is the standard one).
    """

    node_id: NodeId
    operator: Operator
class ProcessingMode(Enum):
    """Selects which kind of messages a Node emits on each `Node.step`."""

    # emit info messages for every reference in the history
    INFO_MODE = 0
    # emit gradient messages for every reference in the history
    GRADIENT_MODE = 1
# i.e. DiffDataFusor
[docs] @dataclass class Node: """ANT Node class.""" node_id: NodeId history: list[AntObjectReference] # map for our own internal state -> exchange vec we accepted from another node # (used in the datafusion problem "receive" info messages) neighbors: dict[NodeId, Neighbor] # map for our own internal state -> exchange vec accepted by some neighbor (used to *send* info messages) info_subscribers: dict[NodeId, Neighbor] # like a neighbor with node_id == self.node_id internal_operator: Operator alias: str initial_state: InternalState communicator: Communicator # to every new AntObjectReference created, there will be measurements added with these default_information_packets: list[tuple[Operator, ExchangeVector, list[GradientSubscription]]] fixed_context: str frozen: bool = False mode: ProcessingMode = ProcessingMode.INFO_MODE calculate_residual_loss: bool = False user_context: str = "" available_labels: list[str] = field(default_factory=list)
[docs] def run(self) -> None: """Run the `step` function forever.""" while True: # maybe stop at some point? self.step() sleep(3)
[docs] def step(self) -> None: """Process and send messages once.""" messages: list[Message] = self.communicator.get_messages() self._process_messages(messages) # send our messages match self.mode: case ProcessingMode.INFO_MODE: # self.solve() for ref in self.history: info_msgs = ref.generate_info_messages() for info_msg in info_msgs: self.communicator.send_message(info_msg) case ProcessingMode.GRADIENT_MODE: for ref in self.history: gradient_messages = ref.generate_gradient_messages() for gradient_message in gradient_messages: self.communicator.send_message(gradient_message) case _: raise ValueError("unkown mode")
def _message_filter(self, msg: Message) -> bool: """Filter messages. We only want messages that are for this node. Used in `process_messages`. """ if msg.recipient != self.node_id: return False # for InfoMessages and GradientMessages we only want to accept from known neighbors if isinstance(msg, InfoMessage) or isinstance(msg, GradientMessage): if msg.sender not in self.neighbors: logger.warning(f"Message sender not found in Neighbors. Ignoring message. Message={msg}") return False return True def _process_messages(self, messages: list[Message]) -> None: """Read and process all kinds of messages. Info messages to get new information, control messages to switch modes and other utility things, gradient messages to receive updated gradient information. """ messages = list(filter(self._message_filter, messages)) control_msgs: list[ControlMessage] = [message for message in messages if isinstance(message, ControlMessage)] info_msgs: list[InfoMessage] = [message for message in messages if isinstance(message, InfoMessage)] gradient_msgs: list[GradientMessage] = [message for message in messages if isinstance(message, GradientMessage)] node_info_msgs: list[NodeInfoMessage] = [ message for message in messages if isinstance(message, NodeInfoMessage) ] probe_response_msgs: list[ProbeResponseMessage] = [ message for message in messages if isinstance(message, ProbeResponseMessage) ] node_context_msgs: list[NodeContextMessage] = [ message for message in messages if isinstance(message, NodeContextMessage) ] for info_msg in info_msgs: self._process_info_message(info_msg) for control_msg in control_msgs: self._process_control_message(control_msg) for grad_msg in gradient_msgs: self._process_gradient_message(grad_msg) for msg_info in node_info_msgs: logger.debug(f"Node {self.alias}: Received NodeInfoMessage: {msg_info}") for msg_probe_response in probe_response_msgs: logger.debug(f"Node {self.alias}: Received ProbeResponseMessage: {msg_probe_response}") for msg_node_context in 
node_context_msgs: logger.debug(f"Node {self.alias}: Received NodeContextMessage: {msg_node_context}") def _process_info_message(self, msg: InfoMessage) -> None: ref = self.poll_history_from_object_id(msg.object_id) if ref is None: ref = self.create_ant_object_reference(msg.object_id, msg.message_counter) self.insert_measurement_from_info_message(ref, msg) else: # try to overwrite existing measurement: for m in ref.measurements: id_ = m.source.id if id_.node_id == msg.sender and ref.object_id == msg.object_id: assert isinstance(m.source, AntLinkExternal), "source is not AntLinkExternal" if msg.message_counter > m.source.update_counter: if m.exchange_vector.mean.shape != msg.state.mean.shape: logger.warning( f"Mean vector shape mismatch: was {m.exchange_vector.mean.shape}, \ got {msg.state.mean.shape}" ) m.exchange_vector.mean = msg.state.mean m.exchange_vector.covariance = msg.state.covariance m.source.update_counter = msg.message_counter return # else: no old measurement to overwrite, insert a new one self.insert_measurement_from_info_message(ref, msg) def _process_control_message(self, msg: ControlMessage) -> None: # early return if the mode did not change if (msg.command == ControlCommand.FZB and self.mode == ProcessingMode.GRADIENT_MODE) or ( msg.command == ControlCommand.UUI and self.mode == ProcessingMode.INFO_MODE ): return match msg.command: case ControlCommand.FZB: # case "freeze+zerograd+gradmode": logger.info("freezing,zeroing") self.solve() self.mode = ProcessingMode.GRADIENT_MODE self._send_probe_responses() self.zerogradients() if self.calculate_residual_loss: self.residual_loss() case ControlCommand.UUI: # case "sgdstep+unfreeze+infomode": logger.info("stepping,unfreezing") self.mode = ProcessingMode.INFO_MODE self.optimize() case ControlCommand.GetNodeInfo: available: list[str] = self.available_labels exchange: list[str] = [] if msg.sender in self.neighbors: # try to get list of exchange labels (human readable) neighbor = self.neighbors[msg.sender] 
if isinstance(neighbor.operator, LinearOperator): m = neighbor.operator.matrix assert m.shape[1] == len(available), ( "m @ available does not work if m does not have len(available) cols" ) # create m @ available_labels as texts for row in m: exchange.append(" + ".join(f"{factor}{name}" for (factor, name) in zip(row, available))) else: # do nothing, exchange is empty pass self.communicator.send_message( NodeInfoMessage(self.node_id, msg.sender, available, exchange, self.alias, msg.request_id) ) case ControlCommand.GetNodeContext: self.communicator.send_message( NodeContextMessage(self.node_id, msg.sender, self.fixed_context, self.user_context, msg.request_id) ) case _: logger.error(f"unknown ControlCommand {msg.command}") def _process_gradient_message(self, msg: GradientMessage) -> None: ref = self.poll_history_from_object_id(msg.object_id) assert ref is not None, "we should not get gradient messages for unknown object ids" ref.backpropagate_from_neighbor(msg.gradient, msg.loss_channel, msg.sender)
[docs] def poll_history_from_object_id(self, object_id: ObjectId) -> AntObjectReference | None: """Search for specific Object Id in history. Args: object_id (ObjectId): the object id to search for in this nodes history. Returns: reference (AntObjectReference | None): if the object id is found, return the AntObjectReference with this object id, otherwise `None` is returned. """ for ref in self.history: if ref.object_id == object_id: return ref return None
[docs] def add_neighbor(self, node_id: NodeId, op: Operator, eval_vec_len: int) -> None: assert node_id != self.node_id, "dont add yourself as a neighbor!" logger.debug(f"add neighbor {node_id.uuid}") if node_id in self.neighbors: logger.warning(f"Node {node_id.uuid} already exists in neighbors. Neighbor not replaced.") return # Mab@Xa = Mba@Xb = exchange_vec_ab => Mab is (m,n1), Mba is (m, n2), exchange_vec_ab is of length m if op.get_eval_vec_len() != eval_vec_len: logger.error("Neighbor operator not compatible in shape") return neighbor = Neighbor(node_id, op) self.neighbors[node_id] = neighbor
[docs] def add_info_subscriber(self, node_id: NodeId, op: Operator, eval_vec_len: int) -> None: assert node_id != self.node_id, "dont add yourself as a neighbor!" logger.debug(f"add neighbor {node_id.uuid}") if node_id in self.info_subscribers: logger.warning(f"Node {node_id.uuid} already exists in info subscribers. Subscriber not replaced.") return # Mab@Xa = Mba@Xb = exchange_vec_ab => Mab is (m,n1), Mba is (m, n2), exchange_vec_ab is of length m if op.get_eval_vec_len() != eval_vec_len: logger.error("Neighbor operator not compatible in shape") return neighbor = Neighbor(node_id, op) self.info_subscribers[node_id] = neighbor
[docs] def solve(self, force: bool = False) -> None: logger.debug(f"solve (force:{force})") # TODO: set some restrictions on what to solve if not (force or self.mode == ProcessingMode.INFO_MODE): return def should_solve(ref: AntObjectReference) -> bool: if ref.up_to_date: # return early return False if self.object_ref_in_past(ref): ref.up_to_date = True return not ref.up_to_date while any(should_solve(ref) for ref in self.history): for ref in self.history: if should_solve(ref): ref.solve()
[docs] def create_ant_object_reference(self, object_id: ObjectId | None, update_counter: int = 0) -> AntObjectReference: if object_id is None: # create a new uuid with current timestamp # see https://github.com/aminalaee/uuid-utils/blob/9ddd132c46278ac8aeb70474e688acec3465ce30/src/lib.rs#L402 uuid = uuid_utils.uuid7() object_id = ObjectId(UUID(bytes=uuid.bytes)) initial_state = self.initial_state state = InternalState(initial_state.mean.copy(), initial_state.covariance.copy()) subscribers: list[InfoSubscriber] = [] # external stuff for n in self.info_subscribers.values(): # if you want to control whether to send to this subscriber or not, put a check here if n.node_id != self.node_id: src = AntLinkExternal(Id(n.node_id, object_id), update_counter) info_sub = InfoSubscriber(n.operator, src, None) subscribers.append(info_sub) ref = AntObjectReference(self, object_id, initial_state, state, info_subscribers=subscribers) # internal stuff for op, info, grad_subscriptions in self.default_information_packets: ref.add_new_measurement(op, info, AntLinkLocal(Id(self.node_id, object_id), grad_subscriptions)) self.history.append(ref) return ref
[docs] def insert_measurement_from_info_message(self, ref: AntObjectReference, msg: InfoMessage) -> None: neighbor = self.neighbors[msg.sender] op = neighbor.operator update_counter = msg.message_counter source = AntLinkExternal(Id(msg.sender, msg.object_id), update_counter) # if you want to control whether you want to receive from this neighbor: put a check here ref.add_new_measurement(op, msg.state, source)
[docs] def insert_measurement_from_sensor(self, object_id: ObjectId, op: Operator, sensor_data: ExchangeVector) -> None: ref = self.poll_history_from_object_id(object_id) if ref is None: ref = self.create_ant_object_reference(object_id) source_id = Id(self.node_id, object_id) # we did not find any old message that we could update -> create a new one source = AntLinkLocal(source_id, []) ref.add_new_measurement(op, sensor_data, source)
[docs] def zerogradients(self) -> None: for ref in self.history: ref.probe_storage.storage.clear() for measurement in ref.measurements: measurement.set_gradients_to_zero()
[docs] def residual_loss(self) -> None: # TODO implement residual loss calculation, store in place pass
[docs] def optimize(self) -> None: for ref in self.history: ref.optimize()
[docs] def find_object_reference(self, object_id: ObjectId) -> AntObjectReference | None: """Search for specific AntObjectReference in this Node.""" for ref in self.history: if ref.object_id == object_id: return ref return None
[docs] def set_user_context(self, user_context: str) -> None: """Set the `user_context` of this node. This will be used when sending NodeContext Messages. """ self.user_context = user_context
[docs] def set_available_labels(self, labels: list[str]) -> None: """Set all available label texts. This will be used when asking for GetNodeInfo. """ self.available_labels = labels
def _send_probe_responses(self) -> None: """Send probe responses for all AntObjectReferences in history.""" for ref in self.history: for loss_channel, gradient in ref.probe_storage.storage.items(): if loss_channel.mode != FuturePastProbe.PROBE or loss_channel.lossref is None: continue self.communicator.send_message( ProbeResponseMessage( self.node_id, NodeId(loss_channel.lossref), ref.object_id, gradient, loss_channel, ) )
[docs] def visualize(self) -> str: """Visualize the node as a string.""" from ddf.ant_link.ant_link import AntLink # noqa: PLC0415 from ddf.ant_link.internal import AntLinkInternalObjectID # noqa: PLC0415 def vis_antlink(ant_link: AntLink) -> str: target = "" match ant_link: case AntLinkLocal(): link_type = "local" target = "self" case AntLinkExternal(): link_type = "external" target = f"node {ant_link.id.node_id.uuid} -\ {ant_link.id.object_id.uuid if ant_link.id.object_id else 'None'}" case AntLinkInternalObjectID(): link_type = "internal" for o in self.history: if o.object_id == ant_link.id.object_id: target = f"ref_{self.history.index(o)}" break case _: link_type = "unknown" target = "unknown" return f"--[{link_type}]-> {target}" tab = 4 * " " lines = [self.alias] for i, ref in enumerate(self.history): lines.append(f"{tab}ref_{i}") for infosub in ref.info_subscribers: lines.append(f"{tab * 2}subscriber {vis_antlink(infosub.source)}") lines.append(f"{tab * 2}-") for measurement in ref.measurements: lines.append(f"{tab * 2}measurement {vis_antlink(measurement.source)}") return "\n".join(lines)
[docs] def object_ref_in_past(self, ref: AntObjectReference) -> bool: return False