Writing Your Own LinearOperators
To define a new LinearOperator class, you must implement at least the following methods (in each description, \(\mathbf A\) denotes the matrix that the LinearOperator represents):
_matmul(), which performs the matrix multiplication \(\mathbf{AB}\);
_size(), which returns a torch.Size containing the dimensions of \(\mathbf A\);
_transpose_nonbatch(), which returns a version of the LinearOperator with its last two (non-batch) dimensions transposed.
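As a rough illustration of the three required methods, the sketch below models a diagonal operator \(\mathbf A = \operatorname{diag}(\mathbf d)\) in plain Python. ToyDiagOperator is a hypothetical stand-in that does not subclass the real LinearOperator: lists of lists play the role of tensors and a tuple plays the role of torch.Size, so the real signatures and tensor semantics may differ.

```python
# Hypothetical, pure-Python stand-in for a LinearOperator subclass.
# It does NOT use the linear_operator package: lists of lists play the
# role of tensors, and a tuple plays the role of torch.Size.


class ToyDiagOperator:
    """Represents A = diag(d) without ever storing the full n x n matrix."""

    def __init__(self, diag):
        self.diag = list(diag)

    def _matmul(self, rhs):
        # (A B)[i][j] = d[i] * B[i][j]; rhs is an n x k list of rows.
        return [[d_i * b for b in row] for d_i, row in zip(self.diag, rhs)]

    def _size(self):
        # A is n x n (this toy has no batch dimensions).
        n = len(self.diag)
        return (n, n)

    def _transpose_nonbatch(self):
        # A diagonal matrix is symmetric, so its transpose is itself.
        return self


A = ToyDiagOperator([1.0, 2.0, 3.0])
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(A._size())     # (3, 3)
print(A._matmul(B))  # [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
```

The point of the design is that _matmul costs O(nk) here rather than the O(n²k) of a dense product, because the operator only ever touches its diagonal.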
In addition, the following methods should be implemented for maximum efficiency:
_bilinear_derivative(), which computes the derivative of a bilinear form involving the LinearOperator (e.g. \(\partial (\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c) / \partial \boldsymbol \theta\));
_get_indices(), which returns a torch.Tensor containing the elements of \(\mathbf A\) selected by a set of tensor indices;
_expand_batch(), which expands the batch dimensions of the LinearOperator;
_check_args(), which performs error checking on the arguments supplied to the LinearOperator constructor.
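To make the bilinear-derivative example concrete, take \(\mathbf A(\boldsymbol \theta) = \operatorname{diag}(\boldsymbol \theta)\). Then \(\mathbf b^T \mathbf A(\boldsymbol \theta) \mathbf c = \sum_i b_i \theta_i c_i\), so \(\partial / \partial \theta_i = b_i c_i\), a closed form that never needs the dense matrix. The sketch below uses hypothetical standalone helpers (not the library's _bilinear_derivative signature) and checks the closed form against central finite differences.

```python
# Hypothetical helpers, not the library's _bilinear_derivative signature.
# For A(theta) = diag(theta): b^T A(theta) c = sum_i b_i * theta_i * c_i,
# so the gradient with respect to theta_i is simply b_i * c_i.


def bilinear_form(theta, b, c):
    """Evaluate b^T diag(theta) c."""
    return sum(b_i * t_i * c_i for b_i, t_i, c_i in zip(b, theta, c))


def diag_bilinear_derivative(b, c):
    """Closed-form gradient of b^T diag(theta) c with respect to theta."""
    return [b_i * c_i for b_i, c_i in zip(b, c)]


def finite_difference_grad(theta, b, c, eps=1e-6):
    """Central-difference approximation used to check the closed form."""
    grad = []
    for i in range(len(theta)):
        up = list(theta)
        down = list(theta)
        up[i] += eps
        down[i] -= eps
        diff = bilinear_form(up, b, c) - bilinear_form(down, b, c)
        grad.append(diff / (2 * eps))
    return grad


theta = [0.5, -1.0, 2.0]
b = [1.0, 2.0, 3.0]
c = [4.0, 5.0, 6.0]
print(diag_bilinear_derivative(b, c))       # [4.0, 10.0, 18.0]
print(finite_difference_grad(theta, b, c))  # approximately the same values
```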
A LinearOperator that does anything nontrivial with its batch dimensions (e.g. summing along them or adding new ones) may also need to define _unsqueeze_batch(), _getitem(), and _permute_batch().
See the documentation for these methods for details.
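The batch methods must rearrange the batch dimensions while leaving the two matrix dimensions in place at the end. A minimal sketch of that shape bookkeeping, assuming the last two dimensions are the matrix dimensions and any leading ones are batch dimensions; the helper names are hypothetical, and the real methods return new LinearOperators rather than shape tuples.

```python
# Shape bookkeeping only. The last two dimensions are the matrix
# dimensions; everything before them is a batch dimension. These helpers
# are hypothetical and operate on plain tuples, not on the library.


def permute_batch_shape(size, *dims):
    """Permute only the batch dimensions, keeping the matrix dims last."""
    batch, matrix = size[:-2], size[-2:]
    if sorted(dims) != list(range(len(batch))):
        raise ValueError("dims must be a permutation of the batch dimensions")
    return tuple(batch[d] for d in dims) + matrix


def unsqueeze_batch_shape(size, dim):
    """Insert a new singleton batch dimension at position dim."""
    batch, matrix = size[:-2], size[-2:]
    if not 0 <= dim <= len(batch):
        raise ValueError("dim must index a batch position")
    return batch[:dim] + (1,) + batch[dim:] + matrix


size = (2, 3, 5, 5)                     # batch shape (2, 3) of 5 x 5 matrices
print(permute_batch_shape(size, 1, 0))  # (3, 2, 5, 5)
print(unsqueeze_batch_shape(size, 1))   # (2, 1, 3, 5, 5)
```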
Note
The base LinearOperator class provides default implementations of many
other operations in order to mimic the behavior of a standard tensor as
closely as possible. For example, we provide default implementations of
__getitem__(), __add__(), etc. that either make use of other linear
operators or exploit the functions that must be defined above.
Rather than overriding the public methods, we recommend that you override
the private versions associated with them (e.g. write a custom _getitem()
rather than a custom __getitem__()). This is because the public methods
perform a good deal of error checking and special-case handling that does
not need to be repeated.
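To illustrate the recommendation above, here is a minimal sketch built on a hypothetical toy base class (ToyOperatorBase, not the library's code). The public __getitem__ validates and normalizes the index once, then delegates to the private _getitem; __add__ returns a lazy sum operator whose _matmul combines its parts' _matmul results, in the spirit of the default implementations described in the note. The subclass ToyDiag overrides only the private _getitem and can therefore assume its indices are already valid.

```python
# Hypothetical toy base class illustrating the "public method checks,
# private method computes" split. Not the real LinearOperator code.


class ToyOperatorBase:
    def _size(self):
        raise NotImplementedError

    def _matmul(self, rhs):
        raise NotImplementedError

    def _getitem(self, row, col):
        # Generic fallback built on _matmul: multiply by the unit column
        # e_col and read off entry `row`. Subclasses may override this.
        n_cols = self._size()[1]
        unit = [[1.0 if j == col else 0.0] for j in range(n_cols)]
        return self._matmul(unit)[row][0]

    def __getitem__(self, index):
        # Public entry point: validate and normalize once, then delegate.
        row, col = index
        n_rows, n_cols = self._size()
        if row < 0:
            row += n_rows
        if col < 0:
            col += n_cols
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            raise IndexError("index out of range")
        return self._getitem(row, col)

    def __add__(self, other):
        # Default addition returns another lazy operator built on _matmul.
        return ToySumOperator(self, other)


class ToySumOperator(ToyOperatorBase):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def _size(self):
        return self.left._size()

    def _matmul(self, rhs):
        a = self.left._matmul(rhs)
        b = self.right._matmul(rhs)
        return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


class ToyDiag(ToyOperatorBase):
    def __init__(self, diag):
        self.diag = list(diag)

    def _size(self):
        return (len(self.diag), len(self.diag))

    def _matmul(self, rhs):
        return [[d * v for v in row] for d, row in zip(self.diag, rhs)]

    def _getitem(self, row, col):
        # Cheap override: no matmul needed, and indices are already checked.
        return self.diag[row] if row == col else 0.0


A = ToyDiag([1.0, 2.0, 3.0])
print(A[-1, -1])      # 3.0 (negative index normalized by __getitem__)
print((A + A)[1, 1])  # 4.0 (ToySumOperator uses the _matmul fallback)
```

Because ToySumOperator does not override _getitem, indexing it falls back to the generic matmul-based version, while ToyDiag gets constant-time indexing without re-implementing any of the bounds checks.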