#!/usr/bin/env python3
from __future__ import annotations
from typing import Optional, Tuple, Union
import torch
from torch import Tensor
from ..utils.getitem import _compute_getitem_size, _is_noop_index
from ..utils.memoize import cached
from ._linear_operator import LinearOperator
from .diag_linear_operator import ConstantDiagLinearOperator
from .triangular_linear_operator import TriangularLinearOperator
from .zero_linear_operator import ZeroLinearOperator
class IdentityLinearOperator(ConstantDiagLinearOperator):
"""
Identity linear operator. Supports arbitrary batch sizes.
:param diag_shape: The size of the identity matrix (i.e. :math:`N`).
    :param batch_shape: The size of the batch dimensions. It may be useful to set these dimensions for broadcasting.
    :param dtype: Dtype that the LinearOperator will be operating on. (Default: :func:`torch.get_default_dtype()`).
:param device: Device that the LinearOperator will be operating on. (Default: CPU).
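
    Example (a minimal usage sketch):

        >>> eye = IdentityLinearOperator(diag_shape=5, batch_shape=torch.Size([2]))
        >>> eye.shape
        torch.Size([2, 5, 5])
        >>> eye.matmul(torch.randn(2, 5, 3)).shape
        torch.Size([2, 5, 3])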
"""
def __init__(
self,
diag_shape: int,
batch_shape: Optional[torch.Size] = torch.Size([]),
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
        # Skip ConstantDiagLinearOperator.__init__ (which expects an explicit diag_values
        # tensor) and register the constructor arguments directly with the base class.
        one = torch.tensor(1.0, dtype=dtype, device=device)
        LinearOperator.__init__(self, diag_shape=diag_shape, batch_shape=batch_shape, dtype=dtype, device=device)
        self.diag_values = one.expand(torch.Size([*batch_shape, 1]))
        self.diag_shape = diag_shape
        self._batch_shape = batch_shape
        # Read dtype/device off the constructed tensor so that the documented defaults
        # (torch.get_default_dtype() and CPU) are reported when None is passed.
        self._dtype = one.dtype
        self._device = one.device
@property
def batch_shape(self) -> torch.Size:
return self._batch_shape
@property
def dtype(self) -> torch.dtype:
return self._dtype
@property
def device(self) -> torch.device:
return self._device
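    # Multiplying by the identity leaves ``rhs`` unchanged, except that its batch
    # dimensions may need to be broadcast against this operator's batch shape.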
def _maybe_reshape_rhs(self, rhs: torch.Tensor) -> torch.Tensor:
if self._batch_shape != rhs.shape[:-2]:
batch_shape = torch.broadcast_shapes(rhs.shape[:-2], self._batch_shape)
return rhs.expand(*batch_shape, *rhs.shape[-2:])
else:
return rhs
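    # The identity is its own Cholesky factor (I = I @ I^T), so the Cholesky solve is a no-op.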
@cached(name="cholesky", ignore_args=True)
def _cholesky(self, upper: Optional[bool] = False) -> TriangularLinearOperator:
return self
def _cholesky_solve(self, rhs: torch.Tensor) -> torch.Tensor:
return self._maybe_reshape_rhs(rhs)
def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
)
def _getitem(
self,
row_index: Union[slice, torch.LongTensor],
col_index: Union[slice, torch.LongTensor],
*batch_indices: Tuple[Union[int, slice, torch.LongTensor], ...],
) -> LinearOperator:
# Special case: if both row and col are not indexed, then we are done
if _is_noop_index(row_index) and _is_noop_index(col_index):
if len(batch_indices):
new_batch_shape = _compute_getitem_size(self, (*batch_indices, row_index, col_index))[:-2]
res = IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=new_batch_shape, dtype=self._dtype, device=self._device
)
return res
else:
return self
return super()._getitem(row_index, col_index, *batch_indices)
def _matmul(self, rhs: torch.Tensor) -> torch.Tensor:
return self._maybe_reshape_rhs(rhs)
def _mul_constant(self, other: Union[float, torch.Tensor]) -> LinearOperator:
return ConstantDiagLinearOperator(self.diag_values * other, diag_shape=self.diag_shape)
def _mul_matrix(self, other: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
return other
def _permute_batch(self, *dims: Tuple[int, ...]) -> LinearOperator:
batch_shape = self.diag_values.permute(*dims, -1).shape[:-1]
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self._dtype, device=self._device
)
def _prod_batch(self, dim: int) -> LinearOperator:
batch_shape = list(self.batch_shape)
del batch_shape[dim]
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=torch.Size(batch_shape), dtype=self.dtype, device=self.device
)
def _root_decomposition(self) -> LinearOperator:
return self.sqrt()
def _root_inv_decomposition(
self,
initial_vectors: Optional[torch.Tensor] = None,
test_vectors: Optional[torch.Tensor] = None,
) -> LinearOperator:
return self.inverse().sqrt()
def _size(self) -> torch.Size:
return torch.Size([*self._batch_shape, self.diag_shape, self.diag_shape])
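    # The identity is its own matrix of singular vectors / eigenvectors, with all singular
    # values and eigenvalues equal to one (exposed via the ``_diag`` property).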
@cached(name="svd")
def _svd(self) -> Tuple[LinearOperator, Tensor, LinearOperator]:
return self, self._diag, self
def _symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional[LinearOperator]]:
return self._diag, self
    def _t_matmul(self, rhs: torch.Tensor) -> torch.Tensor:
return self._maybe_reshape_rhs(rhs)
def _transpose_nonbatch(self) -> LinearOperator:
return self
def _unsqueeze_batch(self, dim: int) -> IdentityLinearOperator:
batch_shape = list(self._batch_shape)
batch_shape.insert(dim, 1)
batch_shape = torch.Size(batch_shape)
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=batch_shape, dtype=self.dtype, device=self.device
)
def abs(self) -> LinearOperator:
return self
def exp(self) -> LinearOperator:
return self
def inverse(self) -> LinearOperator:
return self
def inv_quad_logdet(
self, inv_quad_rhs: Optional[torch.Tensor] = None, logdet: bool = False, reduce_inv_quad: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: Use proper batching for inv_quad_rhs (prepend to shape rather than append)
if inv_quad_rhs is None:
inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
else:
rhs_batch_shape = inv_quad_rhs.shape[1 + self.batch_dim :]
inv_quad_term = inv_quad_rhs.mul(inv_quad_rhs).sum(-(1 + len(rhs_batch_shape)))
if reduce_inv_quad:
inv_quad_term = inv_quad_term.sum(-1)
if logdet:
logdet_term = torch.zeros(self.batch_shape, dtype=self.dtype, device=self.device)
else:
logdet_term = torch.empty(0, dtype=self.dtype, device=self.device)
return inv_quad_term, logdet_term
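    # log() acts elementwise on the diagonal; log(1) == 0, so the result is an
    # all-zero operator of the same shape.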
def log(self) -> LinearOperator:
return ZeroLinearOperator(
*self._batch_shape, self.diag_shape, self.diag_shape, dtype=self._dtype, device=self._device
)
def matmul(self, other: Union[torch.Tensor, LinearOperator]) -> Union[torch.Tensor, LinearOperator]:
is_vec = False
if other.dim() == 1:
is_vec = True
other = other.unsqueeze(-1)
res = self._maybe_reshape_rhs(other)
if is_vec:
res = res.squeeze(-1)
return res
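    # Solving I x = b just returns b (broadcast to this operator's batch shape);
    # an optional ``left_tensor`` is applied afterwards, giving left_tensor @ b.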
def solve(self, right_tensor: torch.Tensor, left_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
res = self._maybe_reshape_rhs(right_tensor)
if left_tensor is not None:
res = left_tensor @ res
return res
def sqrt(self) -> LinearOperator:
return self
    def sqrt_inv_matmul(
        self, rhs: torch.Tensor, lhs: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if lhs is None:
return self._maybe_reshape_rhs(rhs)
else:
sqrt_inv_matmul = lhs @ rhs
inv_quad = lhs.pow(2).sum(dim=-1)
return sqrt_inv_matmul, inv_quad
def type(self, dtype: torch.dtype) -> LinearOperator:
return IdentityLinearOperator(
diag_shape=self.diag_shape, batch_shape=self.batch_shape, dtype=dtype, device=self.device
)
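    # Samples from N(0, I) are i.i.d. standard normal draws, so no root decomposition
    # is needed: just draw base samples of shape (num_samples, *batch_shape, n).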
def zero_mean_mvn_samples(self, num_samples: int) -> torch.Tensor:
base_samples = torch.randn(num_samples, *self.shape[:-1], dtype=self.dtype, device=self.device)
return base_samples
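

# A minimal, hypothetical smoke test (not part of the library API): running this file as a
# module (e.g. via ``python -m``, so the relative imports above resolve) exercises a few of
# the identity shortcuts implemented in this class.
if __name__ == "__main__":
    eye = IdentityLinearOperator(diag_shape=4, batch_shape=torch.Size([3]))
    rhs = torch.randn(3, 4, 2)
    # I @ rhs returns rhs unchanged (the batch shapes already match, so no broadcasting).
    assert torch.equal(eye.matmul(rhs), rhs)
    # solve(I, rhs) likewise returns rhs.
    assert torch.equal(eye.solve(rhs), rhs)
    # logdet(I) is zero for every batch element.
    _, logdet = eye.inv_quad_logdet(inv_quad_rhs=rhs, logdet=True)
    assert torch.equal(logdet, torch.zeros(3))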