# Source code for linear_operator.operators.sum_batch_linear_operator

```#!/usr/bin/env python3

import torch
from jaxtyping import Float
from torch import Tensor

from ..utils.getitem import _noop_index
from ._linear_operator import IndexType, LinearOperator
from .block_linear_operator import BlockLinearOperator

[docs]class SumBatchLinearOperator(BlockLinearOperator):
"""
Represents a lazy tensor that is the sum of several lazy tensor blocks.
The :attr:`block_dim` attribute specifies which dimension of the base LinearOperator
indexes the blocks.
For example, with `block_dim=-3`, a `k x n x n` tensor represents the sum of `k` `n x n` blocks (an `n x n` matrix),
and a `b x k x n x n` tensor represents the sum of `k` `b x n x n` blocks (a `b x n x n` batch matrix).

Args:
:attr:`base_linear_op` (LinearOperator):
A `k x n x n` LinearOperator, or a `b x k x n x n` LinearOperator.
:attr:`block_dim` (int):
The dimension that indexes the blocks.
"""

shape = list(other.shape)
expand_shape = list(other.shape)
shape.insert(-2, 1)
expand_shape.insert(-2, self.base_linear_op.size(-3))
other = other.reshape(*shape).expand(*expand_shape)
return other

def _diagonal(self: Float[LinearOperator, "... M N"]) -> Float[torch.Tensor, "... N"]:
    """
    Return the diagonal of the summed operator.

    The diagonal of a sum is the sum of the diagonals. This method takes the
    base operator's per-block diagonals and reduces them over dimension -2.
    Assuming the base operator is laid out as ``... x k x n x n`` (as the class
    docstring describes), dimension -2 of the ``... x k x n`` diagonal is the
    block dimension.
    """
    per_block_diagonals = self.base_linear_op._diagonal()
    return per_block_diagonals.sum(dim=-2)

def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
    """
    Look up individual entries of the summed operator.

    Each requested entry is the sum of the matching entry across every block.
    The method works in three steps:

    1. Append a trailing singleton axis to every incoming index.
    2. Query the base operator with an extra index that enumerates the blocks
       (dimension -3 of the base operator). That index broadcasts against the
       trailing axis.
    3. Sum the per-block values away along that trailing axis.
    """
    # 0 .. (number of blocks - 1), created on the operator's device.
    block_index = torch.arange(0, self.base_linear_op.size(-3), device=self.device)
    # Positional arguments are evaluated left to right: row, col, then batch indices.
    per_block_entries = self.base_linear_op._get_indices(
        row_index.unsqueeze(-1),
        col_index.unsqueeze(-1),
        *(idx.unsqueeze(-1) for idx in batch_indices),
        block_index,
    )
    return per_block_entries.sum(-1)

def _getitem(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> LinearOperator:
    """
    Index the summed operator and return a new summed operator.

    The caller's row, column and batch indices are forwarded unchanged to the
    base operator. ``_noop_index`` is appended as the final batch index, in the
    position of the base operator's extra block dimension. It is presumably a
    full slice, so every block is kept. The indexed base is then re-wrapped with
    the same constructor keyword arguments.
    """
    indexed_base = self.base_linear_op._getitem(row_index, col_index, *batch_indices, _noop_index)
    operator_cls = self.__class__
    return operator_cls(indexed_base, **self._kwargs)

def _remove_batch_dim(self, other):
return other.sum(-3)

def _size(self) -> torch.Size:
shape = list(self.base_linear_op.shape)
del shape[-3]