#!/usr/bin/env python3
from typing import List, Optional, Tuple, Union
import torch
from jaxtyping import Float
from torch import Tensor
from linear_operator.operators._linear_operator import IndexType, LinearOperator
from linear_operator.operators.linear_operator_representation_tree import LinearOperatorRepresentationTree
from linear_operator.operators.root_linear_operator import RootLinearOperator
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
from linear_operator.utils.memoize import cached
class MulLinearOperator(LinearOperator):
def _check_args(self, left_linear_op, right_linear_op):
if not isinstance(left_linear_op, LinearOperator) or not isinstance(right_linear_op, LinearOperator):
return "MulLinearOperator expects two LinearOperators."
        if left_linear_op.shape != right_linear_op.shape:
            return "MulLinearOperator expects two LinearOperators of the same size: got {} and {}.".format(
                left_linear_op.shape, right_linear_op.shape
            )
    def __init__(self, left_linear_op, right_linear_op):
        """
        Args:
            left_linear_op (LinearOperator): First factor of the elementwise product. Assumed symmetric.
            right_linear_op (LinearOperator): Second factor of the elementwise product. Must have the
                same shape as ``left_linear_op``; also assumed symmetric.
        """
if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
left_linear_op, right_linear_op = right_linear_op, left_linear_op
if not isinstance(left_linear_op, RootLinearOperator):
left_linear_op = left_linear_op.root_decomposition()
if not isinstance(right_linear_op, RootLinearOperator):
right_linear_op = right_linear_op.root_decomposition()
super().__init__(left_linear_op, right_linear_op)
self.left_linear_op = left_linear_op
self.right_linear_op = right_linear_op
def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
res = self.left_linear_op._diagonal() * self.right_linear_op._diagonal()
return res
def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
left_res = self.left_linear_op._get_indices(row_index, col_index, *batch_indices)
right_res = self.right_linear_op._get_indices(row_index, col_index, *batch_indices)
return left_res * right_res
def _matmul(
self: Float[LinearOperator, "*batch M N"],
rhs: Union[Float[torch.Tensor, "*batch2 N C"], Float[torch.Tensor, "*batch2 N"]],
) -> Union[Float[torch.Tensor, "... M C"], Float[torch.Tensor, "... M"]]:
output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
output_batch_shape = output_shape[:-2]
is_vector = False
if rhs.ndimension() == 1:
rhs = rhs.unsqueeze(1)
is_vector = True
# Here we have a root decomposition
if isinstance(self.left_linear_op, RootLinearOperator):
left_root = self.left_linear_op.root.to_dense()
left_res = rhs.unsqueeze(-2) * left_root.unsqueeze(-1)
rank = left_root.size(-1)
n = self.size(-1)
m = rhs.size(-1)
            # Implement the formula (A . B) v = diag(A diag(v) B^T). With the root
            # A = L L^T, this equals sum_k l_k . (B (l_k . v)) over the columns l_k of L.
left_res = left_res.view(*output_batch_shape, n, rank * m)
left_res = self.right_linear_op._matmul(left_res)
left_res = left_res.view(*output_batch_shape, n, rank, m)
res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
        # Fallback: multiply the dense matrices elementwise and matmul directly.
        # Since __init__ wraps both factors in RootLinearOperators, this branch
        # should be unreachable in practice.
        else:
res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)
res = res.squeeze(-1) if is_vector else res
return res
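    # A minimal numerical check of the identity used in _matmul (a sketch; this
    # helper is illustrative and not part of the library API): for symmetric
    # A = L L^T, (A . B) v equals sum_k l_k . (B (l_k . v)) over the columns l_k of L.
    @staticmethod
    def _demo_matmul_identity(n: int = 5, rank: int = 3) -> bool:
        torch.manual_seed(0)
        left_root = torch.randn(n, rank)
        a = left_root @ left_root.mT  # symmetric A = L L^T
        b_half = torch.randn(n, n)
        b = b_half @ b_half.mT  # symmetric B, as MulLinearOperator assumes
        v = torch.randn(n)
        direct = (a * b) @ v
        via_root = sum(left_root[:, k] * (b @ (left_root[:, k] * v)) for k in range(rank))
        return bool(torch.allclose(direct, via_root, atol=1e-4))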
def _mul_constant(
self: Float[LinearOperator, "*batch M N"], other: Union[float, torch.Tensor]
) -> Float[LinearOperator, "*batch M N"]:
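        # c * (A . B) = (c A) . B, so a positive constant can be folded into the
        # left factor while keeping a valid root decomposition.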
if other > 0:
res = self.__class__(self.left_linear_op._mul_constant(other), self.right_linear_op)
else:
            # A negative constant has no real square root, so it cannot be folded
            # into a root decomposition; fall back to the generic implementation.
res = super()._mul_constant(other)
return res
def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> Tuple[Optional[Tensor], ...]:
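        # The gradient of tr(U^T (A . B) V) w.r.t. A is (U V^T) . B. With B = R R^T
        # this factors as U~ V~^T, where the columns of U~ are u_c . r_k (and
        # likewise for V~), so each term delegates to the factor's own
        # _bilinear_derivative.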
if left_vecs.ndimension() == 1:
left_vecs = left_vecs.unsqueeze(1)
right_vecs = right_vecs.unsqueeze(1)
*batch_shape, n, num_vecs = left_vecs.size()
if isinstance(self.right_linear_op, RootLinearOperator):
right_root = self.right_linear_op.root.to_dense()
left_factor = left_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
right_factor = right_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
right_rank = right_root.size(-1)
else:
right_rank = n
eye = torch.eye(n, dtype=self.right_linear_op.dtype, device=self.right_linear_op.device)
left_factor = left_vecs.unsqueeze(-2) * self.right_linear_op.to_dense().unsqueeze(-1)
right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)
left_factor = left_factor.view(*batch_shape, n, num_vecs * right_rank)
right_factor = right_factor.view(*batch_shape, n, num_vecs * right_rank)
left_deriv_args = self.left_linear_op._bilinear_derivative(left_factor, right_factor)
if isinstance(self.left_linear_op, RootLinearOperator):
left_root = self.left_linear_op.root.to_dense()
left_factor = left_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
right_factor = right_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
left_rank = left_root.size(-1)
else:
left_rank = n
eye = torch.eye(n, dtype=self.left_linear_op.dtype, device=self.left_linear_op.device)
left_factor = left_vecs.unsqueeze(-2) * self.left_linear_op.to_dense().unsqueeze(-1)
right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)
left_factor = left_factor.view(*batch_shape, n, num_vecs * left_rank)
right_factor = right_factor.view(*batch_shape, n, num_vecs * left_rank)
right_deriv_args = self.right_linear_op._bilinear_derivative(left_factor, right_factor)
return tuple(list(left_deriv_args) + list(right_deriv_args))
def _expand_batch(
self: Float[LinearOperator, "... M N"], batch_shape: Union[torch.Size, List[int]]
) -> Float[LinearOperator, "... M N"]:
return self.__class__(
self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape)
)
@cached
def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
return self.left_linear_op.to_dense() * self.right_linear_op.to_dense()
def _size(self) -> torch.Size:
return self.left_linear_op.size()
def _transpose_nonbatch(self: Float[LinearOperator, "*batch M N"]) -> Float[LinearOperator, "*batch N M"]:
        # MulLinearOperator assumes symmetric factors; the elementwise product of
        # symmetric matrices is symmetric, so transposition is a no-op.
return self
def representation(self) -> Tuple[torch.Tensor, ...]:
"""
        Returns the Tensors that are used to define the LinearOperator.
        """
        return super().representation()
def representation_tree(self) -> LinearOperatorRepresentationTree:
        return super().representation_tree()
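

# A minimal usage sketch (hedged: assumes DenseLinearOperator is importable from
# linear_operator.operators; the tensors below are purely illustrative).
if __name__ == "__main__":
    from linear_operator.operators import DenseLinearOperator

    base = torch.randn(4, 6)
    covar = DenseLinearOperator(base @ base.mT)  # symmetric positive definite
    prod = MulLinearOperator(covar, covar)  # elementwise square of covar
    vec = torch.randn(4)
    expected = (covar.to_dense() * covar.to_dense()) @ vec
    assert torch.allclose(prod @ vec, expected, atol=1e-4)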