# Source code for ptwt.matmul_transform_2

"""Two-dimensional matrix based fast wavelet transform implementations.

This module uses boundary filters to minimize padding.
"""

import sys
from functools import partial
from typing import List, Optional, Tuple, Union, cast

import numpy as np
import torch

from ._util import (
    Wavelet,
    _as_wavelet,
    _check_axes_argument,
    _check_if_tensor,
    _is_boundary_mode_supported,
    _is_dtype_supported,
    _map_result,
    _swap_axes,
    _undo_swap_axes,
    _unfold_axes,
)
from .constants import OrthogonalizeMethod, PaddingMode
from .conv_transform import _get_filter_tensors
from .conv_transform_2 import (
    _construct_2d_filt,
    _preprocess_tensor_dec2d,
    _waverec2d_fold_channels_2d_list,
)
from .matmul_transform import (
    BaseMatrixWaveDec,
    construct_boundary_a,
    construct_boundary_s,
    orthogonalize,
)
from .sparse_math import (
    batch_mm,
    cat_sparse_identity_matrix,
    construct_strided_conv2d_matrix,
)


def _construct_a_2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    dtype: torch.dtype = torch.float64,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Construct a raw two-dimensional analysis wavelet transformation matrix.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input image.
        width (int): The width of the input image.
        device (torch.device or str): Where to place the matrix.
        dtype (torch.dtype, optional): Desired matrix data type.
            Defaults to torch.float64.
        mode : The convolution type.
            Options are 'full', 'valid', 'same' and 'sameshift'.
            Defaults to 'sameshift'.

    Returns:
        torch.Tensor: A sparse fwt analysis matrix. The four strided
            convolution matrices are stacked along the first axis in the
            order a, h, v, d (i.e. ll, lh, hl, hh).

    Note:
        The constructed matrix is NOT necessarily orthogonal.
        In most cases, construct_boundary_a2 should be used instead.

    """
    dec_lo, dec_hi, _, _ = _get_filter_tensors(
        wavelet, flip=False, device=device, dtype=dtype
    )
    dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)
    # The explicit unpacking documents (and enforces) the four-filter layout.
    ll, lh, hl, hh = dec_filt.squeeze(1)
    analysis = torch.cat(
        [
            construct_strided_conv2d_matrix(filt, height, width, mode=mode)
            for filt in (ll, lh, hl, hh)
        ],
        0,
    )
    return analysis


def _construct_s_2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    dtype: torch.dtype = torch.float64,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Construct a raw fast wavelet transformation synthesis matrix.

    Note:
        The constructed matrix is NOT necessarily orthogonal.
        In most cases, construct_boundary_s2 should be used instead.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input image, which was originally
            transformed.
        width (int): The width of the input image, which was originally
            transformed.
        device (torch.device): Where to place the synthesis matrix,
            usually CPU or GPU.
        dtype (torch.dtype, optional): The data type the matrix should have.
            Defaults to torch.float64.
        mode : The convolution type.
            Options are 'full', 'valid', 'same' and 'sameshift'.
            Defaults to 'sameshift'.

    Returns:
        [torch.Tensor]: The generated fast wavelet synthesis matrix. It is
            the transpose of the four stacked (ll, lh, hl, hh) strided
            convolution matrices built from the flipped reconstruction
            filters.
    """
    wavelet = _as_wavelet(wavelet)
    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=True, device=device, dtype=dtype
    )
    rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)
    ll, lh, hl, hh = rec_filt.squeeze(1)
    synthesis = torch.cat(
        [
            construct_strided_conv2d_matrix(filt, height, width, mode=mode)
            for filt in (ll, lh, hl, hh)
        ],
        0,
    ).coalesce()
    # Transpose the sparse COO matrix by swapping its row and column indices.
    indices = synthesis.indices()
    rows, cols = synthesis.shape
    transpose_indices = torch.stack([indices[1, :], indices[0, :]])
    return torch.sparse_coo_tensor(
        transpose_indices, synthesis.values(), size=(cols, rows), device=device
    )


def construct_boundary_a2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    boundary: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a boundary fwt matrix for the input wavelet.

    The raw 'sameshift' analysis matrix is built first and then handed
    to `orthogonalize`.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input matrix.
            Should be divisible by two.
        width (int): The width of the input matrix.
            Should be divisible by two.
        device (torch.device): Where to place the matrix. Either on
            the CPU or GPU.
        boundary : The method to use for matrix orthogonalization.
            Choose "qr" or "gramschmidt". Defaults to "qr".
        dtype (torch.dtype, optional): The desired data type for the matrix.
            Defaults to torch.float64.

    Returns:
        torch.Tensor: A sparse fwt matrix, with orthogonalized boundary
        wavelets.
    """
    wavelet = _as_wavelet(wavelet)
    raw_analysis = _construct_a_2(
        wavelet, height, width, device, dtype=dtype, mode="sameshift"
    )
    # The squared 1d filter length is passed on; presumably this is the
    # element count of one 2d filter -- TODO confirm against orthogonalize.
    return orthogonalize(raw_analysis, wavelet.dec_len**2, method=boundary)
def construct_boundary_s2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    *,
    boundary: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a 2d-fwt synthesis matrix, with boundary wavelets.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The original height of the input matrix.
        width (int): The width of the original input matrix.
        device (torch.device): Choose CPU or GPU.
        boundary : The method to use for matrix orthogonalization.
            Choose qr or gramschmidt. Defaults to qr.
        dtype (torch.dtype, optional): The data type of the
            sparse matrix, choose float32 or 64.
            Defaults to torch.float64.

    Returns:
        torch.Tensor: The synthesis matrix, used to compute the
        inverse fast wavelet transform.
    """
    wavelet = _as_wavelet(wavelet)
    raw_synthesis = _construct_s_2(wavelet, height, width, device, dtype=dtype)
    # orthogonalize is applied to the transposed synthesis matrix, and the
    # result is transposed back.
    transposed = raw_synthesis.transpose(1, 0)
    orthogonalized = orthogonalize(transposed, wavelet.rec_len**2, method=boundary)
    return orthogonalized.transpose(1, 0)
def _matrix_pad_2(height: int, width: int) -> Tuple[int, int, Tuple[bool, bool]]: pad_tuple = (False, False) if height % 2 != 0: height += 1 pad_tuple = (pad_tuple[0], True) if width % 2 != 0: width += 1 pad_tuple = (True, pad_tuple[1]) return height, width, pad_tuple
class MatrixWavedec2(BaseMatrixWaveDec):
    """Experimental sparse matrix 2d wavelet transform.

    For a completely pad-free transform,
    input images are expected to be divisible by two.
    For multiscale transforms all intermediate
    scale dimensions should be divisible
    by two, i.e. 128, 128 -> 64, 64 -> 32, 32 would work
    well for a level three transform.
    In this case multiplication with the `sparse_fwt_operator`
    property is equivalent.

    Note:
        Constructing the sparse fwt-matrix is expensive.
        For longer wavelets, high-level transforms, and large
        input images this may take a while.
        The matrix is therefore constructed only once.
        In the non-separable case, it can be accessed via
        the sparse_fwt_operator property.

    Example:
        >>> import ptwt, torch, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = datasets.face()[:256, :256, :].astype(np.float32)
        >>> pt_face = torch.tensor(face).permute([2, 0, 1])
        >>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
        >>> mat_coeff = matrixfwt(pt_face)
    """

    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        level: Optional[int] = None,
        axes: Tuple[int, int] = (-2, -1),
        boundary: OrthogonalizeMethod = "qr",
        separable: bool = True,
    ):
        """Create a new matrix fwt object.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
            level (int, optional): The level up to which to compute the fwt.
                If None, the maximum level based on the signal length is
                chosen. Defaults to None.
            axes (int, int): A tuple with the axes to transform.
                Defaults to (-2, -1).
            boundary : The method used for boundary filter treatment.
                Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's
                dense qr implementation, it is fast but memory hungry.
                The 'gramschmidt' option is sparse, memory efficient,
                and slow. Choose 'gramschmidt' if 'qr' runs out of memory.
                Defaults to 'qr'.
            separable (bool): If this flag is set, a separable transformation
                is used, i.e. a 1d transformation along each axis.
                Matrix construction is significantly faster for separable
                transformations since only a small constant-size part of the
                matrices must be orthogonalized. Defaults to True.

        Raises:
            NotImplementedError: If the selected `boundary` mode is not
                supported.
            ValueError: If the wavelet filters have different lengths or
                if `axes` does not contain exactly two entries.
        """
        self.wavelet = _as_wavelet(wavelet)
        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        _check_axes_argument(list(axes))
        self.axes = tuple(axes)
        self.level = level
        self.boundary = boundary
        self.separable = separable
        self.input_signal_shape: Optional[Tuple[int, int]] = None
        self.fwt_matrix_list: List[
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        ] = []
        # Padded (height, width) per level, followed by the final ll size.
        self.size_list: List[Tuple[int, int]] = []
        self.pad_list: List[Tuple[bool, bool]] = []
        self.padded = False

        if not _is_boundary_mode_supported(self.boundary):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_fwt_operator(self) -> torch.Tensor:
        """Compute the operator matrix for padding-free cases.

        This property exists to make the transformation matrix available.
        To benefit from code handling odd-length levels call the object.

        Returns:
            torch.Tensor: The sparse 2d-fwt operator matrix.

        Raises:
            NotImplementedError: if a separable transformation was used or if
                padding had to be used in the creation of the transformation
                matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if self.separable:
            raise NotImplementedError

        # in the non-separable case the list entries are tensors
        fwt_matrix_list = cast(List[torch.Tensor], self.fwt_matrix_list)

        if len(fwt_matrix_list) == 1:
            return fwt_matrix_list[0]
        elif len(fwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            fwt_matrix = fwt_matrix_list[0]
            for scale_mat in fwt_matrix_list[1:]:
                # Extend the coarser level matrix with an identity block so
                # that the detail coefficients pass through unchanged.
                scale_mat = cat_sparse_identity_matrix(scale_mat, fwt_matrix.shape[0])
                fwt_matrix = torch.sparse.mm(scale_mat, fwt_matrix)
            return fwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_analysis_matrices(
        self,
        device: Union[torch.device, str],
        dtype: torch.dtype,
    ) -> None:
        """Build and store the per-level analysis matrices.

        Fills `fwt_matrix_list`, `size_list`, `pad_list` and `padded`.
        Construction stops early (with a warning on stderr) once the current
        height or width is smaller than the filter length.
        """
        if self.level is None or self.input_signal_shape is None:
            raise AssertionError
        self.fwt_matrix_list = []
        self.size_list = []
        self.pad_list = []
        self.padded = False

        filt_len = self.wavelet.dec_len
        current_height, current_width = self.input_signal_shape
        for curr_level in range(1, self.level + 1):
            if current_height < filt_len or current_width < filt_len:
                # we have reached the max decomposition depth.
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input shape {self.input_signal_shape}"
                    f". At level {curr_level}, at least one of the current signal "
                    f"height and width ({current_height}, {current_width}) is smaller "
                    f"then the filter length {filt_len}. Therefore, the transformation "
                    f"is only computed up to the decomposition level {curr_level-1}.\n"
                )
                break
            # the conv matrices require even length inputs.
            current_height, current_width, pad_tuple = _matrix_pad_2(
                current_height, current_width
            )
            if any(pad_tuple):
                self.padded = True
            self.pad_list.append(pad_tuple)
            self.size_list.append((current_height, current_width))
            if self.separable:
                analysis_matrix_rows = construct_boundary_a(
                    wavelet=self.wavelet,
                    length=current_height,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                analysis_matrix_cols = construct_boundary_a(
                    wavelet=self.wavelet,
                    length=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.fwt_matrix_list.append(
                    (analysis_matrix_rows, analysis_matrix_cols)
                )
            else:
                analysis_matrix_2d = construct_boundary_a2(
                    wavelet=self.wavelet,
                    height=current_height,
                    width=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.fwt_matrix_list.append(analysis_matrix_2d)
            current_height = current_height // 2
            current_width = current_width // 2
        # The final entry is the size of the coarsest approximation.
        self.size_list.append((current_height, current_width))

    def __call__(
        self, input_signal: torch.Tensor
    ) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """Compute the fwt for the given input signal.

        The fwt matrices are set up during the first call and stored for
        future use. They are rebuilt when the input height/width, device or
        dtype changes.

        Args:
            input_signal (torch.Tensor): An input signal of shape
                ``[batch_size, height, width]``.
                2d inputs are interpreted as ``[height, width]``.
                4d inputs as ``[batch_size, channels, height, width]``.
                This transform affects the last two dimensions.

        Returns:
            (list): The resulting coefficients per level are stored in
            a pywt style list. The list is ordered as::

                (ll, (lh, hl, hh), ...)

            with 'l' for low-pass and 'h' for high-pass filters.

        Raises:
            ValueError: If the decomposition level is not a positive integer
                or if the input signal has not the expected shape.
        """
        if self.axes != (-2, -1):
            input_signal = _swap_axes(input_signal, list(self.axes))

        input_signal, ds = _preprocess_tensor_dec2d(input_signal)
        input_signal = input_signal.squeeze(1)

        batch_size, height, width = input_signal.shape

        if not _is_dtype_supported(input_signal.dtype):
            raise ValueError(f"Input dtype {input_signal.dtype} not supported")

        re_build = False
        if (
            self.input_signal_shape is None
            or self.input_signal_shape[0] != height
            or self.input_signal_shape[1] != width
        ):
            self.input_signal_shape = height, width
            re_build = True

        if self.level is None:
            wlen = len(self.wavelet)
            self.level = int(
                np.min([np.log2(height / (wlen - 1)), np.log2(width / (wlen - 1))])
            )
            re_build = True
        elif self.level <= 0:
            raise ValueError("level must be a positive integer.")

        if self.fwt_matrix_list:
            # Cached matrices built for another device or dtype cannot be
            # multiplied with this input, so they have to be rebuilt.
            first_entry = self.fwt_matrix_list[0]
            reference = (
                first_entry[0] if isinstance(first_entry, tuple) else first_entry
            )
            if (
                reference.device != input_signal.device
                or reference.dtype != input_signal.dtype
            ):
                re_build = True

        if not self.fwt_matrix_list or re_build:
            self._construct_analysis_matrices(
                device=input_signal.device, dtype=input_signal.dtype
            )

        split_list: List[
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ] = []
        if self.separable:
            ll = input_signal
            for scale, fwt_mats in enumerate(self.fwt_matrix_list):
                fwt_row_matrix, fwt_col_matrix = fwt_mats
                pad = self.pad_list[scale]
                current_height, current_width = self.size_list[scale]
                # pad[0] flags a padded width, pad[1] a padded height.
                if pad[0] or pad[1]:
                    if pad[0] and not pad[1]:
                        ll = torch.nn.functional.pad(ll, [0, 1])
                    elif pad[1] and not pad[0]:
                        ll = torch.nn.functional.pad(ll, [0, 0, 0, 1])
                    elif pad[0] and pad[1]:
                        ll = torch.nn.functional.pad(ll, [0, 1, 0, 1])

                ll = batch_mm(fwt_col_matrix, ll.transpose(-2, -1)).transpose(-2, -1)
                ll = batch_mm(fwt_row_matrix, ll)

                a_coeffs, d_coeffs = torch.split(ll, current_height // 2, dim=-2)
                ll, lh = torch.split(a_coeffs, current_width // 2, dim=-1)
                hl, hh = torch.split(d_coeffs, current_width // 2, dim=-1)

                split_list.append((lh, hl, hh))
            split_list.append(ll)
        else:
            # Flatten each image in column-major order: [height * width, batch].
            ll = input_signal.transpose(-2, -1).reshape([batch_size, -1]).T
            for scale, fwt_matrix in enumerate(self.fwt_matrix_list):
                fwt_matrix = cast(torch.Tensor, fwt_matrix)
                pad = self.pad_list[scale]
                size = self.size_list[scale]
                if pad[0] or pad[1]:
                    # Restore the image layout of the unpadded signal, pad,
                    # and flatten again.
                    if pad[0] and not pad[1]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1] - 1, size[0]
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 1])
                    elif pad[1] and not pad[0]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1], size[0] - 1
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 0, 0, 1])
                    elif pad[0] and pad[1]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1] - 1, size[0] - 1
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 1, 0, 1])
                    ll = ll.transpose(2, 1).reshape([batch_size, -1]).T
                coefficients = torch.sparse.mm(fwt_matrix, ll)
                # The matrix output stacks ll, lh, hl, hh of equal size.
                four_split = torch.split(
                    coefficients, int(np.prod((size[0] // 2, size[1] // 2)))
                )
                reshaped = cast(
                    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                    tuple(
                        (
                            el.T.reshape(
                                batch_size, size[1] // 2, size[0] // 2
                            ).transpose(2, 1)
                        )
                        for el in four_split[1:]
                    ),
                )
                split_list.append(reshaped)
                ll = four_split[0]
            # size_list[-1] holds the coarsest approximation size; unlike the
            # loop variable it is also defined when no level was computed.
            final_height, final_width = self.size_list[-1]
            split_list.append(
                ll.T.reshape(batch_size, final_width, final_height).transpose(2, 1)
            )

        if ds:
            _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
            split_list = _map_result(split_list, _unfold_axes2)

        if self.axes != (-2, -1):
            undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
            split_list = _map_result(split_list, undo_swap_fn)

        return split_list[::-1]
class MatrixWaverec2(object):
    """Synthesis or inverse matrix based-wavelet transformation object.

    Example:
        >>> import ptwt, torch, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = datasets.face()[:256, :256, :].astype(np.float32)
        >>> pt_face = torch.tensor(face).permute([2, 0, 1])
        >>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
        >>> mat_coeff = matrixfwt(pt_face)
        >>> matrixifwt = ptwt.MatrixWaverec2(pywt.Wavelet("haar"))
        >>> reconstruction = matrixifwt(mat_coeff)
    """

    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        axes: Tuple[int, int] = (-2, -1),
        boundary: OrthogonalizeMethod = "qr",
        separable: bool = True,
    ):
        """Create the inverse matrix-based fast wavelet transformation.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
            axes (int, int): The axes transformed by waverec2.
                Defaults to (-2, -1).
            boundary : The method used for boundary filter treatment.
                Choose 'qr' or 'gramschmidt'. 'qr' relies on pytorch's dense qr
                implementation, it is fast but memory hungry. The 'gramschmidt'
                option is sparse, memory efficient, and slow. Choose
                'gramschmidt' if 'qr' runs out of memory. Defaults to 'qr'.
            separable (bool): If this flag is set, a separable transformation
                is used, i.e. a 1d transformation along each axis. This is
                significantly faster than a non-separable transformation since
                only a small constant-size part of the matrices must be
                orthogonalized. For invertibility, the analysis and synthesis
                values must be identical! Defaults to True.

        Raises:
            NotImplementedError: If the selected `boundary` mode is not
                supported.
            ValueError: If the wavelet filters have different lengths or
                if `axes` does not contain exactly two entries.
        """
        self.wavelet = _as_wavelet(wavelet)
        self.boundary = boundary
        self.separable = separable

        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        _check_axes_argument(list(axes))
        # Stored as a tuple so comparisons against (-2, -1) behave the same
        # whether the caller passed a list or a tuple.
        self.axes = tuple(axes)

        self.ifwt_matrix_list: List[
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        ] = []
        self.level: Optional[int] = None
        self.input_signal_shape: Optional[Tuple[int, int]] = None

        self.padded = False

        if not _is_boundary_mode_supported(self.boundary):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_ifwt_operator(self) -> torch.Tensor:
        """Compute the ifwt operator matrix for pad-free cases.

        Returns:
            torch.Tensor: The sparse 2d ifwt operator matrix.

        Raises:
            NotImplementedError: if a separable transformation was used or if
                padding had to be used in the creation of the transformation
                matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if self.separable:
            raise NotImplementedError

        # in the non-separable case the list entries are tensors
        ifwt_matrix_list = cast(List[torch.Tensor], self.ifwt_matrix_list)

        if len(ifwt_matrix_list) == 1:
            return ifwt_matrix_list[0]
        elif len(ifwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            # Compose from the coarsest level towards the finest one.
            ifwt_matrix = ifwt_matrix_list[-1]
            for scale_mat in ifwt_matrix_list[:-1][::-1]:
                ifwt_matrix = cat_sparse_identity_matrix(
                    ifwt_matrix, scale_mat.shape[0]
                )
                ifwt_matrix = torch.sparse.mm(scale_mat, ifwt_matrix)

            return ifwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_synthesis_matrices(
        self,
        device: Union[torch.device, str],
        dtype: torch.dtype,
    ) -> None:
        """Build and store the per-level synthesis matrices (finest first).

        Construction stops early (with a warning on stderr) once the current
        height or width is smaller than the filter length.
        """
        self.ifwt_matrix_list = []
        self.padded = False
        if self.level is None or self.input_signal_shape is None:
            raise AssertionError

        current_height, current_width = self.input_signal_shape
        filt_len = self.wavelet.rec_len

        for curr_level in range(1, self.level + 1):
            if current_height < filt_len or current_width < filt_len:
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input shape {self.input_signal_shape}"
                    f". At level {curr_level}, at least one of the current signal "
                    f"height and width ({current_height}, {current_width}) is smaller "
                    f"then the filter length {filt_len}. Therefore, the transformation "
                    f"is only computed up to the decomposition level {curr_level-1}.\n"
                )
                break
            current_height, current_width, pad_tuple = _matrix_pad_2(
                current_height, current_width
            )
            if any(pad_tuple):
                self.padded = True
            if self.separable:
                synthesis_matrix_rows = construct_boundary_s(
                    wavelet=self.wavelet,
                    length=current_height,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                synthesis_matrix_cols = construct_boundary_s(
                    wavelet=self.wavelet,
                    length=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.ifwt_matrix_list.append(
                    (synthesis_matrix_rows, synthesis_matrix_cols)
                )
            else:
                synthesis_matrix_2d = construct_boundary_s2(
                    self.wavelet,
                    current_height,
                    current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.ifwt_matrix_list.append(synthesis_matrix_2d)
            current_height = current_height // 2
            current_width = current_width // 2

    def __call__(
        self,
        coefficients: List[
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
        ],
    ) -> torch.Tensor:
        """Compute the inverse matrix 2d fast wavelet transform.

        Args:
            coefficients (list): The coefficient list as returned
                by the `MatrixWavedec2`-Object.

        Returns:
            torch.Tensor: The original signal reconstruction. For example of
            shape ``[batch_size, height, width]`` or
            ``[batch_size, channels, height, width]``, depending on the input
            to the forward transform and the value of the `axes` argument.
            Unbatched 2d coefficients yield a result with a leading batch
            dimension of size one.

        Raises:
            ValueError: If the decomposition level is not a positive integer or
                if the coefficients are not in the shape as it is returned from
                a `MatrixWavedec2` object.
        """
        ll = _check_if_tensor(coefficients[0])

        if tuple(self.axes) != (-2, -1):
            swap_fn = partial(_swap_axes, axes=list(self.axes))
            coefficients = _map_result(coefficients, swap_fn)
            ll = _check_if_tensor(coefficients[0])

        ds = None
        if ll.dim() == 1:
            raise ValueError("2d transforms require more than a single input dim.")
        elif ll.dim() == 2:
            # Add a batch dim to EVERY coefficient of unbatched input, so the
            # per-level shape checks and batched products below line up.
            coefficients = _map_result(coefficients, lambda x: x.unsqueeze(0))
            ll = _check_if_tensor(coefficients[0])
        elif ll.dim() >= 4:
            # avoid the channel sum, fold the channels into batches.
            coefficients, ds = _waverec2d_fold_channels_2d_list(coefficients)
            ll = _check_if_tensor(coefficients[0])

        level = len(coefficients) - 1
        height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:])

        re_build = False
        if (
            self.input_signal_shape is None
            or self.input_signal_shape[0] != height
            or self.input_signal_shape[1] != width
        ):
            self.input_signal_shape = height, width
            re_build = True

        if self.level != level:
            self.level = level
            re_build = True

        batch_size = ll.shape[0]
        torch_device = ll.device
        torch_dtype = ll.dtype

        if not _is_dtype_supported(torch_dtype):
            raise ValueError(f"Input dtype {torch_dtype} not supported")

        if self.ifwt_matrix_list:
            # Cached matrices built for another device or dtype cannot be
            # multiplied with these coefficients, so they have to be rebuilt.
            first_entry = self.ifwt_matrix_list[0]
            reference = (
                first_entry[0] if isinstance(first_entry, tuple) else first_entry
            )
            if reference.device != torch_device or reference.dtype != torch_dtype:
                re_build = True

        if not self.ifwt_matrix_list or re_build:
            self._construct_synthesis_matrices(
                device=torch_device,
                dtype=torch_dtype,
            )

        # Coefficients run from coarse to fine, matrices from fine to coarse.
        ifwt_matrices = self.ifwt_matrix_list[::-1]
        for c_pos, coeff_tuple in enumerate(coefficients[1:]):
            if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
                raise ValueError(
                    f"Unexpected detail coefficient type: {type(coeff_tuple)}. Detail "
                    "coefficients must be a 3-tuple of tensors as returned by "
                    "MatrixWavedec2."
                )

            curr_shape = ll.shape
            for coeff in coeff_tuple:
                if torch_device != coeff.device:
                    raise ValueError("coefficients must be on the same device")
                elif torch_dtype != coeff.dtype:
                    raise ValueError("coefficients must have the same dtype")
                elif coeff.shape != curr_shape:
                    raise ValueError(
                        "All coefficients on each level must have the same shape"
                    )

            lh, hl, hh = coeff_tuple

            if self.separable:
                synthesis_matrix_rows, synthesis_matrix_cols = ifwt_matrices[c_pos]
                a_coeffs = torch.cat((ll, lh), -1)
                d_coeffs = torch.cat((hl, hh), -1)
                coeff_tensor = torch.cat((a_coeffs, d_coeffs), -2)
                ll = batch_mm(
                    synthesis_matrix_cols, coeff_tensor.transpose(-2, -1)
                ).transpose(-2, -1)
                ll = batch_mm(synthesis_matrix_rows, ll)
            else:
                # Column-major flatten of all four sub-bands: [batch, 4 * h * w].
                ll = torch.cat(
                    [
                        ll.transpose(2, 1).reshape([batch_size, -1]),
                        lh.transpose(2, 1).reshape([batch_size, -1]),
                        hl.transpose(2, 1).reshape([batch_size, -1]),
                        hh.transpose(2, 1).reshape([batch_size, -1]),
                    ],
                    -1,
                )
                ifwt_mat = cast(torch.Tensor, ifwt_matrices[c_pos])
                ll = cast(torch.Tensor, torch.sparse.mm(ifwt_mat, ll.T))

            if not self.separable:
                pred_len = [s * 2 for s in curr_shape[-2:]][::-1]
                ll = ll.T.reshape([batch_size] + pred_len).transpose(2, 1)
                pred_len = list(ll.shape[1:])
            else:
                pred_len = [s * 2 for s in curr_shape[-2:]]
            # remove the padding
            if c_pos < len(coefficients) - 2:
                next_len = list(coefficients[c_pos + 2][0].shape[-2:])
                if pred_len != next_len:
                    if pred_len[0] != next_len[0]:
                        ll = ll[:, :-1, :]
                    if pred_len[1] != next_len[1]:
                        ll = ll[:, :, :-1]

        if ds:
            ll = _unfold_axes(ll, list(ds), 2)

        if self.axes != (-2, -1):
            ll = _undo_swap_axes(ll, list(self.axes))

        return ll