Source code for ddf.ant_object_reference

"""ANT Object Reference."""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np
from loguru import logger
from numpy.linalg import LinAlgError
from scipy import optimize
from scipy.linalg import block_diag

from .ant_link.ant_link import AntLink, LossChannel
from .ant_link.external import AntLinkExternal
from .ant_link.internal import AntLinkInternalObjectID
from .ant_link.local import AntLinkLocal
from .ant_link.loss_channel import FuturePastProbe
from .ddf import Id, NodeId, ObjectId
from .information import ExchangeVector, InternalState
from .measurement.gradient_storage import GradientStorage
from .measurement.measurement import Measurement
from .message import GradientMessage, InfoMessage
from .operator import Operator

if TYPE_CHECKING:
    from .node import Node


def kl_div(mu_post: np.ndarray, gamma_post: np.ndarray, mu_prior: np.ndarray, gamma_prior: np.ndarray) -> float:
    """Calculate the KL divergence in bits.

    Args:
        mu_post (1D ndarray): Posterior mean vector
        gamma_post (2D ndarray): Posterior covariance
        mu_prior (1D ndarray): Prior mean
        gamma_prior (2D ndarray): Prior covariance

    Raises:
        ValueError: If the input dimensions do not match.

    Returns:
        float: KL divergence in bits
    """
    n_dim = len(mu_post.flatten())
    gamma_post = np.atleast_2d(gamma_post)
    gamma_prior = np.atleast_2d(gamma_prior)
    if (
        (n_dim != gamma_post.shape[0])
        or (gamma_post.shape[0] != gamma_post.shape[1])
        or (gamma_post.shape[1] != len(mu_prior.flatten()))
        or (len(mu_prior.flatten()) != gamma_prior.shape[0])
        or (gamma_prior.shape[0] != gamma_prior.shape[1])
    ):
        raise ValueError("Input dimension mismatch.")
    # The quadratic form yields a (1, 1) array; .item() unwraps it to a Python float.
    return (
        (
            np.trace(np.linalg.inv(gamma_prior) @ gamma_post)
            + (np.reshape((mu_prior - mu_post), (1, -1)) @ np.linalg.inv(gamma_prior))
            @ np.reshape((mu_prior - mu_post), (-1, 1))
            - n_dim
            + np.linalg.slogdet(gamma_prior)[1]
            - np.linalg.slogdet(gamma_post)[1]
        )
        / (2 * np.log(2))
    ).item()
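

# A minimal usage sketch (not part of the library); the vectors below are made-up
# demonstration values. For identical Gaussians the divergence is 0 bits; shifting
# the mean of a unit-covariance 2D Gaussian by one sigma per axis gives
# 2 / (2 ln 2) ≈ 1.44 bits.
def _kl_div_example() -> None:
    mu = np.zeros(2)
    cov = np.eye(2)
    assert np.isclose(kl_div(mu, cov, mu, cov), 0.0)
    assert np.isclose(kl_div(mu + 1.0, cov, mu, cov), 2 / (2 * np.log(2)))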


@dataclass
class InfoSubscriber:
    """Information Subscriber Info.

    Information about another node that needs to be updated when this node has new information.
    """

    operator: Operator
    source: AntLink
    last_message: InfoMessage | None
    communication_threshold: float = 1e-6

    def update(self, state: InternalState) -> None:
        raise NotImplementedError  # TODO what should update do??

    def generate_info_message(self, node: "Node", state: InternalState) -> None | InfoMessage:
        """Generate an info message, but only if it differs enough from the last sent message."""
        update_mean = self.operator.evaluate(state)
        update_cov = self.operator.get_cov(state)
        if isinstance(self.source, AntLinkExternal):
            new_msg = self.source.generate_info_message(node, ExchangeVector(update_mean, update_cov))
            # Send only if there is no prior message or the difference is big enough.
            send_message = False
            if self.last_message is None:
                send_message = True
            else:
                try:
                    kldiv = self.kl_div(
                        new_msg.state,
                        self.last_message.state,
                    )
                    if kldiv > self.communication_threshold:
                        send_message = True
                except LinAlgError:
                    logger.warning("Error calculating KL divergence, sending message.")
                    send_message = True
            if send_message:
                if self.last_message is not None:
                    new_msg.message_counter = self.last_message.message_counter + 1
                self.last_message = new_msg
                return new_msg
        return None

    def kl_div(self, new_state: ExchangeVector, old_state: ExchangeVector) -> float:
        """KL divergence between the new message state and the last sent message state."""
        return kl_div(new_state.mean, new_state.covariance, old_state.mean, old_state.covariance)


def cov2whitening(cov: np.ndarray, method: str = "pca", fudge: float = 1e-18) -> np.ndarray:
    """Determine a whitening matrix from a covariance matrix.

    Args:
        cov (2D ndarray, symmetric): covariance matrix
        method (str, optional): 'pca' uses an eigenvalue decomposition with a nearest
            positive semi-definite approximation (if needed). 'chol' uses a Cholesky
            decomposition and works only for non-singular matrices. Defaults to 'pca'.
        fudge (float, optional): Small fudge parameter to avoid infinities. Defaults to 1e-18.

    Returns:
        2D ndarray: whitening matrix
    """
    assert method == "pca" or method == "chol"  # noqa: PLR1714 python linter does not understand it if merged
    if method == "pca":
        w, v = np.linalg.eigh(cov)
        if (w < 0).sum() > 0:
            w, v = np.linalg.eigh(0.5 * cov + 0.5 * cov.transpose())
            w[w < 0] = 0
        return np.diag(1 / np.sqrt(w + fudge)) @ v.transpose()
    if method == "chol":
        return np.linalg.cholesky(np.linalg.inv(cov)).transpose()
    raise ValueError("method unknown")
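

# A minimal verification sketch (not part of the library); the covariance below is a
# made-up demonstration value. A whitening matrix W satisfies W @ cov @ W.T ≈ I, so
# residuals multiplied by W have (approximately) unit covariance.
def _cov2whitening_example() -> None:
    cov = np.array([[4.0, 1.0], [1.0, 2.0]])
    w_op = cov2whitening(cov, method="pca")
    assert np.allclose(w_op @ cov @ w_op.T, np.eye(2), atol=1e-6)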


@dataclass
class AntObjectReference:
    """A single point in time, or a single batch."""

    parent_node: "Node"
    object_id: ObjectId
    initial_state: InternalState
    """Initial state estimate, kept for the covariance limiter."""
    state: InternalState
    """Current best guess for the state estimate."""
    measurements: list[Measurement] = field(default_factory=list)
    info_subscribers: list[InfoSubscriber] = field(default_factory=list)
    up_to_date: bool = False
    """This AntObjectReference just finished data fusion and there is no new information in measurements."""
    optimizer_lsqs_method: str = "trf"
    lower_bound_x_value: float = -np.inf
    """Value used to populate AntObjectReference.lower_bound_x."""
    upper_bound_x_value: float = np.inf
    """Value used to populate AntObjectReference.upper_bound_x."""
    lower_bound_x: np.ndarray = field(init=False)
    """Lower bounds for the least squares optimizer."""
    upper_bound_x: np.ndarray = field(init=False)
    """Upper bounds for the least squares optimizer."""
    probe_storage: GradientStorage = field(default_factory=lambda: GradientStorage(dict()))

    def __post_init__(self) -> None:
        self.lower_bound_x = self.lower_bound_x_value * np.ones(len(self.initial_state.mean))
        self.upper_bound_x = self.upper_bound_x_value * np.ones(len(self.initial_state.mean))

    def softmax(self, x: np.ndarray) -> np.ndarray:
        """The softmax function of vector x.

        Args:
            x (1D ndarray): Input vector

        Returns:
            1D ndarray: softmax(x)
        """
        return np.exp(x) / np.sum(np.exp(x))

    def get_weight_projection_mat(self) -> np.ndarray:
        """Resize the weights vector.

        Resizing is needed to perform Covariance Intersection with a vector of weights
        which has only one weight for all the measurements whose origin is the same node.
        Each external measurement gets its own column; all local measurements share the
        last column. See the sketch after this method.

        Returns:
            np.ndarray: "Weight Projection Matrix"
        """
        internal_idxs = [idx for idx, m in enumerate(self.measurements) if isinstance(m.source, AntLinkLocal)]
        external_idxs = [idx for idx, m in enumerate(self.measurements) if not isinstance(m.source, AntLinkLocal)]
        index = 0
        rows = len(self.measurements)
        cols = len(external_idxs)
        if len(internal_idxs) > 0:
            cols += 1
        mat = np.zeros((rows, cols))
        for row, measurement in enumerate(self.measurements):
            if isinstance(measurement.source, AntLinkLocal):
                mat[row, -1] = 1
            else:
                mat[row, index] = 1
                index += 1
        return mat
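
    # A minimal sketch (not part of the library) for a hypothetical measurement list
    # [external_a, local_1, external_b, local_2]: the projection matrix expands a
    # 3-weight vector (w_a, w_b, w_local) to one weight per measurement, with both
    # locals sharing the last weight:
    #
    #     [[1, 0, 0],       [[w_a    ],
    #      [0, 0, 1],   @    [w_b    ],   =  [w_a, w_local, w_b, w_local]^T
    #      [0, 1, 0],        [w_local]]
    #      [0, 0, 1]]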

    def solve(self) -> None:
        """Solve the data fusion problem using Covariance Intersection.

        This function determines the weights for the different external measurements,
        plus a single weight for the internal measurements. Then a weighted sum of the
        sources of information is performed.

        Returns:
            None.
        """
        if len(self.measurements) == 0:
            raise ValueError("No measurements available for optimization.")
        # Initial state (working_point) inferred with all weights equal to 1.
        r: InternalState | None = self.weighted_inference(measmt_weights=np.ones(len(self.measurements)))
        if r is None:
            working_point: InternalState = self.initial_state
        else:
            working_point = r
        w_full = self._calculate_weights(working_point)
        # Update the state with the result from data fusion.
        result = self.weighted_inference(measmt_weights=w_full, save_results=True)
        if result is None:
            logger.debug("solve did not return a value")
        else:
            self.up_to_date = True

    def _calculate_weights(self, working_point: InternalState) -> np.ndarray:
        info_mats = [m.get_state_info_mat(working_point) for m in self.measurements]
        weight_proj_mat = self.get_weight_projection_mat()
        if weight_proj_mat.shape[1] > 1:
            # Cost function for CI: we want an unbounded problem from the perspective of
            # the optimizer, but we have a bounded problem with x in [0,1]^n and sum(x)=1.
            # That is why we remove one of the weights, append a (constant!) 1, and then
            # apply softmax, which maps RR^n onto [0,1]^n for us. After optimizing we have
            # to repeat this to get our x in [0,1]^n. See the sketch after this method.
            def ci_det_costfunction(weights_minus1: np.ndarray) -> float:
                weights = np.concatenate((weights_minus1, np.ones(1)), axis=0)
                # Softmax returns a combination of values in [0, 1] that sum up to one.
                soft_max_w = self.softmax(weights)
                projected_weights = weight_proj_mat @ soft_max_w
                return -np.linalg.slogdet(
                    # Sum up the information matrices weighted by the softmax weights
                    # and take the log-determinant of the result.
                    np.add.reduce([softmax * info for (softmax, info) in zip(projected_weights.tolist(), info_mats)])
                )[1]

            # Reduce the number of weights by one, so that one weight is constrained to be equal to 1.
            num_source_info = weight_proj_mat.shape[1] - 1
            w0 = np.ones(num_source_info)
            w_ret = optimize.minimize(
                ci_det_costfunction,  # type: ignore
                w0,
                bounds=optimize.Bounds(ub=np.ones(num_source_info) * 20, lb=-np.ones(num_source_info) * 20),  # type: ignore
                method="Nelder-Mead",
            )
            # Here we reintroduce the constrained weight, equal to 1.
            weights = np.concatenate((w_ret.x, np.ones(1)), axis=0)
            soft_max_w = self.softmax(weights)
            w_full = weight_proj_mat @ soft_max_w
        else:
            # If we only have one source of information, we just set the weight to 1.
            w_full = np.ones(len(self.measurements))
        return w_full
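
    # A minimal sketch (not part of the library) of the reparameterization above, with
    # made-up numbers: the optimizer only sees n - 1 unconstrained values, a constant 1
    # is appended, and softmax maps the result onto the probability simplex (entries in
    # [0, 1] summing to 1), which is exactly the constraint Covariance Intersection needs:
    #
    #     weights_minus1 = np.array([0.3, -1.2])                   # free variables
    #     weights = np.concatenate((weights_minus1, np.ones(1)))   # fixed last entry
    #     soft_max_w = np.exp(weights) / np.sum(np.exp(weights))   # in [0, 1]^3
    #     soft_max_w.sum()                                         # -> 1.0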

    def evaluate_operators(self, state: InternalState) -> list[np.ndarray]:
        """Composite model for the AntObjectReference.

        It combines the outputs of all models.

        Args:
            state (InternalState): State.

        Returns:
            list[np.ndarray]: The output of each measurement operator.
        """
        return [m.operator.evaluate(state) for m in self.measurements]

    def operator_jacobians(self, state: InternalState) -> list[np.ndarray]:
        """Jacobian of the composite model for the AntObjectReference.

        It combines the Jacobians of all models.
        """
        return [m.operator.get_jacobian(state) for m in self.measurements]

    def weighted_inference(self, measmt_weights: np.ndarray, save_results: bool = False) -> InternalState | None:
        """Weighted inference for the AntObjectReference.

        Given a weight per measurement, the measurements are combined into a whitened
        nonlinear least squares problem, which is solved for the state. The weights
        determine the influence of each measurement in the data fusion.

        Args:
            measmt_weights (np.ndarray): Weights for the measurements.
            save_results: Save state and state Jacobians for AntObjectReferences and Measurements?

        Returns:
            InternalState | None: The inferred state, or None if the optimization failed.
        """
        if len(measmt_weights) != len(self.measurements):
            raise ValueError("Length of measmt_weights must be equal to the number of measurements")
        if len(self.measurements) == 0:
            raise ValueError("no measurements available")
        # Agglomerated whitening operators and exchange states.
        whitening_operators = []
        exchange_states = []
        for idx, m in enumerate(self.measurements):
            whitening_operators.append(np.sqrt(measmt_weights[idx]) * m.exchange_vector.get_whitening_op())
            exchange_states.append(m.exchange_vector.mean)

        # Residual function of the least squares problem.
        # state_mean: n, output: m -> jac: (m, n)
        def cost_function(state_mean: np.ndarray) -> np.ndarray:
            operator_results = self.evaluate_operators(InternalState(state_mean, np.array([])))
            results = []
            for whitening_operator, operator_result, exchange_state in zip(
                whitening_operators, operator_results, exchange_states
            ):
                results.append(
                    (
                        whitening_operator
                        @ (
                            operator_result.reshape(operator_result.shape[0], 1)
                            - exchange_state.reshape(exchange_state.shape[0], 1)
                        )
                    ).flatten()
                )
            return np.concatenate(results)

        # Cost function Jacobian; also used to determine the covariance.
        def cost_function_jac(state_mean: np.ndarray) -> np.ndarray:
            operator_jacobians = self.operator_jacobians(InternalState(state_mean, np.array([])))
            results = []
            for whitening_operator, operator_jacobian in zip(whitening_operators, operator_jacobians):
                results.append(whitening_operator @ operator_jacobian)
            return np.concatenate(results, axis=0)

        if self.optimizer_lsqs_method == "trf":
            # TODO: in the C++ version we will use the "fides" method
            # (Trust Region Reflective for boundary-constrained optimization), which performs better.
            opt_result = optimize.least_squares(
                cost_function,
                self.state.mean,
                jac=cost_function_jac,  # type: ignore
                verbose=0,
                ftol=1e-12,
                xtol=1e-12,
                gtol=1e-12,
                max_nfev=100,
                method=self.optimizer_lsqs_method,
                bounds=(self.lower_bound_x, self.upper_bound_x),
            )
        else:
            opt_result = optimize.least_squares(
                cost_function,
                self.state.mean,
                jac=cost_function_jac,  # type: ignore
                verbose=0,
                ftol=1e-10,
                xtol=1e-10,
                max_nfev=100,
                method=self.optimizer_lsqs_method,  # type: ignore
            )
        inferred_state = opt_result["x"]
        if not opt_result.success:
            logger.error(f"Optimization failed: {opt_result.message}")
            return None
        # Covariance of the inferred state. This cannot just be the inverse of the summed
        # information matrices, because those were calculated at the initial guess; once we
        # have the state, the Jacobian L of the loss can be evaluated at that point, and in
        # terms of it cov = (L.T @ L)^-1.
        cost_fn_jac = cost_function_jac(inferred_state)
        inferred_state_cov = np.linalg.inv(cost_fn_jac.T @ cost_fn_jac)
        if save_results:
            l_mat = block_diag(*whitening_operators)
            jac_b = -l_mat
            datafusion_jacobian = -np.linalg.solve(
                cost_fn_jac.transpose() @ cost_fn_jac, cost_fn_jac.transpose() @ jac_b
            )
            col_index = 0
            for m in self.measurements:
                cols = len(m.exchange_vector.mean)
                m.state_jacobian = datafusion_jacobian[:, col_index : col_index + cols]
                col_index += cols
            self.state = InternalState(inferred_state, inferred_state_cov)
            self.update()
        return InternalState(inferred_state, inferred_state_cov)
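
    # A minimal worked sketch (not part of the library) of the covariance formula above:
    # for a linear operator H with whitening W of the measurement covariance Gamma
    # (so that W.T @ W = Gamma^-1), the residual Jacobian is L = W @ H, and
    #
    #     (L.T @ L)^-1 = (H.T @ W.T @ W @ H)^-1 = (H.T @ Gamma^-1 @ H)^-1,
    #
    # i.e. the usual information-form covariance of the (weighted) fused measurements.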

    def backpropagate_from_neighbor(
        self, gradient_wrt_exchange_vector: np.ndarray, loss_channel: LossChannel, node_id: NodeId
    ) -> None:
        """Backpropagate the gradient from a source (a neighbor, another AntObjectReference) to the internal state."""
        op = None
        for sub in self.info_subscribers:
            if sub.source.id.node_id == node_id:
                op = sub.operator
                break
        assert op is not None, "Operator not found in info subscribers."
        jacobian = op.get_jacobian(self.state)
        # loss_gradient: gradient w.r.t. the state of this node.
        loss_gradient = jacobian.T @ gradient_wrt_exchange_vector
        # If the mode is PROBE, accumulate the gradient in the probe gradient storage.
        if loss_channel.mode == FuturePastProbe.PROBE:
            if loss_channel in self.probe_storage.storage:
                self.probe_storage.storage[loss_channel] += loss_gradient
            else:
                self.probe_storage.storage[loss_channel] = loss_gradient
        self.backpropagate(loss_gradient, loss_channel)

    def backpropagate(self, loss_gradient: np.ndarray, loss_channel: LossChannel) -> None:
        """Backpropagate the translated loss gradients through the measurements."""
        in_past = self.parent_node.object_ref_in_past(self)
        future_gradient = loss_channel.mode == FuturePastProbe.FUTURE
        if future_gradient and in_past:
            return
        for m in self.measurements:
            m.backpropagate(self.parent_node, loss_gradient, loss_channel)

    def generate_gradient_messages(self) -> list[GradientMessage]:
        """Generate gradient messages from measurements."""
        ret: list[GradientMessage] = []
        for m in self.measurements:
            ret.extend(m.source.generate_gradient_messages(self.parent_node, m.gradient_storage))
        return ret

    def generate_info_messages(self) -> list[InfoMessage]:
        """Generate info messages for info subscribers."""
        ret: list[InfoMessage] = []
        for sub in self.info_subscribers:
            msg = sub.generate_info_message(self.parent_node, self.state)
            if msg is not None:
                ret.append(msg)
        return ret

    def add_new_measurement(self, operator: Operator, exchange_state: ExchangeVector, source: AntLink) -> None:
        """Add or overwrite a measurement.

        1. local: always just add (they come from the default_info_packets)
        2. external: always overwrite if there is something to overwrite
        3. internal: always overwrite if there is something to overwrite
        """
        old_id = None
        for i, m in enumerate(self.measurements):
            if m.source == source:
                old_id = i
                break
        # This is only a default so we have the right dimensions here.
        default_state_jacobian = np.zeros((self.initial_state.mean.shape[0], exchange_state.mean.shape[0]))
        measurement = Measurement(self.object_id, operator, exchange_state, default_state_jacobian, source)
        if old_id is None or isinstance(source, AntLinkLocal):
            # Create a new measurement.
            self.measurements.append(measurement)
        else:
            # Replace the old measurement.
            self.measurements[old_id] = measurement
        self.up_to_date = False

    def update(self) -> None:
        """Like an info message, but only through history; we do not need to send it over the network.

        This is only done for AntLinkInternalObjectID - the others have different Nodes as target.
        """
        for sub in self.info_subscribers:
            if isinstance(sub.source, AntLinkInternalObjectID):
                assert sub.source.id.node_id == self.parent_node.node_id, (
                    "you added an AntLinkInternalObjectID with a different Node?!"
                )
                if sub.source.id.object_id is not None:
                    ref = self.parent_node.find_object_reference(sub.source.id.object_id)
                    if ref is not None:
                        exch_vec: ExchangeVector = sub.operator.get_exchange_vector(self.state)
                        source = AntLinkInternalObjectID(Id(self.parent_node.node_id, self.object_id))
                        ref.add_new_measurement(self.parent_node.internal_operator, exch_vec, source)

    def optimize(self) -> None:
        """Run one optimization step on each measurement."""
        for m in self.measurements:
            m.optimize_step()