LinearOperator
LinearOperator is a PyTorch package for abstracting away the linear algebra routines needed for structured matrices (or operators).
This package is in beta. Currently, most of the functionality only supports positive semi-definite and triangular matrices. Package development TODOs:
Support PSD operators
Support triangular operators
Interface to specify structure (i.e. symmetric, triangular, PSD, etc.)
Add algebraic routines for symmetric operators
Add algebraic routines for generic square operators
Add algebraic routines for generic rectangular operators
Add sparse operators
Why LinearOperator
Before describing what linear operators are and why they make a useful abstraction, it’s easiest to see an example. Let’s say you wanted to compute the matrix solve \(\boldsymbol A^{-1} \boldsymbol b\).
If you didn’t know anything about the matrix \(\boldsymbol A\), the simplest (and best) way to accomplish this in code is:
import torch
A = torch.randn(1000, 1000)
b = torch.randn(1000)
torch.linalg.solve(A, b)  # computes A^{-1} b
While this is easy, the solve routine is \(\mathcal O(N^3)\), which gets very slow as \(N\) grows large.
However, let’s imagine that we knew that \(\boldsymbol A\) was equal to a low-rank matrix plus a diagonal (i.e. \(\boldsymbol A = \boldsymbol C \boldsymbol C^\top + \boldsymbol D\) for some skinny matrix \(\boldsymbol C\) and some diagonal matrix \(\boldsymbol D\)). There’s now a very efficient \(\mathcal O(N)\) routine (for a fixed rank) to compute \(\boldsymbol A^{-1} \boldsymbol b\): the Woodbury formula. In general, if we know that \(\boldsymbol A\) has structure, we want to use efficient linear algebra routines that exploit this structure rather than the general routines.
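For reference, this is the Woodbury identity written out for \(\boldsymbol A = \boldsymbol C \boldsymbol C^\top + \boldsymbol D\), assuming \(\boldsymbol C \in \mathbb R^{N \times k}\) and a diagonal \(\boldsymbol D\) with positive entries (which keeps every inverse below well defined):

\[
\boldsymbol A^{-1} \boldsymbol b
= \boldsymbol D^{-1} \boldsymbol b
- \boldsymbol D^{-1} \boldsymbol C
  \left( \boldsymbol I_k + \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C \right)^{-1}
  \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol b
\]

Only the \(k \times k\) matrix \(\boldsymbol I_k + \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C\) has to be factorized. Forming it costs \(\mathcal O(N k^2)\) and factorizing it costs \(\mathcal O(k^3)\), so for a fixed rank \(k \ll N\) the solve scales linearly in \(N\).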
Without LinearOperator
Implementing the efficient solve that exploits \(\boldsymbol A\)’s low-rank-plus-diagonal structure would look something like this:
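import torch
def low_rank_plus_diagonal_solve(C, d, b):
    # A = C C^T + diag(d), with C of shape (N, k) and every entry of d positive
    # A^{-1} b = D^{-1} b - D^{-1} C (I + C^T D^{-1} C)^{-1} C^T D^{-1} b
    # where D = diag(d)
    D_inv_b = b / d
    D_inv_C = C / d.unsqueeze(-1)
    eye = torch.eye(C.size(-1), dtype=C.dtype)  # k x k identity (not N x N)
    # Cholesky factor of the small k x k matrix I + C^T D^{-1} C
    L = torch.linalg.cholesky(eye + C.mT @ D_inv_C, upper=False)
    # cholesky_solve expects a 2-D right-hand side, so add a trailing dimension and drop it afterwards
    inner = torch.cholesky_solve((C.mT @ D_inv_b).unsqueeze(-1), L, upper=False)
    return D_inv_b - (D_inv_C @ inner).squeeze(-1)
C = torch.randn(1000, 20)
d = torch.rand(1000) + 1.0  # the diagonal must be positive
b = torch.randn(1000)
low_rank_plus_diagonal_solve(C, d, b)  # computes A^{-1} b in O(N) time, instead of O(N^3)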
While this is efficient code, it’s not ideal for a number of reasons:
It’s a lot more complicated than torch.linalg.solve(A, b).
There’s no object that represents \(\boldsymbol A\). To perform any math with \(\boldsymbol A\), we have to pass around the matrix C and the vector d.
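Every additional operation also needs its own hand-derived routine. For instance, a log-determinant of the same \(\boldsymbol A\) calls for the matrix determinant lemma, \(\det(\boldsymbol C \boldsymbol C^\top + \boldsymbol D) = \det(\boldsymbol I_k + \boldsymbol C^\top \boldsymbol D^{-1} \boldsymbol C)\det(\boldsymbol D)\), and yet another function that takes C and d separately. The sketch below illustrates this; `low_rank_plus_diagonal_logdet` is a hypothetical name, it assumes every entry of d is positive, and it is checked against a dense `torch.logdet` on a small problem:

```python
import torch


def low_rank_plus_diagonal_logdet(C, d):
    """Hypothetical helper: log det(C C^T + diag(d)) without forming the N x N matrix.

    Matrix determinant lemma: det(C C^T + D) = det(I_k + C^T D^{-1} C) * det(D).
    Assumes every entry of d is positive.
    """
    k = C.size(-1)
    D_inv_C = C / d.unsqueeze(-1)                                   # D^{-1} C, shape (N, k)
    capacitance = torch.eye(k, dtype=C.dtype) + C.mT @ D_inv_C      # k x k, positive definite
    L = torch.linalg.cholesky(capacitance)
    # log det(capacitance) = 2 * sum(log(diag(L)))
    logdet_capacitance = 2 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1)
    return logdet_capacitance + torch.log(d).sum(-1)


torch.manual_seed(0)
C = torch.randn(500, 10, dtype=torch.float64)
d = torch.rand(500, dtype=torch.float64) + 0.5      # positive diagonal
fast = low_rank_plus_diagonal_logdet(C, d)
dense = torch.logdet(C @ C.mT + torch.diag(d))      # O(N^3) dense reference
print(torch.allclose(fast, dense))
```

Each new quantity (solve, log-determinant, sampling, ...) repeats this pattern of a bespoke derivation plus a bespoke function signature.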
With LinearOperator
The LinearOperator package offers the best of both worlds:
import torch
from linear_operator.operators import DiagLinearOperator, LowRankRootLinearOperator
C = torch.randn(1000, 20)
d = torch.rand(1000) + 1.0  # positive diagonal keeps A positive definite
b = torch.randn(1000)
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d)  # represents C C^T + diag(d)
It provides an interface that lets us treat \(\boldsymbol A\) as if it were a generic tensor, using the standard PyTorch API:
torch.linalg.solve(A, b) # computes A^{-1} b efficiently!
Under the hood, the LinearOperator object keeps track of the algebraic structure of \(\boldsymbol A\) (low rank plus diagonal) and determines the most efficient routine to use (the Woodbury formula). This way, we get an efficient \(\mathcal O(N)\) solve while abstracting away all of the details.
Crucially, \(\boldsymbol A\) is never explicitly instantiated as a matrix, which makes it possible to scale to very large operators without running out of memory:
# C = torch.randn(10000000, 20)
# d = torch.rand(10000000) + 1.0  (the diagonal must be positive)
# b = torch.randn(10000000)
A = LowRankRootLinearOperator(C) + DiagLinearOperator(d) # represents a 10M x 10M matrix!
torch.linalg.solve(A, b) # computes A^{-1} b efficiently!
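A back-of-the-envelope check of the memory involved, assuming float32 storage (PyTorch’s default dtype) and counting only the operator itself (not b or any workspace):

```python
N, k = 10_000_000, 20
bytes_per_float32 = 4

dense_bytes = N * N * bytes_per_float32               # explicit 10M x 10M matrix
structured_bytes = (N * k + N) * bytes_per_float32    # C (N x k) plus the diagonal d

print(f"dense:      {dense_bytes / 1e12:,.0f} TB")
print(f"structured: {structured_bytes / 1e9:,.2f} GB")
```

This prints 400 TB for the dense matrix versus 0.84 GB for the factors that the LinearOperator actually stores.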
Use Cases
There are several use cases for the LinearOperator package. Here we highlight two general themes:
Modular Code for Structured Matrices
For example, let’s say that you have a generative model that involves sampling from a high-dimensional multivariate Gaussian. This sampling operation will require storing and manipulating a large covariance matrix, so to speed things up you might want to experiment with different structured approximations of that covariance matrix. This is easy with the LinearOperator package.
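import torch
from gpytorch.distributions import MultivariateNormal
from linear_operator.operators import DiagLinearOperator
variance = torch.rand(10000) + 0.1  # variances must be positive
cov = DiagLinearOperator(variance)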
# or
# cov = LowRankRootLinearOperator(...) + DiagLinearOperator(...)
# or
# cov = KroneckerProductLinearOperator(...)
# or
# cov = ToeplitzLinearOperator(...)
# or
# ...
mvn = MultivariateNormal(torch.zeros(cov.size(-1)), cov)  # 10000-dimensional MVN
mvn.rsample() # returns a 10000-dimensional vector
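To make the benefit concrete for the low-rank-plus-diagonal option, here is a minimal sketch of reparameterized sampling that never forms the covariance. It is plain PyTorch written for illustration (`sample_low_rank_plus_diagonal` is a hypothetical helper), not how gpytorch’s `MultivariateNormal.rsample` is implemented. It relies on the fact that if \(\boldsymbol z_1 \sim \mathcal N(\boldsymbol 0, \boldsymbol I_k)\) and \(\boldsymbol z_2 \sim \mathcal N(\boldsymbol 0, \boldsymbol I_N)\) are independent, then \(\boldsymbol C \boldsymbol z_1 + \sqrt{\boldsymbol d} \odot \boldsymbol z_2\) has covariance \(\boldsymbol C \boldsymbol C^\top + \operatorname{diag}(\boldsymbol d)\):

```python
import torch


def sample_low_rank_plus_diagonal(mean, C, d, num_samples):
    """Hypothetical helper: draw samples from N(mean, C C^T + diag(d)).

    Costs O(num_samples * N * k) and never builds the N x N covariance.
    """
    z1 = torch.randn(num_samples, C.size(-1), dtype=C.dtype)   # ~ N(0, I_k)
    z2 = torch.randn(num_samples, C.size(-2), dtype=C.dtype)   # ~ N(0, I_N)
    return mean + z1 @ C.mT + d.sqrt() * z2


torch.manual_seed(0)
N, k = 3, 2
C = torch.randn(N, k, dtype=torch.float64)
d = torch.rand(N, dtype=torch.float64) + 0.5                   # positive variances
mean = torch.zeros(N, dtype=torch.float64)

samples = sample_low_rank_plus_diagonal(mean, C, d, 200_000)
empirical_cov = torch.cov(samples.mT)        # torch.cov expects variables as rows
target_cov = C @ C.mT + torch.diag(d)
print(torch.allclose(empirical_cov, target_cov, atol=0.1))
```

The empirical covariance check uses a tiny \(N\) only so that the dense target is cheap to build; the sampler itself never needs it.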
Efficient Routines for Complex Operators
Many of the efficient linear algebra routines in LinearOperator are iterative algorithms based on matrix-vector multiplication. Since matrix-vector multiplication composes well (a sum of operators is applied as a sum of matrix-vector products, and a Kronecker product can be applied without ever being formed), it is possible to obtain efficient routines for extremely complex compositional LinearOperators:
import torch
from linear_operator.operators import KroneckerProductLinearOperator, RootLinearOperator, ToeplitzLinearOperator
# mat1 = 200 x 200 PSD matrix
# mat2 = 100 x 100 PSD matrix
# vec3 = 20000 vector
A = KroneckerProductLinearOperator(mat1, mat2) + RootLinearOperator(ToeplitzLinearOperator(vec3))
# represents a 20000 x 20000 matrix
torch.linalg.solve(A, torch.randn(20000)) # Sub O(N^3) routine!
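Here is a rough, plain-PyTorch sketch of both ideas; it is not linear_operator’s internal implementation, and all function names are hypothetical. It builds a matvec for an operator shaped like \(\boldsymbol M_1 \otimes \boldsymbol M_2 + \boldsymbol R \boldsymbol R^\top\) that never forms the Kronecker product, then feeds it to a textbook (unpreconditioned) conjugate gradients solver that touches the operator only through that matvec. A generic dense \(\boldsymbol R\) stands in for the Toeplitz root to keep the sketch short (a Toeplitz matrix also admits a fast FFT-based matvec, omitted here):

```python
import torch


def kron_matvec(M1, M2, v):
    # (M1 kron M2) v == vec(M1 V M2^T) with V = v reshaped row-major to (m1, m2),
    # costing O(m1*m2*(m1 + m2)) instead of O((m1*m2)^2)
    V = v.reshape(M1.size(-1), M2.size(-1))
    return (M1 @ V @ M2.mT).reshape(-1)


def composite_matvec(v, M1, M2, R):
    # Matvec with (M1 kron M2) + R R^T: a sum of operators is a sum of matvecs,
    # and (R R^T) v is computed as R (R^T v)
    return kron_matvec(M1, M2, v) + R @ (R.mT @ v)


def conjugate_gradient(matvec, b, tol=1e-10, max_iter=1000):
    # Plain CG for a symmetric positive definite operator accessed only through matvec
    x = torch.zeros_like(b)
    r = b - matvec(x)          # residual
    p = r.clone()              # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


torch.manual_seed(0)
m1, m2 = 20, 10                                              # operator size N = 200
G1 = torch.randn(m1, m1, dtype=torch.float64)
G2 = torch.randn(m2, m2, dtype=torch.float64)
M1 = G1 @ G1.mT + m1 * torch.eye(m1, dtype=torch.float64)    # positive definite
M2 = G2 @ G2.mT + m2 * torch.eye(m2, dtype=torch.float64)    # positive definite
R = torch.randn(m1 * m2, 5, dtype=torch.float64)             # stand-in for the Toeplitz root

b = torch.randn(m1 * m2, dtype=torch.float64)
x = conjugate_gradient(lambda v: composite_matvec(v, M1, M2, R), b)

# Dense reference, only feasible because N is small here
A_dense = torch.kron(M1, M2) + R @ R.mT
x_dense = torch.linalg.solve(A_dense, b)
print(torch.allclose(x, x_dense))
```

Each CG iteration costs one matvec, so the total cost is roughly the number of iterations times the matvec cost; when the operator is reasonably well conditioned this is what makes a sub-\(\mathcal O(N^3)\) solve possible.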
Getting Started
Basic Concepts
Linear Operator Objects