Source code for ddf.ant_link.local

from dataclasses import dataclass

from ..ddf import ObjectId
from ..information import ExchangeVector
from ..measurement.gradient_storage import GradientStorage
from .ant_link import AntLink, GradientSubscription


@dataclass
class AntLinkLocal(AntLink):
    """Link that refers to the same ObjectId on the same Node."""

    # Subscriptions whose solvers are tried on every optimization step.
    gradient_subscriptions: list[GradientSubscription]

    def optimize_step(
        self,
        gradient_storage: GradientStorage,
        object_id: ObjectId,
        exchange_vector: ExchangeVector,
    ) -> None | ExchangeVector:
        """Run each subscribed solver whose loss channel has a stored gradient.

        Subscriptions are visited in list order. A subscription is skipped
        when its ``loss_channel`` is not a key of ``gradient_storage.storage``.
        Otherwise its solver is called with the stored gradient, a freshly
        built ``ExchangeVector`` made from ``.copy()`` of the input's mean and
        covariance, and the subscription's current ``solver_state``. The
        returned solver state is written back onto the subscription.

        Every matching solver is invoked, even after a result has been found,
        so all solver states advance on each call.

        Args:
            gradient_storage: Holds gradients in ``storage``, keyed by each
                subscription's ``loss_channel``.
            object_id: Not used by this implementation.
            exchange_vector: Current state. Solvers receive copies built from
                it rather than the object itself.

        Returns:
            The first non-``None`` state produced by a solver, or ``None`` if
            no solver produced one.
        """
        first_result: ExchangeVector | None = None
        for subscription in self.gradient_subscriptions:
            if subscription.loss_channel not in gradient_storage.storage:
                continue
            gradient = gradient_storage.storage[subscription.loss_channel]
            run_solver = subscription.solver
            private_copy: ExchangeVector = ExchangeVector(
                exchange_vector.mean.copy(),
                exchange_vector.covariance.copy(),
            )
            candidate, subscription.solver_state = run_solver(
                gradient, private_copy, subscription.solver_state
            )
            # Remember only the earliest result, but keep looping so that
            # every remaining solver still gets to update its state.
            if first_result is None and candidate is not None:
                first_result = candidate
        return first_result