#!/usr/bin/env python3
from __future__ import annotations
import warnings
from typing import Callable, List, Optional, Tuple, Union
import torch
from jaxtyping import Float
from torch import Tensor
from linear_operator import settings
from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from linear_operator.utils.memoize import cached
from linear_operator.utils.warnings import NumericalWarning
class AddedDiagLinearOperator(SumLinearOperator):
"""
A :class:`~linear_operator.operators.SumLinearOperator`, but of only two
linear operators, one of which must be a
:class:`~linear_operator.operators.DiagLinearOperator`.
:param linear_ops: The LinearOperator, and the DiagLinearOperator to add to it.
:param preconditioner_override: A preconditioning method to be used with conjugate gradients.
If not provided, the default preconditioner (based on the partial pivoted Cholesky factorization) will be used
(see `Gardner et al., NeurIPS 2018`_ for details).
.. _Gardner et al., NeurIPS 2018:
https://arxiv.org/abs/1809.11165
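
Example (a minimal usage sketch; in practice this operator usually arises from
:meth:`~linear_operator.operators.LinearOperator.add_diagonal` rather than direct construction):
>>> import torch
>>> from linear_operator.operators import DenseLinearOperator, DiagLinearOperator
>>> mat = torch.randn(5, 3)
>>> base = DenseLinearOperator(mat @ mat.mT)  # a generic (PSD) LinearOperator
>>> noise = DiagLinearOperator(torch.full((5,), 0.1))  # diagonal term to add
>>> added = AddedDiagLinearOperator(base, noise)  # lazily represents base + noise
>>> added.shape
torch.Size([5, 5])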
"""
def __init__(
self,
*linear_ops: Union[Tuple[LinearOperator, DiagLinearOperator], Tuple[DiagLinearOperator, LinearOperator]],
preconditioner_override: Optional[Callable] = None,
):
linear_ops = list(linear_ops)
super(AddedDiagLinearOperator, self).__init__(*linear_ops, preconditioner_override=preconditioner_override)
if len(linear_ops) > 2:
raise RuntimeError("An AddedDiagLinearOperator can only have two components")
if isinstance(linear_ops[0], DiagLinearOperator) and isinstance(linear_ops[1], DiagLinearOperator):
raise RuntimeError(
"Trying to lazily add two DiagLinearOperators. Create a single DiagLinearOperator instead."
)
elif isinstance(linear_ops[0], DiagLinearOperator):
self._diag_tensor = linear_ops[0]
self._linear_op = linear_ops[1]
elif isinstance(linear_ops[1], DiagLinearOperator):
self._diag_tensor = linear_ops[1]
self._linear_op = linear_ops[0]
else:
raise RuntimeError(
"One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!"
)
self.preconditioner_override = preconditioner_override
# Placeholders
self._constant_diag = None
self._noise = None
self._piv_chol_self = None # <- Doesn't need to be an attribute, but used for testing purposes
self._precond_lt = None
self._precond_logdet_cache = None
self._q_cache = None
self._r_cache = None
def _matmul(
self: Float[LinearOperator, "*batch M N"],
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
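# (A + D) @ rhs, computed as A @ rhs + d * rhs in one fused call
# (torch.addcmul(input, t1, t2) returns input + t1 * t2)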
return torch.addcmul(self._linear_op._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)
def add_diagonal(
self: Float[LinearOperator, "*batch N N"],
diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]],
) -> Float[LinearOperator, "*batch N N"]:
return self.__class__(self._linear_op, self._diag_tensor.add_diagonal(diag))
def __add__(
self: Float[LinearOperator, "... #M #N"],
other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float],
) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]:
from linear_operator.operators.diag_linear_operator import DiagLinearOperator
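# Keep the result an AddedDiagLinearOperator: diagonal terms are folded into the
# diagonal component, everything else into the non-diagonal component.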
if isinstance(other, DiagLinearOperator):
return self.__class__(self._linear_op, self._diag_tensor + other)
else:
return self.__class__(self._linear_op + other, self._diag_tensor)
def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]:
r"""
Here we use a partial pivoted Cholesky preconditioner:
K \approx L L^T + D,
where L L^T is a low-rank approximation (L has shape n x k) and D is a diagonal.
We can compute the preconditioner's inverse using the Woodbury identity:
(L L^T + D)^{-1} = D^{-1} - D^{-1} L (I + L^T D^{-1} L)^{-1} L^T D^{-1}
:return:
- A function `precondition_closure` that computes the solve (L L^T + D)^{-1} x
- A LinearOperator `precondition_lt` that represents (L L^T + D)
- The log determinant of (L L^T + D)
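
A minimal numerical sketch of the Woodbury identity used here (plain ``torch``,
independent of this class; shapes are illustrative):
>>> import torch
>>> L = torch.randn(5, 2, dtype=torch.float64)
>>> d = torch.rand(5, dtype=torch.float64) + 1.0
>>> P = L @ L.mT + torch.diag(d)
>>> D_inv = torch.diag(1.0 / d)
>>> inner = torch.eye(2, dtype=torch.float64) + L.mT @ D_inv @ L
>>> P_inv = D_inv - D_inv @ L @ torch.linalg.solve(inner, L.mT @ D_inv)
>>> torch.allclose(P_inv, torch.linalg.inv(P))
True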
"""
if self.preconditioner_override is not None:
return self.preconditioner_override(self)
if settings.max_preconditioner_size.value() == 0 or self.size(-1) < settings.min_preconditioning_size.value():
return None, None, None
# Cache a QR decomposition [Q; Q'] R = [D^{-1/2} L; I]
# (for a constant diagonal D = s*I, a rescaled variant [L; s^{1/2} I] is factored instead).
# This makes it fast to compute solves and log determinants with it
#
# Through Woodbury, (L L^T + D)^{-1} reduces down to D^{-1} - D^{-1/2} Q Q^T D^{-1/2}
# Through the matrix determinant lemma, log |L L^T + D| reduces down to 2 log |R| + log |D|
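# Why the Woodbury reduction holds: if [Q; Q'] R = [D^{-1/2} L; I], then R^T R = I + L^T D^{-1} L
# and D^{-1/2} L = Q R, so D^{-1} L (I + L^T D^{-1} L)^{-1} L^T D^{-1}
#   = D^{-1/2} Q R (R^T R)^{-1} R^T Q^T D^{-1/2} = D^{-1/2} Q Q^T D^{-1/2}.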
if self._q_cache is None:
max_iter = settings.max_preconditioner_size.value()
self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
if torch.any(torch.isnan(self._piv_chol_self)).item():
warnings.warn(
"NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
NumericalWarning,
)
return None, None, None
self._init_cache()
# NOTE: We cannot memoize this precondition closure as it causes a memory leak
def precondition_closure(tensor):
# Apply the cached Woodbury factors: P^{-1} x = D^{-1} x - D^{-1/2} Q Q^T D^{-1/2} x
# (the D^{-1/2} scaling is folded into _q_cache unless the diagonal is constant)
qqt = self._q_cache.matmul(self._q_cache.mT.matmul(tensor))
if self._constant_diag:
return (1 / self._noise) * (tensor - qqt)
return (tensor / self._noise) - qqt
return (precondition_closure, self._precond_lt, self._precond_logdet_cache)
def _init_cache(self):
*batch_shape, n, k = self._piv_chol_self.shape
self._noise = self._diag_tensor._diagonal().unsqueeze(-1)
# the check for constant diag needs to be done carefully for batches.
noise_first_element = self._noise[..., :1, :]
self._constant_diag = torch.equal(self._noise, noise_first_element * torch.ones_like(self._noise))
eye = torch.eye(k, dtype=self._piv_chol_self.dtype, device=self._piv_chol_self.device)
eye = eye.expand(*batch_shape, k, k)
if self._constant_diag:
self._init_cache_for_constant_diag(eye, batch_shape, n, k)
else:
self._init_cache_for_non_constant_diag(eye, batch_shape, n)
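# The preconditioner itself, represented as a LinearOperator: P = L_k L_k^T + D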
self._precond_lt = PsdSumLinearOperator(RootLinearOperator(self._piv_chol_self), self._diag_tensor)
def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int, k: int):
# We can factor out the noise for both the QR decomposition and the solves.
self._noise = self._noise.narrow(-2, 0, 1)
self._q_cache, self._r_cache = torch.linalg.qr(
torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2)
)
self._q_cache = self._q_cache[..., :n, :]
# Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I,
# so that log |L_k L_k^T + s*I| = 2 log |R| + (n - k) log s
logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()
def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int):
# With non-constant diagonals, we can't factor out the noise as easily
self._q_cache, self._r_cache = torch.linalg.qr(
torch.cat((self._piv_chol_self / self._noise.sqrt(), eye), dim=-2)
)
self._q_cache = self._q_cache[..., :n, :] / self._noise.sqrt()
# Use the matrix determinant lemma for the logdet, using the fact that R'R = I + L_k' D^{-1} L_k,
# so that log |L_k L_k^T + D| = 2 log |R| + log |D|
logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
logdet -= (1.0 / self._noise).log().sum([-1, -2])
self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()
@cached(name="svd")
def _svd(
self: Float[LinearOperator, "*batch N N"]
) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]:
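# Shortcut: a constant diagonal shift s*I leaves the singular vectors unchanged and shifts the
# singular values by s (valid when the base operator is PSD, so singular values equal eigenvalues).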
if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
U, S_, V = self._linear_op.svd()
S = S_ + self._diag_tensor._diagonal()
return U, S, V
return super()._svd()
def _symeig(
self: Float[LinearOperator, "*batch N N"],
eigenvectors: bool = False,
return_evals_as_lazy: Optional[bool] = False,
) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]:
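# Shortcut: adding a constant diagonal c*I shifts every eigenvalue by c and leaves the
# eigenvectors unchanged, so the base operator's decomposition can be reused.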
if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
evals_, evecs = self._linear_op._symeig(eigenvectors=eigenvectors)
evals = evals_ + self._diag_tensor._diagonal()
return evals, evecs
return super()._symeig(eigenvectors=eigenvectors)
def evaluate_kernel(self):
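# Reconstruct this operator from its tensor representation, then return the sum of the
# (reconstructed) non-diagonal and diagonal components.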
added_diag_linear_op = self.representation_tree()(*self.representation())
return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor