Source code for linear_operator.operators.triangular_linear_operator

#!/usr/bin/env python3
from __future__ import annotations

from typing import Callable, TypeAlias

import torch
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.batch_repeat_linear_operator import BatchRepeatLinearOperator
from linear_operator.operators.dense_linear_operator import DenseLinearOperator

from linear_operator.utils.errors import NotPSDError
from linear_operator.utils.memoize import cached

# Alias for "a dense tensor or a LinearOperator"; used below to annotate the
# ``tensor`` argument of ``TriangularLinearOperator.__init__``.
Allsor: TypeAlias = Tensor | LinearOperator


class _TriangularLinearOperatorBase:
    """Base class that all triangular lazy tensors are derived from."""

    pass


class TriangularLinearOperator(LinearOperator, _TriangularLinearOperatorBase):
    r"""
    A wrapper for LinearOperators when we have additional knowledge that it
    represents a lower- or upper-triangular matrix (or batch of matrices).

    The wrapped operator is stored in ``self._tensor`` and the orientation in
    ``self.upper``. Most methods delegate to the wrapped operator and re-wrap
    the result in the same class with the appropriate ``upper`` flag.

    :param tensor: A `... x N x N` Tensor, representing a (batch of) `N x N` triangular matrix.
    :param upper: If True, the tensor is considered to be upper-triangular, otherwise lower-triangular.
    """

    def __init__(self, tensor: Allsor, upper: bool = False) -> None:
        if isinstance(tensor, TriangularLinearOperator):
            # this is a null-op, we can just use underlying tensor directly.
            tensor = tensor._tensor
            # TODO: Use a metaclass to create a DiagLinearOperator if tensor is diagonal
        elif isinstance(tensor, BatchRepeatLinearOperator):
            # things get kind of messy when interleaving repeats and triangularisms:
            # push the triangular wrapper inside the repeat so the base operator
            # is always a TriangularLinearOperator.
            if not isinstance(tensor.base_linear_op, TriangularLinearOperator):
                tensor = tensor.__class__(
                    TriangularLinearOperator(tensor.base_linear_op, upper=upper),
                    batch_repeat=tensor.batch_repeat,
                )
        if torch.is_tensor(tensor):
            tensor = DenseLinearOperator(tensor)

        super().__init__(tensor, upper=upper)
        self.upper = upper
        self._tensor = tensor

    def __add__(
        self: LinearOperator,  # shape: (..., #M, #N)
        other: Tensor | LinearOperator | float,  # shape: (..., #M, #N)
    ) -> LinearOperator | Tensor:  # shape: (..., M, N)
        """
        Add ``other`` to this operator, staying triangular when possible.

        The result is re-wrapped in this class when ``other`` is a
        DiagLinearOperator or a TriangularLinearOperator with the same
        orientation; otherwise the sum of the wrapped operator and ``other``
        is returned unwrapped.
        """
        # Imported locally, as in the original code (presumably to avoid circular imports).
        from linear_operator.operators.diag_linear_operator import DiagLinearOperator

        if isinstance(other, DiagLinearOperator):
            from linear_operator.operators.added_diag_linear_operator import AddedDiagLinearOperator

            return self.__class__(AddedDiagLinearOperator(self._tensor, other), upper=self.upper)
        # ``not (a ^ b)`` is true exactly when both operands share the same orientation.
        if isinstance(other, TriangularLinearOperator) and not (self.upper ^ other.upper):
            return self.__class__(self._tensor + other._tensor, upper=self.upper)
        return self._tensor + other

    def _cholesky(
        self: LinearOperator, upper: bool | None = False  # shape: (*batch, N, N)
    ) -> LinearOperator:  # shape: (*batch, N, N)
        """Always raises :class:`NotPSDError`; no Cholesky factor is computed for this class."""
        raise NotPSDError("TriangularLinearOperator does not allow a Cholesky decomposition")

    def _cholesky_solve(
        self: LinearOperator,  # shape: (*batch, N, N)
        rhs: LinearOperator | Tensor,  # shape: (*batch2, N, M)
        upper: bool | None = False,
    ) -> LinearOperator | Tensor:  # shape: (..., N, M)
        """
        Solve against the matrix whose Cholesky factor is this operator.

        Uses the wrapped operator's ``_cholesky_solve`` when it is implemented
        and otherwise falls back to two triangular solves.

        :param rhs: Right-hand side of the linear system.
        :param upper: Whether this operator is treated as the upper factor ``U``
            (system ``U^T U``) rather than the lower factor ``L`` (system ``L L^T``).
        """
        # use custom method if implemented
        try:
            res = self._tensor._cholesky_solve(rhs=rhs, upper=upper)
        except NotImplementedError:
            if upper:
                # res = (U.T @ U)^-1 @ v = U^-1 @ U^-T @ v
                w = self._transpose_nonbatch().solve(rhs)
                res = self.solve(w)
            else:
                # res = (L @ L.T)^-1 @ v = L^-T @ L^-1 @ v
                w = self.solve(rhs)
                res = self._transpose_nonbatch().solve(w)
        return res

    def _diagonal(
        self: LinearOperator,  # shape: (..., M, N)
    ) -> torch.Tensor:  # shape: (..., N)
        """Return the diagonal of the wrapped operator."""
        return self._tensor._diagonal()

    def _expand_batch(
        self: LinearOperator, batch_shape: torch.Size | list[int]  # shape: (..., M, N)
    ) -> LinearOperator:  # shape: (..., M, N)
        """Expand the wrapped operator to ``batch_shape``; an empty shape returns ``self``."""
        if len(batch_shape) == 0:
            return self
        return self.__class__(tensor=self._tensor._expand_batch(batch_shape), upper=self.upper)

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        """Delegate element lookup to the wrapped operator."""
        return self._tensor._get_indices(row_index, col_index, *batch_indices)

    def _matmul(
        self: LinearOperator,  # shape: (*batch, M, N)
        rhs: torch.Tensor,  # shape: (*batch2, N, C) or (*batch2, N)
    ) -> torch.Tensor:  # shape: (..., M, C) or (..., M)
        """Multiply the wrapped operator with ``rhs``."""
        return self._tensor.matmul(rhs)

    def _mul_constant(
        self: LinearOperator, other: float | torch.Tensor  # shape: (*batch, M, N)
    ) -> LinearOperator:  # shape: (*batch, M, N)
        """
        Scale the operator by a constant, preserving the triangular wrapper.

        :param other: The constant, either a Python number or a tensor.
        """
        # The signature admits plain numbers, which have no ``unsqueeze``;
        # promote them to a tensor matching this operator first.
        if not torch.is_tensor(other):
            other = torch.tensor(other, dtype=self.dtype, device=self.device)
        return self.__class__(self._tensor * other.unsqueeze(-1), upper=self.upper)

    def _root_decomposition(
        self: LinearOperator,  # shape: (..., N, N)
    ) -> torch.Tensor | LinearOperator:  # shape: (..., N, N)
        """Always raises :class:`NotPSDError`; no root decomposition is computed for this class."""
        raise NotPSDError("TriangularLinearOperator does not allow a root decomposition")

    def _root_inv_decomposition(
        self: LinearOperator,  # shape: (*batch, N, N)
        initial_vectors: torch.Tensor | None = None,
        test_vectors: torch.Tensor | None = None,
    ) -> LinearOperator | Tensor:  # shape: (..., N, N)
        """Always raises :class:`NotPSDError`; no inverse root decomposition is computed for this class."""
        raise NotPSDError("TriangularLinearOperator does not allow an inverse root decomposition")

    def _size(self) -> torch.Size:
        """Return the shape of the wrapped operator."""
        return self._tensor.shape

    def _solve(
        self: LinearOperator,  # shape: (..., N, N)
        rhs: torch.Tensor,  # shape: (..., N, C)
        preconditioner: Callable[[torch.Tensor], torch.Tensor] | None = None,  # shape: (..., N, C)
        num_tridiag: int | None = 0,
    ) -> (
        torch.Tensor  # shape: (..., N, C)
        | tuple[
            torch.Tensor,  # shape: (..., N, C)
            torch.Tensor,  # Note that in case of a tuple the second term size depends on num_tridiag  # shape: (...)
        ]
    ):
        """
        Solve the linear system by delegating to :meth:`solve`.

        ``preconditioner`` and ``num_tridiag`` are accepted for interface
        compatibility but are not used here.
        """
        # already triangular, can just call solve for the solve
        return self.solve(rhs)

    def _sum_batch(self, dim: int) -> LinearOperator:
        """Sum the wrapped operator over batch dimension ``dim`` and re-wrap the result."""
        return self.__class__(self._tensor._sum_batch(dim), upper=self.upper)

    def _transpose_nonbatch(
        self: LinearOperator,  # shape: (*batch, M, N)
    ) -> LinearOperator:  # shape: (*batch, N, M)
        """Transpose the matrix dimensions; the ``upper`` flag is flipped accordingly."""
        return self.__class__(self._tensor._transpose_nonbatch(), upper=not self.upper)

    def abs(self) -> LinearOperator:
        """
        Returns a TriangularLinearOperator wrapping ``self._tensor.abs()``,
        with the same ``upper`` flag.
        """
        # NOTE(review): which entries are affected depends on the wrapped
        # operator's own ``abs`` implementation -- confirm against it.
        return self.__class__(self._tensor.abs(), upper=self.upper)

    def add_diagonal(
        self: LinearOperator,  # shape: (*batch, N, N)
        diag: torch.Tensor,  # shape: (..., N) or (..., 1) or ()
    ) -> LinearOperator:  # shape: (*batch, N, N)
        """Add ``diag`` to the diagonal of the wrapped operator and re-wrap the result."""
        added_diag_lt = self._tensor.add_diagonal(diag)
        return self.__class__(added_diag_lt, upper=self.upper)

    def to_dense(
        self: LinearOperator,  # shape: (*batch, M, N)
    ) -> Tensor:  # shape: (*batch, M, N)
        """Return the dense representation of the wrapped operator."""
        return self._tensor.to_dense()

    def exp(
        self: LinearOperator,  # shape: (*batch, M, N)
    ) -> LinearOperator:  # shape: (*batch, M, N)
        """
        Returns a TriangularLinearOperator wrapping ``self._tensor.exp()``,
        with the same ``upper`` flag.
        """
        # NOTE(review): which entries are affected depends on the wrapped
        # operator's own ``exp`` implementation -- confirm against it.
        return self.__class__(self._tensor.exp(), upper=self.upper)

    def inv_quad_logdet(
        self: LinearOperator,  # shape: (*batch, N, N)
        inv_quad_rhs: Tensor | None = None,  # shape: (*batch, N, M) or (*batch, N)
        logdet: bool | None = False,
        reduce_inv_quad: bool | None = True,
    ) -> tuple[  # fmt: off
        Tensor | None,  # shape: (*batch, M) or (*batch) or (0)
        Tensor | None,  # shape: (...)
    ]:  # fmt: on
        """
        Compute the inverse quadratic form and/or the log-determinant.

        :param inv_quad_rhs: Right-hand side of the inverse quadratic term;
            when None an empty tensor is returned for that term.
        :param logdet: Whether to compute the log-determinant; when falsy an
            empty tensor is returned for that term.
        :param reduce_inv_quad: Whether to sum the inverse quadratic term over
            its last dimension.
        :return: ``(inv_quad_term, logdet_term)``. Entries of ``logdet_term``
            are NaN where the product of the diagonal signs is negative.
        """
        if inv_quad_rhs is None:
            inv_quad_term = torch.empty(0, dtype=self.dtype, device=self.device)
        else:
            # triangular, solve is cheap
            inv_quad_term = (inv_quad_rhs * self.solve(inv_quad_rhs)).sum(dim=-2)

        if logdet:
            diag = self._diagonal()
            logdet_term = diag.abs().log().sum(-1)
            # Element-wise selection: a plain ``if`` on this comparison fails for
            # batched operators, where the comparison has more than one element.
            negative_det = torch.sign(diag).prod(-1) < 0
            logdet_term = torch.where(
                negative_det, torch.full_like(logdet_term, float("nan")), logdet_term
            )
        else:
            logdet_term = torch.empty(0, dtype=self.dtype, device=self.device)

        if inv_quad_term.numel() and reduce_inv_quad:
            inv_quad_term = inv_quad_term.sum(-1)

        return inv_quad_term, logdet_term

    @cached
    def inverse(
        self: LinearOperator,  # shape: (*batch, N, N)
    ) -> LinearOperator:  # shape: (*batch, N, N)
        """
        Returns the inverse of the TriangularLinearOperator, computed by
        solving against the identity matrix. The result keeps the same
        ``upper`` flag.
        """
        eye = torch.eye(self._tensor.size(-1), device=self._tensor.device, dtype=self._tensor.dtype)
        inv = self.solve(eye)
        return self.__class__(inv, upper=self.upper)

    def solve(
        self: LinearOperator,  # shape: (..., N, N)
        right_tensor: Tensor,  # shape: (..., N, P) or (N)
        left_tensor: Tensor | None = None,  # shape: (..., O, N)
    ) -> Tensor:  # shape: (..., N, P) or (..., N) or (..., O, P) or (..., O)
        """
        Solve ``self @ X = right_tensor`` and optionally left-multiply the result.

        :param right_tensor: Right-hand side; a 1-D tensor is treated as a
            single column and the added dimension is removed from the result.
        :param left_tensor: Optional matrix multiplied onto the solution from
            the left (applied exactly once).
        :return: ``self^{-1} @ right_tensor`` or
            ``left_tensor @ self^{-1} @ right_tensor``.
        """
        squeeze = False
        if right_tensor.dim() == 1:
            right_tensor = right_tensor.unsqueeze(-1)
            squeeze = True

        # Every branch computes the plain solve only; ``left_tensor`` is applied
        # once below so that it is never multiplied in twice.
        if isinstance(self._tensor, DenseLinearOperator):
            res = torch.linalg.solve_triangular(self.to_dense(), right_tensor, upper=self.upper)
        elif isinstance(self._tensor, BatchRepeatLinearOperator):
            res = self._tensor.base_linear_op.solve(right_tensor)
            # TODO: Proper broadcasting
            res = res.expand(self._tensor.batch_repeat + res.shape[-2:])
        else:
            # TODO: Can we be smarter here?
            res = self._tensor.solve(right_tensor=right_tensor)

        # Multiply while ``res`` still has its trailing column dimension so that
        # batched solutions of a 1-D right-hand side are handled correctly.
        if left_tensor is not None:
            res = left_tensor @ res

        if squeeze:
            res = res.squeeze(-1)

        return res

    def solve_triangular(
        self, rhs: torch.Tensor, upper: bool, left: bool = True, unitriangular: bool = False
    ) -> torch.Tensor:
        """
        Triangular solve with a ``torch.linalg.solve_triangular``-style signature.

        :param rhs: Right-hand side of the system.
        :param upper: Must match this operator's ``upper`` flag.
        :param left: Only ``True`` is supported.
        :param unitriangular: Only ``False`` is supported.
        :raises RuntimeError: If ``upper`` differs from ``self.upper``.
        :raises NotImplementedError: If ``left`` is False or ``unitriangular`` is True.
        """
        if upper != self.upper:
            raise RuntimeError(
                f"Incompatible argument: {self.__class__.__name__}.solve_triangular called with 'upper={upper}', "
                f"but 'LinearOperator' has 'upper={self.upper}'."
            )
        if not left:
            raise NotImplementedError(
                f"Argument 'left=False' not yet supported for {self.__class__.__name__}.solve_triangular."
            )
        if unitriangular:
            raise NotImplementedError(
                f"Argument 'unitriangular=True' not yet supported for {self.__class__.__name__}.solve_triangular."
            )
        return self.solve(right_tensor=rhs)