from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from ddf.ant_link.external import AntLinkExternal
from ..ant_link.ant_link import AntLink, LossChannel
from ..ddf import ObjectId
from ..information import ExchangeVector, InternalState
from ..operator import Operator
from .gradient_storage import GradientStorage
if TYPE_CHECKING:
from ..node import Node
@dataclass
class Measurement:
    """A measurement of a parent AntObjectReference received through an AntLink.

    Holds the measurement's exchange vector and the Jacobian relating the
    parent's state to it. It accumulates loss gradients w.r.t. the exchange
    vector per loss channel. It delegates both gradient forwarding and the
    optimisation step to :attr:`source`.
    """

    #: Object id of the parent AntObjectReference.
    #: This could be replaced with a direct pointer to the parent.
    parent_object_id: ObjectId
    #: Operator whose Jacobian (evaluated at a state) is used by
    #: :meth:`get_state_info_mat` to map the exchange covariance to state space.
    operator: Operator
    #: Contains ``b`` and ``cov_b`` (the covariance is read as ``.covariance``).
    exchange_vector: ExchangeVector
    #: Jacobian of the AntObjectReference state w.r.t. the exchange_state
    #: of this measurement.
    state_jacobian: np.ndarray
    #: Link this measurement came from; gradients are forwarded to it and it
    #: performs the optimisation step.
    source: AntLink
    #: Gradients w.r.t. the exchange vector, accumulated per loss channel.
    gradient_storage: GradientStorage = field(default_factory=lambda: GradientStorage({}))

    def backpropagate(
        self, node: "Node", loss_gradient_wrt_state: np.ndarray, loss_channel: LossChannel
    ) -> None:
        """Propagate a state-space loss gradient back to this measurement's exchange vector.

        The gradient w.r.t. the exchange vector is
        ``state_jacobian.T @ loss_gradient_wrt_state``. It is added to the running
        total kept in :attr:`gradient_storage` under ``loss_channel`` and is then
        forwarded to ``source.backpropagate``.

        Args:
            node: Passed through unchanged to ``source.backpropagate``.
            loss_gradient_wrt_state: Gradient of the loss w.r.t. the parent's state.
            loss_channel: Key under which the gradient is accumulated and forwarded.
        """
        # Chain rule: dL/d(exchange) = (d state / d exchange)^T @ dL/d(state).
        parameter_gradient = self.state_jacobian.T @ loss_gradient_wrt_state
        storage = self.gradient_storage.storage
        if loss_channel not in storage:
            # Start from a separate zero buffer. Later in-place ``+=`` calls then
            # never alias ``parameter_gradient``, which is also handed to ``source``.
            storage[loss_channel] = np.zeros_like(parameter_gradient)
        storage[loss_channel] += parameter_gradient
        self.source.backpropagate(node, parameter_gradient, loss_channel)

    def optimize_step(self) -> None:
        """Ask :attr:`source` for an updated exchange vector.

        ``source.optimize_step`` receives the accumulated gradient storage, the
        parent object id and the current exchange vector. If it returns
        something other than ``None``, that value replaces :attr:`exchange_vector`.
        """
        mean_update = self.source.optimize_step(
            self.gradient_storage, self.parent_object_id, self.exchange_vector
        )
        if mean_update is not None:
            self.exchange_vector = mean_update

    def get_state_info_mat(self, state: InternalState) -> np.ndarray:
        """Return the information matrix in state space (not exchange state space).

        Computes ``J.T @ inv(C) @ J``. Here ``J = operator.get_jacobian(state)`` and
        ``C`` is the covariance of :attr:`exchange_vector`. ``inv(C) @ J`` is
        obtained with :func:`numpy.linalg.solve` rather than by forming
        ``inv(C)`` explicitly. This is cheaper and numerically more accurate.

        Args:
            state: State at which the operator's Jacobian is evaluated.

        Returns:
            The information matrix ``J.T @ inv(C) @ J``.

        Raises:
            numpy.linalg.LinAlgError: If the covariance is singular.
        """
        jacobian = self.operator.get_jacobian(state)
        weighted_jacobian = np.linalg.solve(self.exchange_vector.covariance, jacobian)
        return jacobian.T @ weighted_jacobian

    def set_gradients_to_zero(self) -> None:
        """Discard all accumulated gradients.

        Replaces :attr:`gradient_storage` with a new, empty storage; the previous
        storage object is not mutated. When :attr:`source` is an
        ``AntLinkExternal``, its ``max_gradients`` are cleared as well.
        """
        self.gradient_storage = GradientStorage({})
        if isinstance(self.source, AntLinkExternal):
            self.source.max_gradients.clear()