Source code for linear_operator.operators.mul_linear_operator

#!/usr/bin/env python3

import torch

from ..utils.broadcasting import _matmul_broadcast_shape
from ..utils.memoize import cached
from ._linear_operator import LinearOperator
from .root_linear_operator import RootLinearOperator


class MulLinearOperator(LinearOperator):
    def _check_args(self, left_linear_op, right_linear_op):
        if not isinstance(left_linear_op, LinearOperator) or not isinstance(right_linear_op, LinearOperator):
            return "MulLinearOperator expects two LinearOperators."
        if left_linear_op.shape != right_linear_op.shape:
            return "MulLinearOperator expects two LinearOperators of the same size: got {} and {}.".format(
                left_linear_op, right_linear_op
            )

    def __init__(self, left_linear_op, right_linear_op):
        """
        Args:
            - left_linear_op (LinearOperator): the first operator in the elementwise (Hadamard) product.
            - right_linear_op (LinearOperator): the second operator in the elementwise (Hadamard) product.
        """
        if left_linear_op._root_decomposition_size() < right_linear_op._root_decomposition_size():
            left_linear_op, right_linear_op = right_linear_op, left_linear_op

        if not isinstance(left_linear_op, RootLinearOperator):
            left_linear_op = left_linear_op.root_decomposition()
        if not isinstance(right_linear_op, RootLinearOperator):
            right_linear_op = right_linear_op.root_decomposition()
        super().__init__(left_linear_op, right_linear_op)
        self.left_linear_op = left_linear_op
        self.right_linear_op = right_linear_op

    def _diagonal(self):
        res = self.left_linear_op._diagonal() * self.right_linear_op._diagonal()
        return res

    def _get_indices(self, row_index, col_index, *batch_indices):
        left_res = self.left_linear_op._get_indices(row_index, col_index, *batch_indices)
        right_res = self.right_linear_op._get_indices(row_index, col_index, *batch_indices)
        return left_res * right_res

    def _matmul(self, rhs):
        output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
        output_batch_shape = output_shape[:-2]

        is_vector = False
        if rhs.ndimension() == 1:
            rhs = rhs.unsqueeze(1)
            is_vector = True

        # Here we have a root decomposition
        if isinstance(self.left_linear_op, RootLinearOperator):
            left_root = self.left_linear_op.root.to_dense()
            left_res = rhs.unsqueeze(-2) * left_root.unsqueeze(-1)

            rank = left_root.size(-1)
            n = self.size(-1)
            m = rhs.size(-1)
            # Now implement the formula (A . B) v = diag(A D_v B)
            left_res = left_res.view(*output_batch_shape, n, rank * m)
            left_res = self.right_linear_op._matmul(left_res)
            left_res = left_res.view(*output_batch_shape, n, rank, m)
            res = left_res.mul_(left_root.unsqueeze(-1)).sum(-2)
        # This is the case where we're not doing a root decomposition, because the matrix is too small
        else:  # Dead?
            res = (self.left_linear_op.to_dense() * self.right_linear_op.to_dense()).matmul(rhs)
        res = res.squeeze(-1) if is_vector else res
        return res

    def _mul_constant(self, constant):
        if constant > 0:
            res = self.__class__(self.left_linear_op._mul_constant(constant), self.right_linear_op)
        else:
            # Negative constants can screw up the root_decomposition
            # So we'll do a standard _mul_constant
            res = super()._mul_constant(constant)
        return res

    def _bilinear_derivative(self, left_vecs, right_vecs):
        if left_vecs.ndimension() == 1:
            left_vecs = left_vecs.unsqueeze(1)
            right_vecs = right_vecs.unsqueeze(1)

        *batch_shape, n, num_vecs = left_vecs.size()

        if isinstance(self.right_linear_op, RootLinearOperator):
            right_root = self.right_linear_op.root.to_dense()
            left_factor = left_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * right_root.unsqueeze(-1)
            right_rank = right_root.size(-1)
        else:
            right_rank = n
            eye = torch.eye(n, dtype=self.right_linear_op.dtype, device=self.right_linear_op.device)
            left_factor = left_vecs.unsqueeze(-2) * self.right_linear_op.to_dense().unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)

        left_factor = left_factor.view(*batch_shape, n, num_vecs * right_rank)
        right_factor = right_factor.view(*batch_shape, n, num_vecs * right_rank)
        left_deriv_args = self.left_linear_op._bilinear_derivative(left_factor, right_factor)

        if isinstance(self.left_linear_op, RootLinearOperator):
            left_root = self.left_linear_op.root.to_dense()
            left_factor = left_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * left_root.unsqueeze(-1)
            left_rank = left_root.size(-1)
        else:
            left_rank = n
            eye = torch.eye(n, dtype=self.left_linear_op.dtype, device=self.left_linear_op.device)
            left_factor = left_vecs.unsqueeze(-2) * self.left_linear_op.to_dense().unsqueeze(-1)
            right_factor = right_vecs.unsqueeze(-2) * eye.unsqueeze(-1)

        left_factor = left_factor.view(*batch_shape, n, num_vecs * left_rank)
        right_factor = right_factor.view(*batch_shape, n, num_vecs * left_rank)
        right_deriv_args = self.right_linear_op._bilinear_derivative(left_factor, right_factor)

        return tuple(list(left_deriv_args) + list(right_deriv_args))

    def _expand_batch(self, batch_shape):
        return self.__class__(
            self.left_linear_op._expand_batch(batch_shape), self.right_linear_op._expand_batch(batch_shape)
        )

    @cached
    def to_dense(self):
        return self.left_linear_op.to_dense() * self.right_linear_op.to_dense()

    def _size(self):
        return self.left_linear_op.size()

    def _transpose_nonbatch(self):
        # mul.linear_op only works with symmetric matrices
        return self
    def representation(self):
        """
        Returns the Tensors that are used to define the LinearOperator
        """
        res = super(MulLinearOperator, self).representation()
        return res

    def representation_tree(self):
        return super(MulLinearOperator, self).representation_tree()
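
The root-decomposition branch of _matmul above relies on the identity (A . B) v = diag(A D_v B) for symmetric B, or equivalently, with A = L L^T, (A . B) v = sum_k l_k * (B (l_k * v)), where l_k is the k-th column of L. The following is a minimal sketch in plain PyTorch that checks both forms numerically; it does not use the linear_operator API, and all names and sizes are illustrative only.

import torch

torch.manual_seed(0)
n, rank = 5, 3

# A symmetric PSD matrix A = L L^T built from a random root L, and a symmetric B.
L = torch.randn(n, rank)
A = L @ L.transpose(-1, -2)
B = torch.randn(n, n)
B = B @ B.transpose(-1, -2)
v = torch.randn(n)

# Direct evaluation of the elementwise (Hadamard) product times a vector.
direct = (A * B) @ v

# Diagonal form from the comment in _matmul: (A . B) v = diag(A D_v B), with D_v = diag(v).
via_diag = torch.diagonal(A @ torch.diag(v) @ B)

# Root-decomposition form computed by _matmul:
# (A . B) v = sum_k  l_k * (B @ (l_k * v)), summing over the columns l_k of L.
scaled = L * v.unsqueeze(-1)             # column k holds l_k * v, shape (n, rank)
via_root = (L * (B @ scaled)).sum(-1)    # multiply by l_k again and sum over the rank dimension

assert torch.allclose(direct, via_diag, atol=1e-5)
assert torch.allclose(direct, via_root, atol=1e-5)

Note that because __init__ converts both arguments to RootLinearOperators, the dense fallback branch (marked "# Dead?") is normally never reached, and _transpose_nonbatch can return self only because both factors are assumed symmetric.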