Source code for linear_operator.operators.added_diag_linear_operator

#!/usr/bin/env python3

from __future__ import annotations

import warnings
from typing import Callable, List, Optional, Tuple, Union

import torch
from jaxtyping import Float
from torch import Tensor

from linear_operator import settings
from linear_operator.operators._linear_operator import LinearOperator
from linear_operator.operators.diag_linear_operator import ConstantDiagLinearOperator, DiagLinearOperator
from linear_operator.operators.psd_sum_linear_operator import PsdSumLinearOperator
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.operators.sum_linear_operator import SumLinearOperator
from linear_operator.utils.memoize import cached
from linear_operator.utils.warnings import NumericalWarning


class AddedDiagLinearOperator(SumLinearOperator):
    """
    A :class:`~linear_operator.operators.SumLinearOperator`, but of only two
    linear operators, exactly one of which must be a
    :class:`~linear_operator.operators.DiagLinearOperator` (in either position).

    :param linear_ops: The LinearOperator, and the DiagLinearOperator to add to it.
    :param preconditioner_override: A preconditioning method to be used with conjugate gradients.
        If not provided, the default preconditioner (based on the partial pivoted Cholesky factorization)
        will be used (see `Gardner et al., NeurIPS 2018`_ for details).

    :raises RuntimeError: If the number of operators is not exactly two, if both
        operators are DiagLinearOperators, or if neither of them is.

    .. _Gardner et al., NeurIPS 2018:
        https://arxiv.org/abs/1809.11165
    """

    def __init__(
        self,
        *linear_ops: Union[Tuple[LinearOperator, DiagLinearOperator], Tuple[DiagLinearOperator, LinearOperator]],
        preconditioner_override: Optional[Callable] = None,
    ):
        linear_ops = list(linear_ops)
        super().__init__(*linear_ops, preconditioner_override=preconditioner_override)

        # Exactly two components are required. Checking `!= 2` (rather than `> 2`)
        # keeps fewer-than-two inputs from surfacing as a bare IndexError below.
        if len(linear_ops) != 2:
            raise RuntimeError("An AddedDiagLinearOperator can only have two components")

        first_is_diag = isinstance(linear_ops[0], DiagLinearOperator)
        second_is_diag = isinstance(linear_ops[1], DiagLinearOperator)

        if first_is_diag and second_is_diag:
            raise RuntimeError(
                "Trying to lazily add two DiagLinearOperators. Create a single DiagLinearOperator instead."
            )
        elif first_is_diag:
            self._diag_tensor = linear_ops[0]
            self._linear_op = linear_ops[1]
        elif second_is_diag:
            self._diag_tensor = linear_ops[1]
            self._linear_op = linear_ops[0]
        else:
            raise RuntimeError(
                "One of the LinearOperators input to AddedDiagLinearOperator must be a DiagLinearOperator!"
            )

        self.preconditioner_override = preconditioner_override

        # Placeholders, filled lazily by `_preconditioner` / `_init_cache`
        self._constant_diag = None
        self._noise = None
        self._piv_chol_self = None  # <- Doesn't need to be an attribute, but used for testing purposes
        self._precond_lt = None
        self._precond_logdet_cache = None
        self._q_cache = None
        self._r_cache = None

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        """Return ``linear_op @ rhs + diag * rhs`` computed with a single ``addcmul``."""
        # The diagonal is unsqueezed into a trailing singleton dim so that it
        # broadcasts across the last dimension of `rhs`.
        return torch.addcmul(self._linear_op._matmul(rhs), self._diag_tensor._diag.unsqueeze(-1), rhs)

    def add_diagonal(
        self: Float[LinearOperator, "*batch N N"],
        diag: Union[Float[torch.Tensor, "... N"], Float[torch.Tensor, "... 1"], Float[torch.Tensor, ""]],
    ) -> Float[LinearOperator, "*batch N N"]:
        """Return a new operator of the same class with ``diag`` folded into the diagonal component."""
        # NOTE(review): `preconditioner_override` is not forwarded to the new instance — confirm intended.
        return self.__class__(self._linear_op, self._diag_tensor.add_diagonal(diag))

    def __add__(
        self: Float[LinearOperator, "... #M #N"],
        other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float],
    ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]:
        """
        Add ``other`` to this operator, keeping the (non-diagonal + diagonal) structure.

        A DiagLinearOperator is merged into the diagonal component; anything else is
        added to the non-diagonal component.
        """
        # NOTE(review): `preconditioner_override` is not forwarded to the new instance — confirm intended.
        if isinstance(other, DiagLinearOperator):
            return self.__class__(self._linear_op, self._diag_tensor + other)
        return self.__class__(self._linear_op + other, self._diag_tensor)

    def _preconditioner(self) -> Tuple[Optional[Callable], Optional[LinearOperator], Optional[torch.Tensor]]:
        r"""
        Here we use a partial pivoted Cholesky preconditioner:

        K \approx L L^T + D

        where L L^T is a low rank approximation, and D is a diagonal.
        We can compute the preconditioner's inverse using Woodbury

        (L L^T + D)^{-1} = D^{-1} - D^{-1} L (I + L D^{-1} L^T)^{-1} L^T D^{-1}

        If ``preconditioner_override`` was supplied, its result is returned instead.
        ``(None, None, None)`` is returned when preconditioning is disabled by the
        settings, when the matrix is smaller than the minimum preconditioning size,
        or when the pivoted Cholesky factor contains NaNs.

        :return:
            - A function `precondition_closure` that computes the solve (L L^T + D)^{-1} x
            - A LinearOperator `precondition_lt` that represents (L L^T + D)
            - The log determinant of (L L^T + D)
        """
        if self.preconditioner_override is not None:
            return self.preconditioner_override(self)

        if settings.max_preconditioner_size.value() == 0 or self.size(-1) < settings.min_preconditioning_size.value():
            return None, None, None

        # Cache a QR decomposition [Q; Q'] R = [D^{-1/2}; L]
        # This makes it fast to compute solves and log determinants with it
        #
        # Through woodbury, (L L^T + D)^{-1} reduces down to (D^{-1} - D^{-1/2} Q Q^T D^{-1/2})
        # Through matrix determinant lemma, log |L L^T + D| reduces down to 2 log |R|
        if self._q_cache is None:
            max_iter = settings.max_preconditioner_size.value()
            self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
            if torch.any(torch.isnan(self._piv_chol_self)).item():
                warnings.warn(
                    "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
                    NumericalWarning,
                )
                return None, None, None
            self._init_cache()

        # NOTE: We cannot memoize this precondition closure as it causes a memory leak
        def precondition_closure(tensor):
            # This makes it fast to compute solves with it
            qqt = self._q_cache.matmul(self._q_cache.mT.matmul(tensor))
            if self._constant_diag:
                return (1 / self._noise) * (tensor - qqt)
            return (tensor / self._noise) - qqt

        return (precondition_closure, self._precond_lt, self._precond_logdet_cache)

    def _init_cache(self):
        """Populate the cached QR factors, log determinant and preconditioner operator."""
        *batch_shape, n, k = self._piv_chol_self.shape
        self._noise = self._diag_tensor._diagonal().unsqueeze(-1)

        # the check for constant diag needs to be done carefully for batches.
        noise_first_element = self._noise[..., :1, :]
        self._constant_diag = torch.equal(self._noise, noise_first_element * torch.ones_like(self._noise))

        eye = torch.eye(k, dtype=self._piv_chol_self.dtype, device=self._piv_chol_self.device)
        eye = eye.expand(*batch_shape, k, k)

        if self._constant_diag:
            self._init_cache_for_constant_diag(eye, batch_shape, n, k)
        else:
            self._init_cache_for_non_constant_diag(eye, batch_shape, n)

        self._precond_lt = PsdSumLinearOperator(RootLinearOperator(self._piv_chol_self), self._diag_tensor)

    def _init_cache_for_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int, k: int):
        """Fill the QR and logdet caches when every diagonal entry (per batch) is the same."""
        # We can factor out the noise for both QR and solves.
        self._noise = self._noise.narrow(-2, 0, 1)
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self, self._noise.sqrt() * eye), dim=-2)
        )
        # Only the first n rows of Q (those stacked from the Cholesky factor) are kept.
        self._q_cache = self._q_cache[..., :n, :]

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet = logdet + (n - k) * self._noise.squeeze(-2).squeeze(-1).log()
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    def _init_cache_for_non_constant_diag(self, eye: Tensor, batch_shape: Union[torch.Size, List[int]], n: int):
        """Fill the QR and logdet caches when the diagonal entries are not all equal."""
        # With non-constant diagonals, we cant factor out the noise as easily
        self._q_cache, self._r_cache = torch.linalg.qr(
            torch.cat((self._piv_chol_self / self._noise.sqrt(), eye), dim=-2)
        )
        # Keep the first n rows of Q and rescale them by noise^{-1/2}.
        self._q_cache = self._q_cache[..., :n, :] / self._noise.sqrt()

        # Use the matrix determinant lemma for the logdet, using the fact that R'R = L_k'L_k + s*I
        logdet = self._r_cache.diagonal(dim1=-1, dim2=-2).abs().log().sum(-1).mul(2)
        logdet -= (1.0 / self._noise).log().sum([-1, -2])
        self._precond_logdet_cache = logdet.view(*batch_shape) if len(batch_shape) else logdet.squeeze()

    @cached(name="svd")
    def _svd(
        self: Float[LinearOperator, "*batch N N"]
    ) -> Tuple[Float[LinearOperator, "*batch N N"], Float[Tensor, "... N"], Float[LinearOperator, "*batch N N"]]:
        """
        Return the SVD ``(U, S, V)`` of this operator.

        When the diagonal is a ConstantDiagLinearOperator, the SVD of the non-diagonal
        component is reused and the diagonal is added to its singular values;
        otherwise the parent implementation is used.
        """
        if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
            U, S_, V = self._linear_op.svd()
            S = S_ + self._diag_tensor._diagonal()
            return U, S, V
        return super()._svd()

    def _symeig(
        self: Float[LinearOperator, "*batch N N"],
        eigenvectors: bool = False,
        return_evals_as_lazy: Optional[bool] = False,
    ) -> Tuple[Float[Tensor, "*batch M"], Optional[Float[LinearOperator, "*batch N M"]]]:
        """
        Return the eigenvalues (and optionally eigenvectors) of this operator.

        When the diagonal is a ConstantDiagLinearOperator, the eigendecomposition of the
        non-diagonal component is reused and the diagonal is added to its eigenvalues;
        otherwise the parent implementation is used.

        :param eigenvectors: Forwarded to the underlying ``_symeig`` call.
        :param return_evals_as_lazy: Forwarded to the parent implementation in the
            fallback path. It is not used in the constant-diagonal path.
        """
        if isinstance(self._diag_tensor, ConstantDiagLinearOperator):
            evals_, evecs = self._linear_op._symeig(eigenvectors=eigenvectors)
            evals = evals_ + self._diag_tensor._diagonal()
            return evals, evecs
        # Forward the caller's flag instead of silently dropping it.
        return super()._symeig(eigenvectors=eigenvectors, return_evals_as_lazy=return_evals_as_lazy)

    def evaluate_kernel(self):
        """Rebuild this operator from its representation and return the sum of its two components."""
        added_diag_linear_op = self.representation_tree()(*self.representation())
        return added_diag_linear_op._linear_op + added_diag_linear_op._diag_tensor