Source code for linear_operator.operators.mul_linear_operator

#!/usr/bin/env python3
from typing import List, Optional, Tuple, Union

import torch
from jaxtyping import Float
from torch import Tensor

from ..utils.broadcasting import _matmul_broadcast_shape
from ..utils.memoize import cached
from ._linear_operator import IndexType, LinearOperator
from .linear_operator_representation_tree import LinearOperatorRepresentationTree
from .root_linear_operator import RootLinearOperator


class MulLinearOperator(LinearOperator):
    def _check_args(self, left_linear_op, right_linear_op):
        if not isinstance(left_linear_op, LinearOperator) or not isinstance(right_linear_op, LinearOperator):
            return "MulLinearOperator expects two LinearOperators."
        if left_linear_op.shape != right_linear_op.shape:
            return "MulLinearOperator expects two LinearOperators of the same size: got {} and {}.".format(
                left_linear_op, right_linear_op
            )

    def __init__(self, left_linear_op, right_linear_op):
        """
        Represents the elementwise (Hadamard) product of two linear operators.

        Args:
            left_linear_op (LinearOperator): one factor of the product.
            right_linear_op (LinearOperator): the other factor; must have the same shape.
        """
        # Put the operator with the larger root decomposition on the left.
        if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
            left_linear_op, right_linear_op = right_linear_op, left_linear_op
        # Both factors are stored as root decompositions.
        if not isinstance(left_linear_op, RootLinearOperator):
            left_linear_op = left_linear_op.root_decomposition()
        if not isinstance(right_linear_op, RootLinearOperator):
            right_linear_op = right_linear_op.root_decomposition()
        super().__init__(left_linear_op, right_linear_op)
        self.left_linear_op = left_linear_op
        self.right_linear_op = right_linear_op

    def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
        # diag(A . B) = diag(A) . diag(B)
        res = self.left_linear_op._diagonal() * self.right_linear_op._diagonal()
        return res

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        left_res = self.left_linear_op._get_indices(row_index, col_index, *batch_indices)
        right_res = self.right_linear_op._get_indices(row_index, col_index, *batch_indices)
        return left_res * right_res

    def _matmul(
        self: Float[LinearOperator, "*batch M N"],
        rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
    ) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
        output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
        output_batch_shape = output_shape[:-2]

        is_vector = False
        if rhs.ndimension() == 1:
            rhs = rhs.unsqueeze(1)
            is_vector = True

        # Here we have a root decomposition
        if isinstance(self.left_linear_op, RootLinearOperator):
            left_root = self.left_linear_op.root.to_dense()
            left_res = rhs.unsqueeze(-2) * left_root.unsqueeze(-1)

            rank = left_root.size(-1)
            n = self.size(-1)
            m = rhs.size(-1)
            # Now implement the formula (A . B) v = diag(A D_v B)
            left_res = left_res.view(*output_batch_shape, n, rank * m)
            left_res = self.right_linear_op._matmul(left_res)
            left_res = left_res.view(*output_batch_shape, n, rank, m)
            res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
        # This is the case where we're not doing a root decomposition, because the matrix is too small
        else:  # Dead? (__init__ always wraps left_linear_op in a RootLinearOperator)
            res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)

        res = res.squeeze(-1) if is_vector else res
        return res
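    # A note on the formula used in _matmul above. Entrywise,
    #     ((A . B) v)_i = sum_j A_ij B_ij v_j = (A D_v B^T)_ii,    D_v = diag(v),
    # and for the symmetric operators used here B^T = B, which gives the quoted
    # identity (A . B) v = diag(A D_v B). Substituting the root decomposition
    # A = L L^T yields
    #     ((A . B) v)_i = sum_r L_ir [B (l_r . v)]_i,
    # where l_r is the r-th column of L. The code stacks the rank * m rescaled
    # right-hand sides (l_r . v_c) into a single (n, rank * m) matrix, so only one
    # _matmul with B is needed and the n x n Hadamard product is never formed.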
    def _mul_constant(
        self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor]
    ) -> Float[LinearOperator, "*batch M N"]:
        if other > 0:
            # c * (A . B) == (c * A) . B, and a positive constant preserves positive
            # semi-definiteness, so the root decomposition machinery still applies.
            res = self.__class__(self.left_linear_op._mul_constant(other), self.right_linear_op)
        else:
            # Negative constants can screw up the root decomposition,
            # so we'll do a standard _mul_constant.
            res = super()._mul_constant(other)
        return res

    def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
        if left_vecs.ndimension() == 1:
            left_vecs = left_vecs.unsqueeze(1)
            right_vecs = right_vecs.unsqueeze(1)
        *batch_shape, n, num_vecs = left_vecs.size()

        # Gradient terms for the left factor, computed through its own
        # _bilinear_derivative with probe vectors expanded by the right factor.
        if isinstance(self.right_linear_op, RootLinearOperator):
            right_root = self.right_linear_op.root.to_dense()
            left_factor = left_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
            right_rank = right_root.size(-1)
        else:
            right_rank = n
            eye = torch.eye(n, dtype=self.right_linear_op.dtype, device=self.right_linear_op.device)
            left_factor = left_vecs.unsqueeze(-2) * self.right_linear_op.to_dense().unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)
        left_factor = left_factor.view(*batch_shape, n, num_vecs * right_rank)
        right_factor = right_factor.view(*batch_shape, n, num_vecs * right_rank)
        left_deriv_args = self.left_linear_op._bilinear_derivative(left_factor, right_factor)

        # Symmetric computation for the right factor, expanded by the left factor.
        if isinstance(self.left_linear_op, RootLinearOperator):
            left_root = self.left_linear_op.root.to_dense()
            left_factor = left_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
            left_rank = left_root.size(-1)
        else:
            left_rank = n
            eye = torch.eye(n, dtype=self.left_linear_op.dtype, device=self.left_linear_op.device)
            left_factor = left_vecs.unsqueeze(-2) * self.left_linear_op.to_dense().unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)
        left_factor = left_factor.view(*batch_shape, n, num_vecs * left_rank)
        right_factor = right_factor.view(*batch_shape, n, num_vecs * left_rank)
        right_deriv_args = self.right_linear_op._bilinear_derivative(left_factor, right_factor)

        return tuple(list(left_deriv_args) + list(right_deriv_args))
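    # A note on the expansion used in _bilinear_derivative above. Writing the right
    # factor as B = R R^T, we have
    #     u^T (A . B) v = sum_ij u_i A_ij B_ij v_j = sum_k (u . r_k)^T A (v . r_k),
    # where r_k is the k-th column of R. Derivatives with respect to A's parameters
    # therefore reduce to A's own _bilinear_derivative, evaluated at the probe
    # vectors rescaled column-by-column against R -- exactly the
    # (n, num_vecs * rank) left_factor / right_factor matrices built above
    # (and symmetrically for B, using A's root).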
    def _expand_batch(
        self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
    ) -> Float[LinearOperator, "... M N"]:
        return self.__class__(
            self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape)
        )

    @cached
    def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
        return self.left_linear_op.to_dense() * self.right_linear_op.to_dense()

    def _size(self) -> torch.Size:
        return self.left_linear_op.size()

    def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        # MulLinearOperator only works with symmetric matrices, so transposing is a no-op.
        return self

    def representation(self) -> Tuple[torch.Tensor, ...]:
        """
        Returns the Tensors that are used to define the LinearOperator.
        """
        res = super(MulLinearOperator, self).representation()
        return res
    def representation_tree(self) -> LinearOperatorRepresentationTree:
        return super(MulLinearOperator, self).representation_tree()
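To see the operator in action, here is a rough usage sketch. It assumes the public
DenseLinearOperator constructor and the to_dense/matmul API, and the tolerances are
illustrative only (the stored root decompositions reproduce the factors up to
numerical error):

import torch
from linear_operator.operators import DenseLinearOperator
from linear_operator.operators.mul_linear_operator import MulLinearOperator

torch.manual_seed(0)
n = 5
# MulLinearOperator root-decomposes its arguments, so both factors should be
# symmetric positive definite.
left = torch.randn(n, n)
right = torch.randn(n, n)
A = DenseLinearOperator(left @ left.mT + n * torch.eye(n))
B = DenseLinearOperator(right @ right.mT + n * torch.eye(n))

prod = MulLinearOperator(A, B)
v = torch.randn(n)

# The operator represents the elementwise (Hadamard) product of A and B.
assert torch.allclose(prod.to_dense(), A.to_dense() * B.to_dense(), atol=1e-4)
assert torch.allclose(prod.matmul(v), (A.to_dense() * B.to_dense()) @ v, atol=1e-4)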