Source code for linear_operator.operators.masked_linear_operator

from __future__ import annotations

import torch
from torch import Tensor

from linear_operator.operators._linear_operator import _is_noop_index, IndexType, LinearOperator

from linear_operator.utils.generic import _to_helper


class MaskedLinearOperator(LinearOperator):
    r"""
    A :obj:`~linear_operator.operators.LinearOperator` that applies a mask to the rows and columns of a base
    :obj:`~linear_operator.operators.LinearOperator`.

    Only the rows selected by ``row_mask`` and the columns selected by ``col_mask`` of the base operator are
    exposed. Matrix products are computed by zero-padding the right-hand side up to the base operator's size,
    multiplying with the base operator, and then selecting the unmasked entries of the result.
    """

    def __init__(
        self,
        base: LinearOperator,  # shape: (*batch, M0, N0)
        row_mask: Tensor,  # shape: (M0)
        col_mask: Tensor,  # shape: (N0)
    ):
        r"""
        Create a new :obj:`~linear_operator.operators.MaskedLinearOperator` that applies a mask to the rows and
        columns of a base :obj:`~linear_operator.operators.LinearOperator`.

        :param base: The base :obj:`~linear_operator.operators.LinearOperator`.
        :param row_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the rows.
        :param col_mask: A :obj:`torch.BoolTensor` containing the mask to apply to the columns.
        """
        super().__init__(base, row_mask, col_mask)
        self.base = base
        self.row_mask = row_mask
        self.col_mask = col_mask
        # Computed once up front: `_diagonal` is only supported when both masks are identical.
        self.row_eq_col_mask = torch.equal(row_mask, col_mask)

    @staticmethod
    def _expand(
        tensor: Tensor,  # shape: (*batch, N, C)
        mask: Tensor,  # shape: (N0)
    ) -> Tensor:  # shape: (*batch, N0, C)
        r"""
        Scatter the rows of ``tensor`` into a zero tensor whose second-to-last dimension has the length of ``mask``.

        :param tensor: The tensor whose rows are written to the positions where ``mask`` is set.
        :param mask: The mask selecting the destination rows.
        :return: A new tensor (same device and dtype as ``tensor``) that is zero wherever ``mask`` is not set.
        """
        res = torch.zeros(
            *tensor.shape[:-2],
            mask.size(-1),
            tensor.size(-1),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        res[..., mask, :] = tensor
        return res

    def _matmul(
        self: LinearOperator,  # shape: (*batch, M, N)
        rhs: torch.Tensor,  # shape: (*batch2, N, C) or (*batch2, N)
    ) -> torch.Tensor:  # shape: (..., M, C) or (..., M)
        r"""
        Multiply the masked operator with ``rhs``.

        ``rhs`` is zero-padded along the column mask, multiplied with the base operator, and the rows selected by
        the row mask are returned.
        """
        rhs_expanded = self._expand(rhs, self.col_mask)
        res_expanded = self.base._matmul(rhs_expanded)
        return res_expanded[..., self.row_mask, :]

    def _t_matmul(
        self: LinearOperator,  # shape: (*batch, M, N)
        rhs: Tensor | LinearOperator,  # shape: (*batch2, M, P)
    ) -> LinearOperator | Tensor:  # shape: (..., N, P)
        r"""
        Multiply the transpose of the masked operator with ``rhs``.

        Mirrors :meth:`_matmul` with the roles of the row and column masks swapped.
        """
        rhs_expanded = self._expand(rhs, self.row_mask)
        res_expanded = self.base._t_matmul(rhs_expanded)
        return res_expanded[..., self.col_mask, :]

    def _size(self) -> torch.Size:
        r"""
        Return the size of the masked operator: the batch dimensions of the base operator followed by the number
        of set entries in the row mask and in the column mask.
        """
        return torch.Size(
            (
                *self.base.size()[:-2],
                torch.count_nonzero(self.row_mask),
                torch.count_nonzero(self.col_mask),
            )
        )

    def _transpose_nonbatch(
        self: LinearOperator,  # shape: (*batch, M, N)
    ) -> LinearOperator:  # shape: (*batch, N, M)
        r"""
        Transpose the two non-batch dimensions by transposing the base operator and swapping the two masks.
        """
        return self.__class__(self.base.mT, self.col_mask, self.row_mask)

    def _diagonal(
        self: LinearOperator,  # shape: (..., M, N)
    ) -> torch.Tensor:  # shape: (..., N)
        r"""
        Return the entries of the base operator's diagonal that are selected by the mask.

        :raises NotImplementedError: If the row mask and the column mask differ.
        """
        if not self.row_eq_col_mask:
            raise NotImplementedError(
                "The diagonal of a MaskedLinearOperator is only implemented when row_mask equals col_mask."
            )
        diag = self.base.diagonal()
        return diag[..., self.row_mask]

    def to_dense(
        self: LinearOperator,  # shape: (*batch, M, N)
    ) -> Tensor:  # shape: (*batch, M, N)
        r"""
        Densify the base operator and select the masked rows, then the masked columns.
        """
        full_dense = self.base.to_dense()
        return full_dense[..., self.row_mask, :][..., :, self.col_mask]

    def _bilinear_derivative(self, left_vecs: Tensor, right_vecs: Tensor) -> tuple[Tensor | None, ...]:
        r"""
        Delegate the bilinear derivative to the base operator after zero-padding both sets of vectors.

        The result of the base operator is extended by ``(None, None)``.
        NOTE(review): presumably these two entries are the (non-differentiable) ``row_mask`` and ``col_mask``
        arguments passed to ``__init__`` - confirm against ``LinearOperator._bilinear_derivative``.
        """
        left_vecs = self._expand(left_vecs, self.row_mask)
        right_vecs = self._expand(right_vecs, self.col_mask)
        return self.base._bilinear_derivative(left_vecs, right_vecs) + (None, None)

    def _expand_batch(
        self: LinearOperator, batch_shape: torch.Size | list[int]  # shape: (..., M, N)
    ) -> LinearOperator:  # shape: (..., M, N)
        r"""
        Expand the batch dimensions of the base operator; the masks are reused unchanged.
        """
        return self.__class__(self.base._expand_batch(batch_shape), self.row_mask, self.col_mask)

    def _unsqueeze_batch(self, dim: int) -> LinearOperator:
        r"""
        Unsqueeze a batch dimension of the base operator; the masks are reused unchanged.
        """
        return self.__class__(self.base._unsqueeze_batch(dim), self.row_mask, self.col_mask)

    def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
        r"""
        Index the operator.

        When neither the rows nor the columns are actually indexed, only the base operator's batch dimensions
        are indexed and the masks are kept, so the result stays a masked operator. Every other case is handled
        by the parent class.
        """
        if not (_is_noop_index(row_index) and _is_noop_index(col_index)):
            return super()._getitem(row_index, col_index, *batch_indices)
        if len(batch_indices):
            return self.__class__(self.base[batch_indices], self.row_mask, self.col_mask)
        return self

    def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
        r"""
        Look up individual entries by translating masked row/column indices into indices of the base operator.
        """
        # Position i of each mapping holds the base-operator index of the i-th unmasked row / column.
        row_mapping = torch.arange(self.base.size(-2), device=self.base.device)[self.row_mask]
        col_mapping = torch.arange(self.base.size(-1), device=self.base.device)[self.col_mask]
        return self.base._get_indices(row_mapping[row_index], col_mapping[col_index], *batch_indices)

    def _permute_batch(self, *dims: int) -> LinearOperator:
        r"""
        Permute the batch dimensions of the base operator; the masks are reused unchanged.
        """
        return self.__class__(self.base._permute_batch(*dims), self.row_mask, self.col_mask)

    def to(
        self: LinearOperator,  # shape: (*batch, M, N)
        *args,
        **kwargs,
    ) -> LinearOperator:  # shape: (*batch, M, N)
        r"""
        Return a copy of this operator moved to the requested device and/or converted to the requested dtype.

        Overrides the parent implementation so that the boolean mask tensors are never dtype-converted; they are
        only moved between devices. A dtype conversion is applied to an argument only when a dtype was actually
        requested, the argument is not boolean, and the argument's dtype and the requested dtype are either both
        floating point or both non-floating point. In every other case only the device is changed.

        :param args: Positional arguments, parsed by ``_to_helper`` into a device and a dtype.
        :param kwargs: Keyword arguments, parsed by ``_to_helper`` into a device and a dtype.
        :return: A new operator of the same class built from the converted arguments.
        """
        device, dtype = _to_helper(*args, **kwargs)

        new_args = []
        for arg in self._args:
            if not hasattr(arg, "to"):
                new_args.append(arg)
                continue
            # `dtype` is None for device-only calls, so it must be checked before it is dereferenced.
            # Boolean args (the masks) are excluded explicitly: for a non-floating target dtype the
            # floating-point comparison alone would match them and turn the masks into integer indices.
            cast_dtype = (
                dtype is not None
                and hasattr(arg, "dtype")
                and arg.dtype != torch.bool
                and arg.dtype.is_floating_point == dtype.is_floating_point
            )
            if cast_dtype:
                new_args.append(arg.to(dtype=dtype, device=device))
            else:
                new_args.append(arg.to(device=device))

        new_kwargs = {}
        for name, val in self._kwargs.items():
            if hasattr(val, "to"):
                new_kwargs[name] = val.to(dtype=dtype, device=device)
            else:
                new_kwargs[name] = val

        return self.__class__(*new_args, **new_kwargs)