#!/usr/bin/env python3
import torch
from linear_operator.utils.broadcasting import _matmul_broadcast_shape
def make_sparse_from_indices_and_values(interp_indices, interp_values, num_rows):
    """
    Produce a sparse COO tensor with a fixed number of non-zero entries in each column.

    Args:
        interp_indices (Tensor): (*batch_shape) x num_cols x n_nonzero_entries
            long tensor holding the row indices of the nonzero entries for each column.
        interp_values (Tensor): (*batch_shape) x num_cols x n_nonzero_entries
            the corresponding values.
        num_rows (int): the number of rows in the result matrix.

    Returns:
        torch.Tensor (sparse COO): (*batch_shape) x num_rows x num_cols

    Raises:
        RuntimeError: if interp_indices or interp_values is not a tensor.
    """
    # Validate BOTH inputs — the error message always promised both were checked.
    if not torch.is_tensor(interp_indices) or not torch.is_tensor(interp_values):
        raise RuntimeError("interp_indices and interp_values should be tensors")

    # Is it batch mode?
    batch_shape = interp_values.shape[:-2]
    n_target_points, n_coefficients = interp_values.shape[-2:]

    # Build one flattened index vector per batch dimension (row-major order).
    batch_tensors = []
    for i, batch_size in enumerate(batch_shape):
        batch_tensor = torch.arange(0, batch_size, dtype=torch.long, device=interp_values.device)
        batch_tensor = (
            batch_tensor.unsqueeze_(1)
            .repeat(
                batch_shape[:i].numel(),
                batch_shape[i + 1 :].numel() * n_target_points * n_coefficients,
            )
            .view(-1)
        )
        batch_tensors.append(batch_tensor)

    # Column index of the result for every flattened entry.
    row_tensor = torch.arange(0, n_target_points, dtype=torch.long, device=interp_values.device)
    row_tensor = row_tensor.unsqueeze_(1).repeat(batch_shape.numel(), n_coefficients).view(-1)
    index_tensor = torch.stack([*batch_tensors, interp_indices.reshape(-1), row_tensor], 0)

    # Value tensor: drop explicit zeros so the sparse tensor stays compact.
    value_tensor = interp_values.reshape(-1)
    nonzero_indices = value_tensor.nonzero(as_tuple=False)
    if nonzero_indices.numel():
        # squeeze(-1), not squeeze_(): a single nonzero entry yields shape (1, 1),
        # and squeezing every dim would produce a 0-dim index that index_select rejects.
        nonzero_indices = nonzero_indices.squeeze(-1)
        index_tensor = index_tensor.index_select(1, nonzero_indices)
        value_tensor = value_tensor.index_select(0, nonzero_indices)
    else:
        # No nonzero entries: fall back to a single explicit zero at the origin.
        # (Fresh tensors — resize_ on a reshape view, as the old code did, fails
        # on modern torch.)
        index_tensor = torch.zeros(
            index_tensor.size(0), 1, dtype=torch.long, device=index_tensor.device
        )
        value_tensor = torch.zeros(1, dtype=value_tensor.dtype, device=value_tensor.device)

    # Build the sparse tensor. torch.sparse_coo_tensor preserves dtype/device,
    # replacing the deprecated torch.sparse.*Tensor / torch.cuda.sparse.* classes.
    interp_size = torch.Size((*batch_shape, num_rows, n_target_points))
    return torch.sparse_coo_tensor(
        index_tensor,
        value_tensor,
        interp_size,
        dtype=value_tensor.dtype,
        device=value_tensor.device,
    )
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply.

    Computes sparse @ dense where either operand (or both) may carry leading
    batch dimensions; batch shapes are broadcast via _matmul_broadcast_shape.

    Args:
        sparse: sparse COO tensor, optionally batched — (*batch_shape) x num_rows x num_cols.
        dense: dense tensor whose trailing two dims are num_cols x k.

    Returns:
        Dense tensor of shape (*broadcast_batch_shape) x num_rows x k.
    """
    # Case 1: batched sparse operand. Strategy: rewrite the batch of sparse
    # matrices as one big block-diagonal 2D sparse matrix, stack the dense
    # batches vertically, and perform a single 2D sparse @ dense product.
    if sparse.ndimension() > 2:
        # Expand the tensors to account for broadcasting.
        output_shape = _matmul_broadcast_shape(sparse.shape, dense.shape)
        expanded_sparse_shape = output_shape[:-2] + sparse.shape[-2:]
        # Left-pad the sparse shape with singleton dims so it lines up with the
        # broadcast output shape.
        unsqueezed_sparse_shape = [1 for _ in range(len(output_shape) - sparse.dim())] + list(sparse.shape)
        repeat_sizes = tuple(
            output_size // sparse_size
            for output_size, sparse_size in zip(expanded_sparse_shape, unsqueezed_sparse_shape)
        )
        sparse = sparse_repeat(sparse, *repeat_sizes)
        dense = dense.expand(*output_shape[:-2], dense.size(-2), dense.size(-1))

        # Figure out how much must be added to the row/column indices to create
        # a block-diagonal matrix.
        *batch_shape, num_rows, num_cols = sparse.shape
        batch_size = torch.Size(batch_shape).numel()
        # Row-major multipliers that flatten a multi-dim batch index into a
        # single linear batch id.
        batch_multiplication_factor = torch.tensor(
            [torch.Size(batch_shape[i + 1 :]).numel() for i in range(len(batch_shape))],
            dtype=torch.long,
            device=sparse.device,
        )
        if batch_multiplication_factor.is_cuda:
            # NOTE(review): the matmul is routed through float on CUDA,
            # presumably because integer matmul was unsupported there — confirm
            # this workaround is still required on current torch versions.
            batch_assignment = (sparse._indices()[:-2].float().t() @ batch_multiplication_factor.float()).long()
        else:
            batch_assignment = sparse._indices()[:-2].t() @ batch_multiplication_factor

        # Create the block-diagonal sparse tensor: offset each entry's row/col
        # index by its linear batch id times the block size.
        indices = sparse._indices()[-2:].clone()
        indices[0].add_(batch_assignment, alpha=num_rows)
        indices[1].add_(batch_assignment, alpha=num_cols)
        sparse_2d = torch.sparse_coo_tensor(
            indices,
            sparse._values(),
            torch.Size((batch_size * num_rows, batch_size * num_cols)),
            dtype=sparse._values().dtype,
            device=sparse._values().device,
        )

        # Stack the dense batches vertically so they align with the diagonal blocks.
        dense_2d = dense.reshape(batch_size * num_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(*batch_shape, num_rows, -1)
        return res

    # Case 2: only the dense operand is batched. Fold the batch dimension into
    # the column dimension, multiply once, then unfold.
    elif dense.dim() > 2:
        *batch_shape, num_rows, num_cols = dense.size()
        batch_size = torch.Size(batch_shape).numel()
        dense = dense.view(batch_size, num_rows, num_cols)
        res = torch.dsmm(sparse, dense.transpose(0, 1).reshape(-1, batch_size * num_cols))
        res = res.view(-1, batch_size, num_cols)
        res = res.transpose(0, 1).reshape(*batch_shape, -1, num_cols)
        return res

    # Case 3: plain 2D sparse @ dense.
    else:
        return torch.dsmm(sparse, dense)
def sparse_eye(size):
    """
    Return the (size x size) identity matrix as a sparse COO tensor.

    Args:
        size (int): number of rows (= columns).

    Returns:
        torch.Tensor (sparse COO): the size x size identity matrix.
    """
    # Diagonal entries: index (i, i) with value 1.
    indices = torch.arange(0, size, dtype=torch.long).unsqueeze(0).expand(2, size)
    values = torch.ones(size)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.*Tensor
    # class lookup, consistent with the rest of this module.
    return torch.sparse_coo_tensor(indices, values, torch.Size([size, size]))
def sparse_getitem(sparse, idxs):
    """
    Index a 1D or 2D sparse COO tensor with ints and/or slices.

    Args:
        sparse: sparse COO tensor of dimension 1 or 2.
        idxs: an int, a slice (step 1 only), or a tuple of these
            (at most one per dimension).

    Returns:
        A new sparse COO tensor (or the 0-dim sum of the matching values when
        every dimension is indexed away). The input tensor is never modified.

    Raises:
        RuntimeError: for >2D input, too many indices, a sliced step != 1, or
            an unsupported index type.
    """
    if not isinstance(idxs, tuple):
        idxs = (idxs,)

    if not sparse.ndimension() <= 2:
        raise RuntimeError("Must be a 1d or 2d sparse tensor")

    if len(idxs) > sparse.ndimension():
        raise RuntimeError("Invalid index for %d-order tensor" % sparse.ndimension())

    indices = sparse._indices()
    values = sparse._values()
    size = list(sparse.size())

    # Process indices right-to-left so removing a dimension does not shift the
    # positions of the dimensions not yet processed.
    for i, idx in list(enumerate(idxs))[::-1]:
        if isinstance(idx, int):
            del size[i]
            mask = indices[i].eq(idx)
            if torch.any(mask):
                # Keep the matching entries, dropping index row i.
                new_indices = torch.zeros(
                    indices.size(0) - 1,
                    int(torch.sum(mask)),
                    dtype=indices.dtype,
                    device=indices.device,
                )
                for j in range(indices.size(0)):
                    if i > j:
                        new_indices[j].copy_(indices[j][mask])
                    elif i < j:
                        new_indices[j - 1].copy_(indices[j][mask])
                indices = new_indices
                values = values[mask]
            else:
                # No entry matches: represent the result with one explicit zero.
                # Build fresh tensors — the old code resize_()d the tensors
                # returned by sparse._indices()/_values(), corrupting the
                # caller's input tensor in place.
                indices = torch.zeros(
                    indices.size(0) - 1, 1, dtype=indices.dtype, device=indices.device
                )
                values = torch.zeros(1, dtype=values.dtype, device=values.device)
            if not len(size):
                # Fully indexed: return a scalar (sum of the matching values).
                return sum(values)

        elif isinstance(idx, slice):
            start, stop, step = idx.indices(size[i])
            size = list(size[:i]) + [stop - start] + list(size[i + 1 :])
            if step != 1:
                raise RuntimeError("Slicing with step is not supported")
            mask = indices[i].lt(stop) & indices[i].ge(start)
            if torch.any(mask):
                new_indices = torch.zeros(
                    indices.size(0),
                    int(torch.sum(mask)),
                    dtype=indices.dtype,
                    device=indices.device,
                )
                for j in range(indices.size(0)):
                    new_indices[j].copy_(indices[j][mask])
                # Shift the sliced dimension so indices are relative to `start`.
                new_indices[i].sub_(start)
                indices = new_indices
                values = values[mask]
            else:
                # Same non-mutating empty-result fallback as the int branch.
                indices = torch.zeros(
                    indices.size(0), 1, dtype=indices.dtype, device=indices.device
                )
                values = torch.zeros(1, dtype=values.dtype, device=values.device)
        else:
            raise RuntimeError("Unknown index type")

    return torch.sparse_coo_tensor(indices, values, torch.Size(size), dtype=values.dtype, device=values.device)
def sparse_repeat(sparse, *repeat_sizes):
    """
    Repeat a sparse COO tensor along each dimension, like torch.Tensor.repeat.

    Args:
        sparse: sparse COO tensor.
        *repeat_sizes: one int per output dimension (may also be passed as a
            single tuple). If more sizes than dimensions are given, singleton
            dimensions are prepended first, as with dense repeat.

    Returns:
        A new sparse COO tensor whose i-th dimension has size
        repeat_sizes[i] * sparse.size(i).
    """
    # Accept both sparse_repeat(t, (2, 3)) and sparse_repeat(t, 2, 3).
    # Check repeat_sizes[0] — repeat_sizes itself is always a tuple (*args), so
    # the old isinstance(repeat_sizes, tuple) test unpacked a lone int argument
    # and crashed on len() below.
    if len(repeat_sizes) == 1 and isinstance(repeat_sizes[0], tuple):
        repeat_sizes = repeat_sizes[0]

    # Prepend singleton dimensions when more repeat sizes than dims were given.
    if len(repeat_sizes) > len(sparse.shape):
        num_new_dims = len(repeat_sizes) - len(sparse.shape)
        new_indices = sparse._indices()
        new_indices = torch.cat(
            [
                torch.zeros(
                    num_new_dims,
                    new_indices.size(1),
                    dtype=new_indices.dtype,
                    device=new_indices.device,
                ),
                new_indices,
            ],
            0,
        )
        sparse = torch.sparse_coo_tensor(
            new_indices,
            sparse._values(),
            torch.Size((*[1 for _ in range(num_new_dims)], *sparse.shape)),
            dtype=sparse.dtype,
            device=sparse.device,
        )

    for i, repeat_size in enumerate(repeat_sizes):
        if repeat_size > 1:
            new_indices = sparse._indices().repeat(1, repeat_size)
            # The k-th copy along dim i must be offset by k * sparse.size(i).
            # (The old code added only k, which scrambles indices whenever the
            # repeated dimension has size > 1.)
            adding_factor = torch.arange(
                0, repeat_size, dtype=new_indices.dtype, device=new_indices.device
            ).unsqueeze_(1)
            adding_factor.mul_(sparse.size(i))
            new_indices[i].view(repeat_size, -1).add_(adding_factor)
            sparse = torch.sparse_coo_tensor(
                new_indices,
                sparse._values().repeat(repeat_size),
                torch.Size(
                    (
                        *sparse.shape[:i],
                        repeat_size * sparse.size(i),
                        *sparse.shape[i + 1 :],
                    )
                ),
                dtype=sparse.dtype,
                device=sparse.device,
            )

    return sparse
def to_sparse(dense):
    """
    Convert a dense tensor to an equivalent sparse COO tensor.

    Args:
        dense: dense tensor (any shape, dtype, or device).

    Returns:
        Sparse COO tensor with the same shape, dtype, and device, whose
        explicit entries are exactly the nonzero entries of ``dense`` (or a
        single explicit zero at the origin if ``dense`` is all zeros).
    """
    mask = dense.ne(0)
    indices = mask.nonzero(as_tuple=False)
    if indices.numel():
        values = dense[mask]
    else:
        # All-zero input: store a single explicit zero at the origin. The old
        # code built a 0-dim values tensor here, whose shape did not match the
        # single (1 x ndim) index row — values must have length 1.
        indices = torch.zeros(1, dense.ndimension(), dtype=torch.long, device=dense.device)
        values = torch.zeros(1, dtype=dense.dtype, device=dense.device)
    # torch.sparse_coo_tensor preserves the dtype/device of its inputs, so the
    # legacy torch.sparse.*Tensor lookup and the .cuda() round-trip are gone.
    return torch.sparse_coo_tensor(indices.t(), values, dense.size(), dtype=dense.dtype, device=dense.device)