import math
import warnings

import torch

from linear_operator import settings
from linear_operator.utils.linear_cg import linear_cg
from linear_operator.utils.minres import minres
from linear_operator.utils.warnings import NumericalWarning
def contour_integral_quad(
    linear_op,
    rhs,
    inverse=False,
    weights=None,
    shifts=None,
    max_lanczos_iter=20,
    num_contour_quadrature=None,
    shift_offset=0,
):
r"""
Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`
using contour integral quadrature.
:param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
:param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b`
:param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False)
or `\mathbf K^{-1/2} \mathbf b` (if True)
:param int max_lanczos_iter: (default 10) Number of Lanczos iterations to run (to estimate eigenvalues)
:param int num_contour_quadrature: How many quadrature samples to use for approximation. Default is in settings.
:rtype: torch.Tensor
:return: Approximation to :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`.
"""
    import numpy as np
    from scipy.special import ellipj, ellipk

    if num_contour_quadrature is None:
        num_contour_quadrature = settings.num_contour_quadrature.value()

    output_batch_shape = torch.broadcast_shapes(linear_op.batch_shape, rhs.shape[:-2])

    preconditioner, preconditioner_lt, _ = linear_op._preconditioner()

    def sqrt_precond_matmul(rhs):
        if preconditioner_lt is not None:
            solves, weights, _, _ = contour_integral_quad(preconditioner_lt, rhs, inverse=False)
            return (solves * weights).sum(0)
        else:
            return rhs

    # if not inverse:
    rhs = sqrt_precond_matmul(rhs)
    if shifts is None:
        # Determine if init_vecs has extra_dimensions
        num_extra_dims = max(0, rhs.dim() - linear_op.dim())
        lanczos_init = rhs.__getitem__(
            (
                *([0] * num_extra_dims),
                Ellipsis,
                slice(None, None, None),
                slice(None, 1, None),
            )
        ).expand(*linear_op.shape[:-1], 1)
        with warnings.catch_warnings(), torch.no_grad():
            warnings.simplefilter("ignore", NumericalWarning)  # Suppress CG stopping warning
            _, lanczos_mat = linear_cg(
                lambda v: linear_op._matmul(v),
                rhs=lanczos_init,
                n_tridiag=1,
                max_iter=max_lanczos_iter,
                tolerance=1e-5,
                max_tridiag_iter=max_lanczos_iter,
                preconditioner=preconditioner,
            )
            lanczos_mat = lanczos_mat.squeeze(0)  # We have an extra singleton batch dimension from the Lanczos init
r"""
K^{-1/2} b = 2/pi \int_0^\infty (K - t^2 I)^{-1} dt
We'll approximate this integral as a sum using quadrature
We'll determine the appropriate values of t, as well as their weights using elliptical integrals
"""
        # Compute an approximate condition number
        # We'll do this with Lanczos
        try:
            if settings.verbose_linalg.on():
                settings.verbose_linalg.logger.debug(
                    f"Running torch.linalg.eigvalsh on a matrix of size {lanczos_mat.shape}."
                )

            approx_eigs = torch.linalg.eigvalsh(lanczos_mat)
            if approx_eigs.min() <= 0:
                raise RuntimeError
        except RuntimeError:
            approx_eigs = linear_op._diagonal()

        max_eig = approx_eigs.max(dim=-1)[0]
        min_eig = approx_eigs.min(dim=-1)[0]
        k2 = min_eig / max_eig
        # Compute the shifts needed for the contour
        flat_shifts = torch.zeros(num_contour_quadrature + 1, k2.numel(), dtype=k2.dtype, device=k2.device)
        flat_weights = torch.zeros(num_contour_quadrature, k2.numel(), dtype=k2.dtype, device=k2.device)

        # For loop because numpy
        for i, (sub_k2, sub_min_eig) in enumerate(zip(k2.flatten().tolist(), min_eig.flatten().tolist())):
            # Compute shifts
            Kp = ellipk(1 - sub_k2)  # Complete elliptic integral of the first kind
            N = num_contour_quadrature
            t = 1j * (np.arange(1, N + 1) - 0.5) * Kp / N
            sn, cn, dn, _ = ellipj(np.imag(t), 1 - sub_k2)  # Jacobi elliptic functions
            # Evaluate the elliptic functions at the imaginary nodes t via Jacobi's imaginary transformation
            cn = 1.0 / cn
            dn = dn * cn
            sn = 1j * sn * cn
            w = np.sqrt(sub_min_eig) * sn
            w_pow2 = np.real(np.power(w, 2))
            sub_shifts = torch.tensor(w_pow2, dtype=rhs.dtype, device=rhs.device)

            # Compute weights
            constant = -2 * Kp * np.sqrt(sub_min_eig) / (math.pi * N)
            dzdt = torch.tensor(cn * dn, dtype=rhs.dtype, device=rhs.device)
            dzdt.mul_(constant)
            sub_weights = dzdt

            # Store results
            flat_shifts[1:, i].copy_(sub_shifts)
            flat_weights[:, i].copy_(sub_weights)
        weights = flat_weights.view(num_contour_quadrature, *k2.shape, 1, 1)
        shifts = flat_shifts.view(num_contour_quadrature + 1, *k2.shape)
        shifts.sub_(shift_offset)

        # Make sure we have the right shape
        if k2.shape != output_batch_shape:
            weights = torch.stack([w.expand(*output_batch_shape, 1, 1) for w in weights], 0)
            shifts = torch.stack([s.expand(output_batch_shape) for s in shifts], 0)
    # Compute the solves at the given shifts
    # Do one more matmul if we don't want to include the inverse
    with torch.no_grad():
        solves = minres(
            lambda v: linear_op._matmul(v),
            rhs,
            value=-1,
            shifts=shifts,
            preconditioner=preconditioner,
        )

    no_shift_solves = solves[0]
    solves = solves[1:]
    if not inverse:
        solves = linear_op._matmul(solves)

    return solves, weights, no_shift_solves, shifts
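

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API): approximate
    # K^{-1/2} b for a small, well-conditioned SPD matrix and compare against a dense reference.
    # Assumes `DenseLinearOperator` is importable from `linear_operator.operators`; combining the
    # returned solves and weights as `(solves * weights).sum(0)` mirrors the pattern used in
    # `sqrt_precond_matmul` above.
    from linear_operator.operators import DenseLinearOperator

    torch.manual_seed(0)
    A = torch.randn(50, 50, dtype=torch.float64)
    K = A @ A.T + 50 * torch.eye(50, dtype=torch.float64)  # SPD with a modest condition number
    b = torch.randn(50, 1, dtype=torch.float64)

    solves, weights, _, _ = contour_integral_quad(DenseLinearOperator(K), b, inverse=True)
    approx = (solves * weights).sum(0)  # quadrature sum approximating K^{-1/2} b

    # Dense reference via an eigendecomposition of K
    evals, evecs = torch.linalg.eigh(K)
    exact = evecs @ torch.diag(evals.rsqrt()) @ evecs.T @ b
    print("relative error:", ((approx - exact).norm() / exact.norm()).item())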