from dataclasses import dataclass
from ..ddf import ObjectId
from ..information import ExchangeVector
from ..measurement.gradient_storage import GradientStorage
from .ant_link import AntLink, GradientSubscription
@dataclass
class AntLinkLocal(AntLink):
    """Refer to the same ObjectId on the same Node.

    On each optimization step this link passes gradients found in a
    ``GradientStorage`` to the solvers registered in
    ``gradient_subscriptions``.
    """

    # Solver subscriptions consulted, in list order, by ``optimize_step``.
    gradient_subscriptions: list[GradientSubscription]

    def optimize_step(
        self,
        gradient_storage: GradientStorage,
        object_id: ObjectId,
        exchange_vector: ExchangeVector,
    ) -> None | ExchangeVector:
        """Run every subscribed solver whose gradient is available.

        Each subscription's ``loss_channel`` is looked up in
        ``gradient_storage.storage``. If it is present, the subscription's
        solver is called with three arguments: that gradient, a fresh
        ``ExchangeVector`` built from copies of ``exchange_vector.mean`` and
        ``exchange_vector.covariance``, and the subscription's current
        ``solver_state``. The solver state it returns is stored back on the
        subscription.

        Args:
            gradient_storage: Holds gradients in ``.storage``, keyed by loss
                channel.
            object_id: Not used by this link type. NOTE(review): presumably
                kept for signature compatibility with ``AntLink``; confirm.
            exchange_vector: The current state. Solvers receive copies of
                its ``mean`` and ``covariance``, never the originals.

        Returns:
            The first non-``None`` state produced by a solver, in
            subscription order. Returns ``None`` if no subscription had a
            gradient available or every solver returned ``None``.
        """
        first_state: ExchangeVector | None = None
        for sub in self.gradient_subscriptions:
            if sub.loss_channel not in gradient_storage.storage:
                continue
            grad = gradient_storage.storage[sub.loss_channel]
            # Give each solver its own copy so that in-place changes by one
            # solver are not seen by the next.
            state_copy = ExchangeVector(
                exchange_vector.mean.copy(), exchange_vector.covariance.copy()
            )
            new_state, sub.solver_state = sub.solver(
                grad, state_copy, sub.solver_state
            )
            # Only the first result is returned, but do not stop early:
            # every solver must run so its solver_state is updated.
            if first_state is None and new_state is not None:
                first_state = new_state
        return first_state