# Source code for ptwt.matmul_transform_2

"""Two-dimensional matrix based fast wavelet transform implementations.

This module uses boundary filters to minimize padding.
"""

from __future__ import annotations

import sys
from functools import partial
from typing import Optional, Union, cast

import numpy as np
import torch

from ._util import (
    Wavelet,
    _as_wavelet,
    _check_axes_argument,
    _check_if_tensor,
    _is_boundary_mode_supported,
    _is_dtype_supported,
    _map_result,
    _swap_axes,
    _undo_swap_axes,
    _unfold_axes,
)
from .constants import (
    OrthogonalizeMethod,
    PaddingMode,
    WaveletCoeff2d,
    WaveletDetailTuple2d,
)
from .conv_transform import _get_filter_tensors
from .conv_transform_2 import (
    _construct_2d_filt,
    _preprocess_tensor_dec2d,
    _waverec2d_fold_channels_2d_list,
)
from .matmul_transform import (
    BaseMatrixWaveDec,
    construct_boundary_a,
    construct_boundary_s,
    orthogonalize,
)
from .sparse_math import (
    batch_mm,
    cat_sparse_identity_matrix,
    construct_strided_conv2d_matrix,
)


def _construct_a_2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    dtype: torch.dtype = torch.float64,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Build the raw (not orthogonalized) 2d analysis matrix of a wavelet.

    One strided 2d convolution matrix is created per sub-band filter and the
    four matrices are stacked along the first (row) dimension.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input image.
        width (int): The width of the input image.
        device (torch.device or str): Where to place the matrix.
        dtype (torch.dtype, optional): Desired matrix data type.
            Defaults to torch.float64.
        mode : The convolution type.
            Options are 'full', 'valid', 'same' and 'sameshift'.
            Defaults to 'sameshift'.

    Returns:
        A sparse fwt analysis matrix.
        The matrices are ordered a, h, v, d or ll, lh, hl, hh.

    Note:
        The constructed matrix is NOT necessarily orthogonal.
        In most cases, construct_boundary_a2d should be used instead.
    """
    lo_filter, hi_filter, _, _ = _get_filter_tensors(
        wavelet, flip=False, device=device, dtype=dtype
    )
    # The 2d filter stack unpacks into exactly four sub-band kernels.
    ll, lh, hl, hh = _construct_2d_filt(lo=lo_filter, hi=hi_filter).squeeze(1)
    sub_band_matrices = [
        construct_strided_conv2d_matrix(kernel, height, width, mode=mode)
        for kernel in (ll, lh, hl, hh)
    ]
    return torch.cat(sub_band_matrices, 0)


def _construct_s_2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    dtype: torch.dtype = torch.float64,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Build the raw (not orthogonalized) 2d synthesis matrix of a wavelet.

    The four strided convolution matrices of the flipped reconstruction
    filters are stacked along the first dimension and the stacked sparse
    matrix is returned transposed.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input image, which was originally
            transformed.
        width (int): The width of the input image, which was originally
            transformed.
        device (torch.device): Where to place the synthesis matrix,
            usually CPU or GPU.
        dtype (torch.dtype, optional): The data type the matrix should have.
            Defaults to torch.float64.
        mode : The convolution type.
            Options are 'full', 'valid', 'same' and 'sameshift'.
            Defaults to 'sameshift'.

    Returns:
        The generated fast wavelet synthesis matrix.

    Note:
        The constructed matrix is NOT necessarily orthogonal.
        In most cases, construct_boundary_s2d should be used instead.
    """
    wavelet = _as_wavelet(wavelet)
    _, _, lo_filter, hi_filter = _get_filter_tensors(
        wavelet, flip=True, device=device, dtype=dtype
    )
    # The 2d filter stack unpacks into exactly four sub-band kernels.
    ll, lh, hl, hh = _construct_2d_filt(lo=lo_filter, hi=hi_filter).squeeze(1)
    stacked = torch.cat(
        [
            construct_strided_conv2d_matrix(kernel, height, width, mode=mode)
            for kernel in (ll, lh, hl, hh)
        ],
        0,
    ).coalesce()

    # Transpose the sparse matrix by hand: swap the row and column index
    # rows and swap the two entries of the size.
    coo_indices = stacked.indices()
    n_rows, n_cols = stacked.shape[0], stacked.shape[1]
    swapped_indices = torch.stack([coo_indices[1, :], coo_indices[0, :]])
    return torch.sparse_coo_tensor(
        swapped_indices, stacked.values(), size=(n_cols, n_rows), device=device
    )


def construct_boundary_a2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    boundary: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a boundary fwt matrix for the input wavelet.

    The raw 'sameshift' analysis matrix from ``_construct_a_2`` is passed
    through ``orthogonalize`` together with the squared decomposition
    filter length.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The height of the input matrix.
            Should be divisible by two.
        width (int): The width of the input matrix.
            Should be divisible by two.
        device (torch.device): Where to place the matrix.
            Either on the CPU or GPU.
        boundary : The method used for boundary filter treatment,
            see :data:`ptwt.constants.OrthogonalizeMethod`.
            Defaults to 'qr'.
        dtype (torch.dtype, optional): The desired data type for the matrix.
            Defaults to torch.float64.

    Returns:
        A sparse fwt matrix, with orthogonalized boundary wavelets.
    """
    wavelet_obj = _as_wavelet(wavelet)
    raw_analysis = _construct_a_2(
        wavelet_obj, height, width, device, dtype=dtype, mode="sameshift"
    )
    squared_filt_len = wavelet_obj.dec_len**2
    return orthogonalize(raw_analysis, squared_filt_len, method=boundary)
def construct_boundary_s2(
    wavelet: Union[Wavelet, str],
    height: int,
    width: int,
    device: Union[torch.device, str],
    *,
    boundary: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a 2d-fwt matrix, with boundary wavelets.

    The raw synthesis matrix from ``_construct_s_2`` is transposed, passed
    through ``orthogonalize`` together with the squared reconstruction
    filter length, and transposed back.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        height (int): The original height of the input matrix.
        width (int): The width of the original input matrix.
        device (torch.device): Choose CPU or GPU.
        boundary : The method used for boundary filter treatment,
            see :data:`ptwt.constants.OrthogonalizeMethod`.
            Defaults to 'qr'.
        dtype (torch.dtype, optional): The data type of the
            sparse matrix, choose float32 or 64.
            Defaults to torch.float64.

    Returns:
        The synthesis matrix, used to compute the
        inverse fast wavelet transform.
    """
    wavelet_obj = _as_wavelet(wavelet)
    raw_synthesis = _construct_s_2(wavelet_obj, height, width, device, dtype=dtype)
    squared_filt_len = wavelet_obj.rec_len**2
    orth_transposed = orthogonalize(
        raw_synthesis.transpose(1, 0), squared_filt_len, method=boundary
    )
    return orth_transposed.transpose(1, 0)
def _matrix_pad_2(height: int, width: int) -> tuple[int, int, tuple[bool, bool]]: pad_tuple = (False, False) if height % 2 != 0: height += 1 pad_tuple = (pad_tuple[0], True) if width % 2 != 0: width += 1 pad_tuple = (True, pad_tuple[1]) return height, width, pad_tuple
class MatrixWavedec2(BaseMatrixWaveDec):
    """Experimental sparse matrix 2d wavelet transform.

    For a completely pad-free transform,
    input images are expected to be divisible by two.
    For multiscale transforms all intermediate
    scale dimensions should be divisible
    by two, i.e. ``128, 128 -> 64, 64 -> 32, 32`` would work
    well for a level three transform.
    In this case multiplication with the `sparse_fwt_operator`
    property is equivalent.

    Note:
        Constructing the sparse fwt-matrix is expensive.
        For longer wavelets, high-level transforms, and large
        input images this may take a while.
        The matrix is therefore constructed only once.
        In the non-separable case, it can be accessed via
        the sparse_fwt_operator property.

    Example:
        >>> import ptwt, torch, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = datasets.face()[:256, :256, :].astype(np.float32)
        >>> pt_face = torch.tensor(face).permute([2, 0, 1])
        >>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
        >>> mat_coeff = matrixfwt(pt_face)
    """

    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        level: Optional[int] = None,
        axes: tuple[int, int] = (-2, -1),
        boundary: OrthogonalizeMethod = "qr",
        separable: bool = True,
    ):
        """Create a new matrix fwt object.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
                Refer to the output from ``pywt.wavelist(kind='discrete')``
                for possible choices.
            level (int, optional): The level up to which to compute the fwt.
                If None, the maximum level based on the signal length is
                chosen. Defaults to None.
            axes (int, int): A tuple with the axes to transform.
                Defaults to (-2, -1).
            boundary : The method used for boundary filter treatment,
                see :data:`ptwt.constants.OrthogonalizeMethod`.
                Defaults to 'qr'.
            separable (bool): If this flag is set, a separable transformation
                is used, i.e. a 1d transformation along each axis.
                Matrix construction is significantly faster for separable
                transformations since only a small constant-size part of the
                matrices must be orthogonalized. Defaults to True.

        Raises:
            NotImplementedError: If the selected `boundary` mode
                is not supported.
            ValueError: If `axes` does not hold exactly two entries or
                if the wavelet filters have different lengths.
        """
        self.wavelet = _as_wavelet(wavelet)
        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        else:
            _check_axes_argument(list(axes))
            self.axes = tuple(axes)
        self.level = level
        self.boundary = boundary
        self.separable = separable
        self.input_signal_shape: Optional[tuple[int, int]] = None
        # One entry per level: a single 2d matrix (non-separable) or a
        # (row matrix, column matrix) pair (separable).
        self.fwt_matrix_list: list[
            Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
        ] = []
        # Per level padding flags as returned by _matrix_pad_2, i.e.
        # (width_was_padded, height_was_padded).
        self.pad_list: list[tuple[bool, bool]] = []
        # Padded (height, width) per level, followed by one final entry with
        # the shape of the approximation left after the last level.
        self.size_list: list[tuple[int, int]] = []
        self.padded = False

        if not _is_boundary_mode_supported(self.boundary):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_fwt_operator(self) -> torch.Tensor:
        """Compute the operator matrix for padding-free cases.

        This property exists to make the transformation matrix available.
        To benefit from code handling odd-length levels call the object.

        Returns:
            The sparse 2d-fwt operator matrix. If only a single level matrix
            is stored it is returned as is.

        Raises:
            NotImplementedError: if a separable transformation was used or if
                more than one level is stored and padding had to be used in
                the creation of the transformation matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if self.separable:
            raise NotImplementedError

        # in the non-separable case the list entries are tensors
        fwt_matrix_list = cast(list[torch.Tensor], self.fwt_matrix_list)

        if len(fwt_matrix_list) == 1:
            return fwt_matrix_list[0]
        elif len(fwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            fwt_matrix = fwt_matrix_list[0]
            for scale_mat in fwt_matrix_list[1:]:
                # Extend the coarser level matrix to the size of the operator
                # accumulated so far before multiplying both.
                scale_mat = cat_sparse_identity_matrix(scale_mat, fwt_matrix.shape[0])
                fwt_matrix = torch.sparse.mm(scale_mat, fwt_matrix)
            return fwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_analysis_matrices(
        self,
        device: Union[torch.device, str],
        dtype: torch.dtype,
    ) -> None:
        """Build and store the analysis matrices for every level.

        Resets and refills ``fwt_matrix_list``, ``pad_list`` and
        ``size_list`` and updates ``padded``. Construction stops early, with
        a warning written to stderr, as soon as the current height or width
        is smaller than the filter length.

        Args:
            device (torch.device or str): Where to place the matrices.
            dtype (torch.dtype): The data type of the matrices.

        Raises:
            AssertionError: If `level` or `input_signal_shape` is not set.
        """
        if self.level is None or self.input_signal_shape is None:
            raise AssertionError
        self.fwt_matrix_list = []
        self.size_list = []
        self.pad_list = []
        self.padded = False

        filt_len = self.wavelet.dec_len
        current_height, current_width = self.input_signal_shape
        for curr_level in range(1, self.level + 1):
            if current_height < filt_len or current_width < filt_len:
                # we have reached the max decomposition depth.
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input shape {self.input_signal_shape}"
                    f". At level {curr_level}, at least one of the current signal "
                    f"height and width ({current_height}, {current_width}) is smaller "
                    f"then the filter length {filt_len}. Therefore, the transformation "
                    f"is only computed up to the decomposition level {curr_level-1}.\n"
                )
                break
            # the conv matrices require even length inputs.
            current_height, current_width, pad_tuple = _matrix_pad_2(
                current_height, current_width
            )
            if any(pad_tuple):
                self.padded = True
            self.pad_list.append(pad_tuple)
            self.size_list.append((current_height, current_width))
            if self.separable:
                analysis_matrix_rows = construct_boundary_a(
                    wavelet=self.wavelet,
                    length=current_height,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                analysis_matrix_cols = construct_boundary_a(
                    wavelet=self.wavelet,
                    length=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.fwt_matrix_list.append(
                    (analysis_matrix_rows, analysis_matrix_cols)
                )
            else:
                analysis_matrix_2d = construct_boundary_a2(
                    wavelet=self.wavelet,
                    height=current_height,
                    width=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.fwt_matrix_list.append(analysis_matrix_2d)
            current_height = current_height // 2
            current_width = current_width // 2
        # The last entry always holds the shape of the final approximation,
        # also if no level at all could be computed.
        self.size_list.append((current_height, current_width))

    def __call__(self, input_signal: torch.Tensor) -> WaveletCoeff2d:
        """Compute the fwt for the given input signal.

        The fwt matrix is set up during the first call
        and stored for future use.

        Args:
            input_signal (torch.Tensor): An input signal of shape
                ``[batch_size, height, width]``.
                2d inputs are interpreted as ``[height, width]``.
                4d inputs as ``[batch_size, channels, height, width]``.
                This transform affects the last two dimensions.

        Returns:
            The resulting coefficients per level are stored in a pywt style
            tuple, see :data:`ptwt.constants.WaveletCoeff2d`.
            If not a single level could be computed because the input is
            smaller than the filter, the tuple only holds the input as
            approximation.

        Raises:
            ValueError: If the decomposition level is not a positive integer,
                if the input dtype is not supported or if the input signal
                has not the expected shape.
        """
        if self.axes != (-2, -1):
            input_signal = _swap_axes(input_signal, list(self.axes))

        input_signal, ds = _preprocess_tensor_dec2d(input_signal)
        input_signal = input_signal.squeeze(1)

        batch_size, height, width = input_signal.shape
        if not _is_dtype_supported(input_signal.dtype):
            raise ValueError(f"Input dtype {input_signal.dtype} not supported")

        re_build = False
        if (
            self.input_signal_shape is None
            or self.input_signal_shape[0] != height
            or self.input_signal_shape[1] != width
        ):
            self.input_signal_shape = height, width
            re_build = True

        # NOTE(review): an automatically chosen level is written to
        # self.level and therefore reused by later calls, also if the
        # input shape changes.
        if self.level is None:
            wlen = len(self.wavelet)
            self.level = int(
                np.min([np.log2(height / (wlen - 1)), np.log2(width / (wlen - 1))])
            )
            re_build = True
        elif self.level <= 0:
            raise ValueError("level must be a positive integer.")

        if not self.fwt_matrix_list or re_build:
            self._construct_analysis_matrices(
                device=input_signal.device, dtype=input_signal.dtype
            )

        split_list: list[WaveletDetailTuple2d] = []
        if self.separable:
            ll = input_signal
            for scale, fwt_mats in enumerate(self.fwt_matrix_list):
                fwt_row_matrix, fwt_col_matrix = fwt_mats
                # pad is (width_was_padded, height_was_padded)
                pad = self.pad_list[scale]
                current_height, current_width = self.size_list[scale]
                if pad[0] or pad[1]:
                    if pad[0] and not pad[1]:
                        # append one element to the last (width) axis
                        ll = torch.nn.functional.pad(ll, [0, 1])
                    elif pad[1] and not pad[0]:
                        # append one element to the second last (height) axis
                        ll = torch.nn.functional.pad(ll, [0, 0, 0, 1])
                    elif pad[0] and pad[1]:
                        ll = torch.nn.functional.pad(ll, [0, 1, 0, 1])
                # transform along the width first, then along the height
                ll = batch_mm(fwt_col_matrix, ll.transpose(-2, -1)).transpose(-2, -1)
                ll = batch_mm(fwt_row_matrix, ll)

                a_coeffs, d_coeffs = torch.split(ll, current_height // 2, dim=-2)
                ll, lh = torch.split(a_coeffs, current_width // 2, dim=-1)
                hl, hh = torch.split(d_coeffs, current_width // 2, dim=-1)
                split_list.append(WaveletDetailTuple2d(lh, hl, hh))
        else:
            # Flatten every image column by column and move the batch
            # dimension to the back for the sparse matrix product.
            ll = input_signal.transpose(-2, -1).reshape([batch_size, -1]).T
            for scale, fwt_matrix in enumerate(self.fwt_matrix_list):
                fwt_matrix = cast(torch.Tensor, fwt_matrix)
                # pad is (width_was_padded, height_was_padded)
                pad = self.pad_list[scale]
                size = self.size_list[scale]
                if pad[0] or pad[1]:
                    # Restore the image shape, pad and flatten again.
                    if pad[0] and not pad[1]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1] - 1, size[0]
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 1])
                    elif pad[1] and not pad[0]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1], size[0] - 1
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 0, 0, 1])
                    elif pad[0] and pad[1]:
                        ll_reshape = ll.T.reshape(
                            batch_size, size[1] - 1, size[0] - 1
                        ).transpose(2, 1)
                        ll = torch.nn.functional.pad(ll_reshape, [0, 1, 0, 1])
                    ll = ll.transpose(2, 1).reshape([batch_size, -1]).T
                coefficients = torch.sparse.mm(fwt_matrix, ll)
                # split the result into the ll, lh, hl and hh parts
                four_split = torch.split(
                    coefficients, int(np.prod((size[0] // 2, size[1] // 2)))
                )
                reshaped = cast(
                    tuple[torch.Tensor, torch.Tensor, torch.Tensor],
                    tuple(
                        (
                            el.T.reshape(
                                batch_size, size[1] // 2, size[0] // 2
                            ).transpose(2, 1)
                        )
                        for el in four_split[1:]
                    ),
                )
                split_list.append(WaveletDetailTuple2d(*reshaped))
                ll = four_split[0]
            # The last size_list entry holds the shape of the remaining
            # approximation. It equals the halved size of the last computed
            # level and stays valid if no level could be computed at all,
            # where the loop variable `size` would be undefined.
            final_height, final_width = self.size_list[-1]
            ll = ll.T.reshape(batch_size, final_width, final_height).transpose(2, 1)
        split_list.reverse()
        result: WaveletCoeff2d = ll, *split_list

        if ds:
            _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
            result = _map_result(result, _unfold_axes2)

        if self.axes != (-2, -1):
            undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
            result = _map_result(result, undo_swap_fn)

        return result
class MatrixWaverec2(object):
    """Synthesis or inverse matrix based-wavelet transformation object.

    Example:
        >>> import ptwt, torch, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = datasets.face()[:256, :256, :].astype(np.float32)
        >>> pt_face = torch.tensor(face).permute([2, 0, 1])
        >>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
        >>> mat_coeff = matrixfwt(pt_face)
        >>> matrixifwt = ptwt.MatrixWaverec2(pywt.Wavelet("haar"))
        >>> reconstruction = matrixifwt(mat_coeff)
    """

    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        axes: tuple[int, int] = (-2, -1),
        boundary: OrthogonalizeMethod = "qr",
        separable: bool = True,
    ):
        """Create the inverse matrix-based fast wavelet transformation.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
                Refer to the output from ``pywt.wavelist(kind='discrete')``
                for possible choices.
            axes (int, int): The axes transformed by waverec2.
                Defaults to (-2, -1).
            boundary : The method used for boundary filter treatment,
                see :data:`ptwt.constants.OrthogonalizeMethod`.
                Defaults to 'qr'.
            separable (bool): If this flag is set, a separable transformation
                is used, i.e. a 1d transformation along each axis.
                This is significantly faster than a non-separable
                transformation since only a small constant-size part of the
                matrices must be orthogonalized.
                For invertibility, the analysis and synthesis values
                must be identical!
                Defaults to True.

        Raises:
            NotImplementedError: If the selected `boundary` mode
                is not supported.
            ValueError: If `axes` does not hold exactly two entries or
                if the wavelet filters have different lengths.
        """
        self.wavelet = _as_wavelet(wavelet)
        self.boundary = boundary
        self.separable = separable

        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        else:
            _check_axes_argument(list(axes))
            # Stored as a tuple (like in MatrixWavedec2) so that comparisons
            # with (-2, -1) also work if a list was passed.
            self.axes = tuple(axes)

        # One entry per level: a single 2d matrix (non-separable) or a
        # (row matrix, column matrix) pair (separable).
        self.ifwt_matrix_list: list[
            Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
        ] = []
        self.level: Optional[int] = None
        self.input_signal_shape: Optional[tuple[int, int]] = None
        self.padded = False

        if not _is_boundary_mode_supported(self.boundary):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_ifwt_operator(self) -> torch.Tensor:
        """Compute the ifwt operator matrix for pad-free cases.

        Returns:
            The sparse 2d ifwt operator matrix. If only a single level matrix
            is stored it is returned as is.

        Raises:
            NotImplementedError: if a separable transformation was used or if
                more than one level is stored and padding had to be used in
                the creation of the transformation matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if self.separable:
            raise NotImplementedError

        # in the non-separable case the list entries are tensors
        ifwt_matrix_list = cast(list[torch.Tensor], self.ifwt_matrix_list)

        if len(ifwt_matrix_list) == 1:
            return ifwt_matrix_list[0]
        elif len(ifwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            # Start with the last stored matrix and work towards the first.
            ifwt_matrix = ifwt_matrix_list[-1]
            for scale_mat in ifwt_matrix_list[:-1][::-1]:
                ifwt_matrix = cat_sparse_identity_matrix(
                    ifwt_matrix, scale_mat.shape[0]
                )
                ifwt_matrix = torch.sparse.mm(scale_mat, ifwt_matrix)
            return ifwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_synthesis_matrices(
        self,
        device: Union[torch.device, str],
        dtype: torch.dtype,
    ) -> None:
        """Build and store the synthesis matrices for every level.

        Resets and refills ``ifwt_matrix_list`` and updates ``padded``.
        Construction stops early, with a warning written to stderr, as soon
        as the current height or width is smaller than the filter length.

        Args:
            device (torch.device or str): Where to place the matrices.
            dtype (torch.dtype): The data type of the matrices.

        Raises:
            AssertionError: If `level` or `input_signal_shape` is not set.
        """
        self.ifwt_matrix_list = []
        self.padded = False
        if self.level is None or self.input_signal_shape is None:
            raise AssertionError

        current_height, current_width = self.input_signal_shape
        filt_len = self.wavelet.rec_len

        for curr_level in range(1, self.level + 1):
            if current_height < filt_len or current_width < filt_len:
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input shape {self.input_signal_shape}"
                    f". At level {curr_level}, at least one of the current signal "
                    f"height and width ({current_height}, {current_width}) is smaller "
                    f"then the filter length {filt_len}. Therefore, the transformation "
                    f"is only computed up to the decomposition level {curr_level-1}.\n"
                )
                break
            # the matrices are built for even sizes only
            current_height, current_width, pad_tuple = _matrix_pad_2(
                current_height, current_width
            )
            if any(pad_tuple):
                self.padded = True
            if self.separable:
                synthesis_matrix_rows = construct_boundary_s(
                    wavelet=self.wavelet,
                    length=current_height,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                synthesis_matrix_cols = construct_boundary_s(
                    wavelet=self.wavelet,
                    length=current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.ifwt_matrix_list.append(
                    (synthesis_matrix_rows, synthesis_matrix_cols)
                )
            else:
                synthesis_matrix_2d = construct_boundary_s2(
                    self.wavelet,
                    current_height,
                    current_width,
                    boundary=self.boundary,
                    device=device,
                    dtype=dtype,
                )
                self.ifwt_matrix_list.append(synthesis_matrix_2d)
            current_height = current_height // 2
            current_width = current_width // 2

    def __call__(
        self,
        coefficients: WaveletCoeff2d,
    ) -> torch.Tensor:
        """Compute the inverse matrix 2d fast wavelet transform.

        Args:
            coefficients (WaveletCoeff2d): The coefficient tuple as returned
                by the `MatrixWavedec2` object,
                see :data:`ptwt.constants.WaveletCoeff2d`.
                If the approximation tensor is 2d, a leading batch dimension
                of size one is added to all coefficient tensors.

        Returns:
            The original signal reconstruction.
            For example of shape ``[batch_size, height, width]`` or
            ``[batch_size, channels, height, width]`` depending on the input
            to the forward transform and the value of the `axes` argument.
            For 2d coefficient tensors the added batch dimension is kept.

        Raises:
            ValueError: If the decomposition level is not a positive integer
                or if the coefficients are not in the shape as it is
                returned from a `MatrixWavedec2` object.
        """
        # Normalize once, so that the swap below and its inversion at the
        # end always take the same branch, also if `axes` is a list.
        axes = tuple(self.axes)

        ll = _check_if_tensor(coefficients[0])
        if axes != (-2, -1):
            swap_fn = partial(_swap_axes, axes=list(axes))
            coefficients = _map_result(coefficients, swap_fn)
            ll = _check_if_tensor(coefficients[0])

        ds = None
        if ll.dim() == 1:
            raise ValueError("2d transforms require more than a single input dim.")
        elif ll.dim() == 2:
            # Add the batch dim to ALL coefficient tensors of an unbatched
            # input. Adding it to ll only would let the per-level shape
            # comparison below fail for every detail tensor.
            coefficients = _map_result(
                coefficients, lambda coeff: coeff.unsqueeze(0)
            )
            ll = _check_if_tensor(coefficients[0])
        elif ll.dim() >= 4:
            # avoid the channel sum, fold the channels into batches.
            coefficients, ds = _waverec2d_fold_channels_2d_list(coefficients)
            ll = _check_if_tensor(coefficients[0])

        level = len(coefficients) - 1
        # The last coefficient entry is used to derive the (padded) size of
        # the signal, which the first level matrices are built for.
        height, width = tuple(c * 2 for c in coefficients[-1][0].shape[-2:])

        re_build = False
        if (
            self.input_signal_shape is None
            or self.input_signal_shape[0] != height
            or self.input_signal_shape[1] != width
        ):
            self.input_signal_shape = height, width
            re_build = True

        if self.level != level:
            self.level = level
            re_build = True

        batch_size = ll.shape[0]
        torch_device = ll.device
        torch_dtype = ll.dtype

        if not _is_dtype_supported(torch_dtype):
            raise ValueError(f"Input dtype {torch_dtype} not supported")

        if not self.ifwt_matrix_list or re_build:
            self._construct_synthesis_matrices(
                device=torch_device,
                dtype=torch_dtype,
            )

        for c_pos, coeff_tuple in enumerate(coefficients[1:]):
            if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
                raise ValueError(
                    f"Unexpected detail coefficient type: {type(coeff_tuple)}. Detail "
                    "coefficients must be a 3-tuple of tensors as returned by "
                    "MatrixWavedec2."
                )

            curr_shape = ll.shape
            for coeff in coeff_tuple:
                if torch_device != coeff.device:
                    raise ValueError("coefficients must be on the same device")
                elif torch_dtype != coeff.dtype:
                    raise ValueError("coefficients must have the same dtype")
                elif coeff.shape != curr_shape:
                    raise ValueError(
                        "All coefficients on each level must have the same shape"
                    )
            lh, hl, hh = coeff_tuple

            # The matrix list is stored starting with the full input size,
            # the coefficients are processed in the reversed order.
            if self.separable:
                synthesis_matrix_rows, synthesis_matrix_cols = self.ifwt_matrix_list[
                    ::-1
                ][c_pos]
                a_coeffs = torch.cat((ll, lh), -1)
                d_coeffs = torch.cat((hl, hh), -1)
                coeff_tensor = torch.cat((a_coeffs, d_coeffs), -2)
                if len(curr_shape) == 2:
                    coeff_tensor = coeff_tensor.unsqueeze(0)
                ll = batch_mm(
                    synthesis_matrix_cols, coeff_tensor.transpose(-2, -1)
                ).transpose(-2, -1)
                ll = batch_mm(synthesis_matrix_rows, ll)
            else:
                # flatten all four parts column by column and concatenate
                ll = torch.cat(
                    [
                        ll.transpose(2, 1).reshape([batch_size, -1]),
                        lh.transpose(2, 1).reshape([batch_size, -1]),
                        hl.transpose(2, 1).reshape([batch_size, -1]),
                        hh.transpose(2, 1).reshape([batch_size, -1]),
                    ],
                    -1,
                )
                ifwt_mat = cast(torch.Tensor, self.ifwt_matrix_list[::-1][c_pos])
                ll = cast(torch.Tensor, torch.sparse.mm(ifwt_mat, ll.T))

            if not self.separable:
                pred_len = [s * 2 for s in curr_shape[-2:]][::-1]
                ll = ll.T.reshape([batch_size] + pred_len).transpose(2, 1)
                pred_len = list(ll.shape[1:])
            else:
                pred_len = [s * 2 for s in curr_shape[-2:]]
            # remove the padding
            if c_pos < len(coefficients) - 2:
                next_len = list(coefficients[c_pos + 2][0].shape[-2:])
                if pred_len != next_len:
                    if pred_len[0] != next_len[0]:
                        ll = ll[:, :-1, :]
                    if pred_len[1] != next_len[1]:
                        ll = ll[:, :, :-1]

        if ds:
            ll = _unfold_axes(ll, list(ds), 2)

        if axes != (-2, -1):
            ll = _undo_swap_axes(ll, list(axes))

        return ll