"""Efficiently construct fwt operations using sparse matrices."""

from itertools import product
from typing import List

import torch

from .constants import PaddingMode


def _dense_kron(
    sparse_tensor_a: torch.Tensor, sparse_tensor_b: torch.Tensor
) -> torch.Tensor:
    """Faster than sparse_kron.

    Limited to resolutions of approximately 128x128 pixels
    by memory on my machine.

    Args:
        sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n].
        sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q].

    Returns:
        torch.Tensor: The resulting [mp, nq] tensor.

    """
    return torch.kron(
        sparse_tensor_a.to_dense(), sparse_tensor_b.to_dense()
    ).to_sparse()
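
# Illustrative sketch (not part of the original module): since _dense_kron
# round-trips through dense tensors, it should match torch.kron exactly.
#
#     a = torch.eye(2).to_sparse()
#     b = torch.tensor([[0.0, 1.0], [2.0, 3.0]]).to_sparse()
#     k = _dense_kron(a, b)  # sparse [4, 4] result
#     assert torch.allclose(k.to_dense(), torch.kron(a.to_dense(), b.to_dense()))
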


def sparse_kron(
    sparse_tensor_a: torch.Tensor, sparse_tensor_b: torch.Tensor
) -> torch.Tensor:
    """Compute a sparse Kronecker product.

    As defined at:
    https://en.wikipedia.org/wiki/Kronecker_product
    Adapted from:
    https://github.com/scipy/scipy/blob/v1.7.1/scipy/sparse/construct.py#L274-L357

    Args:
        sparse_tensor_a (torch.Tensor): Sparse 2d-Tensor a of shape [m, n].
        sparse_tensor_b (torch.Tensor): Sparse 2d-Tensor b of shape [p, q].

    Returns:
        torch.Tensor: The resulting [mp, nq] tensor.

    """
    assert sparse_tensor_a.device == sparse_tensor_b.device
    sparse_tensor_a = sparse_tensor_a.coalesce()
    sparse_tensor_b = sparse_tensor_b.coalesce()
    output_shape = (
        sparse_tensor_a.shape[0] * sparse_tensor_b.shape[0],
        sparse_tensor_a.shape[1] * sparse_tensor_b.shape[1],
    )
    nzz_a = len(sparse_tensor_a.values())
    nzz_b = len(sparse_tensor_b.values())
    # take care of the zero case.
    if nzz_a == 0 or nzz_b == 0:
        return torch.sparse_coo_tensor(
            torch.zeros([2, 1]),
            torch.zeros([1]),
            size=output_shape,
            device=sparse_tensor_a.device,
        )
    # expand A's entries into blocks
    row = sparse_tensor_a.indices()[0, :].repeat_interleave(nzz_b)
    col = sparse_tensor_a.indices()[1, :].repeat_interleave(nzz_b)
    data = sparse_tensor_a.values().repeat_interleave(nzz_b)
    row *= sparse_tensor_b.shape[0]
    col *= sparse_tensor_b.shape[1]
    # increment block indices
    row, col = row.reshape(-1, nzz_b), col.reshape(-1, nzz_b)
    row += sparse_tensor_b.indices()[0, :]
    col += sparse_tensor_b.indices()[1, :]
    row, col = row.reshape(-1), col.reshape(-1)
    # compute block entries
    data = data.reshape(-1, nzz_b) * sparse_tensor_b.values()
    data = data.reshape(-1)
    result = torch.sparse_coo_tensor(
        torch.stack([row, col], 0),
        data,
        size=output_shape,
        device=sparse_tensor_a.device,
    )
    return result
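
# Illustrative sketch (not part of the original module): sparse_kron should
# agree with the dense reference implementation _dense_kron above.
#
#     a = torch.tensor([[1.0, 0.0], [0.0, 2.0]]).to_sparse()
#     b = torch.tensor([[0.0, 3.0], [4.0, 0.0]]).to_sparse()
#     ab = sparse_kron(a, b)  # sparse [4, 4] result
#     assert torch.allclose(ab.to_dense(), _dense_kron(a, b).to_dense())
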
def cat_sparse_identity_matrix(
    sparse_matrix: torch.Tensor, new_length: int
) -> torch.Tensor:
    """Concatenate a sparse input matrix and a sparse identity matrix.

    Args:
        sparse_matrix (torch.Tensor): The input matrix.
        new_length (int): The length up to which the diagonal should
            be elongated.

    Returns:
        torch.Tensor: Square [input, eye] matrix of size
            [new_length, new_length].

    """
    # assert square matrix.
    assert (
        sparse_matrix.shape[0] == sparse_matrix.shape[1]
    ), "Matrices must be square. Odd inputs can cause non-square matrices."
    assert (
        new_length > sparse_matrix.shape[0]
    ), "new_length must exceed the current size; cannot add a negative number of entries."
    x = torch.arange(
        sparse_matrix.shape[0],
        new_length,
        dtype=sparse_matrix.dtype,
        device=sparse_matrix.device,
    )
    y = torch.arange(
        sparse_matrix.shape[0],
        new_length,
        dtype=sparse_matrix.dtype,
        device=sparse_matrix.device,
    )
    extra_indices = torch.stack([x, y])
    extra_values = torch.ones(
        [new_length - sparse_matrix.shape[0]],
        dtype=sparse_matrix.dtype,
        device=sparse_matrix.device,
    )
    new_indices = torch.cat([sparse_matrix.coalesce().indices(), extra_indices], -1)
    new_values = torch.cat([sparse_matrix.coalesce().values(), extra_values], -1)
    new_matrix = torch.sparse_coo_tensor(
        new_indices, new_values, device=sparse_matrix.device
    )
    return new_matrix
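
# Illustrative sketch (not part of the original module): extending a sparse
# 2x2 identity to length 4 pads the diagonal with ones.
#
#     mat = torch.eye(2).to_sparse()
#     ext = cat_sparse_identity_matrix(mat, 4)
#     assert torch.allclose(ext.to_dense(), torch.eye(4))
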
def sparse_diag(
    diagonal: torch.Tensor, col_offset: int, rows: int, cols: int
) -> torch.Tensor:
    """Create a rows-by-cols sparse diagonal-matrix.

    The matrix is constructed by placing the entries of the input
    vector along the diagonal specified by col_offset.

    Args:
        diagonal (torch.Tensor): The values for the diagonal.
        col_offset (int): Move the diagonal to the right by offset columns.
        rows (int): The number of rows in the final matrix.
        cols (int): The number of columns in the final matrix.

    Returns:
        torch.Tensor: A sparse matrix with a shifted diagonal.

    """
    diag_indices = torch.stack(
        [
            torch.arange(len(diagonal), device=diagonal.device),
            torch.arange(len(diagonal), device=diagonal.device),
        ]
    )
    if col_offset > 0:
        diag_indices[1] += col_offset
    if col_offset < 0:
        diag_indices[0] += abs(col_offset)

    # crop diagonal entries that fall outside of the matrix.
    if torch.max(diag_indices[0]) >= rows:
        mask = diag_indices[0] < rows
        diag_indices = diag_indices[:, mask]
        diagonal = diagonal[mask]
    if torch.max(diag_indices[1]) >= cols:
        mask = diag_indices[1] < cols
        diag_indices = diag_indices[:, mask]
        diagonal = diagonal[mask]

    diag = torch.sparse_coo_tensor(
        diag_indices,
        diagonal,
        size=(rows, cols),
        dtype=diagonal.dtype,
        device=diagonal.device,
    )
    return diag
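
# Illustrative sketch (not part of the original module): a positive
# col_offset shifts the diagonal to the right.
#
#     d = torch.tensor([1.0, 2.0, 3.0])
#     m = sparse_diag(d, col_offset=1, rows=3, cols=4)
#     # m.to_dense() == [[0., 1., 0., 0.],
#     #                  [0., 0., 2., 0.],
#     #                  [0., 0., 0., 3.]]
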
def sparse_replace_row(
    matrix: torch.Tensor, row_index: int, row: torch.Tensor
) -> torch.Tensor:
    """Replace a row within a sparse [rows, cols] matrix.

    I.e. matrix[row_index, :] = row.

    Args:
        matrix (torch.Tensor): A sparse two-dimensional matrix.
        row_index (int): The row to replace.
        row (torch.Tensor): The row to insert into the sparse matrix.

    Returns:
        torch.Tensor: A sparse matrix, with the new row inserted at
            row_index.

    """
    matrix = matrix.coalesce()
    assert (
        matrix.shape[-1] == row.shape[-1]
    ), "matrix and replacement row must share the same column number."

    # delete the existing entries we don't want to keep.
    diag_indices = torch.arange(matrix.shape[0])
    diag = torch.ones_like(diag_indices)
    diag[row_index] = 0
    removal_matrix = torch.sparse_coo_tensor(
        torch.stack([diag_indices] * 2, 0),
        diag,
        size=matrix.shape,
        device=matrix.device,
        dtype=matrix.dtype,
    )

    row = row.coalesce()
    addition_matrix = torch.sparse_coo_tensor(
        torch.stack((row.indices()[0, :] + row_index, row.indices()[1, :]), 0),
        row.values(),
        size=matrix.shape,
        device=matrix.device,
        dtype=matrix.dtype,
    )
    result = torch.sparse.mm(removal_matrix, matrix) + addition_matrix
    return result
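
# Illustrative sketch (not part of the original module): the replacement row
# must itself be sparse and of matching width.
#
#     mat = torch.eye(3).to_sparse()
#     new_row = torch.tensor([[0.0, 5.0, 0.0]]).to_sparse()
#     res = sparse_replace_row(mat, 1, new_row)
#     # res.to_dense()[1] == tensor([0., 5., 0.])
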
def _orth_by_qr(
    matrix: torch.Tensor, rows_to_orthogonalize: torch.Tensor
) -> torch.Tensor:
    """Orthogonalize a wavelet matrix through qr decomposition.

    A dense qr-decomposition is used for GPU-efficiency reasons.
    If memory becomes a constraint, choose _orth_by_gram_schmidt
    instead, which is implemented using only sparse function calls.

    Args:
        matrix (torch.Tensor): The matrix to orthogonalize.
        rows_to_orthogonalize (torch.Tensor): The matrix rows, which need work.

    Returns:
        torch.Tensor: The corrected sparse matrix.

    """
    selection_indices = torch.stack(
        [
            torch.arange(len(rows_to_orthogonalize), device=matrix.device),
            rows_to_orthogonalize,
        ],
        0,
    )
    selection_matrix = torch.sparse_coo_tensor(
        selection_indices,
        torch.ones_like(rows_to_orthogonalize),
        dtype=matrix.dtype,
        device=matrix.device,
    )

    sel = torch.sparse.mm(selection_matrix, matrix)
    q, _ = torch.linalg.qr(sel.to_dense().T)
    q_rows = q.T.to_sparse()

    # remove the rows that will be replaced.
    diag_indices = torch.arange(matrix.shape[0])
    diag = torch.ones_like(diag_indices)
    diag[rows_to_orthogonalize] = 0
    removal_matrix = torch.sparse_coo_tensor(
        torch.stack([diag_indices] * 2, 0),
        diag,
        size=matrix.shape,
        device=matrix.device,
        dtype=matrix.dtype,
    )
    result = torch.sparse.mm(removal_matrix, matrix)
    # insert the orthogonalized rows.
    for pos, row in enumerate(q_rows):
        row = row.unsqueeze(0).coalesce()
        addition_matrix = torch.sparse_coo_tensor(
            torch.stack(
                (
                    row.indices()[0, :] + rows_to_orthogonalize[pos],
                    row.indices()[1, :],
                ),
                0,
            ),
            row.values(),
            size=matrix.shape,
            device=matrix.device,
            dtype=matrix.dtype,
        )
        result += addition_matrix
    return result.coalesce()


def _orth_by_gram_schmidt(
    matrix: torch.Tensor, to_orthogonalize: torch.Tensor
) -> torch.Tensor:
    """Orthogonalize by using sparse Gram-Schmidt.

    This function is very memory efficient but also very slow.

    Args:
        matrix (torch.Tensor): The sparse matrix to work on.
        to_orthogonalize (torch.Tensor): The matrix rows, which need work.

    Returns:
        torch.Tensor: The orthogonalized sparse matrix.

    """
    done: List[int] = []
    # loop over the rows we want to orthogonalize
    for row_no_to_ortho in to_orthogonalize:
        current_row = matrix.select(0, row_no_to_ortho).unsqueeze(0)
        # subtract the projections onto the already processed rows.
        projection_sum = torch.zeros_like(current_row)
        for done_row_no in done:
            done_row = matrix.select(0, done_row_no).unsqueeze(0)
            non_orthogonal = torch.sparse.mm(current_row, done_row.transpose(1, 0))
            non_orthogonal_values = non_orthogonal.coalesce().values()
            if len(non_orthogonal_values) == 0:
                non_orthogonal_item: float = 0.0
            else:
                non_orthogonal_item = non_orthogonal_values.item()
            projection_sum += non_orthogonal_item * done_row
        orthogonal_row = current_row - projection_sum
        length = torch.native_norm(orthogonal_row)
        orthonormal_row = orthogonal_row / length
        matrix = sparse_replace_row(matrix, row_no_to_ortho, orthonormal_row)
        done.append(row_no_to_ortho)
    return matrix
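
# Illustrative sketch (not part of the original module): after _orth_by_qr
# the selected rows form an orthonormal set.
#
#     mat = torch.tensor([[1.0, 1.0], [1.0, 0.0]]).to_sparse()
#     ortho = _orth_by_qr(mat, torch.tensor([0, 1])).to_dense()
#     assert torch.allclose(ortho @ ortho.T, torch.eye(2), atol=1e-6)
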
def construct_conv_matrix(
    filter: torch.Tensor,
    input_rows: int,
    *,
    mode: PaddingMode = "valid",
) -> torch.Tensor:
    """Construct a convolution matrix.

    Full, valid and same padding are supported.
    For reference see:
    https://github.com/RoyiAvital/StackExchangeCodes/blob/\
master/StackOverflow/Q2080835/CreateConvMtxSparse.m

    Args:
        filter (torch.Tensor): The 1D-filter to convolve with.
        input_rows (int): The number of rows in, i.e. the length of,
            the input vector.
        mode: String identifier for the desired padding.
            Choose 'full', 'valid', 'same' or 'sameshift'.
            Defaults to 'valid'.

    Returns:
        torch.Tensor: The sparse convolution tensor.

    Raises:
        ValueError: If the padding is not 'full', 'same' or 'valid'.

    """
    filter_length = len(filter)

    if mode == "full":
        start_row = 0
        stop_row = input_rows + filter_length - 1
    elif mode == "same" or mode == "sameshift":
        filter_offset = filter_length % 2
        # signal_offset = input_rows % 2
        start_row = filter_length // 2 - 1 + filter_offset
        stop_row = start_row + input_rows - 1
    elif mode == "valid":
        start_row = filter_length - 1
        stop_row = input_rows - 1
    else:
        raise ValueError("unknown padding type.")

    product_lst = [
        (row, col)
        for col, row in product(range(input_rows), range(filter_length))
        if row + col in range(start_row, stop_row + 1)
    ]
    row_indices = torch.tensor(
        [row + col - start_row for row, col in product_lst], device=filter.device
    )
    col_indices = torch.tensor([col for row, col in product_lst], device=filter.device)
    indices = torch.stack([row_indices, col_indices])
    values = torch.stack([filter[row] for row, col in product_lst])
    return torch.sparse_coo_tensor(
        indices, values, device=filter.device, dtype=filter.dtype
    )
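
# Illustrative sketch (not part of the original module): multiplying with the
# matrix reproduces a full 1d convolution, e.g. numpy.convolve(sig, fil).
#
#     fil = torch.tensor([1.0, 2.0, 3.0])
#     sig = torch.tensor([4.0, 5.0, 6.0, 7.0])
#     mat = construct_conv_matrix(fil, len(sig), mode="full")
#     res = torch.sparse.mm(mat, sig.unsqueeze(-1)).squeeze()
#     # res == tensor([ 4., 13., 28., 34., 32., 21.])
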
def construct_conv2d_matrix(
    filter: torch.Tensor,
    input_rows: int,
    input_columns: int,
    mode: PaddingMode = "valid",
    fully_sparse: bool = True,
) -> torch.Tensor:
    """Create a two-dimensional sparse convolution matrix.

    Multiplying with this matrix should be equivalent to a call to
    scipy.signal.convolve2d followed by a reshape.

    Args:
        filter (torch.Tensor): A filter of shape [height, width]
            to convolve with.
        input_rows (int): The number of rows in the input matrix.
        input_columns (int): The number of columns in the input matrix.
        mode: The desired padding method. Options are 'full', 'same'
            and 'valid'. Defaults to 'valid', i.e. no padding.
        fully_sparse (bool): Use a sparse implementation of the Kronecker
            product to save memory. Defaults to True.

    Returns:
        torch.Tensor: A sparse convolution matrix.

    Raises:
        ValueError: If the padding mode is neither full, same nor valid.

    """
    if fully_sparse:
        kron = sparse_kron
    else:
        kron = _dense_kron

    kernel_column_number = filter.shape[-1]
    matrix_block_number = kernel_column_number

    if mode == "full":
        diag_index = 0
        kronecker_rows = input_columns + kernel_column_number - 1
    elif mode == "same" or mode == "sameshift":
        filter_offset = kernel_column_number % 2
        diag_index = kernel_column_number // 2 - 1 + filter_offset
        kronecker_rows = input_columns
    elif mode == "valid":
        diag_index = kernel_column_number - 1
        kronecker_rows = input_columns - kernel_column_number + 1
    else:
        raise ValueError("unknown conv mode.")

    # construct a one-dimensional convolution matrix per filter column.
    block_matrix_list = []
    for i in range(matrix_block_number):
        block_matrix_list.append(
            construct_conv_matrix(filter[:, i], input_rows, mode=mode)
        )

    diag_values = torch.ones(
        min(kronecker_rows, input_columns),
        dtype=filter.dtype,
        device=filter.device,
    )
    diag = sparse_diag(diag_values, diag_index, kronecker_rows, input_columns)
    sparse_conv_matrix = kron(diag, block_matrix_list[0])
    for block_matrix in block_matrix_list[1:]:
        diag_index -= 1
        diag = sparse_diag(diag_values, diag_index, kronecker_rows, input_columns)
        sparse_conv_matrix += kron(diag, block_matrix)
    return sparse_conv_matrix
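
# Illustrative sketch (not part of the original module): the Kronecker
# construction implies a column-stacking vectorization. This flattening
# convention is inferred from the code above, so treat it as an assumption.
#
#     fil = torch.rand(2, 2)
#     img = torch.rand(3, 4)
#     mat = construct_conv2d_matrix(fil, 3, 4, mode="full")  # shape [20, 12]
#     res = torch.sparse.mm(mat, img.T.flatten().unsqueeze(-1))
#     res_2d = res.reshape(5, 4).T  # [output_columns, output_rows] -> [4, 5]
#     # res_2d should match scipy.signal.convolve2d(img.numpy(), fil.numpy())
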
def construct_strided_conv_matrix(
    filter: torch.Tensor,
    input_rows: int,
    stride: int = 2,
    *,
    mode: PaddingMode = "valid",
) -> torch.Tensor:
    """Construct a strided convolution matrix.

    Args:
        filter (torch.Tensor): The filter coefficients to convolve with.
        input_rows (int): The number of rows in the input vector.
        stride (int): The step size of the convolution. Defaults to two.
        mode: Choose 'valid', 'same' or 'sameshift'. Defaults to 'valid'.

    Returns:
        torch.Tensor: The strided sparse convolution matrix.

    """
    conv_matrix = construct_conv_matrix(filter, input_rows, mode=mode)
    if mode == "sameshift":
        # select conv_matrix[1::stride, :] sparsely.
        select_rows = torch.arange(1, conv_matrix.shape[0], stride)
    else:
        # select conv_matrix[::stride, :] sparsely.
        select_rows = torch.arange(0, conv_matrix.shape[0], stride)
    selection_matrix = torch.sparse_coo_tensor(
        torch.stack([torch.arange(0, len(select_rows)), select_rows]),
        torch.ones_like(select_rows),
        size=[len(select_rows), conv_matrix.shape[0]],
        dtype=conv_matrix.dtype,
        device=conv_matrix.device,
    )
    return torch.sparse.mm(selection_matrix, conv_matrix)
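
# Illustrative sketch (not part of the original module): with stride 2 the
# strided matrix keeps every second row of the unstrided one.
#
#     fil = torch.tensor([1.0, -1.0])
#     full_mat = construct_conv_matrix(fil, 6, mode="valid")     # [5, 6]
#     strided = construct_strided_conv_matrix(fil, 6, stride=2)  # [3, 6]
#     assert torch.allclose(strided.to_dense(), full_mat.to_dense()[::2])
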
def construct_strided_conv2d_matrix(
    filter: torch.Tensor,
    input_rows: int,
    input_columns: int,
    stride: int = 2,
    mode: PaddingMode = "full",
) -> torch.Tensor:
    """Create a strided sparse two-dimensional convolution matrix.

    Args:
        filter (torch.Tensor): The two-dimensional convolution filter.
        input_rows (int): The number of rows in the 2d-input matrix.
        input_columns (int): The number of columns in the 2d-input matrix.
        stride (int): The stride between the filter positions. Defaults to 2.
        mode: The convolution type. Defaults to 'full'.
            'sameshift' starts at 1 instead of 0.

    Raises:
        ValueError: Raised if an unknown convolution string is provided.

    Returns:
        torch.Tensor: The sparse convolution tensor.

    """
    filter_shape = filter.shape
    if mode == "full":
        output_rows = filter_shape[0] + input_rows - 1
        output_columns = filter_shape[1] + input_columns - 1
    elif mode == "valid":
        output_rows = input_rows - filter_shape[0] + 1
        output_columns = input_columns - filter_shape[1] + 1
    elif mode == "same" or mode == "sameshift":
        output_rows = input_rows
        output_columns = input_columns
    else:
        raise ValueError("Padding mode not accepted.")

    convolution_matrix = construct_conv2d_matrix(
        filter, input_rows, input_columns, mode=mode
    )

    output_elements = output_rows * output_columns
    element_numbers = torch.arange(output_elements, device=filter.device).reshape(
        output_columns, output_rows
    )

    start = 0
    if mode == "sameshift":
        start += 1

    strided_rows = element_numbers[start::stride, start::stride]
    strided_rows = strided_rows.flatten()
    selection_eye = torch.sparse_coo_tensor(
        torch.stack(
            [
                torch.arange(len(strided_rows), device=convolution_matrix.device),
                strided_rows,
            ],
            0,
        ),
        torch.ones(len(strided_rows)),
        dtype=convolution_matrix.dtype,
        device=convolution_matrix.device,
        size=[len(strided_rows), convolution_matrix.shape[0]],
    )
    # return convolution_matrix.index_select(0, strided_rows)
    return torch.sparse.mm(selection_eye, convolution_matrix)
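
# Illustrative sketch (not part of the original module): shape check for a
# valid-mode strided 2d convolution.
#
#     fil = torch.rand(3, 3)
#     strided = construct_strided_conv2d_matrix(fil, 8, 8, stride=2, mode="valid")
#     # the valid output grid is 6x6; stride 2 keeps 3x3 = 9 positions.
#     assert strided.shape == (9, 64)
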
def batch_mm(matrix: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor:
    """Calculate a batched matrix-matrix product using torch tensors.

    This calculates the product of a matrix with a batch of dense
    matrices. The former can be dense or sparse.

    Args:
        matrix (torch.Tensor): Sparse or dense matrix, size (m, n).
        matrix_batch (torch.Tensor): Batched dense matrices, size (b, n, k).

    Returns:
        torch.Tensor: The batched matrix-matrix product, size (b, m, k).

    Raises:
        ValueError: If the matrices cannot be multiplied
            due to incompatible matrix shapes.

    """
    batch_size = matrix_batch.shape[0]
    if matrix.shape[1] != matrix_batch.shape[1]:
        raise ValueError("Matrix shapes not compatible.")

    # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)
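
# Illustrative sketch (not part of the original module): multiplying a batch
# with a sparse identity returns the batch unchanged.
#
#     mat = torch.eye(3).to_sparse()
#     batch = torch.rand(5, 3, 2)
#     out = batch_mm(mat, batch)  # shape [5, 3, 2]
#     assert torch.allclose(out, batch)
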
def _batch_dim_mm(
    matrix: torch.Tensor, batch_tensor: torch.Tensor, dim: int
) -> torch.Tensor:
    """Multiply batch_tensor with matrix along the dimension specified by dim.

    Args:
        matrix (torch.Tensor): A matrix of shape [m, n].
        batch_tensor (torch.Tensor): A tensor whose selected dim is of
            length n.
        dim (int): The position of the desired dimension.

    Returns:
        torch.Tensor: The multiplication result.

    """
    dim_length = batch_tensor.shape[dim]
    permuted_tensor = batch_tensor.transpose(dim, -1)
    permuted_shape = permuted_tensor.shape
    res = torch.sparse.mm(matrix, permuted_tensor.reshape(-1, dim_length).T).T
    return res.reshape(list(permuted_shape[:-1]) + [matrix.shape[0]]).transpose(
        -1, dim
    )
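
# Illustrative sketch (not part of the original module): scaling along dim 1
# with a sparse 2*identity doubles the tensor.
#
#     mat = (2.0 * torch.eye(4)).to_sparse()
#     data = torch.rand(3, 4, 5)
#     out = _batch_dim_mm(mat, data, dim=1)
#     assert torch.allclose(out, 2.0 * data)
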