Using LinearOperator Objects

LinearOperator objects share (mostly) the same API as torch.Tensor objects. Under the hood, these objects use __torch_function__ so that calls to functions in the torch and torch.linalg namespaces are dispatched to efficient, structure-aware implementations. Supported functions include:

  • torch.add

  • torch.cat

  • torch.clone

  • torch.diagonal

  • torch.dim

  • torch.div

  • torch.expand

  • torch.logdet

  • torch.matmul

  • torch.numel

  • torch.permute

  • torch.prod

  • torch.squeeze

  • torch.sub

  • torch.sum

  • torch.transpose

  • torch.unsqueeze

  • torch.linalg.cholesky

  • torch.linalg.eigh

  • torch.linalg.eigvalsh

  • torch.linalg.solve

  • torch.linalg.svd

Each of these functions returns either a torch.Tensor or a new LinearOperator object, depending on the function. For example:

# A = RootLinearOperator(...)
# B = ToeplitzLinearOperator(...)
# d = vec

C = torch.matmul(A, B)  # A new LinearOperator representing the product of A and B
torch.linalg.solve(C, d)  # A torch.Tensor

For more examples, see the examples folder.
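
To make the dispatch mechanism concrete, here is a minimal sketch of how a class can intercept torch functions through __torch_function__. ToyDiagOperator is a hypothetical stand-in written for this illustration; it is not part of the library and handles only two functions, whereas the real LinearOperator implementation covers far more cases.

```python
import torch


class ToyDiagOperator:
    """Hypothetical diagonal operator that stores only its diagonal entries."""

    def __init__(self, diag):
        self.diag = diag

    def to_dense(self):
        # Materialize the full matrix; used here only to check results.
        return torch.diag_embed(self.diag)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # torch calls this hook when a ToyDiagOperator is passed to a torch function.
        if func is torch.matmul:
            op, rhs = args
            if isinstance(op, cls) and isinstance(rhs, torch.Tensor):
                # Scale the rows of rhs instead of forming the dense matrix.
                return op.diag.unsqueeze(-1) * rhs
        if func is torch.logdet:
            (op,) = args
            # The log-determinant of a diagonal matrix is the sum of log entries.
            return op.diag.log().sum(-1)
        # Anything else is unsupported in this sketch.
        return NotImplemented


op = ToyDiagOperator(torch.tensor([1.0, 2.0, 4.0]))
rhs = torch.ones(3, 2)

product = torch.matmul(op, rhs)  # handled by the row-scaling branch
logdet = torch.logdet(op)        # handled by the sum-of-logs branch
```

Both calls go through the ordinary torch namespace, yet neither builds the 3 x 3 matrix; to_dense exists only so the results can be compared against the dense computation.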

Batch Support and Broadcasting

LinearOperator objects operate naturally in batch mode. For example, to represent a batch of three 100 x 100 diagonal matrices:

# d = torch.randn(3, 100)
D = DiagLinearOperator(d)  # Represents an operator of size 3 x 100 x 100

These objects fully support broadcasted operations:

D @ torch.randn(100, 2)  # Returns a tensor of size 3 x 100 x 2

D2 = DiagLinearOperator(torch.randn([2, 1, 100]))  # Represents an operator of size 2 x 1 x 100 x 100
D2 + D  # Represents an operator of size 2 x 3 x 100 x 100

Indexing

LinearOperator objects can be indexed much like torch.Tensor objects. Supported indexing includes:

  • Integer indexing (get a row, column, or batch)

  • Slice indexing (get a subset of rows, columns, or batches)

  • LongTensor indexing (get a set of individual entries by index)

  • Ellipses (support indexing operations with arbitrary batch dimensions)

D = DiagLinearOperator(torch.randn(2, 3, 100))  # Represents an operator of size 2 x 3 x 100 x 100
D[-1]  # Returns a 3 x 100 x 100 operator
D[..., :10, -5:]  # Returns a 2 x 3 x 10 x 5 operator
D[..., torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([0, 1, 2, 3])]  # Returns a 2 x 3 x 4 tensor

Composition and Decoration

LinearOperator objects can be composed with one another in various ways, including:

  • Addition (LinearOpA + LinearOpB)

  • Matrix multiplication (LinearOpA @ LinearOpB)

  • Concatenation (torch.cat([LinearOpA, LinearOpB], dim=-2))

  • Kronecker product (torch.kron(LinearOpA, LinearOpB))

In addition, there are many ways to “decorate” LinearOperator objects. These include:

  • Elementwise multiplying by constants (torch.mul(2., LinearOpA))

  • Summing over batches (torch.sum(LinearOpA, dim=-3))

  • Elementwise multiplying over batches (torch.prod(LinearOpA, dim=-3))

See the documentation for a full list of supported composition and decoration operations.