Source code for linear_operator.operators.added_diag_linear_operator

#!/usr/bin/env python3

from __future__ import annotations

import warnings
from typing import Callable, List, Optional, Tuple, Union

import torch
from jaxtyping import Float
from torch import Tensor

from .. import settings
from ..utils.memoize import cached
from ..utils.warnings import NumericalWarning
from ._linear_operator import LinearOperator
from .diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from .psd_sum_linear_operator import PsdSumLinearOperator
from .root_linear_operator import RootLinearOperator
from .sum_linear_operator import SumLinearOperator


class AddedDiagLinearOperator(SumLinearOperator):
    """
    A :class:`~linear_operator.operators.SumLinearOperator`, but of only two linear operators,
    one of which must be a :class:`~linear_operator.operators.DiagLinearOperator`.

    :param linear_ops: The LinearOperator, and the DiagLinearOperator to add to it.
    :param preconditioner_override: A preconditioning method to be used with conjugate gradients.
        If not provided, the default preconditioner (based on the partial pivoted Cholesky factorization)
        will be used (see `Gardner et al., NeurIPS 2018`_ for details).

    .. _Gardner et al., NeurIPS 2018:
        https://arxiv.org/abs/1809.11165
    """

    def __init__(
        self,
        *linear_ops: Union[Tuple[LinearOperator, DiagLinearOperator], Tuple[DiagLinearOperator, LinearOperator]],
        preconditioner_override: Optional[Callable] = None,
    ):
        linear_ops = list(linear_ops)
        super(AddedDiagLinearOperator, self).__init__(*linear_ops, preconditioner_override=preconditioner_override)
        if len(linear_ops) > 2:
            raise RuntimeError("An AddedDiagLinearOperator can only have two components")

        if isinstance(linear_ops[0], DiagLinearOperator) and isinstance(linear_ops[1], DiagLinearOperator):
            raise RuntimeError(
                "Trying to lazily add two DiagLinearOperators. Create a single DiagLinearOperator instead."
            )
        elif isinstance(linear_ops[0], DiagLinearOperator):
            self._diag_tensor = linear_ops[0]
            self._linear_op = linear_ops[1]
        elif isinstance(linear_ops[1], DiagLinearOperator):
            self._diag_tensor = linear_ops[1]
            self._linear_op = linear_ops[0]
        else:
            raise RuntimeError(
                "One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!"
            )

        self.preconditioner_override = preconditioner_override

        # Placeholders
        self._constant_diag = None
        self._noise = None
        self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
        self._precond_lt = None
        self._precond_logdet_cache = None
        self._q_cache = None
        self._r_cache = None

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        return torch.addcmul(self._linear_op._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)

    def add_diagonal(
        self: Float[LinearOperator, "*batch N N"],
        diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]],
    ) -> Float[LinearOperator, "*batch N N"]:
        return self.__class__(self._linear_op, self._diag_tensor.add_diagonal(diag))

    def __add__(
        self: Float[LinearOperator, "... #M #N"],
        other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float],
    ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]:
        from .diag_linear_operator import DiagLinearOperator

        if isinstance(other, DiagLinearOperator):
            return self.__class__(self._linear_op, self._diag_tensor + other)
        else:
            return self.__class__(self._linear_op + other, self._diag_tensor)

    def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]:
        r"""
        Here we use a partial pivoted Cholesky preconditioner:

            K \approx L L^T + D

        where L L^T is a low rank approximation, and D is a diagonal.
        We can compute the preconditioner's inverse using the Woodbury identity:

            (L L^T + D)^{-1} = D^{-1} - D^{-1} L (I + L^T D^{-1} L)^{-1} L^T D^{-1}

        :return:
            - A function `precondition_closure` that computes the solve (L L^T + D)^{-1} x
            - A LinearOperator `precondition_lt` that represents (L L^T + D)
            - The log determinant of (L L^T + D)
        """
        if self.preconditioner_override is not None:
            return self.preconditioner_override(self)

        if settings.max_preconditioner_size.value() == 0 or self.size(-1) < settings.min_preconditioning_size.value():
            return None, None, None

        # Cache a QR decomposition [Q; Q'] R = [D^{-1/2} L; I]
        # This makes it fast to compute solves and log determinants with it
        #
        # Through Woodbury, (L L^T + D)^{-1} reduces down to (D^{-1} - D^{-1/2} Q Q^T D^{-1/2})
        # Through the matrix determinant lemma, log |L L^T + D| reduces down to 2 log |R|
        if self._q_cache is None:
            max_iter = settings.max_preconditioner_size.value()
            self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
            if torch.any(torch.isnan(self._piv_chol_self)).item():
                warnings.warn(
                    "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
                    NumericalWarning,
                )
                return None, None, None
            self._init_cache()

        # NOTE: We cannot memoize this precondition closure as it causes a memory leak
        def precondition_closure(tensor):
            # This makes it fast to compute solves with it
            qqt = self._q_cache.matmul(self._q_cache.mT.matmul(tensor))
            if self._constant_diag:
                return (1 / self._noise) * (tensor - qqt)
            return (tensor / self._noise) - qqt

        return (precondition_closure, self._precond_lt, self._precond_logdet_cache)

    def _init_cache(self):
        *batch_shape, n, k = self._piv_chol_self.shape
        self._noise = self._diag_tensor._diagonal().unsqueeze(-1)

        # the check for constant diag needs to be done carefully for batches.
        noise_first_element = self._noise[..., :1, :]
        self._constant_diag = torch.equal(self._noise, noise_first_element * torch.ones_like(self._noise))
        eye = torch.eye(k, dtype=self._piv_chol_self.dtype, device=self._piv_chol_self.device)
        eye = eye.expand(*batch_shape, k, k)

        if self._constant_diag:
            self._init_cache_for_constant_diag(eye, batch_shape, n, k)
        else:
            self._init_cache_for_non_constant_diag(eye, batch_shape, n)

        self._precond_lt = PsdSumLinearOperator(RootLinearOperator(self._piv_chol_self), self._diag_tensor)

    def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int, k: int):
        # We can factor out the noise for both QR and solves.
        self._noise = self._noise.narrow(-2, 0, 1)
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2)
        )
        self._q_cache = self._q_cache[..., :n, :]

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int):
        # With non-constant diagonals, we can't factor out the noise as easily
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self / self._noise.sqrt(), eye), dim=-2)
        )
        self._q_cache = self._q_cache[..., :n, :] / self._noise.sqrt()

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k' D^{-1} L_k + I
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet -= (1.0 / self._noise).log().sum([-1, -2])
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    @cached(name="svd")
    def _svd(
        self: Float[LinearOperator, "*batch N N"]
    ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]:
        if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
            U, S_, V = self._linear_op.svd()
            S = S_ + self._diag_tensor._diagonal()
            return U, S, V
        return super()._svd()

    def _symeig(
        self: Float[LinearOperator, "*batch N N"],
        eigenvectors: bool = False,
        return_evals_as_lazy: Optional[bool] = False,
    ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]:
        if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
            evals_, evecs = self._linear_op._symeig(eigenvectors=eigenvectors)
            evals = evals_ + self._diag_tensor._diagonal()
            return evals, evecs
        return super()._symeig(eigenvectors=eigenvectors)

    def evaluate_kernel(self):
        added_diag_linear_op = self.representation_tree()(*self.representation())
        return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor
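

For context, here is a minimal usage sketch (not part of the module above). It assumes that DenseLinearOperator, DiagLinearOperator, and AddedDiagLinearOperator are importable from linear_operator.operators and that the base LinearOperator exposes solve() and to_dense(); shapes and values are illustrative only.

# Minimal usage sketch: wrap K + D lazily as an AddedDiagLinearOperator.
# All shapes/values below are illustrative; they are not part of the library source.
import torch
from linear_operator.operators import (
    AddedDiagLinearOperator,
    DenseLinearOperator,
    DiagLinearOperator,
)

n = 6
root = torch.randn(n, n)
covar = DenseLinearOperator(root @ root.mT + torch.eye(n))  # a PSD operator K
noise = DiagLinearOperator(0.1 * torch.ones(n))             # the diagonal term D

# Solves and log-determinants on this operator can route through the
# pivoted-Cholesky preconditioner defined above (for matrices large enough to
# trigger preconditioning, per settings.min_preconditioning_size).
K_plus_D = AddedDiagLinearOperator(covar, noise)

rhs = torch.randn(n, 2)
x = K_plus_D.solve(rhs)                                     # (K + D)^{-1} rhs
print(torch.allclose(K_plus_D.to_dense() @ x, rhs, atol=1e-4))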
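

The following self-contained check (plain PyTorch, all names local to the snippet) illustrates the reduction that _preconditioner relies on: QR-factorizing the stacked matrix [D^{-1/2} L; I] gives R with R^T R = I + L^T D^{-1} L, so the Woodbury expression D^{-1} - D^{-1/2} Q Q^T D^{-1/2} (with Q the top n rows of the QR's Q factor) reproduces (L L^T + D)^{-1}, and the matrix determinant lemma recovers log |L L^T + D| from diag(R) and D.

# Sanity check of the Woodbury/QR reduction, written with dense tensors only;
# n, k, and the random inputs are arbitrary.
import torch

torch.manual_seed(0)
n, k = 50, 5
L = torch.randn(n, k, dtype=torch.float64)        # low-rank factor, K ~= L L^T
d = torch.rand(n, dtype=torch.float64) + 0.5      # positive (non-constant) diagonal D

# QR of the stacked matrix [D^{-1/2} L; I]; then R^T R = I + L^T D^{-1} L
stacked = torch.cat([L / d.sqrt().unsqueeze(-1), torch.eye(k, dtype=torch.float64)], dim=0)
Q, R = torch.linalg.qr(stacked)
Q_top = Q[:n, :]

# Woodbury: (L L^T + D)^{-1} = D^{-1} - D^{-1/2} Q_top Q_top^T D^{-1/2}
Q_scaled = Q_top / d.sqrt().unsqueeze(-1)
precond_inv = torch.diag(1.0 / d) - Q_scaled @ Q_scaled.T
direct_inv = torch.linalg.inv(L @ L.T + torch.diag(d))
print(torch.allclose(precond_inv, direct_inv, atol=1e-8))

# Matrix determinant lemma: log|L L^T + D| = 2 * sum(log|diag(R)|) + sum(log d)
logdet_qr = 2 * R.diagonal().abs().log().sum() + d.log().sum()
logdet_direct = torch.logdet(L @ L.T + torch.diag(d))
print(torch.allclose(logdet_qr, logdet_direct, atol=1e-8))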