Source code for ptwt.matmul_transform_3

"""Implement 3D separable boundary transforms."""

from __future__ import annotations

import sys
from functools import partial
from typing import NamedTuple, Optional, Union

import numpy as np
import torch

from ._util import (
    Wavelet,
    _as_wavelet,
    _check_axes_argument,
    _check_if_tensor,
    _fold_axes,
    _is_boundary_mode_supported,
    _is_dtype_supported,
    _map_result,
    _swap_axes,
    _undo_swap_axes,
    _unfold_axes,
)
from .constants import OrthogonalizeMethod, WaveletCoeffNd
from .conv_transform_3 import _waverec3d_fold_channels_3d_list
from .matmul_transform import construct_boundary_a, construct_boundary_s
from .sparse_math import _batch_dim_mm


class _PadTuple(NamedTuple):
    """Replaces _PadTuple = namedtuple("_PadTuple", ("depth", "height", "width"))."""

    depth: bool
    height: bool
    width: bool


def _matrix_pad_3(
    depth: int, height: int, width: int
) -> tuple[int, int, int, _PadTuple]:
    pad_depth, pad_height, pad_width = (False, False, False)
    if height % 2 != 0:
        height += 1
        pad_height = True
    if width % 2 != 0:
        width += 1
        pad_width = True
    if depth % 2 != 0:
        depth += 1
        pad_depth = True
    return depth, height, width, _PadTuple(pad_depth, pad_height, pad_width)


[docs] class MatrixWavedec3(object): """Compute 3d separable transforms.""" def __init__( self, wavelet: Union[Wavelet, str], level: Optional[int] = None, axes: tuple[int, int, int] = (-3, -2, -1), boundary: OrthogonalizeMethod = "qr", ): """Create a *separable* three-dimensional fast boundary wavelet transform. Input signals should have the shape [batch_size, depth, height, width], this object transforms the last three dimensions. Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` for possible choices. level (int, optional): The desired decomposition level. Defaults to None. boundary : The method used for boundary filter treatment, see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the chosen orthogonalization method is not implemented. ValueError: If the analysis and synthesis filters do not have the same length. """ self.wavelet = _as_wavelet(wavelet) self.level = level self.boundary = boundary if len(axes) != 3: raise ValueError("3D transforms work with three axes.") else: _check_axes_argument(list(axes)) self.axes = axes self.input_signal_shape: Optional[tuple[int, int, int]] = None self.fwt_matrix_list: list[list[torch.Tensor]] = [] if not _is_boundary_mode_supported(self.boundary): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: raise ValueError("All filters must have the same length") def _construct_analysis_matrices( self, device: Union[torch.device, str], dtype: torch.dtype, ) -> None: if self.level is None or self.input_signal_shape is None: raise AssertionError self.fwt_matrix_list = [] self.size_list = [] self.pad_list = [] self.padded = False filt_len = self.wavelet.dec_len current_depth, current_height, current_width = self.input_signal_shape for curr_level in range(1, self.level + 1): if ( current_height < filt_len or current_width < filt_len or current_depth < filt_len ): # we have reached the max decomposition depth. sys.stderr.write( f"Warning: The selected number of decomposition levels {self.level}" f" is too large for the given input shape {self.input_signal_shape}" f". At level {curr_level}, at least one of the current signal " f"depth, height, and width ({current_depth}, {current_height}," f"{current_width}) is smaller " f"then the filter length {filt_len}. Therefore, the transformation " f"is only computed up to the decomposition level {curr_level-1}.\n" ) break # the conv matrices require even length inputs. current_depth, current_height, current_width, pad_tuple = _matrix_pad_3( depth=current_depth, height=current_height, width=current_width ) if any(pad_tuple): self.padded = True self.pad_list.append(pad_tuple) self.size_list.append((current_depth, current_height, current_width)) matrix_construction_fun = partial( construct_boundary_a, wavelet=self.wavelet, boundary=self.boundary, device=device, dtype=dtype, ) analysis_matrics = [ matrix_construction_fun(length=dimension_length) for dimension_length in (current_depth, current_height, current_width) ] self.fwt_matrix_list.append(analysis_matrics) current_depth, current_height, current_width = ( current_depth // 2, current_height // 2, current_width // 2, ) self.size_list.append((current_depth, current_height, current_width))
[docs] def __call__(self, input_signal: torch.Tensor) -> WaveletCoeffNd: """Compute a separable 3d-boundary wavelet transform. Args: input_signal (torch.Tensor): An input signal. For example of shape ``[batch_size, depth, height, width]``. Returns: The resulting coefficients for each level are stored in a tuple, see :data:`ptwt.constants.WaveletCoeffNd`. Raises: ValueError: If the input dimensions don't work. """ if self.axes != (-3, -2, -1): input_signal = _swap_axes(input_signal, list(self.axes)) ds = None if input_signal.dim() < 3: raise ValueError("At least three dimensions are required for 3d wavedec.") elif len(input_signal.shape) == 3: input_signal = input_signal.unsqueeze(1) else: input_signal, ds = _fold_axes(input_signal, 3) _, depth, height, width = input_signal.shape if not _is_dtype_supported(input_signal.dtype): raise ValueError(f"Input dtype {input_signal.dtype} not supported") re_build = False if ( self.input_signal_shape is None or self.input_signal_shape[0] != depth or self.input_signal_shape[1] != height or self.input_signal_shape[2] != width ): self.input_signal_shape = depth, height, width re_build = True if self.level is None: wlen = len(self.wavelet) self.level = int( np.min( [ np.log2(depth / (wlen - 1)), np.log2(height / (wlen - 1)), np.log2(width / (wlen - 1)), ] ) ) re_build = True elif self.level <= 0: raise ValueError("level must be a positive integer.") if not self.fwt_matrix_list or re_build: self._construct_analysis_matrices( device=input_signal.device, dtype=input_signal.dtype ) split_list: list[dict[str, torch.Tensor]] = [] lll = input_signal for scale, fwt_mats in enumerate(self.fwt_matrix_list): # fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats pad_tuple = self.pad_list[scale] # current_depth, current_height, current_width = self.size_list[scale] if pad_tuple.width: lll = torch.nn.functional.pad(lll, [0, 1, 0, 0, 0, 0]) if pad_tuple.height: lll = torch.nn.functional.pad(lll, [0, 0, 0, 1, 0, 0]) if pad_tuple.depth: lll = torch.nn.functional.pad(lll, [0, 0, 0, 0, 0, 1]) for dim, mat in enumerate(fwt_mats[::-1]): lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1)) def _split_rec( tensor: torch.Tensor, key: str, depth: int, dict: dict[str, torch.Tensor], ) -> None: if key: dict[key] = tensor if len(key) < depth: dim = len(key) + 1 ca, cd = torch.split(tensor, tensor.shape[-dim] // 2, dim=-dim) _split_rec(ca, "a" + key, depth, dict) _split_rec(cd, "d" + key, depth, dict) coeff_dict: dict[str, torch.Tensor] = {} _split_rec(lll, "", 3, coeff_dict) lll = coeff_dict["aaa"] result_keys = list( filter(lambda x: len(x) == 3 and not x == "aaa", coeff_dict.keys()) ) coeff_dict = { key: tensor for key, tensor in coeff_dict.items() if key in result_keys } split_list.append(coeff_dict) split_list.reverse() result: WaveletCoeffNd = lll, *split_list if ds: _unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3) result = _map_result(result, _unfold_axes_fn) if self.axes != (-3, -2, -1): undo_swap_fn = partial(_undo_swap_axes, axes=self.axes) result = _map_result(result, undo_swap_fn) return result
[docs] class MatrixWaverec3(object): """Reconstruct a signal from 3d-separable-fwt coefficients.""" def __init__( self, wavelet: Union[Wavelet, str], axes: tuple[int, int, int] = (-3, -2, -1), boundary: OrthogonalizeMethod = "qr", ): """Compute a three-dimensional separable boundary wavelet synthesis transform. Args: wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` for possible choices. axes (tuple[int, int, int]): Transform these axes instead of the last three. Defaults to (-3, -2, -1). boundary : The method used for boundary filter treatment, see :data:`ptwt.constants.OrthogonalizeMethod`. Defaults to 'qr'. Raises: NotImplementedError: If the selected `boundary` mode is not supported. ValueError: If the wavelet filters have different lengths. """ self.wavelet = _as_wavelet(wavelet) if len(axes) != 3: raise ValueError("3D transforms work with three axes") else: _check_axes_argument(list(axes)) self.axes = axes self.boundary = boundary self.ifwt_matrix_list: list[list[torch.Tensor]] = [] self.input_signal_shape: Optional[tuple[int, int, int]] = None self.level: Optional[int] = None if not _is_boundary_mode_supported(self.boundary): raise NotImplementedError if self.wavelet.dec_len != self.wavelet.rec_len: raise ValueError("All filters must have the same length") def _construct_synthesis_matrices( self, device: Union[torch.device, str], dtype: torch.dtype, ) -> None: self.ifwt_matrix_list = [] self.padded = False if self.level is None or self.input_signal_shape is None: raise AssertionError current_depth, current_height, current_width = self.input_signal_shape filt_len = self.wavelet.rec_len for curr_level in range(1, self.level + 1): if ( current_depth < filt_len or current_height < filt_len or current_width < filt_len ): sys.stderr.write( f"Warning: The selected number of decomposition levels {self.level}" f" is too large for the given input shape {self.input_signal_shape}" f". At level {curr_level}, at least one of the current signal " f" depth, height and width ({current_depth}, {current_height}, " f"{current_width}) is smaller than the filter length {filt_len}." f" Therefore, the transformation " f"is only computed up to the decomposition level {curr_level-1}.\n" ) break # the conv matrices require even length inputs. current_depth, current_height, current_width, pad_tuple = _matrix_pad_3( depth=current_depth, height=current_height, width=current_width ) if any(pad_tuple): self.padded = True matrix_construction_fun = partial( construct_boundary_s, wavelet=self.wavelet, boundary=self.boundary, device=device, dtype=dtype, ) synthesis_matrices = [ matrix_construction_fun(length=dimension_length) for dimension_length in (current_depth, current_height, current_width) ] self.ifwt_matrix_list.append(synthesis_matrices) current_depth, current_height, current_width = ( current_depth // 2, current_height // 2, current_width // 2, ) def _cat_coeff_recursive(self, input_dict: dict[str, torch.Tensor]) -> torch.Tensor: done_dict = {} a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys())) for a_key in a_initial_keys: d_key = "d" + a_key[1:] cat_d = input_dict[d_key] d_shape = cat_d.shape # undo any analysis padding. cat_a = input_dict[a_key][:, : d_shape[1], : d_shape[2], : d_shape[3]] cat_tensor = torch.cat([cat_a, cat_d], dim=-len(a_key)) if a_key[1:]: done_dict[a_key[1:]] = cat_tensor else: return cat_tensor return self._cat_coeff_recursive(done_dict)
[docs] def __call__(self, coefficients: WaveletCoeffNd) -> torch.Tensor: """Reconstruct a batched 3d-signal from its coefficients. Args: coefficients (WaveletCoeffNd): The output from the `MatrixWavedec3` object, see :data:`ptwt.constants.WaveletCoeffNd`. Returns: torch.Tensor: A reconstruction of the original signal. Raises: ValueError: If the data structure is inconsistent. """ if self.axes != (-3, -2, -1): swap_axes_fn = partial(_swap_axes, axes=list(self.axes)) coefficients = _map_result(coefficients, swap_axes_fn) ds = None # the Union[tensor, dict] idea is coming from pywt. We don't change it here. res_lll = _check_if_tensor(coefficients[0]) if res_lll.dim() < 3: raise ValueError( "Three dimensional transforms require at least three dimensions." ) elif res_lll.dim() >= 5: coefficients, ds = _waverec3d_fold_channels_3d_list(coefficients) res_lll = _check_if_tensor(coefficients[0]) level = len(coefficients) - 1 if type(coefficients[-1]) is dict: depth, height, width = tuple( c * 2 for c in coefficients[-1]["ddd"].shape[-3:] ) else: raise ValueError("Waverec3 expects dicts of tensors.") re_build = False if ( self.input_signal_shape is None or self.input_signal_shape[0] != depth or self.input_signal_shape[1] != height or self.input_signal_shape[2] != width ): self.input_signal_shape = depth, height, width re_build = True if self.level != level: self.level = level re_build = True lll = coefficients[0] if not isinstance(lll, torch.Tensor): raise ValueError( "First element of coeffs must be the approximation coefficient tensor." ) torch_device = lll.device torch_dtype = lll.dtype if not _is_dtype_supported(torch_dtype): if not _is_dtype_supported(torch_dtype): raise ValueError(f"Input dtype {torch_dtype} not supported") if not self.ifwt_matrix_list or re_build: self._construct_synthesis_matrices( device=torch_device, dtype=torch_dtype, ) for c_pos, coeff_dict in enumerate(coefficients[1:]): if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7: raise ValueError( f"Unexpected detail coefficient type: {type(coeff_dict)}. Detail " "coefficients must be a dict containing 7 tensors as returned by " "MatrixWavedec3." ) test_shape = None for coeff in coeff_dict.values(): if test_shape is None: test_shape = coeff.shape if torch_device != coeff.device: raise ValueError("coefficients must be on the same device") elif torch_dtype != coeff.dtype: raise ValueError("coefficients must have the same dtype") elif test_shape != coeff.shape: raise ValueError( "All coefficients on each level must have the same shape" ) coeff_dict["a" * len(list(coeff_dict.keys())[-1])] = lll lll = self._cat_coeff_recursive(coeff_dict) for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]): lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1)) if ds: lll = _unfold_axes(lll, ds, 3) if self.axes != (-3, -2, -1): lll = _undo_swap_axes(lll, list(self.axes)) return lll