import time
from typing import List, Union, Optional
import numpy as np
import torch
from dataclasses import dataclass
@dataclass
class ObservationLinearization:
"""Stores an observation model for a LinearRecursiveEstimator.
Model is measurement = matrix * x[state_indices] + bias + eps
with eps ~ N(0,var).
"""
matrix : torch.Tensor # shape #obs x #indices or #obs x n if state_indices is None
var : torch.Tensor # shape #obs or #obs x #obs
bias : Optional[torch.Tensor] = None # shape #obs
state_indices : Optional[torch.Tensor] = None # shape #indices
def __post_init__(self):
assert self.matrix.ndim == 2
nobs = self.matrix.size(0)
if self.state_indices is not None:
assert self.state_indices.ndim == 1
assert self.state_indices.size(0) == self.matrix.size(1)
if self.bias is not None:
assert self.bias.size(0) == nobs
if isinstance(self.var,(float,int)):
self.var = torch.full((nobs,),self.var,dtype=self.matrix.dtype,device=self.matrix.device)
if self.var.ndim == 1:
assert self.var.size(0) == nobs
elif self.var.ndim == 2:
assert self.var.size(0) == nobs
assert self.var.size(1) == nobs
else:
raise ValueError("Invalid var shape")
def diag_var(self):
"""Returns the diagonal of the covariance matrix"""
if self.var.ndim == 1:
return self.var
return torch.diag(self.var)
def covar(self):
"""Returns the full covariance matrix"""
if self.var.ndim == 1:
return torch.diag(self.var)
return self.var
def predict(self, x:torch.Tensor) -> torch.Tensor:
"""Standard mean prediction. x is the state vector"""
if self.state_indices is not None:
pred = self.matrix @ x[self.state_indices]
else:
pred = self.matrix @ x
if self.bias is not None:
return pred + self.bias
return pred
def predict_batch(self, xs:torch.Tensor) -> torch.Tensor:
"""Batch mean prediction. xs is a B x N state tensor batch."""
if self.state_indices is not None:
pred = xs[:,self.state_indices] @ self.matrix.T
else:
pred = xs @ self.matrix.T
if self.bias is not None:
return pred + self.bias
return pred
def to(self, device, dtype=None):
"""
Move the model to the given device and dtype.
NOTE: this will not create a copy of the model, but only
internally change the device and dtype of the tensors.
"""
self.matrix.to(device)
self.var.to(device)
if dtype is not None:
self.matrix = self.matrix.to(dtype)
self.var = self.var.to(dtype)
if self.bias is not None:
self.bias.to(device)
if dtype is not None:
self.bias = self.bias.to(dtype)
if self.state_indices is not None:
self.state_indices.to(device)
if dtype is not None:
self.state_indices = self.state_indices.to(dtype)
return self
@staticmethod
def merge(*obs_models : 'ObservationLinearization') -> 'ObservationLinearization':
"""
Merge multiple sparse linear observation models into one obsernation model.
This method takes a list of `ObservationLinearization` objects, each of
which may observe only a subset of the full state, the return will have
a single merged observation model with the following properties:
matrix: single merged observation matrix of shape
(sum_i n_obs_i, n_unique_state_indices),
bias: concatenated bias vector of shape (sum_i n_obs_i,),
var: concatenated variance vector of shape (sum_i n_obs_i,),
state_indices: 1D array of the sorted unique state indices.
Internally, it uses `np.unique(..., return_inverse=True)` on the
concatenated state index lists to compute the overall state dimension
and to map each submatrix into the correct columns of the merged
matrix.
Args:
obs_models : List[ObservationLinearization]
List of sparse observation models to merge. Each must have
`.matrix` (Tensor), `.var` (Tensor or scalar), optional
`.bias` (Tensor), and optional `.state_indices` (1D Tensor).
If `state_indices is None`, the model is assumed to act on the
full state vector of size `matrix.size(1)`.
Returns:
merged_obs : ObservationLinearization
A new observation model whose
- `merged_obs.matrix` is shape `(total_meas, n_unique)`,
- `merged_obs.bias` is shape `(total_meas,)`,
- `merged_obs.var` is shape `(total_meas,)` (variances),
- `merged_obs.state_indices` is a 1D LongTensor of length `n_unique`.
`merged_obs.state_indices[i]` gives the original state index
corresponding to column `i` of `merged_obs.matrix`.
merged_measurement: np.ndarray
An array of shape `(total_meas,)` containing the concatenated
measurements from all models. The order of the measurements
corresponds to the order of the rows in `merged_obs.matrix`,
which is the order in the measurement_list.
"""
# 0) convert to list if a single model is passed
obs_models = list(obs_models)
# 1) how many rows each model contributes
num_meas_list = [model.matrix.shape[0] for model in obs_models]
total_meas = sum(num_meas_list)
tsr_params = { 'dtype': obs_models[0].matrix.dtype,
'device': obs_models[0].matrix.device }
# 2) gather all state_indices as one long array
state_idx_lists = [
(obs.state_indices.numpy() if obs.state_indices is not None
else np.arange(obs.matrix.size(1)))
for obs in obs_models
]
all_state_idxs = np.concatenate(state_idx_lists)
# compute unique indices + inverse map
unique_indices, inverse = np.unique(all_state_idxs, return_inverse=True)
# split inverse back per-model
splits = np.cumsum([len(idx) for idx in state_idx_lists])[:-1]
per_model_inverse = np.split(inverse, splits)
# 3) build merged observation matrix
merged_matrix = torch.zeros((total_meas, len(unique_indices)), **tsr_params)
# row block boundaries
row_ends = np.cumsum(num_meas_list)
row_starts = np.concatenate([[0], row_ends[:-1]])
for (obs, r0, r1, inv_idx) in zip(
obs_models, row_starts, row_ends, per_model_inverse
):
# obs.matrix is shape (n_obs_i, len(inv_idx))
merged_matrix[r0:r1, inv_idx] = obs.matrix
# 4) merge biases (fill with zeros if None)
merged_bias = torch.cat([
(obs.bias if obs.bias is not None else torch.zeros(n, **tsr_params))
for obs, n in zip(obs_models, num_meas_list) ], dim=0)
# 5) merge variances as vector of diag elements
# TODO: support merging full covariance matrices
merged_var = torch.cat([obs.diag_var() for obs in obs_models], dim=0)
# wrap in a new ObservationLinearization
merged_state_tensor = torch.tensor(unique_indices,
device=tsr_params['device'],
dtype=torch.long)
merged_obs = ObservationLinearization(
matrix=merged_matrix,
var=merged_var,
bias=merged_bias,
state_indices=merged_state_tensor
)
return merged_obs
def diag_AtB(A : torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""Returns the diagonal of A.T @ B with A and B both m x n matrices"""
assert A.shape == B.shape
return (A * B).sum(dim=0)
[docs]
class LinearRecursiveEstimator:
"""Base estimator for a high-dimensional stationary state with observations
as linear-Gaussian functions of state.
The state x is composed of a heterogeneous factor y and latent factor q, both
of which are optional:
.. math::
x = y + A*q \\in \\mathbb{R}^n
where `y` in :math:`\\mathbb{R}^n` is a heterogeneous term and `q` in :math:`\\mathbb{R}^k`
is a latent factor.
If the latent factor is included, A is a known basis matrix. Diagonal priors
should be defined over y and q,
.. math::
y \\sim N(\\mu_y,\\text{diag}(\\Sigma_y)), q \\sim N(\\mu_q,\\text{diag}(\\Sigma_q))
such that `x` has the prior `N(\\mu_y + A*\\mu_q, diag(\\Sig_y) + A * diag(\\Sig_q) * A^T)`.
To mark the heterogeneous factor as optional, set Sig_y = 0. To mark the
latent factor as optional, set Sig_q = 0 or A = None.
Observations are given by vectors z^i, observation matrices W^i, optional
indices ind^i, biases z0^i, and covariances :math:`Sigma_z^i` such that
.. math::
z^i = W^i * x[ind^i] + z0^i + \\epsilon^i, \\epsilon^i \\sim N(0,\\Sigma_z^i).
This class provides shared methods for different instantiations of the
estimator. It gives a replay buffer and
Unified interface:
1. **add_observation(obs_model, measurement)**
Adds an observation model and measurement to the buffer.
2. **update_estimation()**
Reads the most recent measurements. This function is implementation-dependent.
3. **finalize_estimate()**
Updates `y_mu`, `y_var`, `q_mu`, `q_var` so the `x` mean/variance can be extracted.
This function is implementation-dependent.
NOTE: for numerical stability purposes, all torch tensors in this class are double precision.
Attributes:
obs_buffer: the observation buffer, stores (observation model, measurement) pairs
max_buffer_len: the max length of the observation buffer, or None for unlimited buffer.
y_mu, y_var: the mean and variance of the heterogeneous factor y
q_mu, q_var: the mean and variance of the latent factor q
A: the basis matrix for the latent factor q
"""
def __init__(self, max_dim:int,
x_mu:Union[float,torch.Tensor]=0.0, x_var:Union[float,torch.Tensor]=0.0,
latent_mu:Union[float,torch.Tensor]=0.0, latent_var:Union[float,torch.Tensor]=0.0,
latent_basis:Optional[torch.Tensor]=None,
max_buffer_len:Optional[int] = None) -> None:
self.max_buffer_len = max_buffer_len
self.obs_buffer = [] # type : List[Tuple[ObservationLinearization,torch.Tensor]]
if isinstance(x_mu,torch.Tensor):
device = x_mu.device
else:
device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
# NOTE: all tensors are double precision for numerical stability
self.tsr_params = {'dtype': torch.double, 'device': device}
if isinstance(x_mu,torch.Tensor):
assert x_var.ndim == 1,"Mean must be a vector"
self.y_mu = x_mu.to(**self.tsr_params)
else:
self.y_mu : torch.Tensor = torch.full((max_dim,), x_mu, **self.tsr_params)
if isinstance(x_var,torch.Tensor):
assert x_var.ndim == 1,"Variance must be a per-point variance"
self.y_var = x_var.to(**self.tsr_params)
else:
self.y_var : torch.Tensor = torch.full((max_dim,), x_var, **self.tsr_params)
assert len(self.y_mu) == len(self.y_var)
# create the latent basis matrix
if latent_basis is not None:
assert latent_basis.ndim == 2
assert latent_basis.shape[0] == max_dim
self.A : torch.Tensor = latent_basis
if self.A is not None:
self.A = self.A.to(**self.tsr_params)
latent_dim = 0 if latent_basis is None else latent_basis.shape[1]
if isinstance(latent_mu,torch.Tensor):
assert latent_var.ndim == 1,"Latent mean must be a vector"
self.q_mu = latent_mu.to(**self.tsr_params)
else:
self.q_mu : torch.Tensor = torch.full((latent_dim,), latent_mu, **self.tsr_params)
assert self.q_mu.size(0) == latent_dim
if isinstance(latent_var,torch.Tensor):
assert latent_var.ndim == 1,"Latent variance must be a vector"
self.q_var = latent_var.to(**self.tsr_params)
else:
self.q_var : torch.Tensor = torch.full((latent_dim,), latent_var, **self.tsr_params)
assert self.q_var.size(0) == latent_dim
[docs]
def to(self, device):
self.y_mu.to(device)
self.y_var.to(device)
self.q_mu.to(device)
self.q_var.to(device)
if self.A is not None:
self.A.to(device)
return self
[docs]
def num_obs(self):
"""Returns the number of observations in the buffer."""
return len(self.obs_buffer)
[docs]
def clear_obs_buffer(self):
"""Clear the observation buffer, set self.obs_buffer to an empty list."""
self.obs_buffer = []
[docs]
def add_observation(self, model : ObservationLinearization, measurement : torch.Tensor):
"""Add an observation to the buffer."""
assert isinstance(measurement,torch.Tensor)
# estimator only supports 1D measurements
# so we flatten the measurement tensor
if measurement.ndim > 1:
measurement = measurement.flatten()
assert measurement.device == model.matrix.device,'Inconsistent tensor devices in observation'
self.obs_buffer.append((model,measurement))
if self.max_buffer_len is not None and self.num_obs() > self.max_buffer_len: # fifo buffer
self.obs_buffer.pop(0)
[docs]
def get_mean(self, idx=None) -> torch.Tensor:
"""Returns the estimated x mean."""
if idx is None:
if self.A is not None:
return self.y_mu + self.A @ self.q_mu
else:
return self.y_mu.clone()
else:
x = self.y_mu[idx]
if self.A is not None:
return x + self.A[idx, :] @ self.q_mu
else:
return x.clone()
[docs]
def get_var(self, idx=None) -> torch.Tensor:
"""Returns the diagonal variance of x."""
# Return the full covariance matrix if idx is None
if idx is None:
if self.A is not None:
return self.y_var + (self.A**2) @ self.q_var
else:
return self.y_var.clone()
# Return the diagonal variance of the selected indices
else:
var = self.y_var[idx]
if self.A is not None:
A = self.A[idx, :]
return var + (A**2) @ self.q_var
else:
return var.clone()
[docs]
def get_covar(self, idx=None) -> torch.Tensor:
"""Returns the full covariance matrix of x."""
if idx is None:
if self.A is not None:
return torch.diag(self.y_var) + self.A @ torch.diag(self.q_var) @ self.A.T
else:
return torch.diag(self.y_var)
else:
cov = torch.diag(self.y_var[idx])
if self.A is not None:
A = self.A[idx, :]
return cov + A @ torch.diag(self.q_var) @ A.T
else:
return cov
[docs]
def update_estimation(self):
"""Update the estimation based on the current observation buffer"""
raise NotImplementedError
[docs]
def finalize_estimation(self):
"""Finalize the estimation after all observations are added"""
pass
[docs]
def predict_observation(self, obs : ObservationLinearization) -> torch.Tensor:
"""Predicts the mean observation at the current state."""
res = obs.matrix @ self.get_mean(idx=obs.state_indices)
if obs.bias is not None:
return res + obs.bias
return res
[docs]
def get_observed_indices(self) -> torch.LongTensor:
"""Returns all the indices observed across all observations in the buffer """
touch_idx = torch.concatenate([o.state_indices for o,m in self.obs_buffer])
return torch.unique(touch_idx)
[docs]
def get_unobserved_indices(self) -> torch.LongTensor:
"""Complement of get_observed_indices.
NOTE: converts to CPU
"""
touch_idx = self.get_observed_indices()
return torch.tensor(np.setdiff1d(np.arange(len(self.x), touch_idx.cpu().numpy())),device=touch_idx.device,dtype=touch_idx.dtype)
[docs]
def state_dict(self):
"""Subclasses may override this to save additional state."""
res = {'y_mu': self.y_mu, 'y_var': self.y_var, 'q_mu': self.q_mu, 'q_var': self.q_var}
if self.A is not None:
res['A'] = self.A
return res
[docs]
def load_state_dict(self, state_dict : dict):
"""Subclasses may override this to load additional state."""
self.y_mu = state_dict['y_mu']
self.y_var = state_dict['y_var']
self.q_mu = state_dict['q_mu']
self.q_var = state_dict['q_var']
if 'A' in state_dict:
self.A = state_dict['A']
class SGDEstimator(LinearRecursiveEstimator):
"""
Stochastic Gradient Descent (SGD) estimator for linear estimation problem.
Uses a replay buffer with the given batch_size for each update step.
If you don't use finalize_estimation(), you will need to detach the mean
/ variance tensors before using them.
"""
def __init__(self, max_dim:int,
x_mu:Union[float,torch.Tensor]=0.0, x_var:Union[float,torch.Tensor]=0.0,
latent_mu:Union[float,torch.Tensor]=0.0, latent_var:Union[float,torch.Tensor]=0.0,
latent_basis:Optional[torch.Tensor]=None,
max_buffer_len:int = None,
non_negative:bool=True, batch_size:int=100) -> None:
super().__init__(max_dim, x_mu,x_var,latent_mu,latent_var,latent_basis,max_buffer_len)
include_heterogeneous = (isinstance(x_var,torch.Tensor) or x_var != 0.0)
include_latent = self.A is not None and (isinstance(latent_var,torch.Tensor) or latent_var != 0.0)
if include_heterogeneous:
params_lst = [self.y_mu]
else:
params_lst = []
if include_latent:
params_lst += [self.q_mu]
for p in params_lst:
p.requires_grad_()
self.optimizer = torch.optim.Adam(params_lst, lr=1e-3)
self.observation_loss = torch.nn.GaussianNLLLoss(full=True)
self.regularization_loss = torch.nn.GaussianNLLLoss(full=True)
self.batch_size = batch_size
self.non_negative = non_negative
def sgd_step(self, obs_model : ObservationLinearization, measurement : torch.Tensor) -> float:
"""
Rune a single stochastic gradient descent step on the given observation.
"""
tau_hat = self.predict_observation(obs_model)
loss = self.observation_loss(measurement, tau_hat, obs_model.var)
loss += self.regularization_loss(self.y_mu, torch.zeros_like(self.y_mu), self.y_var)
if self.A is not None:
loss += self.regularization_loss(self.q_mu, torch.zeros_like(self.q_mu), self.q_var)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def update_estimation(self, verbose=False):
"""
SGD update step on the observation buffer.
The estimator will sample a batch of observations from the buffer and
run a single SGD step on each observation.
"""
self.q_mu.requires_grad_(True)
self.y_mu.requires_grad_(True)
num_sample = min(self.batch_size, self.num_obs())
step_idx_ary = np.random.randint(0, self.num_obs(), num_sample)
for step_idx in step_idx_ary:
loss = self.sgd_step(*self.obs_buffer[step_idx])
if verbose:
print("loss:", loss)
if self.non_negative:
with torch.no_grad():
self.y_mu.clamp_(min=0)
if self.A is not None:
self.q_mu.clamp_(min=0)
def finalize_estimation(self):
"""
SGD estimator does not change estimation during finalization.
This function will disable gradient computation on the mean tensors
to support evaluation steps.
"""
self.y_mu.requires_grad_(False)
self.q_mu.requires_grad_(False)
class DiagonalEKF(LinearRecursiveEstimator):
"""
Diagonal EKF estimator that maintains the diagonal part of the
covariance matrix for the estimation.
Note: this will use the information filter formulation with *_lam indicating
the inverse of the diagonal of the covariance matrix (diagonal precision),
and *_mu being the product of the precision matrix and the normal mean.
On the state x=(y,q), where y is heterogeneous and q is latent, the
observation model is z = W*Sel[ind]*[I | A]*x + eps
which implies the observation matrix is H = W*[Sel[ind] | A[ind:]].
TODO: math below may be incorrect, need to verify
Hence, the information filter update computes
delta_i = H^T * R^-1 z
= [ Sel[ind]^T W^T * R^-1 * z \\
A[ind:]^T W^T * R^-1 * z ]
and
variance I = H^T R^-1 H
= [Sel[ind]^T] W^T * R^-1 * W [Sel[ind] | A[ind:]] [A[ind:]^T ]
And then updates (y_lam, q_lam) += diag(I), (y_mu, q_mu) += i.
"""
def __init__(self, max_dim:int,
x_mu:Union[float,torch.Tensor]=0.0, x_var:Union[float,torch.Tensor]=0.0,
latent_mu:Union[float,torch.Tensor]=0.0, latent_var:Union[float,torch.Tensor]=0.0,
latent_basis:Optional[torch.Tensor]=None,
max_buffer_len:int = None, num_replay_update:int=10,
non_negative:bool=True, update_format='kalman') -> None:
super().__init__(max_dim, x_mu, x_var, latent_mu, latent_var, latent_basis, max_buffer_len)
self.include_heterogeneous = (isinstance(x_var,torch.Tensor) or x_var != 0.0)
self.include_latent = self.A is not None and (isinstance(latent_var,torch.Tensor) or latent_var != 0.0)
self.num_replay_update = num_replay_update
self.non_negative = non_negative
# NOTE: Diagonal EKF has DIFFERENT update results using Kalman or Information filter formulations
# The Kalman update is more numerically stable,
# but the Information filter update does not need to invert the covariance matrix.
assert update_format in ['kalman','information']
self.update_format = update_format
def diag_ekf_step(self, obs_model : ObservationLinearization,
measurement : torch.Tensor, verbose:bool=False, replay=False):
"""
Run a single Diagonal EKF update step on the given observation.
If replay is True, the estimator will not update the variance terms.
Args:
obs_model: a linear observation model
measurement: the observation value as a torch tensor
verbose: whether to print debug information
replay: whether the update is triggered by a replay
"""
if verbose:
start_time = time.time()
obs_idx = obs_model.state_indices
obs_W_tsr = obs_model.matrix
if verbose:
print("obs tau tsr:", measurement)
print('obs_W_tsr shape:', obs_W_tsr.T.shape)
# compute the residual error in the observation
z = measurement.clone()
if obs_model.bias is not None:
z -= obs_model.bias
z_hat = obs_model.predict(self.get_mean())
res_z = z - z_hat
# Compute inverse of the covariance matrix in observation noise
if obs_model.var.ndim == 1:
# obs_model.var is a scalar
Rinv = 1.0 / obs_model.var
Rinvz = Rinv * res_z
else:
Rinv = torch.linalg.inv(obs_model.var)
Rinvz = Rinv @ res_z
if verbose:
print("Rinv shape:", Rinv.shape)
print("Rinvz shape:", Rinvz.shape)
# NOTE: Very important!!! The covariance update should happen BEFORE
# the mean update, as the mean update uses the covariance matrix
# at the current time step.
# Want the diagonal of these fat (n > m) matrices
# Wt_Rinv_W = (obs_W_tsr.T @ Rinv * obs_W_tsr) if obs_model.var.ndim == 1 else (obs_W_tsr.T @ Rinv) @ obs_W_tsr
# and this self.A[obs_idx, :].T @ Wt_Rinv_W @ self.A[obs_idx, :]
if not replay:
if self.include_heterogeneous:
# Compute information matrix update for the heterogeneous factor
if obs_model.var.ndim == 1:
y_lam_inc = diag_AtB(obs_W_tsr, Rinv[:,None] * obs_W_tsr)
else:
y_lam_inc = diag_AtB(obs_W_tsr, Rinv @ obs_W_tsr)
y_lam_updated = 1.0 / self.y_var[obs_idx] + y_lam_inc
self.y_var[obs_idx] = 1.0 / y_lam_updated
if self.include_latent:
WA = obs_W_tsr @ self.A[obs_idx,:]
self.q_var += Rinv @ diag_AtB(WA,WA) if obs_model.var.ndim == 1 else diag_AtB(WA, Rinv @ WA)
# Update the mean of the state
# first get diagonal component of y_var
y_var = self.get_var(obs_idx)
if verbose:
print('y_var min, mean, max:', y_var.min(), y_var.mean(), y_var.max())
# Information filter update
if self.update_format == 'information':
y_mu_inc = y_var * (obs_W_tsr.T @ Rinvz)
# Kalman filter update
elif self.update_format == 'kalman':
sig_WT = y_var.unsqueeze(1) * obs_W_tsr.T
if obs_model.var.ndim == 1:
R_mat = torch.diag(obs_model.var)
else:
R_mat = obs_model.var
W_sig_WT = obs_W_tsr @ sig_WT
y_mu_inc = sig_WT @ torch.linalg.inv(W_sig_WT + R_mat) @ res_z
if self.include_heterogeneous:
self.y_mu[obs_idx] += y_mu_inc
if self.include_latent:
# TODO: fix latent mean update rules
q_mu_inc = (self.A[obs_idx, :].T @ obs_W_tsr @ Rinvz)
self.q_mu += q_mu_inc
if verbose:
print("number of points:", obs_idx.shape[0])
print("update step time:", time.time()-start_time)
def update_estimation(self, verbose:bool=False):
"""
Update estimation of Diagonal EKF estimator.
If reply is False, the estimator will only update
using the most recent observation.
If replay is True, the estimator will first update
using the most recent observation, then replay a fixed
number of history observations in the replay buffer.
"""
assert self.num_obs() > 0, "Need to have at least one observation"
#do the update
self.diag_ekf_step(*self.obs_buffer[-1], verbose, replay=False)
if self.num_replay_update > 0:
step_idx_ary = np.random.randint(0, self.num_obs(),
(self.num_replay_update,))
for step_idx in step_idx_ary:
self.diag_ekf_step(*self.obs_buffer[step_idx], verbose, replay=True)
#clamp
if self.non_negative:
with torch.no_grad():
self.y_mu.clamp_(min=0)
if self.A is not None:
self.q_mu.clamp_(min=0)
class DenseEKF(LinearRecursiveEstimator):
"""
Dense EKF estimator that maintains the full dense covariance
matrix for the estimation.
This will try to maintain the full dense covariance matrix for only
the observed indices, and will extend the matrix as needed.
finalize_estimation() is essential here to get the updates.
Attributes:
iter_per_update: number of iteration per update
len_in_mem: length of the memory
mem_inc: memory increase per update
idx_raw2obs: mapping from raw index to observed state index
info_mat: subset dense information matrix for (q,y) covariance
update_method: update method for the estimation
non_negative: whether the estimation is non-negative
num_replay_update: how many replays to use in each update
"""
def __init__(self, max_dim:int,
x_mu:Union[float,torch.Tensor]=0.0, x_var:Union[float,torch.Tensor]=0.0,
latent_mu:Union[float,torch.Tensor]=0.0, latent_var:Union[float,torch.Tensor]=0.0,
latent_basis:Optional[torch.Tensor]=None,
max_buffer_len:int = None, init_len_in_mem: int=100,
update_method: str='inv', num_replay_update : int=0, non_negative: bool = True) -> None:
super().__init__(max_dim, x_mu, x_var, latent_mu, latent_var, latent_basis, max_buffer_len)
self.include_heterogeneous = (isinstance(x_var,torch.Tensor) or x_var != 0.0)
self.include_latent = self.A is not None and (isinstance(latent_var,torch.Tensor) or latent_var != 0.0)
self.idx_raw2obs = -1*torch.ones(max_dim, device = self.y_mu.device, dtype=int)
self.iter_per_update = 1
self.mem_inc = 1000
self.num_latent = 0
self.info_mat = torch.eye(init_len_in_mem,device=self.y_mu.device, dtype=self.y_mu.dtype)
if self.include_heterogeneous:
self.y_info_mu = self.y_mu / self.y_var
else:
self.y_info_mu = self.y_mu
if self.include_latent:
self.num_latent = self.A.shape[1]
assert init_len_in_mem >= self.num_latent
self.info_mat.diagonal()[self.num_latent] = 1.0 / self.q_var
self.q_info_mu = self.q_mu / self.q_var
else:
self.q_info_mu = self.q_mu
self.update_method = update_method
self.non_negative = non_negative
self.num_replay_update = num_replay_update
def ekf_step(self, obs_model: ObservationLinearization, measurement: torch.Tensor, verbose=False, replay=False):
"""
Run a single EKF step that updates the full dense covariance matrix.
"""
obs_idx = obs_model.state_indices
obs_W_tsr = obs_model.matrix
z = measurement.clone()
if obs_model.bias is not None:
z -= obs_model.bias
Rinv = 1.0 / obs_model.var if obs_model.var.ndim == 1 else torch.linalg.inv(obs_model.var)
Rinvz = z * Rinv if obs_model.var.ndim == 1 else Rinv @ z
y_mu_inc = (obs_W_tsr.T @ Rinvz)
if self.include_heterogeneous:
self.y_info_mu[obs_idx] += y_mu_inc
if self.include_latent:
q_mu_inc = (self.A[obs_idx, :].T @ obs_W_tsr @ Rinvz)
self.q_info_mu += q_mu_inc
if not replay:
Wt_Rinv_W = (obs_W_tsr.T @ (Rinv[:,None] * obs_W_tsr)) if obs_model.var.ndim == 1 else (obs_W_tsr.T @ Rinv) @ obs_W_tsr
touch_idx = self.get_touched_idx(obs_idx) + self.num_latent
if self.include_latent:
self.info_mat[:self.num_latent,:self.num_latent] += (self.A[obs_idx, :].T @ Wt_Rinv_W @ self.A[obs_idx, :])
if self.include_heterogeneous:
mat_idx = torch.meshgrid(touch_idx, touch_idx, indexing='ij')
self.info_mat[mat_idx] += Wt_Rinv_W
if self.include_latent and self.include_heterogeneous:
mat_idx = torch.meshgrid(torch.arange(self.num_latent), touch_idx, indexing='ij')
self.info_mat[mat_idx] += self.A[obs_idx,:].T @ Wt_Rinv_W
mat_idx = torch.meshgrid(touch_idx, torch.arange(self.num_latent), indexing='ij')
self.info_mat[mat_idx] += Wt_Rinv_W @ self.A[obs_idx,:]
def update_estimation(self, verbose=False):
self.ekf_step(*self.obs_buffer[-1], verbose)
if self.num_replay_update > 0:
step_idx_ary = np.random.randint(0, self.num_obs(),
(self.num_replay_update,))
for step_idx in step_idx_ary:
self.ekf_step(*self.obs_buffer[step_idx], verbose, replay=True)
def finalize_estimation(self, verbose=False):
"""
Dense EKF finalization step requires solving a linear system.
For computational efficiency, the covariance matrix are not saved
during the online estimation stage. During finalization, we need
to compute the covariance matrix from the information matrix by solving
a linear system.
"""
# index conversion
sorted_raw_idx = self.get_sorted_raw_idx()
touch_info_vec = torch.concat([self.q_info_mu,self.y_info_mu[sorted_raw_idx]])
num = self.num_latent + self.num_touched()
touch_info_mat = self.info_mat[:num, :num]
if self.update_method == 'gs':
mu_hat = torch.zeros_like(touch_info_vec)
for _ in range(self.iter_per_update):
U = torch.triu(touch_info_mat, diagonal=1)
res_vec = -U @ mu_hat + touch_info_vec
torch.linalg.solve_triangular(torch.tril(touch_info_mat), res_vec,
upper=False, out=mu_hat)
elif self.update_method == 'inv':
mu_hat = torch.zeros_like(touch_info_vec)
torch.linalg.solve(touch_info_mat, touch_info_vec, out=mu_hat)
else:
raise ValueError("Invalid update_method, must be gs or inv")
if self.non_negative:
mu_hat.clamp_(min=0)
self.q_mu = mu_hat[:self.num_latent]
self.y_mu[sorted_raw_idx] = mu_hat.reshape(-1)[self.num_latent:]
def num_touched(self):
"""Return number of points in contact from the observation buffer"""
return torch.sum(self.idx_raw2obs != -1).item()
def get_touch_info_mat(self):
"""
Return the information matrix for the touched indices
NOTE: the information matrix is not the full dense matrix, we only
need to index the left-top corner of the matrix with size num_touched x num_touched.
"""
num = self.num_touched()
return self.info_mat[:num, :num]
def get_sorted_raw_idx(self):
"""Return the sorted raw index of the touched indices"""
all_raw_idx = torch.where(self.idx_raw2obs != -1)[0]
idx_touch2raw = torch.argsort(self.idx_raw2obs[all_raw_idx])
return all_raw_idx[idx_touch2raw]
def get_touched_idx(self, obs_idx:torch.Tensor):
"""
Return the touched index for the observation index
NOTE: if the observation index is not in the touched index,
this function will add the index to the touched index.
TODO: finish the increment of the information matrix.
"""
touched_idx = self.idx_raw2obs[obs_idx]
none_idx = obs_idx[touched_idx == -1]
if none_idx.size() != 0:
add_idx = torch.arange(none_idx.shape[0], device=self.y_mu.device, dtype=int)
add_idx += self.num_touched()
self.idx_raw2obs[none_idx] = add_idx
if self.num_touched() > self.info_mat.shape[0]:
#extend the information matrix with the new indices
raise NotImplementedError("Dense EKF extension implemented yet")
old_len = len(self.info_mat)
new_len = old_len+self.mem_inc
prior_info_val = self.prior_lam*self.prior_mu
new_info_mat = self.prior_lam*torch.eye(new_len, **self.tsr_params)
new_info_mat[:old_len, :old_len] = self.info_mat
# update info vec and mat
self.info_mat = new_info_mat
# re-read touched index
touched_idx = self.idx_raw2obs[obs_idx]
return touched_idx
def state_dict(self):
"""Get the state dictionary for the dense EKF estimator"""
res = super(self).state_dict()
res['info_mat'] = self.info_mat
res['y_info_mu'] = self.y_info_mu
res['q_info_mu'] = self.q_info_mu
return res
def load_state_dict(self, state_dict : dict):
"""Load the state dictionary for the dense EKF estimator"""
super(self).load_state_dict(state_dict)
self.info_mat = state_dict['info_mat']
self.y_info_mu = state_dict['y_info_mu']
self.q_info_mu = state_dict['q_info_mu']