# Source code for ddf.measurement.measurement

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np

from ddf.ant_link.external import AntLinkExternal

from ..ant_link.ant_link import AntLink, LossChannel
from ..ddf import ObjectId
from ..information import ExchangeVector, InternalState
from ..operator import Operator
from .gradient_storage import GradientStorage

if TYPE_CHECKING:
    from ..node import Node


@dataclass
class Measurement:
    """A measurement held by an AntObjectReference, received over an AntLink.

    Bundles the exchanged estimate (``exchange_vector``), the operator whose
    Jacobian maps the parent state into the exchange space, the link it came
    from (``source``), and gradients w.r.t. the exchange vector accumulated
    per loss channel (``gradient_storage``).
    """

    parent_object_id: ObjectId
    """object id of the parent AntObjectReference. This could be replaced with a direct pointer to the parent."""
    operator: Operator
    exchange_vector: ExchangeVector
    """contains b and cov_b"""
    state_jacobian: np.ndarray
    """Jacobian of the AntObjectReference state w.r.t. the exchange_state of this measurement."""
    source: AntLink
    gradient_storage: GradientStorage = field(default_factory=lambda: GradientStorage({}))

    def backpropagate(
        self,
        node: "Node",
        loss_gradient_wrt_state: np.ndarray,
        loss_channel: LossChannel,
    ) -> None:
        """Propagate a loss gradient from state space back to this measurement.

        The gradient w.r.t. the parent state is mapped to a gradient w.r.t. this
        measurement's exchange vector, added to ``gradient_storage`` under
        ``loss_channel``, and then forwarded to ``source``.

        Args:
            node: Node performing the backpropagation; passed through to ``source``.
            loss_gradient_wrt_state: Gradient of the loss w.r.t. the parent state.
            loss_channel: Key under which the gradient is accumulated.
        """
        # Chain rule: with J = d(state)/d(exchange), d(loss)/d(exchange) = J^T d(loss)/d(state).
        parameter_gradient = self.state_jacobian.T @ loss_gradient_wrt_state
        storage = self.gradient_storage.storage
        if loss_channel not in storage:
            # Start from a fresh zero buffer so the in-place accumulation below
            # never aliases the array that is forwarded to the source link.
            storage[loss_channel] = np.zeros_like(parameter_gradient)
        storage[loss_channel] += parameter_gradient
        self.source.backpropagate(node, parameter_gradient, loss_channel)

    def optimize_step(self) -> None:
        """Apply one optimization step delegated to the source link.

        The source receives the accumulated gradients, the parent object id and
        the current exchange vector. If it returns a value, that value replaces
        ``exchange_vector``; if it returns ``None``, the measurement is unchanged.
        """
        mean_update = self.source.optimize_step(
            self.gradient_storage, self.parent_object_id, self.exchange_vector
        )
        if mean_update is not None:
            self.exchange_vector = mean_update

    def get_state_info_mat(self, state: InternalState) -> np.ndarray:
        """Returns information matrix in state space (not exchange state space).

        Computes ``J.T @ inv(cov) @ J``, where ``J`` is the operator Jacobian
        evaluated at ``state`` and ``cov`` is the covariance of the exchange
        vector.

        Raises:
            numpy.linalg.LinAlgError: If the exchange covariance is singular
                or not square.
        """
        jacobian = self.operator.get_jacobian(state)
        # Solve cov @ X = J rather than forming inv(cov) explicitly: the result is
        # mathematically identical, but cheaper and numerically more stable.
        return jacobian.T @ np.linalg.solve(self.exchange_vector.covariance, jacobian)

    def set_gradients_to_zero(self) -> None:
        """Discard all accumulated gradients.

        Replaces ``gradient_storage`` with a new, empty storage. When the source
        is an ``AntLinkExternal``, also clears that link's ``max_gradients``.
        """
        self.gradient_storage = GradientStorage({})
        if isinstance(self.source, AntLinkExternal):
            self.source.max_gradients.clear()