from abc import ABC
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from ..ddf import Id, ObjectId
from ..information import ExchangeVector
from ..measurement.gradient_storage import GradientStorage
from ..message import GradientMessage
from .loss_channel import LossChannel
if TYPE_CHECKING:
from ..node import Node
@dataclass
class GradientSubscription:
    """Subscription of a link to a loss channel, bundled with its solver.

    Groups the loss channel with the solver callable that maps
    ``(gradient, exchange_vector, solver_state)`` to an updated
    ``(exchange_vector, solver_state)``, plus the state dict the solver
    carries from one iteration to the next.
    """

    # Channel on which gradients for this subscription arrive.
    loss_channel: LossChannel
    # Solver step: (gradient, current ExchangeVector, state) ->
    # (updated ExchangeVector, updated state).
    solver: Callable[[np.ndarray, ExchangeVector, dict], tuple[ExchangeVector, dict]]
    # Mutable scratch space handed back to the solver on every call.
    solver_state: dict
@dataclass
class AntLink(ABC):
    """Abstract base class for links.

    The default implementations below are intentional no-ops so that a
    subclass only needs to override the hooks it actually participates in.
    """

    # Identifier of this link.
    id: Id

    def backpropagate(self, node: "Node", gradient: np.ndarray, losstype: LossChannel) -> None:
        """Receive a loss gradient for *losstype* at *node*.

        The base implementation does nothing.
        """
        return  # nop

    def optimize_step(
        self,
        gradient_storage: GradientStorage,
        object_id: ObjectId,
        exchange_vector: ExchangeVector,
    ) -> None | ExchangeVector:
        """Return updated values for state_vec and state_cov.

        The base implementation returns *exchange_vector* unchanged.
        """
        return exchange_vector

    def generate_gradient_messages(self, node: "Node", gradient_storage: GradientStorage) -> list[GradientMessage]:
        """Produce gradient messages to send from *node*.

        The base implementation produces none (empty list).
        """
        return []
def Rprop_step(  # noqa: PLR0913
    grad: np.ndarray,
    exchange_vector: ExchangeVector,
    delta_0: np.ndarray,
    delta_min: np.ndarray,
    delta_max: np.ndarray,
    solver_dict: dict,
    lb: np.ndarray | None = None,
    ub: np.ndarray | None = None,
) -> tuple[ExchangeVector, dict]:
    """Calculate a gradient descent step.

    Calculate a gradient descent step in the device parameters while
    controlling the learning rate with the iRProp algorithm.
    Respects the provided parameter bounds.

    See C. Igel and M. Huesken. Improving the Rprop learning algorithm.
    In Proceedings of the Second International Symposium on Neural
    Computation, NC2000, 2000.

    Args:
        grad (1D ndarray): loss gradient
        exchange_vector (ExchangeVector): current parameter vector
        delta_0 (1D ndarray): initial step size
        delta_min (1D ndarray): minimum step size
        delta_max (1D ndarray): maximum step size
        solver_dict (dict): dict for transferring temporary variables from
            iteration to iteration
        lb (1D ndarray | None): parameter lower bound vector; ``None`` or an
            empty array means unbounded below
        ub (1D ndarray | None): parameter upper bound vector; ``None`` or an
            empty array means unbounded above

    Returns:
        ExchangeVector: new parameter vector after the update
        dict: the solver_dict to pass in the next iteration
    """
    grad_current = grad.flatten()
    # Get iteration variables either from the previous step or initialize them.
    if "Rprop_Delta" in solver_dict:
        delta = solver_dict["Rprop_Delta"]
        grad_last = solver_dict["Rprop_grad_last"]
    else:
        # Copy: the in-place updates below must not mutate the caller's
        # delta_0 array (the original aliased it and clobbered it).
        delta = delta_0.copy()
        grad_last = np.zeros(np.array(exchange_vector.mean).size)
    # Sign agreement between consecutive gradients drives the step-size update.
    sign_product = grad_current * grad_last
    same_sign = sign_product > 0
    sign_flip = sign_product < 0
    # Same sign: accelerate (factor 1.2), capped at delta_max.
    delta[same_sign] = np.minimum(delta[same_sign] * 1.2, delta_max[same_sign])
    # Sign flip: we overshot, decelerate (factor 0.5), floored at delta_min.
    delta[sign_flip] = np.maximum(delta[sign_flip] * 0.5, delta_min[sign_flip])
    newval = exchange_vector.mean.flatten() - delta * np.sign(grad_current)  # make the step
    grad_current[sign_flip] = 0  # This is what makes it iRProp instead of normal RProp
    # Ensure bounds are obeyed; missing or empty bounds mean unbounded.
    if lb is None or lb.size == 0:
        lb = np.full_like(newval, -np.inf)
    if ub is None or ub.size == 0:
        ub = np.full_like(newval, np.inf)
    newval = np.clip(newval, lb, ub)
    curr_step = newval - exchange_vector.mean.flatten()
    # Zero the gradient where we are pinned to a boundary so that we don't
    # keep increasing delta towards delta_max while no progress is possible.
    grad_current[curr_step == 0.0] = 0.0
    solver_dict["Rprop_Delta"] = delta
    solver_dict["Rprop_grad_last"] = grad_current
    return ExchangeVector(newval, exchange_vector.covariance), solver_dict