
#!/usr/bin/env python3

from __future__ import annotations

from typing import List, Optional, Tuple, Union

import torch
from jaxtyping import Float
from torch import Tensor

from linear_operator.operators._linear_operator import IndexType, LinearOperator, to_dense


class DenseLinearOperator(LinearOperator):
    def _check_args(self, tsr):
        if not torch.is_tensor(tsr):
            return "DenseLinearOperator must take a torch.Tensor; got {}".format(tsr.__class__.__name__)
        if tsr.dim() < 2:
            return "DenseLinearOperator expects a matrix (or batches of matrices) - got a Tensor of size {}.".format(
                tsr.shape
            )

    def __init__(self, tsr):
        """
        A LinearOperator that wraps a dense torch.Tensor (i.e. the operator is not lazy).

        Args:
        - tsr (Tensor): The matrix (or batch of matrices) to wrap.
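
        Example (a minimal usage sketch):
            >>> mat = torch.randn(3, 3)
            >>> op = DenseLinearOperator(mat)
            >>> op.shape
            torch.Size([3, 3])
            >>> torch.equal(op.to_dense(), mat)
            True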
        """
        super().__init__(tsr)
        self.tensor = tsr

    def _cholesky_solve(
        self: Float[LinearOperator, "*batch N N"],
        rhs: Union[Float[LinearOperator, "*batch2 N M"], Float[Tensor, "*batch2 N M"]],
        upper: Optional[bool] = False,
    ) -> Union[Float[LinearOperator, "... N M"], Float[Tensor, "... N M"]]:
        return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)

    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
        return self.tensor.diagonal(dim1=-1, dim2=-2)

    def _expand_batch(
        self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
    ) -> Float[LinearOperator, "... M N"]:
        return self.__class__(self.tensor.expand(*batch_shape, *self.matrix_shape))

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        # Perform the __getitem__
        res = self.tensor[(*batch_indices, row_index, col_index)]
        return res

    def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
        # Perform the __getitem__
        res = self.tensor[(*batch_indices, row_index, col_index)]
        return self.__class__(res)

    def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
        return torch.isclose(self.tensor, to_dense(other), rtol=rtol, atol=atol, equal_nan=equal_nan)

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        return torch.matmul(self.tensor, rhs)

    def _prod_batch(self, dim: int) -> LinearOperator:
        return self.__class__(self.tensor.prod(dim))

    def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
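        # The derivative of sum_i (l_i^T A r_i) with respect to the dense tensor A
        # is sum_i (l_i r_i^T), which the batched matmul below computes in one shot.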
        res = left_vecs.matmul(right_vecs.mT)
        return (res,)

    def _size(self) -> torch.Size:
        return self.tensor.size()

    def _sum_batch(self, dim: int) -> LinearOperator:
        return self.__class__(self.tensor.sum(dim))

    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        return DenseLinearOperator(self.tensor.mT)

    def _t_matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[Tensor, "*batch2 M P"], Float[LinearOperator, "*batch2 M P"]],
    ) -> Union[Float[LinearOperator, "... N P"], Float[Tensor, "... N P"]]:
        return torch.matmul(self.tensor.mT, rhs)

    def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
        return self.tensor

    def __add__(
        self: Float[LinearOperator, "... #M #N"],
        other: Union[Float[Tensor, "... #M #N"], Float[LinearOperator, "... #M #N"], float],
    ) -> Union[Float[LinearOperator, "... M N"], Float[Tensor, "... M N"]]:
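        # Eagerly sum dense representations when the other operand is dense;
        # otherwise fall back to the generic (lazy) LinearOperator.__add__.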
        if isinstance(other, DenseLinearOperator):
            return DenseLinearOperator(self.tensor + other.tensor)
        elif isinstance(other, torch.Tensor):
            return DenseLinearOperator(self.tensor + other)
        else:
            return super().__add__(other)


def to_linear_operator(obj: Union[torch.Tensor, LinearOperator]) -> LinearOperator:
    """
    A function which ensures that `obj` is a LinearOperator.

    - If `obj` is a LinearOperator, it is returned unchanged.
    - If `obj` is a (normal) Tensor, it is wrapped in a `DenseLinearOperator`.
    """
    if torch.is_tensor(obj):
        return DenseLinearOperator(obj)
    elif isinstance(obj, LinearOperator):
        return obj
    else:
        raise TypeError("object of class {} cannot be made into a LinearOperator".format(obj.__class__.__name__))


__all__ = ["DenseLinearOperator", "to_linear_operator"]
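

# A minimal usage sketch (illustrative, not part of the library source): wraps a
# dense tensor with ``to_linear_operator`` and exercises the resulting operator.
if __name__ == "__main__":
    mat = torch.randn(2, 4, 4)  # a batch of two 4x4 matrices
    op = to_linear_operator(mat)  # a plain Tensor is wrapped in a DenseLinearOperator
    assert isinstance(op, DenseLinearOperator)
    assert to_linear_operator(op) is op  # an existing LinearOperator is returned as-is

    rhs = torch.randn(2, 4, 3)
    out = op.matmul(rhs)  # batched matrix-matrix product, dispatched to _matmul
    assert out.shape == torch.Size([2, 4, 3])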