"""ANT Object Reference."""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from loguru import logger
from numpy.linalg import LinAlgError
from scipy import optimize
from scipy.linalg import block_diag
from .ant_link.ant_link import AntLink, LossChannel
from .ant_link.external import AntLinkExternal
from .ant_link.internal import AntLinkInternalObjectID
from .ant_link.local import AntLinkLocal
from .ant_link.loss_channel import FuturePastProbe
from .ddf import Id, NodeId, ObjectId
from .information import ExchangeVector, InternalState
from .measurement.gradient_storage import GradientStorage
from .measurement.measurement import Measurement
from .message import GradientMessage, InfoMessage
from .operator import Operator
if TYPE_CHECKING:
from .node import Node
def kl_div(mu_post: np.ndarray, gamma_post: np.ndarray, mu_prior: np.ndarray, gamma_prior: np.ndarray) -> float:
"""Calculate the KL divergence in bits.
Args:
mu_post (1D ndarray): Posterior mean vector
gamma_post (2D ndarray): Posterior covariance
mu_prior (1D ndarray): Prior mean
gamma_prior (2D ndarray): Prior covariance
    Raises:
        ValueError: If the mean and covariance dimensions are inconsistent.
Returns:
float: KL divergence in bits
"""
n_dim = len(mu_post.flatten())
gamma_post = np.atleast_2d(gamma_post)
gamma_prior = np.atleast_2d(gamma_prior)
if (
(n_dim != gamma_post.shape[0])
or (gamma_post.shape[0] != gamma_post.shape[1])
or (gamma_post.shape[1] != len(mu_prior.flatten()))
or (len(mu_prior.flatten()) != gamma_prior.shape[0])
or (gamma_prior.shape[0] != gamma_prior.shape[1])
):
raise ValueError("Input dimension mismatch.")
    return (
        np.trace(np.linalg.inv(gamma_prior) @ gamma_post)
        + (
            np.reshape(mu_prior - mu_post, (1, -1))
            @ np.linalg.inv(gamma_prior)
            @ np.reshape(mu_prior - mu_post, (-1, 1))
        ).item()
        - n_dim
        + np.linalg.slogdet(gamma_prior)[1]
        - np.linalg.slogdet(gamma_post)[1]
    ) / (2 * np.log(2))
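# Illustrative check of kl_div (example values are an assumption, not from the source):
# for 1-D Gaussians N(0, 1) (posterior) and N(1, 1) (prior), the closed form is 0.5 nats,
# i.e. 0.5 / ln(2) ≈ 0.7213 bits, which kl_div reproduces:
#
#   kl_div(np.array([0.0]), np.array([[1.0]]), np.array([1.0]), np.array([[1.0]]))
#   # ≈ 0.7213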
@dataclass
class InfoSubscriber:
"""Information Subscriber Info.
Information about another node that needs to be updated when this node has new information.
"""
operator: Operator
source: AntLink
last_message: InfoMessage | None
communication_threshold: float = 1e-6
def update(self, state: InternalState) -> None:
raise NotImplementedError # TODO what should update do??
def generate_info_message(self, node: "Node", state: InternalState) -> None | InfoMessage:
        """Generate an info message if the state changed enough since the last one sent."""
        update_mean = self.operator.evaluate(state)
        update_cov = self.operator.get_cov(state)
        # Only external links generate network messages; other link types produce nothing here.
        if not isinstance(self.source, AntLinkExternal):
            return None
        new_msg = self.source.generate_info_message(node, ExchangeVector(update_mean, update_cov))
# if no prior message or difference is big enough, send new message
send_message = False
if self.last_message is None:
send_message = True
else:
try:
kldiv = self.kl_div(
new_msg.state,
self.last_message.state,
)
if kldiv > self.communication_threshold:
send_message = True
except LinAlgError:
logger.warning("Error calculating KL divergence, sending message.")
send_message = True
if send_message:
if self.last_message is not None:
new_msg.message_counter = self.last_message.message_counter + 1
self.last_message = new_msg
return new_msg
return None
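    # Flow sketch: a new message is sent only when the KL divergence between the candidate
    # message and the previously sent one exceeds communication_threshold; this trades
    # communication volume against the staleness of the subscriber's information.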
def kl_div(self, new_state: ExchangeVector, old_state: ExchangeVector) -> float:
"""Check if the info message is different from the current message."""
# compare with self.old_message
return kl_div(new_state.mean, new_state.covariance, old_state.mean, old_state.covariance)
def cov2whitening(cov: np.ndarray, method: str = "pca", fudge: float = 1e-18) -> np.ndarray:
"""Function for determining whitening matrices from covariance matrices.
Args:
cov (2D ndarray symmetric): covariance matrix
method (str, optional): Method. 'pca' Eigenvalue decomposition nearest semi-positive definite
approximation (if needed). 'chol' cholesky decomposition, only for non singular matrices. Defaults to 'pca'.
fudge (float, optional): Small fudge parameter to avoid infinities. Defaults to 1e-18.
Returns:
2D ndarray: whitening matrix
"""
assert method == "pca" or method == "chol" # noqa: PLR1714 python linter does not understand it if merged
if method == "pca":
w, v = np.linalg.eigh(cov)
if (w < 0).sum() > 0:
w, v = np.linalg.eigh(0.5 * cov + 0.5 * cov.transpose())
w[w < 0] = 0
whitening_op = np.diag(1 / np.sqrt(w + fudge)) @ v.transpose()
return whitening_op
if method == "chol":
whitening_op = np.linalg.cholesky(np.linalg.inv(cov)).transpose()
return whitening_op
raise ValueError("method unkown")
@dataclass
class AntObjectReference:
"""A single point in time, or a single batch."""
parent_node: "Node"
object_id: ObjectId
initial_state: InternalState
"""initial state estimate, keep for covarianve limiter"""
state: InternalState
"""current best guess for state estimate"""
measurements: list[Measurement] = field(default_factory=list)
info_subscribers: list[InfoSubscriber] = field(default_factory=list)
up_to_date: bool = False
"""this AntObjectReference just finished DataFusion and there is no new information in measurements"""
optimizer_lsqs_method: str = "trf"
lower_bound_x_value: float = -np.inf
"""value used to populate AntObjectReference.lower_bound_x"""
upper_bound_x_value: float = np.inf
"""value used to populate AntObjectReference.upper_bound_x"""
lower_bound_x: np.ndarray = field(init=False)
"""lower bounds for least squares optimizer"""
upper_bound_x: np.ndarray = field(init=False)
"""upper bounds for least squares optimizer"""
probe_storage: GradientStorage = field(default_factory=lambda: GradientStorage(dict()))
def __post_init__(self) -> None:
self.lower_bound_x = self.lower_bound_x_value * np.ones(len(self.initial_state.mean))
self.upper_bound_x = self.upper_bound_x_value * np.ones(len(self.initial_state.mean))
def softmax(self, x: np.ndarray) -> np.ndarray:
"""The softmax function of vector x.
Args:
x (1D ndarray): Input vector
Returns:
1D ndarray: softmax(x)
"""
        # Shift by the maximum for numerical stability; the result is mathematically unchanged.
        z = np.exp(x - np.max(x))
        return z / np.sum(z)
def get_weight_projection_mat(self) -> np.ndarray:
"""Resize the weights vector.
Resizing is needed to perform Covariance Intersection with a vector of weights
which only has one weight for all the measurements whose origin is the same node.
Returns:
np.ndarray: "Weight Projection Matrix"
"""
internal_idxs = [idx for idx, m in enumerate(self.measurements) if isinstance(m.source, AntLinkLocal)]
external_idxs = [idx for idx, m in enumerate(self.measurements) if not isinstance(m.source, AntLinkLocal)]
index = 0
rows = len(self.measurements)
cols = len(external_idxs)
if len(internal_idxs) > 0:
cols += 1
mat = np.zeros((rows, cols))
for row, measurement in enumerate(self.measurements):
if isinstance(measurement.source, AntLinkLocal):
mat[row, -1] = 1
else:
mat[row, index] = 1
index += 1
return mat
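    # Illustrative example (the measurement ordering is an assumption): with measurements
    # [local, external_A, local, external_B], get_weight_projection_mat returns a (4, 3)
    # matrix mapping the reduced weights [w_ext_A, w_ext_B, w_local] onto the measurements:
    #
    #   [[0, 0, 1],   # local      -> shared local weight (last column)
    #    [1, 0, 0],   # external_A -> weight 0
    #    [0, 0, 1],   # local      -> shared local weight
    #    [0, 1, 0]]   # external_B -> weight 1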
def solve(self) -> None:
"""Solve the data fusion problem using Covariance Intersection.
        This function determines one weight for each of the different external measurements plus a
        single weight for the internal measurements. Then a weighted sum of the sources of
        information is performed.
Returns:
None.
"""
if len(self.measurements) == 0:
raise ValueError("No measurements available for optimization.")
        # Initial state (working_point) inferred with all weights equal to 1.
r: InternalState | None = self.weighted_inference(measmt_weights=np.ones(len(self.measurements)))
if r is None:
working_point: InternalState = self.initial_state
else:
working_point = r
w_full = self._calculate_weights(working_point)
# Update the state with the result from data fusion.
result = self.weighted_inference(measmt_weights=w_full, save_results=True)
if result is None:
logger.debug("solve did not return a value")
else:
self.up_to_date = True
def _calculate_weights(self, working_point: InternalState) -> np.ndarray:
info_mats = [m.get_state_info_mat(working_point) for m in self.measurements]
weight_proj_mat = self.get_weight_projection_mat()
if weight_proj_mat.shape[1] > 1:
            # Cost function for CI: we want an unbounded problem from the optimizer's perspective,
            # but the real problem is bounded, with x in [0, 1]^n and sum(x) = 1.
            # That is why we remove one of the weights, append a (constant!) 1, and then apply
            # softmax, which maps R^n -> [0, 1]^n for us. After optimizing we have to repeat this
            # step to recover our x in [0, 1]^n.
def ci_det_costfunction(weights_minus1: np.ndarray) -> np.ndarray:
weights = np.concatenate((weights_minus1, np.ones(1)), axis=0)
# Softmax returns a combination of values [0-1] that sum up to one.
soft_max_w = self.softmax(weights)
projected_weights = weight_proj_mat @ soft_max_w
return -np.linalg.slogdet(
                # We sum the information matrices weighted by the softmax weights
                # and take the log-determinant of the result.
np.add.reduce([softmax * info for (softmax, info) in zip(projected_weights.tolist(), info_mats)])
)[1]
            # Reduce the number of weights by one, so that we constrain one of the weights to be equal to 1.
num_source_info = weight_proj_mat.shape[1] - 1
w0 = np.ones(num_source_info)
w_ret = optimize.minimize(
ci_det_costfunction, # type: ignore
w0,
bounds=optimize.Bounds(ub=np.ones(num_source_info) * 20, lb=-np.ones(num_source_info) * 20), # type: ignore
method="Nelder-Mead",
)
# Here we introduce the constrained weight to be equal to 1.
weights = np.concatenate((w_ret.x, np.ones(1)), axis=0)
soft_max_w = self.softmax(weights)
w_full = weight_proj_mat @ soft_max_w
else:
# If we only have one source of information, we just set the weight to 1.
w_full = np.ones(len(self.measurements))
return w_full
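    # Illustrative sketch of the softmax reparametrization in _calculate_weights (values are
    # assumptions): with three weight columns the optimizer searches unconstrained R^2 over
    # the first two entries; the fixed trailing 1 plus softmax,
    #
    #   softmax([a, b, 1]) = exp([a, b, 1]) / sum(exp([a, b, 1])),
    #
    # pins the result to the probability simplex, so sum(w) = 1 and 0 < w_i < 1 hold by
    # construction, with no explicit constraints handed to the optimizer.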
def evaluate_operators(self, state: InternalState) -> list[np.ndarray]:
"""Composite model for the AntObjectReference.
It combines the outputs of all models
Args:
state (np.ndarray): State vector.
Returns:
np.ndarray: The composite model output.
"""
return [m.operator.evaluate(state) for m in self.measurements]
def operator_jacobians(self, state: InternalState) -> list[np.ndarray]:
"""Jacobian of the composite model for the AntObjectReference.
It combines the jacobians of all models
"""
return [m.operator.get_jacobian(state) for m in self.measurements]
def weighted_inference(self, measmt_weights: np.ndarray, save_results: bool = False) -> InternalState | None:
"""Weighted inference for the AntObjectReference.
Calculation of the weights that determine the influence of each measurement in the
data fusion. With these weights the measurements are combined as information matrices
that are used.
Args:
measmt_weights (np.ndarray): Weights for the measurements.
save_results: save state and state jacobians for AntObjectReferences and Measurements?
Returns:
Information: The infered state.
"""
if len(measmt_weights) != len(self.measurements):
raise ValueError("Length of measmt_weights must be equal to the number of measurements")
if len(self.measurements) == 0:
raise ValueError("no measurements available")
        # Agglomerated whitening operators and exchange states.
whitening_operators = []
exchange_states = []
for idx, m in enumerate(self.measurements):
whitening_operators.append(np.sqrt(measmt_weights[idx]) * m.exchange_vector.get_whitening_op())
exchange_states.append(m.exchange_vector.mean)
# cost function to determine the weights.
# state_mean: n, output: m -> jac: (m,n)
def cost_function(state_mean: np.ndarray) -> np.ndarray:
operator_results = self.evaluate_operators(InternalState(state_mean, np.array([])))
results = []
for whitening_operator, operator_result, exchange_state in zip(
whitening_operators, operator_results, exchange_states
):
results.append(
(
whitening_operator
@ (
operator_result.reshape(operator_result.shape[0], 1)
- exchange_state.reshape(exchange_state.shape[0], 1)
)
).flatten()
)
            result = np.concatenate(results)
return result
# Cost function jacobian to determine the covariance.
def cost_function_jac(state_mean: np.ndarray) -> np.ndarray:
operator_jacobians = self.operator_jacobians(InternalState(state_mean, np.array([])))
results = []
for whitening_operator, operator_jacobian in zip(whitening_operators, operator_jacobians):
results.append(whitening_operator @ operator_jacobian)
            result = np.concatenate(results, axis=0)
return result
if self.optimizer_lsqs_method == "trf": # TODO: in the c++ version well use "fides" method
# (Trust Region Reflective for boundary costrained optimization) that performs better.
opt_result = optimize.least_squares(
cost_function,
self.state.mean,
jac=cost_function_jac, # type: ignore
verbose=0,
ftol=1e-12,
xtol=1e-12,
gtol=1e-12,
max_nfev=100,
method=self.optimizer_lsqs_method,
bounds=(self.lower_bound_x, self.upper_bound_x),
)
else:
opt_result = optimize.least_squares(
cost_function,
self.state.mean,
jac=cost_function_jac, # type: ignore
verbose=0,
ftol=1e-10,
xtol=1e-10,
max_nfev=100,
method=self.optimizer_lsqs_method, # type: ignore
)
inferred_state = opt_result["x"]
if not opt_result.success:
logger.error(f"Optimization failed: {opt_result.message}")
return None
        # Calculation of the covariance of the inferred state. This cannot simply be the inverse of
        # the summed information matrices, because those are calculated with the initial guess of
        # the weights; once we have the inferred state, the Jacobian of the loss can be evaluated
        # at that point, and in terms of it cov = (L.T @ L)^-1.
cost_fn_jac = cost_function_jac(inferred_state)
inferred_state_cov = np.linalg.inv(cost_fn_jac.T @ cost_fn_jac)
if save_results:
l_mat = block_diag(*whitening_operators)
jac_b = -l_mat
datafusion_jacobian = -np.linalg.solve(
cost_fn_jac.transpose() @ cost_fn_jac, cost_fn_jac.transpose() @ jac_b
)
col_index = 0
for m in self.measurements:
cols = len(m.exchange_vector.mean)
m.state_jacobian = datafusion_jacobian[:, col_index : col_index + cols]
col_index += cols
self.state = InternalState(inferred_state, inferred_state_cov)
self.update()
return InternalState(inferred_state, inferred_state_cov)
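    # Shape sketch for weighted_inference (dimensions are illustrative assumptions): with two
    # measurements of output dimensions m1 and m2 and an n-dimensional state, cost_function
    # stacks a residual of length m1 + m2, cost_function_jac returns an (m1 + m2, n) matrix J,
    # and the fused covariance (J.T @ J)^-1 is (n, n). Each residual block is
    # sqrt(w_i) * W_i @ (h_i(x) - z_i), i.e. the whitened, weighted mismatch between the
    # operator prediction and the exchanged measurement mean.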
def backpropagate_from_neighbor(
self, gradient_wrt_exchange_vector: np.ndarray, loss_channel: LossChannel, node_id: NodeId
) -> None:
"""Backpropagate the gradient from a source (neighbor, another AntObjReference) to the internal state."""
op = None
for sub in self.info_subscribers:
if sub.source.id.node_id == node_id:
op = sub.operator
break
assert op is not None, "Operator not found in info subscribers."
jacobian = op.get_jacobian(self.state)
# loss_gradient: Gradient w.r.t. the state of this Node.
loss_gradient = jacobian.T @ gradient_wrt_exchange_vector
# if mode is Probe: save it in the probe gradient storage
if loss_channel.mode == FuturePastProbe.PROBE:
if loss_channel in self.probe_storage.storage:
self.probe_storage.storage[loss_channel] += loss_gradient
else:
self.probe_storage.storage[loss_channel] = loss_gradient
self.backpropagate(loss_gradient, loss_channel)
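    # Chain-rule sketch: with exchange vector y = op(x) and an incoming gradient dL/dy, the
    # gradient w.r.t. this node's state is dL/dx = J.T @ dL/dy, with J = d(op)/dx evaluated
    # at the current state. That is the jacobian.T @ gradient_wrt_exchange_vector step above.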
def backpropagate(self, loss_gradient: np.ndarray, loss_channel: LossChannel) -> None:
"""Backpropagate the translated loss gradients through the measurements."""
in_past = self.parent_node.object_ref_in_past(self)
future_gradient = loss_channel.mode == FuturePastProbe.FUTURE
if future_gradient and in_past:
return
for m in self.measurements:
m.backpropagate(self.parent_node, loss_gradient, loss_channel)
def generate_gradient_messages(self) -> list[GradientMessage]:
"""Generate gradient messages from measurements."""
ret: list[GradientMessage] = []
for m in self.measurements:
ret.extend(m.source.generate_gradient_messages(self.parent_node, m.gradient_storage))
return ret
def generate_info_messages(self) -> list[InfoMessage]:
"""Generate info messages for info subscribers."""
ret: list[InfoMessage] = []
for sub in self.info_subscribers:
msg = sub.generate_info_message(self.parent_node, self.state)
if msg is not None:
ret.append(msg)
return ret
def add_new_measurement(self, operator: Operator, exchange_state: ExchangeVector, source: AntLink) -> None:
"""Add or overwrite a measurement.
1. locals: always just add (they come from the default_info_packets)
2. external: always overwrite if there is something to overwrite
3. internal: always overwrite if there is something to overwrite
"""
old_id = None
for i, m in enumerate(self.measurements):
if m.source == source:
old_id = i
break
# this is only a default so we have the right dimensions here
default_state_jacobian = np.zeros((self.initial_state.mean.shape[0], exchange_state.mean.shape[0]))
measurement = Measurement(self.object_id, operator, exchange_state, default_state_jacobian, source)
if old_id is None or isinstance(source, AntLinkLocal):
# create new measurement
self.measurements.append(measurement)
else:
# replace old measurement
self.measurements[old_id] = measurement
self.up_to_date = False
def update(self) -> None:
"""Like an info message, but only through history.
we do not need to send it over the network
this is only done for AntLinkInternalObjectID - the others have different Nodes as target
"""
for sub in self.info_subscribers:
if isinstance(sub.source, AntLinkInternalObjectID):
assert sub.source.id.node_id == self.parent_node.node_id, (
"you added an AntLinkInternalObjectID with a different Node?!"
)
if sub.source.id.object_id is not None:
ref = self.parent_node.find_object_reference(sub.source.id.object_id)
if ref is not None:
exch_vec: ExchangeVector = sub.operator.get_exchange_vector(self.state)
source = AntLinkInternalObjectID(Id(self.parent_node.node_id, self.object_id))
ref.add_new_measurement(self.parent_node.internal_operator, exch_vec, source)
    def optimize(self) -> None:
        """Run one optimization step on each measurement."""
for m in self.measurements:
m.optimize_step()