"""MVN estimates.
Include two subclasses to have type information about the type of MVN estimate
"""
import copy
from dataclasses import dataclass
import numpy as np
@dataclass
class MVNEstimate:
    """Information class for storing information about the state of the system.

    Holds a mean vector and its covariance matrix. Equality and hashing are
    value-based on both arrays.
    """

    mean: np.ndarray
    """Mean vector of the state."""
    covariance: np.ndarray
    """Covariance matrix of the state."""

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is an MVNEstimate with identical arrays.

        Both ``mean`` and ``covariance`` must have the same shape and equal
        elements (``np.array_equal``). Unlike a plain element-wise ``==``, this
        does not let broadcasting turn mismatched shapes into a match, and it
        does not depend on NumPy's handling of incompatible shapes.
        """
        if not isinstance(other, MVNEstimate):
            return False
        # bool() so callers always get a builtin bool, never numpy.bool_.
        return bool(
            np.array_equal(self.mean, other.mean)
            and np.array_equal(self.covariance, other.covariance)
        )

    def __hash__(self) -> int:
        """Return a hash consistent with ``__eq__``.

        ndarrays are unhashable, so the hash is built from the shape and element
        values of each array. The instance is mutable: changing ``mean`` or
        ``covariance`` while it is used as a set member / dict key invalidates
        its hash.
        """
        return hash((self._array_key(self.mean), self._array_key(self.covariance)))

    @staticmethod
    def _array_key(arr: np.ndarray) -> tuple:
        """Return a hashable ``(shape, flat values)`` key for an array."""
        as_array = np.asarray(arr)
        values = as_array.ravel().tolist()
        # NaN never compares equal, and on Python >= 3.10 hash(nan) is
        # identity-based; map NaN to a fixed token so the hash is stable.
        return (as_array.shape, tuple(None if v != v else v for v in values))

    def copy(self) -> "MVNEstimate":
        """Return a deep copy; the arrays are copied, not shared."""
        # `copy` here resolves to the module-level `copy` import, not this method.
        return copy.deepcopy(self)

    def get_whitening_op(self, method: str = "pca", fudge: float = 1e-18) -> np.ndarray:
        """Compute a whitening matrix ``W`` for ``self.covariance``.

        ``W`` is built so that ``W @ covariance @ W.T`` is (approximately) the
        identity.

        Args:
            method (str, optional): 'pca' uses an eigenvalue decomposition; if
                negative eigenvalues appear, the covariance is symmetrized and
                negative eigenvalues are clipped to zero (nearest positive
                semi-definite approximation). 'chol' uses a Cholesky
                decomposition of the inverse covariance and only works for
                non-singular, positive definite matrices. Defaults to 'pca'.
            fudge (float, optional): Small constant added to the eigenvalues
                before the inverse square root, avoiding infinities for
                (near-)zero eigenvalues. Only used by 'pca'. Defaults to 1e-18.

        Returns:
            2D ndarray: whitening matrix.

        Raises:
            ValueError: if ``method`` is neither 'pca' nor 'chol'.
            numpy.linalg.LinAlgError: for 'chol', if the covariance cannot be
                inverted or its inverse is not positive definite.
        """
        if method == "pca":
            eigvals, eigvecs = np.linalg.eigh(self.covariance)
            if (eigvals < 0).any():
                # Retry on the symmetric part, then clip to the PSD cone.
                eigvals, eigvecs = np.linalg.eigh(
                    0.5 * self.covariance + 0.5 * self.covariance.transpose()
                )
                eigvals[eigvals < 0] = 0
            return np.diag(1 / np.sqrt(eigvals + fudge)) @ eigvecs.transpose()
        if method == "chol":
            # inv(C) = L L^T  =>  (L^T) C (L^T)^T = I.
            return np.linalg.cholesky(np.linalg.inv(self.covariance)).transpose()
        raise ValueError(
            f"Unknown whitening method {method!r}; expected 'pca' or 'chol'."
        )
class InternalState(MVNEstimate):
    """MVNEstimate tagged as the system's internal state.

    Adds no fields or behavior; the subclass only carries type information.
    """
class ExchangeVector(MVNEstimate):
    """MVNEstimate tagged as an exchange vector.

    Adds no fields or behavior; the subclass only carries type information.
    """