Writing Your Own LinearOperators

To define a new LinearOperator class, a user must define, at a minimum, the following methods (in each example, \(\mathbf A\) denotes the matrix that the LinearOperator represents):

  • _matmul(), which performs a matrix multiplication \(\mathbf {AB}\) with a right-hand-side tensor \(\mathbf B\).

  • _size(), which returns a torch.Size containing the dimensions of \(\mathbf A\).

  • _transpose_nonbatch(), which returns a transposed version of the LinearOperator.

In addition to these, the following methods should be implemented for maximum efficiency:

  • _bilinear_derivative(), which computes the derivative of a bilinear form with respect to the LinearOperator’s representation (e.g. \(\partial (\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c) / \partial \boldsymbol \theta\)).

  • _get_indices(), which returns a torch.Tensor containing the elements of the LinearOperator selected by the given tensor indices.

  • _expand_batch(), which expands the batch dimensions of the LinearOperator.

  • _check_args(), which performs error checking on the arguments supplied to the LinearOperator constructor.
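As a worked instance of the derivative that _bilinear_derivative() is described as computing, consider the simplest case (chosen here purely for illustration) in which \(\mathbf A(\boldsymbol \theta)\) is diagonal with \(\boldsymbol \theta\) on its diagonal:

\[
\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c = \sum_i \theta_i b_i c_i
\quad\Longrightarrow\quad
\frac{\partial (\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c)}{\partial \theta_i} = b_i c_i .
\]

For such an operator the derivative is just the elementwise product of \(\mathbf b\) and \(\mathbf c\), so it can be computed without ever forming the dense matrix. This is the kind of structure a custom implementation can exploit.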

Beyond these, a LinearOperator may need to define the following methods if it does anything interesting with the batch dimensions (e.g. sums along them, adds additional ones, etc.): _unsqueeze_batch(), _getitem(), and _permute_batch(). See the documentation for these methods for details.

Note

The base LinearOperator class provides default implementations of many other operations in order to mimic the behavior of a standard tensor as closely as possible. For example, we provide default implementations of __getitem__(), __add__(), etc., that either make use of other linear operators or exploit the methods that must be defined above.

Rather than overriding the public methods, we recommend that you override the private versions associated with these methods (e.g. write a custom _getitem() rather than a custom __getitem__()). This is because the public methods do quite a bit of error checking and special-casing that doesn’t need to be repeated.
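The public/private split can be illustrated with a dependency-free toy. This is a sketch of the general pattern only, not the library's code: ToyBase and ToyDiag are hypothetical names, and the real _getitem() works on tensors and batch dimensions, so consult its documentation for the actual signature.

```python
class ToyBase:
    """Toy stand-in for a base class; NOT the real LinearOperator."""

    def _size(self):
        raise NotImplementedError

    def _getitem(self, row, col):
        raise NotImplementedError

    def __getitem__(self, index):
        # Public method: validate and normalize once, for every subclass.
        if not (isinstance(index, tuple) and len(index) == 2):
            raise IndexError("expected a (row, col) pair")
        row, col = index
        n_rows, n_cols = self._size()
        if not (-n_rows <= row < n_rows and -n_cols <= col < n_cols):
            raise IndexError("index out of range")
        # Convert negative indices to non-negative ones, then delegate.
        return self._getitem(row % n_rows, col % n_cols)


class ToyDiag(ToyBase):
    """Toy diagonal matrix that overrides only the private hooks."""

    def __init__(self, diag):
        self.diag = list(diag)

    def _size(self):
        n = len(self.diag)
        return (n, n)

    def _getitem(self, row, col):
        # Indices arrive already validated and non-negative.
        return self.diag[row] if row == col else 0.0


toy = ToyDiag([1.0, 2.0, 3.0])
corner = toy[-1, -1]  # negative indices are handled by the public method
```

Because ToyDiag overrides only the private hook, it inherits the bounds checks and the negative-index handling without duplicating them.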