Source code for linear_operator.utils.sparse

#!/usr/bin/env python3

import torch

from linear_operator.utils.broadcasting import _matmul_broadcast_shape


def make_sparse_from_indices_and_values(interp_indices, interp_values, num_rows):
    """
    This produces a sparse tensor with a fixed number of non-zero entries in each column.

    Args:
        - interp_indices - Tensor (batch_size) x num_cols x n_nonzero_entries
            A matrix containing the row indices of the nonzero entries for each column
        - interp_values - Tensor (batch_size) x num_cols x n_nonzero_entries
            The corresponding values
        - num_rows - the number of rows in the result matrix

    Returns:
        - SparseTensor - (batch_size) x num_rows x num_cols
    """
    if not torch.is_tensor(interp_indices):
        raise RuntimeError("interp_indices and interp_values should be tensors")

    # Is it batch mode?
    batch_shape = interp_values.shape[:-2]
    n_target_points, n_coefficients = interp_values.shape[-2:]

    # Index tensor
    batch_tensors = []
    for i, batch_size in enumerate(batch_shape):
        batch_tensor = torch.arange(0, batch_size, dtype=torch.long, device=interp_values.device)
        batch_tensor = (
            batch_tensor.unsqueeze_(1)
            .repeat(
                batch_shape[:i].numel(),
                batch_shape[i + 1 :].numel() * n_target_points * n_coefficients,
            )
            .view(-1)
        )
        batch_tensors.append(batch_tensor)

    row_tensor = torch.arange(0, n_target_points, dtype=torch.long, device=interp_values.device)
    row_tensor = row_tensor.unsqueeze_(1).repeat(batch_shape.numel(), n_coefficients).view(-1)
    index_tensor = torch.stack([*batch_tensors, interp_indices.reshape(-1), row_tensor], 0)

    # Value tensor
    value_tensor = interp_values.reshape(-1)
    nonzero_indices = value_tensor.nonzero(as_tuple=False)
    if nonzero_indices.storage():
        nonzero_indices.squeeze_()
        index_tensor = index_tensor.index_select(1, nonzero_indices)
        value_tensor = value_tensor.index_select(0, nonzero_indices)
    else:
        index_tensor = index_tensor.resize_(interp_indices.dim(), 1).zero_()
        value_tensor = value_tensor.resize_(1).zero_()

    # Make the sparse tensor
    type_name = value_tensor.type().split(".")[-1]  # e.g. FloatTensor
    interp_size = torch.Size((*batch_shape, num_rows, n_target_points))
    if index_tensor.is_cuda:
        cls = getattr(torch.cuda.sparse, type_name)
    else:
        cls = getattr(torch.sparse, type_name)
    res = cls(index_tensor, value_tensor, interp_size)
    return res
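
# Illustrative usage (not part of the library source): a minimal sketch of
# building a 5 x 3 matrix with two nonzero entries per column. The shapes and
# values below are assumptions chosen for demonstration.
def _example_make_sparse_from_indices_and_values():
    # Column j gets nonzero entries at rows interp_indices[j] with values interp_values[j]
    interp_indices = torch.tensor([[0, 1], [2, 3], [1, 4]])  # 3 columns x 2 entries each
    interp_values = torch.tensor([[0.5, 0.5], [0.25, 0.75], [0.9, 0.1]])
    sparse = make_sparse_from_indices_and_values(interp_indices, interp_values, num_rows=5)
    return sparse.to_dense()  # dense 5 x 3 view of the result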
def bdsmm(sparse, dense):
    """
    Batch dense-sparse matrix multiply
    """
    # Make the batch sparse matrix into a block-diagonal matrix
    if sparse.ndimension() > 2:
        # Expand the tensors to account for broadcasting
        output_shape = _matmul_broadcast_shape(sparse.shape, dense.shape)
        expanded_sparse_shape = output_shape[:-2] + sparse.shape[-2:]
        unsqueezed_sparse_shape = [1 for _ in range(len(output_shape) - sparse.dim())] + list(sparse.shape)
        repeat_sizes = tuple(
            output_size // sparse_size
            for output_size, sparse_size in zip(expanded_sparse_shape, unsqueezed_sparse_shape)
        )
        sparse = sparse_repeat(sparse, *repeat_sizes)
        dense = dense.expand(*output_shape[:-2], dense.size(-2), dense.size(-1))

        # Figure out how much needs to be added to the row/column indices to create
        # a block-diagonal matrix
        *batch_shape, num_rows, num_cols = sparse.shape
        batch_size = torch.Size(batch_shape).numel()
        batch_multiplication_factor = torch.tensor(
            [torch.Size(batch_shape[i + 1 :]).numel() for i in range(len(batch_shape))],
            dtype=torch.long,
            device=sparse.device,
        )
        if batch_multiplication_factor.is_cuda:
            batch_assignment = (sparse._indices()[:-2].float().t() @ batch_multiplication_factor.float()).long()
        else:
            batch_assignment = sparse._indices()[:-2].t() @ batch_multiplication_factor

        # Create block-diagonal sparse tensor
        indices = sparse._indices()[-2:].clone()
        indices[0].add_(batch_assignment, alpha=num_rows)
        indices[1].add_(batch_assignment, alpha=num_cols)
        sparse_2d = torch.sparse_coo_tensor(
            indices,
            sparse._values(),
            torch.Size((batch_size * num_rows, batch_size * num_cols)),
            dtype=sparse._values().dtype,
            device=sparse._values().device,
        )

        dense_2d = dense.reshape(batch_size * num_cols, -1)
        res = torch.dsmm(sparse_2d, dense_2d)
        res = res.view(*batch_shape, num_rows, -1)
        return res

    elif dense.dim() > 2:
        *batch_shape, num_rows, num_cols = dense.size()
        batch_size = torch.Size(batch_shape).numel()
        dense = dense.view(batch_size, num_rows, num_cols)
        res = torch.dsmm(sparse, dense.transpose(0, 1).reshape(-1, batch_size * num_cols))
        res = res.view(-1, batch_size, num_cols)
        res = res.transpose(0, 1).reshape(*batch_shape, -1, num_cols)
        return res

    else:
        return torch.dsmm(sparse, dense)
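
# Illustrative usage (not part of the library source): bdsmm treats a batch of
# sparse matrices as one block-diagonal matrix, so multiplying a batch of
# identity matrices should leave the dense operand unchanged. The shapes here
# are assumptions chosen for demonstration.
def _example_bdsmm():
    batch_sparse = sparse_repeat(sparse_eye(4), 2, 1, 1)  # 2 x 4 x 4 batch of identities
    dense = torch.randn(2, 4, 3)
    res = bdsmm(batch_sparse, dense)  # 2 x 4 x 3
    assert torch.allclose(res, dense)
    return res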
def sparse_eye(size):
    """
    Returns the identity matrix as a sparse matrix
    """
    indices = torch.arange(0, size).long().unsqueeze(0).expand(2, size)
    values = torch.tensor(1.0).expand(size)
    cls = getattr(torch.sparse, values.type().split(".")[-1])
    return cls(indices, values, torch.Size([size, size]))
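
# Illustrative check (not part of the library source): the sparse identity
# densifies to torch.eye.
def _example_sparse_eye():
    assert torch.equal(sparse_eye(3).to_dense(), torch.eye(3))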
def sparse_getitem(sparse, idxs):
    """
    Indexes a 1d or 2d sparse tensor with integers and/or step-1 slices.
    """
    if not isinstance(idxs, tuple):
        idxs = (idxs,)

    if not sparse.ndimension() <= 2:
        raise RuntimeError("Must be a 1d or 2d sparse tensor")

    if len(idxs) > sparse.ndimension():
        raise RuntimeError("Invalid index for %d-order tensor" % sparse.ndimension())

    indices = sparse._indices()
    values = sparse._values()
    size = list(sparse.size())

    for i, idx in list(enumerate(idxs))[::-1]:
        if isinstance(idx, int):
            del size[i]
            mask = indices[i].eq(idx)
            if torch.any(mask):
                new_indices = torch.zeros(
                    indices.size(0) - 1,
                    torch.sum(mask),
                    dtype=indices.dtype,
                    device=indices.device,
                )
                for j in range(indices.size(0)):
                    if i > j:
                        new_indices[j].copy_(indices[j][mask])
                    elif i < j:
                        new_indices[j - 1].copy_(indices[j][mask])
                indices = new_indices
                values = values[mask]
            else:
                indices.resize_(indices.size(0) - 1, 1).zero_()
                values.resize_(1).zero_()

            if not len(size):
                return sum(values)

        elif isinstance(idx, slice):
            start, stop, step = idx.indices(size[i])
            size = list(size[:i]) + [stop - start] + list(size[i + 1 :])
            if step != 1:
                raise RuntimeError("Slicing with step is not supported")
            mask = indices[i].lt(stop) & indices[i].ge(start)
            if torch.any(mask):
                new_indices = torch.zeros(
                    indices.size(0),
                    torch.sum(mask),
                    dtype=indices.dtype,
                    device=indices.device,
                )
                for j in range(indices.size(0)):
                    new_indices[j].copy_(indices[j][mask])
                new_indices[i].sub_(start)
                indices = new_indices
                values = values[mask]
            else:
                indices.resize_(indices.size(0), 1).zero_()
                values.resize_(1).zero_()

        else:
            raise RuntimeError("Unknown index type")

    return torch.sparse_coo_tensor(indices, values, torch.Size(size), dtype=values.dtype, device=values.device)


def sparse_repeat(sparse, *repeat_sizes):
    """
    Repeats a sparse tensor along each dimension (like torch.Tensor.repeat),
    prepending singleton dimensions if more repeat sizes than dimensions are given.
    """
    # Allow the repeat sizes to be passed as a single tuple
    if len(repeat_sizes) == 1 and isinstance(repeat_sizes[0], tuple):
        repeat_sizes = repeat_sizes[0]

    if len(repeat_sizes) > len(sparse.shape):
        num_new_dims = len(repeat_sizes) - len(sparse.shape)
        new_indices = sparse._indices()
        new_indices = torch.cat(
            [
                torch.zeros(
                    num_new_dims,
                    new_indices.size(1),
                    dtype=new_indices.dtype,
                    device=new_indices.device,
                ),
                new_indices,
            ],
            0,
        )
        sparse = torch.sparse_coo_tensor(
            new_indices,
            sparse._values(),
            torch.Size((*[1 for _ in range(num_new_dims)], *sparse.shape)),
            dtype=sparse.dtype,
            device=sparse.device,
        )

    for i, repeat_size in enumerate(repeat_sizes):
        if repeat_size > 1:
            new_indices = sparse._indices().repeat(1, repeat_size)
            adding_factor = torch.arange(
                0, repeat_size, dtype=new_indices.dtype, device=new_indices.device
            ).unsqueeze_(1)
            new_indices[i].view(repeat_size, -1).add_(adding_factor)
            sparse = torch.sparse_coo_tensor(
                new_indices,
                sparse._values().repeat(repeat_size),
                torch.Size(
                    (
                        *sparse.shape[:i],
                        repeat_size * sparse.size(i),
                        *sparse.shape[i + 1 :],
                    )
                ),
                dtype=sparse.dtype,
                device=sparse.device,
            )

    return sparse


def to_sparse(dense):
    """
    Converts a dense tensor into an equivalent sparse (COO) tensor.
    """
    mask = dense.ne(0)
    indices = mask.nonzero(as_tuple=False)
    if indices.storage():
        values = dense[mask]
    else:
        indices = indices.resize_(1, dense.ndimension()).zero_()
        values = torch.tensor(0, dtype=dense.dtype, device=dense.device)

    # Construct sparse tensor
    klass = getattr(torch.sparse, dense.type().split(".")[-1])
    res = klass(indices.t(), values, dense.size())
    if dense.is_cuda:
        res = res.cuda()
    return res
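
# Illustrative usage (not part of the library source): a minimal sketch tying
# the remaining helpers together. The 2 x 2 matrix is an assumption chosen for
# demonstration.
def _example_sparse_helpers():
    dense = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
    sparse = to_sparse(dense)                # sparse COO copy of `dense`
    row = sparse_getitem(sparse, 0)          # first row, as a 1d sparse tensor
    tiled = sparse_repeat(sparse, 3, 1, 1)   # 3 x 2 x 2 batch of copies
    return row.to_dense(), tiled.to_dense()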