"""ANT Node Module."""
from dataclasses import dataclass, field
from enum import Enum
from time import sleep
from uuid import UUID
import uuid_utils
from loguru import logger
from ddf.ant_link.loss_channel import FuturePastProbe
from .ant_link.ant_link import GradientSubscription
from .ant_link.local import AntLinkLocal
from .ant_link.external import AntLinkExternal
from .ant_object_reference import AntObjectReference, InfoSubscriber
from .communicator import Communicator
from .ddf import Id, NodeId, ObjectId
from .information import ExchangeVector, InternalState
from .message import (
ControlCommand,
ControlMessage,
GradientMessage,
InfoMessage,
Message,
NodeContextMessage,
NodeInfoMessage,
ProbeResponseMessage,
)
from .operator import LinearOperator, Operator
[docs]
@dataclass
class Neighbor:
    # Pairs the id of another node with the operator associated with it.
    # Instances are created by `Node.add_neighbor` and `Node.add_info_subscriber`.
    node_id: NodeId
    # only to _send_ data! (receive data as it is)
    operator: Operator  # This translates from InternalState to exchange_vector (the shape is standard)
[docs]
class ProcessingMode(Enum):
    # The mode decides what `Node.step` sends and is switched by control messages.
    # INFO_MODE: `step` sends the info messages generated by every object reference.
    INFO_MODE = 0
    # GRADIENT_MODE: `step` sends the gradient messages generated by every object reference.
    GRADIENT_MODE = 1
# i.e. DiffDataFusor
[docs]
@dataclass
class Node:
    """ANT Node class.

    A node owns a history of AntObjectReferences, exchanges messages with
    other nodes through its communicator and switches between info mode and
    gradient mode when it receives control messages.
    """
    # identity of this node; incoming messages are only accepted if their recipient equals it
    node_id: NodeId
    # all AntObjectReferences known to this node (appended by `create_ant_object_reference`)
    history: list[AntObjectReference]
    # map for our own internal state -> exchange vec we accepted from another node
    # (used in the datafusion problem to "receive" info messages)
    neighbors: dict[NodeId, Neighbor]
    # map for our own internal state -> exchange vec accepted by some neighbor (used to *send* info messages)
    info_subscribers: dict[NodeId, Neighbor]
    # like a neighbor with node_id == self.node_id
    internal_operator: Operator
    # human readable name, used in log output and in NodeInfoMessages
    alias: str
    # state every newly created AntObjectReference starts from (mean and covariance are copied)
    initial_state: InternalState
    # transport used to fetch incoming and to send outgoing messages
    communicator: Communicator
    # every newly created AntObjectReference gets one measurement per entry of this list
    default_information_packets: list[tuple[Operator, ExchangeVector, list[GradientSubscription]]]
    # sent, together with `user_context`, in NodeContextMessages
    fixed_context: str
    # NOTE(review): not referenced by any of the methods visible in this file
    frozen: bool = False
    # current processing mode; changed by FZB / UUI control messages
    mode: ProcessingMode = ProcessingMode.INFO_MODE
    # if set, `residual_loss` is called when switching to gradient mode
    calculate_residual_loss: bool = False
    # free text that can be changed with `set_user_context`
    user_context: str = ""
    # label texts reported when a node is asked for GetNodeInfo
    available_labels: list[str] = field(default_factory=list)
[docs]
def run(self, interval: float = 3) -> None:
    """Run the `step` function forever.

    Args:
        interval (float): seconds to sleep after every call to `step`.
            Defaults to 3, the pause that was previously hard-coded.

    This method does not return on its own; it only ends when `step` or
    `sleep` raises.
    """
    while True:  # maybe stop at some point?
        self.step()
        sleep(interval)
[docs]
def step(self) -> None:
"""Process and send messages once."""
messages: list[Message] = self.communicator.get_messages()
self._process_messages(messages)
# send our messages
match self.mode:
case ProcessingMode.INFO_MODE:
# self.solve()
for ref in self.history:
info_msgs = ref.generate_info_messages()
for info_msg in info_msgs:
self.communicator.send_message(info_msg)
case ProcessingMode.GRADIENT_MODE:
for ref in self.history:
gradient_messages = ref.generate_gradient_messages()
for gradient_message in gradient_messages:
self.communicator.send_message(gradient_message)
case _:
raise ValueError("unkown mode")
def _message_filter(self, msg: Message) -> bool:
    """Decide whether a message should be processed by this node.

    A message is rejected if it is addressed to another node. Info and
    gradient messages are additionally rejected (with a warning) if their
    sender is not one of our known neighbors.
    Used in `_process_messages`.
    """
    if msg.recipient != self.node_id:
        return False

    # only InfoMessages and GradientMessages require a known sender
    needs_known_sender = isinstance(msg, (InfoMessage, GradientMessage))
    if needs_known_sender and msg.sender not in self.neighbors:
        logger.warning(f"Message sender not found in Neighbors. Ignoring message. Message={msg}")
        return False
    return True
def _process_messages(self, messages: list[Message]) -> None:
    """Sort the accepted messages by type and handle each group.

    Messages rejected by `_message_filter` are dropped. Of the remaining
    ones, info messages are processed first, then control messages, then
    gradient messages. Node info, probe response and node context messages
    are only written to the debug log.
    """
    accepted = [candidate for candidate in messages if self._message_filter(candidate)]

    def of_kind(kind: type) -> list:
        # every group is selected independently, in arrival order
        return [candidate for candidate in accepted if isinstance(candidate, kind)]

    control_msgs: list[ControlMessage] = of_kind(ControlMessage)
    info_msgs: list[InfoMessage] = of_kind(InfoMessage)
    gradient_msgs: list[GradientMessage] = of_kind(GradientMessage)
    node_info_msgs: list[NodeInfoMessage] = of_kind(NodeInfoMessage)
    probe_response_msgs: list[ProbeResponseMessage] = of_kind(ProbeResponseMessage)
    node_context_msgs: list[NodeContextMessage] = of_kind(NodeContextMessage)

    for received_info in info_msgs:
        self._process_info_message(received_info)
    for received_control in control_msgs:
        self._process_control_message(received_control)
    for received_gradient in gradient_msgs:
        self._process_gradient_message(received_gradient)

    # the remaining message types are only logged
    for msg_info in node_info_msgs:
        logger.debug(f"Node {self.alias}: Received NodeInfoMessage: {msg_info}")
    for msg_probe_response in probe_response_msgs:
        logger.debug(f"Node {self.alias}: Received ProbeResponseMessage: {msg_probe_response}")
    for msg_node_context in node_context_msgs:
        logger.debug(f"Node {self.alias}: Received NodeContextMessage: {msg_node_context}")
def _process_info_message(self, msg: InfoMessage) -> None:
    """Store the state carried by an info message as a measurement.

    If the object id is unknown, a new AntObjectReference is created and the
    message becomes its measurement. Otherwise the measurement that came
    from the same sender is overwritten when the message counter is newer
    (an older or equal counter leaves it untouched). If no such measurement
    exists, a new one is inserted.
    """
    ref = self.poll_history_from_object_id(msg.object_id)
    if ref is None:
        ref = self.create_ant_object_reference(msg.object_id, msg.message_counter)
        self.insert_measurement_from_info_message(ref, msg)
        return

    # try to overwrite existing measurement:
    for measurement in ref.measurements:
        origin = measurement.source.id
        same_origin = origin.node_id == msg.sender and ref.object_id == msg.object_id
        if not same_origin:
            continue
        assert isinstance(measurement.source, AntLinkExternal), "source is not AntLinkExternal"
        if msg.message_counter > measurement.source.update_counter:
            old_shape = measurement.exchange_vector.mean.shape
            new_shape = msg.state.mean.shape
            if old_shape != new_shape:
                # a changed shape is only reported, the new values are stored anyway
                logger.warning(f"Mean vector shape mismatch: was {old_shape}, " f"got {new_shape}")
            measurement.exchange_vector.mean = msg.state.mean
            measurement.exchange_vector.covariance = msg.state.covariance
            measurement.source.update_counter = msg.message_counter
        return

    # no old measurement to overwrite, insert a new one
    self.insert_measurement_from_info_message(ref, msg)
def _process_control_message(self, msg: ControlMessage) -> None:
    """React to a control command.

    FZB switches to gradient mode, UUI switches back to info mode; both are
    ignored if the node already is in the requested mode. GetNodeInfo and
    GetNodeContext are answered with a message to the sender. Any other
    command is logged as an error.
    """
    command = msg.command

    if command == ControlCommand.FZB:
        # "freeze+zerograd+gradmode"
        if self.mode == ProcessingMode.GRADIENT_MODE:
            return  # the mode did not change
        logger.info("freezing,zeroing")
        self.solve()
        self.mode = ProcessingMode.GRADIENT_MODE
        # probe responses are sent before `zerogradients` clears the probe storage
        self._send_probe_responses()
        self.zerogradients()
        if self.calculate_residual_loss:
            self.residual_loss()

    elif command == ControlCommand.UUI:
        # "sgdstep+unfreeze+infomode"
        if self.mode == ProcessingMode.INFO_MODE:
            return  # the mode did not change
        logger.info("stepping,unfreezing")
        self.mode = ProcessingMode.INFO_MODE
        self.optimize()

    elif command == ControlCommand.GetNodeInfo:
        labels: list[str] = self.available_labels
        exchange_labels: list[str] = []
        # try to get list of exchange labels (human readable); it stays empty
        # for unknown senders and for operators that are not linear
        if msg.sender in self.neighbors:
            sender_operator = self.neighbors[msg.sender].operator
            if isinstance(sender_operator, LinearOperator):
                matrix = sender_operator.matrix
                assert matrix.shape[1] == len(labels), (
                    "m @ available does not work if m does not have len(available) cols"
                )
                # create m @ available_labels as texts, one text per matrix row
                exchange_labels = [
                    " + ".join(f"{factor}⋅{name}" for (factor, name) in zip(row, labels)) for row in matrix
                ]
        reply = NodeInfoMessage(self.node_id, msg.sender, labels, exchange_labels, self.alias, msg.request_id)
        self.communicator.send_message(reply)

    elif command == ControlCommand.GetNodeContext:
        context_reply = NodeContextMessage(
            self.node_id, msg.sender, self.fixed_context, self.user_context, msg.request_id
        )
        self.communicator.send_message(context_reply)

    else:
        logger.error(f"unknown ControlCommand {msg.command}")
def _process_gradient_message(self, msg: GradientMessage) -> None:
    """Pass a received gradient on to the AntObjectReference it refers to.

    The object id of the message has to be present in the history already.
    """
    target = self.poll_history_from_object_id(msg.object_id)
    assert target is not None, "we should not get gradient messages for unknown object ids"
    target.backpropagate_from_neighbor(msg.gradient, msg.loss_channel, msg.sender)
[docs]
def poll_history_from_object_id(self, object_id: ObjectId) -> AntObjectReference | None:
    """Look up an object id in the history of this node.

    Args:
        object_id (ObjectId): the object id to look for.

    Returns:
        reference (AntObjectReference | None): the first AntObjectReference
            in the history that has this object id, or `None` if there is
            no such reference.
    """
    matching = (candidate for candidate in self.history if candidate.object_id == object_id)
    return next(matching, None)
[docs]
def add_neighbor(self, node_id: NodeId, op: Operator, eval_vec_len: int) -> None:
    """Register another node as a neighbor.

    Args:
        node_id (NodeId): id of the other node, must differ from our own id.
        op (Operator): operator stored for this neighbor.
        eval_vec_len (int): length the operator has to report via
            `get_eval_vec_len()`.

    An already known neighbor is not replaced (warning), and an operator
    with a different length is rejected (error); in both cases nothing is
    stored.
    """
    assert node_id != self.node_id, "dont add yourself as a neighbor!"
    logger.debug(f"add neighbor {node_id.uuid}")

    # Mab@Xa = Mba@Xb = exchange_vec_ab => Mab is (m,n1), Mba is (m, n2), exchange_vec_ab is of length m
    if node_id in self.neighbors:
        logger.warning(f"Node {node_id.uuid} already exists in neighbors. Neighbor not replaced.")
    elif op.get_eval_vec_len() != eval_vec_len:
        logger.error("Neighbor operator not compatible in shape")
    else:
        self.neighbors[node_id] = Neighbor(node_id, op)
[docs]
def add_info_subscriber(self, node_id: NodeId, op: Operator, eval_vec_len: int) -> None:
    """Register another node as a subscriber of our info messages.

    Args:
        node_id (NodeId): id of the subscribing node, must differ from our
            own id.
        op (Operator): operator stored for this subscriber.
        eval_vec_len (int): length the operator has to report via
            `get_eval_vec_len()`.

    An already known subscriber is not replaced (warning), and an operator
    with a different length is rejected (error); in both cases nothing is
    stored.
    """
    assert node_id != self.node_id, "dont add yourself as an info subscriber!"
    logger.debug(f"add info subscriber {node_id.uuid}")
    if node_id in self.info_subscribers:
        logger.warning(f"Node {node_id.uuid} already exists in info subscribers. Subscriber not replaced.")
        return
    # Mab@Xa = Mba@Xb = exchange_vec_ab => Mab is (m,n1), Mba is (m, n2), exchange_vec_ab is of length m
    if op.get_eval_vec_len() != eval_vec_len:
        logger.error("Info subscriber operator not compatible in shape")
        return
    self.info_subscribers[node_id] = Neighbor(node_id, op)
[docs]
def solve(self, force: bool = False) -> None:
    """Solve every AntObjectReference that is not up to date.

    Args:
        force (bool): solve even if the node is not in info mode.

    References for which `object_ref_in_past` is true are marked as up to
    date instead of being solved. The pass over the history is repeated
    until no reference asks to be solved anymore.
    NOTE(review): this presumably relies on `ref.solve()` marking the
    reference as up to date - confirm, otherwise the loop does not end.
    """
    logger.debug(f"solve (force:{force})")
    # TODO: set some restrictions on what to solve
    if not force and self.mode != ProcessingMode.INFO_MODE:
        return

    def needs_solving(candidate: AntObjectReference) -> bool:
        if candidate.up_to_date:
            # return early
            return False
        if self.object_ref_in_past(candidate):
            candidate.up_to_date = True
        return not candidate.up_to_date

    while any(map(needs_solving, self.history)):
        for outdated in filter(needs_solving, self.history):
            outdated.solve()
[docs]
def create_ant_object_reference(self, object_id: ObjectId | None, update_counter: int = 0) -> AntObjectReference:
    """Create a new AntObjectReference and append it to the history.

    Args:
        object_id (ObjectId | None): id of the new reference; if `None`, a
            fresh id is generated.
        update_counter (int): counter the external links to the info
            subscribers start with.

    Returns:
        reference (AntObjectReference): the newly created reference. It has
            one info subscriber per entry of `info_subscribers` and one
            measurement per entry of `default_information_packets`.
    """
    if object_id is None:
        # create a new uuid with current timestamp
        # see https://github.com/aminalaee/uuid-utils/blob/9ddd132c46278ac8aeb70474e688acec3465ce30/src/lib.rs#L402
        object_id = ObjectId(UUID(bytes=uuid_utils.uuid7().bytes))

    template = self.initial_state
    # the reference works on its own copy of mean and covariance
    working_state = InternalState(template.mean.copy(), template.covariance.copy())

    # external stuff
    # if you want to control whether to send to this subscriber or not, extend the condition
    subscribers: list[InfoSubscriber] = [
        InfoSubscriber(sub.operator, AntLinkExternal(Id(sub.node_id, object_id), update_counter), None)
        for sub in self.info_subscribers.values()
        if sub.node_id != self.node_id
    ]
    new_ref = AntObjectReference(self, object_id, template, working_state, info_subscribers=subscribers)

    # internal stuff
    for operator, packet, grad_subscriptions in self.default_information_packets:
        local_link = AntLinkLocal(Id(self.node_id, object_id), grad_subscriptions)
        new_ref.add_new_measurement(operator, packet, local_link)

    self.history.append(new_ref)
    return new_ref
[docs]
def insert_measurement_from_info_message(self, ref: AntObjectReference, msg: InfoMessage) -> None:
    """Add the state of an info message as a new measurement to `ref`.

    The operator stored for the sender in `neighbors` is used, and the
    measurement is linked to the sender with the counter of the message.
    The sender has to be a known neighbor.
    """
    sender_operator = self.neighbors[msg.sender].operator
    link = AntLinkExternal(Id(msg.sender, msg.object_id), msg.message_counter)
    # if you want to control whether you want to receive from this neighbor: put a check here
    ref.add_new_measurement(sender_operator, msg.state, link)
[docs]
def insert_measurement_from_sensor(self, object_id: ObjectId, op: Operator, sensor_data: ExchangeVector) -> None:
    """Add local sensor data as a new measurement for `object_id`.

    If there is no AntObjectReference for the object id yet, it is created.
    The measurement is always added as a new one with a local link that has
    no gradient subscriptions.
    """
    known = self.poll_history_from_object_id(object_id)
    ref = known if known is not None else self.create_ant_object_reference(object_id)
    # existing measurements are never updated here -> always create a new one
    local_link = AntLinkLocal(Id(self.node_id, object_id), [])
    ref.add_new_measurement(op, sensor_data, local_link)
[docs]
def zerogradients(self) -> None:
    """Reset all gradient information stored in the history.

    For every AntObjectReference the probe storage is emptied and each of
    its measurements is asked to set its gradients to zero.
    """
    for entry in self.history:
        entry.probe_storage.storage.clear()
        for item in entry.measurements:
            item.set_gradients_to_zero()
[docs]
def residual_loss(self) -> None:
    """Calculate the residual loss (not implemented yet, does nothing)."""
    # TODO implement residual loss calculation, store in place
    return None
[docs]
def optimize(self) -> None:
    """Call `optimize()` on every AntObjectReference in the history, in order."""
    for entry in self.history:
        entry.optimize()
[docs]
def find_object_reference(self, object_id: ObjectId) -> AntObjectReference | None:
    """Search for specific AntObjectReference in this Node.

    Args:
        object_id (ObjectId): the object id to search for.

    Returns:
        reference (AntObjectReference | None): the first AntObjectReference
            in the history with this object id, or `None` if there is none.
    """
    # same lookup as `poll_history_from_object_id`; delegate instead of duplicating the loop
    return self.poll_history_from_object_id(object_id)
[docs]
def set_user_context(self, user_context: str) -> None:
    """Replace the `user_context` of this node.

    The stored text is sent in the answer to a GetNodeContext command.
    """
    self.user_context = user_context
[docs]
def set_available_labels(self, labels: list[str]) -> None:
    """Replace the list of available label texts.

    The list is stored as it is (no copy) and reported in the answer to a
    GetNodeInfo command.
    """
    self.available_labels = labels
def _send_probe_responses(self) -> None:
    """Send a ProbeResponseMessage for every stored probe gradient.

    All entries of the probe storage of every AntObjectReference in the
    history are visited. Entries whose loss channel is not in PROBE mode or
    has no `lossref` are skipped; for the others the gradient is sent to
    the node identified by `lossref`.
    """
    for ref in self.history:
        for channel, gradient in ref.probe_storage.storage.items():
            if channel.mode != FuturePastProbe.PROBE:
                continue
            if channel.lossref is None:
                continue
            response = ProbeResponseMessage(
                self.node_id,
                NodeId(channel.lossref),
                ref.object_id,
                gradient,
                channel,
            )
            self.communicator.send_message(response)
[docs]
def visualize(self) -> str:
    """Render the node as a multi-line text.

    The first line is the alias of the node. For every AntObjectReference
    in the history follows a `ref_<index>` line, one line per info
    subscriber, a separator line and one line per measurement.

    Returns:
        text (str): the lines joined with newlines.
    """
    from ddf.ant_link.ant_link import AntLink  # noqa: PLC0415
    from ddf.ant_link.internal import AntLinkInternalObjectID  # noqa: PLC0415

    def describe_link(link: AntLink) -> str:
        # turn a link into "--[<type>]-> <target>"
        destination = ""
        if isinstance(link, AntLinkLocal):
            kind, destination = "local", "self"
        elif isinstance(link, AntLinkExternal):
            kind = "external"
            linked_object = link.id.object_id
            destination = f"node {link.id.node_id.uuid} -{linked_object.uuid if linked_object else 'None'}"
        elif isinstance(link, AntLinkInternalObjectID):
            kind = "internal"
            # point to the reference in our own history, if there is one
            for known in self.history:
                if known.object_id == link.id.object_id:
                    destination = f"ref_{self.history.index(known)}"
                    break
        else:
            kind = destination = "unknown"
        return f"--[{kind}]-> {destination}"

    indent = 4 * " "
    nested = indent * 2
    rendered = [self.alias]
    for position, ref in enumerate(self.history):
        rendered.append(f"{indent}ref_{position}")
        rendered.extend(f"{nested}subscriber {describe_link(sub.source)}" for sub in ref.info_subscribers)
        rendered.append(f"{nested}-")
        rendered.extend(f"{nested}measurement {describe_link(item.source)}" for item in ref.measurements)
    return "\n".join(rendered)
[docs]
def object_ref_in_past(self, ref: AntObjectReference) -> bool:
    """Tell whether `ref` lies in the past.

    Currently this always returns `False`, the argument is not inspected.
    """
    return False