Source code for ptwt.separable_conv_transform

"""Compute separable convolution-based transforms.

This module takes multi-dimensional convolutions apart.
It uses single-dimensional convolutions to transform
axes individually.
Under the hood, code in this module transforms all dimensions
using torch.nn.functional.conv1d and its transpose.
"""

from __future__ import annotations

from typing import Optional, Union

import numpy as np
import torch

from ._util import (
    AxisHint,
    _as_wavelet,
    _check_same_device_dtype,
    _postprocess_coeffs,
    _postprocess_tensor,
    _preprocess_coeffs,
    _preprocess_tensor,
)
from .constants import (
    BoundaryMode,
    Wavelet,
    WaveletCoeff2dSeparable,
    WaveletCoeffNd,
    WaveletDetailDict,
)
from .conv_transform import wavedec, waverec

__all__ = ["fswavedec2", "fswavedec3", "fswaverec2", "fswaverec3"]


def _separable_conv_dwtn_(
    rec_dict: WaveletDetailDict,
    input_arg: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    key: str = "",
) -> None:
    """Compute a single-level separable fast wavelet transform, in place.

    Every axis except the leading (batch) axis is filtered once. The
    recursion walks the axes from last to first. Each step prepends ``"a"``
    (approximation) or ``"d"`` (detail) to ``key``. In a finished key, the
    last character therefore refers to axis ``-1`` and the first character
    to the first non-batch axis.

    Args:
        rec_dict (WaveletDetailDict): Output dictionary. Each leaf tensor
            is stored here under its filter-path key.
        input_arg (torch.Tensor): Tensor of shape ``[batch, data_1, ... data_n]``.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        mode : The padding mode. The following methods are supported::

                "reflect", "zero", "constant", "periodic".

            Defaults to "reflect".
        key (str): The filter path applied so far. Defaults to "".
    """
    depth = len(key)
    n_axes = len(input_arg.shape) - 1
    if depth > n_axes:
        # The path is longer than the number of transformable axes: nothing to do.
        return
    if depth == n_axes:
        # Every non-batch axis has been filtered; record the leaf.
        rec_dict[key] = input_arg
        return
    # Filter the next axis, counting from the end of the shape.
    low, high = wavedec(input_arg, wavelet, level=1, mode=mode, axis=-(depth + 1))
    for prefix, branch in (("a", low), ("d", high)):
        _separable_conv_dwtn_(rec_dict, branch, wavelet, mode=mode, key=prefix + key)


def _separable_conv_idwtn(
    in_dict: WaveletDetailDict, wavelet: Union[Wavelet, str]
) -> torch.Tensor:
    """Invert one level of the separable fast wavelet transform.

    For every key ``"a" + rest`` the matching ``"d" + rest`` entry is looked
    up. The pair is merged along axis ``-len(key)`` with a 1D synthesis
    transform, and the result is stored under ``rest``. The function then
    recurses on the merged dictionary. Once the keys are a single character
    long, the fully reconstructed tensor is returned.

    Args:
        in_dict (WaveletDetailDict): The dictionary produced
            by _separable_conv_dwtn_ .
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet, as used by :func:`_separable_conv_dwtn_`.

    Returns:
        A reconstruction of the original signal.
    """
    merged = {}
    approx_names = [name for name in in_dict if name[0] == "a"]
    for approx_name in approx_names:
        rest = approx_name[1:]
        axis = len(approx_name)
        detail = in_dict["d" + rest]
        # Crop the approximation to the detail shape, undoing analysis padding.
        approx = in_dict[approx_name][tuple(slice(0, size) for size in detail.shape)]
        # Move the target axis last so the 1D synthesis can operate on it.
        approx_t = approx.transpose(-1, -axis)
        detail_t = detail.transpose(-1, -axis)
        # Collapse all remaining axes into one batch axis for waverec.
        rec = waverec(
            [
                approx_t.reshape(-1, approx_t.shape[-1]),
                detail_t.reshape(-1, detail_t.shape[-1]),
            ],
            wavelet,
        )
        # Restore the leading axes, then move the reconstructed axis back.
        rec = rec.reshape(list(approx_t.shape[:-1]) + [rec.shape[-1]])
        rec = rec.transpose(-axis, -1)
        if not rest:
            return rec
        merged[rest] = rec
    return _separable_conv_idwtn(merged, wavelet)


def _separable_conv_wavedecn(
    input: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
) -> WaveletCoeffNd:
    """Compute a multilevel separable padded wavelet analysis transform.

    Args:
        input (torch.Tensor): A tensor i.e. of shape ``[batch,axis_1, ... axis_n]``.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        mode : The desired padding mode.
        level (int): The desired decomposition level. If None, it is
            ``int(min(log2(axis_len / (filter_len - 1))))`` over all
            non-batch axes. A value of zero or less yields no detail levels.

    Returns:
        A tuple with the approximation coefficients and
        for each scale a dictionary containing the detail coefficients,
        ordered from the coarsest to the finest scale.
        The dictionaries use a string of length n as keys with
        'a' denoting the low pass or approximation filter and
        'd' the high-pass or detail filter.
    """
    if level is None:
        filter_len = len(_as_wavelet(wavelet))
        level = int(
            min(np.log2(size / (filter_len - 1)) for size in input.shape[1:])
        )

    # Key of the pure low-pass result, which feeds the next level.
    approx_key = "a" * (len(input.shape) - 1)
    details: list[WaveletDetailDict] = []
    approx = input
    for _ in range(level):
        coeff_dict: WaveletDetailDict = {}
        _separable_conv_dwtn_(coeff_dict, approx, wavelet, mode=mode, key="")
        approx = coeff_dict.pop(approx_key)
        details.append(coeff_dict)
    # Details were collected fine-to-coarse; emit them coarse-to-fine.
    return (approx, *reversed(details))


def _separable_conv_waverecn(
    coeffs: WaveletCoeffNd,
    wavelet: Union[Wavelet, str],
) -> torch.Tensor:
    """Separable n-dimensional wavelet synthesis transform.

    The detail dictionaries in ``coeffs`` are not modified. Each level is
    reconstructed from a fresh copy that also holds the running
    approximation.

    Args:
        coeffs (WaveletCoeffNd):
            The output as produced by `_separable_conv_wavedecn`.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet, as used by :func:`_separable_conv_wavedecn`.

    Returns:
        The reconstruction of the original signal.

    Raises:
        ValueError: If the coeffs is not structured as expected.
    """
    if not isinstance(coeffs[0], torch.Tensor):
        raise ValueError("approximation tensor must be first in coefficient list.")
    if not all(isinstance(entry, dict) for entry in coeffs[1:]):
        raise ValueError("All entries after approximation tensor must be dicts.")

    approx: torch.Tensor = coeffs[0]
    for level_dict in coeffs[1:]:
        if not level_dict:
            raise ValueError("Detail coefficient dicts must not be empty.")
        # The approximation key is all 'a's, as long as the detail keys.
        approx_key = "a" * max(map(len, level_dict))
        # Build a new dict rather than inserting into ``level_dict``.
        # The latter may be the caller's own object, and writing to it would
        # silently add an extra key to their coefficients.
        level_input = {**level_dict, approx_key: approx}
        approx = _separable_conv_idwtn(level_input, wavelet)
    return approx


def fswavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
    axes: AxisHint = None,
) -> WaveletCoeff2dSeparable:
    """Compute a fully separable 2D-padded analysis wavelet transform.

    Each axis is transformed on its own using single-dimensional
    convolutions. Under the hood, all dimensions are transformed with
    :func:`torch.nn.functional.conv1d`.

    Args:
        data (torch.Tensor): The input data tensor with at least two dimensions.
            By default, the last two axes are transformed.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output of ``pywt.wavelist(kind="discrete")``
            for a list of possible choices.
        mode: The desired padding mode for extending the signal along the edges.
            See :data:`ptwt.constants.BoundaryMode`. Defaults to ``reflect``.
        level (int, optional): The maximum decomposition level.
            If None, the level is computed based on the signal shape.
            Defaults to None.
        axes: Compute the transform over these axes.
            If None, the last two axes are used.

    Returns:
        A tuple starting with the approximation coefficient tensor.
        It is followed by one dictionary of detail coefficients per scale,
        see :data:`ptwt.constants.WaveletCoeff2dSeparable`.
        The dictionaries use the filter order strings::

            "ad", "da", "dd"

        as keys. ``a`` denotes the low pass or approximation filter and
        ``d`` the high-pass or detail filter.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 10, 10)
        >>> coeff = ptwt.fswavedec2(data, "haar", level=2)
    """
    return _fswavedecn(
        data,
        wavelet,
        ndim=2,
        mode=mode,
        level=level,
        axes=axes,
    )
def fswavedec3(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
    axes: AxisHint = None,
) -> WaveletCoeffNd:
    """Compute a fully separable 3D-padded analysis wavelet transform.

    Each axis is transformed on its own using single-dimensional
    convolutions. Under the hood, all dimensions are transformed with
    :func:`torch.nn.functional.conv1d`.

    Args:
        data (torch.Tensor): The input data tensor with at least three dimensions.
            By default, the last three axes are transformed.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output of ``pywt.wavelist(kind="discrete")``
            for possible choices.
        mode: The desired padding mode for extending the signal along the edges.
            See :data:`ptwt.constants.BoundaryMode`. Defaults to ``reflect``.
        level (int, optional): The maximum decomposition level.
            If None, the level is computed based on the signal shape.
            Defaults to None.
        axes: Compute the transform over these axes.
            If None, the last three axes are used.

    Returns:
        A tuple starting with the approximation coefficient tensor.
        It is followed by one dictionary of detail coefficients per scale,
        see :data:`ptwt.constants.WaveletCoeffNd`.
        The dictionaries use the filter order strings::

            "aad", "ada", "add", "daa", "dad", "dda", "ddd"

        as keys. ``a`` denotes the low pass or approximation filter and
        ``d`` the high-pass or detail filter.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = ptwt.fswavedec3(data, "haar", level=2)
    """
    return _fswavedecn(
        data,
        wavelet,
        ndim=3,
        mode=mode,
        level=level,
        axes=axes,
    )
def fswaverec2(
    coeffs: WaveletCoeff2dSeparable,
    wavelet: Union[Wavelet, str],
    *,
    axes: AxisHint = None,
) -> torch.Tensor:
    """Compute a fully separable 2D-padded synthesis wavelet transform.

    Each axis is reconstructed on its own with single-dimensional
    (transposed) convolutions.

    Args:
        coeffs: The wavelet coefficients as computed by
            :func:`ptwt.fswavedec2`, see
            :data:`ptwt.constants.WaveletCoeff2dSeparable`.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        axes: Compute the transform over these axes.
            If None, the last two axes are used.

    Returns:
        A reconstruction of the signal encoded in the wavelet coefficients.
        Its shape depends on the shape of the input to :func:`ptwt.fswavedec2`.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 10, 10)
        >>> coeff = ptwt.fswavedec2(data, "haar", level=2)
        >>> rec = ptwt.fswaverec2(coeff, "haar")
    """
    return _fswaverecn(
        coeffs,
        wavelet,
        ndim=2,
        axes=axes,
    )
def fswaverec3(
    coeffs: WaveletCoeffNd,
    wavelet: Union[Wavelet, str],
    *,
    axes: AxisHint = None,
) -> torch.Tensor:
    """Compute a fully separable 3D-padded synthesis wavelet transform.

    Each axis is reconstructed on its own with single-dimensional
    (transposed) convolutions.

    Args:
        coeffs: The wavelet coefficients as computed by
            :func:`ptwt.fswavedec3`, see :data:`ptwt.constants.WaveletCoeffNd`.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        axes: Compute the transform over these axes.
            If None, the last three axes are used.

    Returns:
        A reconstruction of the signal encoded in the wavelet coefficients.
        Its shape depends on the shape of the input to :func:`ptwt.fswavedec3`.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = ptwt.fswavedec3(data, "haar", level=2)
        >>> rec = ptwt.fswaverec3(coeff, "haar")
    """
    return _fswaverecn(
        coeffs,
        wavelet,
        ndim=3,
        axes=axes,
    )
def _fswavedecn(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    ndim: int,
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
    axes: AxisHint = None,
) -> WaveletCoeffNd:
    """Compute a fully separable :math:`N`-dimensional padded FWT.

    Single-dimensional convolutions are used to transform each axis individually.
    Under the hood, all dimensions are transformed using
    :func:`torch.nn.functional.conv1d`.

    Args:
        data (torch.Tensor): An input signal with at least :math:`N` dimensions.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output of ``pywt.wavelist(kind="discrete")``
            for possible choices.
        ndim (int): The number of dimensions :math:`N`.
        mode: The desired padding mode for extending the signal along the edges.
            See :data:`ptwt.constants.BoundaryMode`. Defaults to ``reflect``.
        level (int, optional): The number of desired scales. Defaults to None.
        axes: Compute the transform over these axes. If None, the last
            :math:`N` axes are used.

    Returns:
        A tuple starting with the approximation coefficient tensor.
        It is followed by one dictionary of detail coefficients per scale,
        see :data:`ptwt.constants.WaveletCoeffNd`.

    Example:
        >>> import torch
        >>> from ptwt.separable_conv_transform import _fswavedecn
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2)

    Note:
        ND-Transforms are generally out of this project's scope.
    """
    if axes is None:
        axes = tuple(range(-ndim, 0))
    data, ds = _preprocess_tensor(data, ndim=ndim, axes=axes, add_channel_dim=False)
    coeffs = _separable_conv_wavedecn(data, wavelet, mode=mode, level=level)
    return _postprocess_coeffs(coeffs, ndim=ndim, ds=ds, axes=axes)


def _fswaverecn(
    coeffs: WaveletCoeffNd,
    wavelet: Union[Wavelet, str],
    ndim: int,
    *,
    axes: AxisHint = None,
) -> torch.Tensor:
    """Invert a fully separable :math:`N`-dimensional padded FWT.

    Single-dimensional (transposed) convolutions are used to reconstruct
    each axis individually.

    Args:
        coeffs (WaveletCoeffNd): The wavelet coefficients as computed by
            :func:`_fswavedecn`, see :data:`ptwt.constants.WaveletCoeffNd`.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        ndim (int): The number of dimensions :math:`N`.
        axes: Compute the transform over these axes. If None, the last
            :math:`N` axes are used.

    Returns:
        A reconstruction of the signal encoded in the wavelet coefficients.
        Its shape depends on the shape of the input to :func:`_fswavedecn`.

    Example:
        >>> import torch
        >>> from ptwt.separable_conv_transform import _fswavedecn, _fswaverecn
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = _fswavedecn(data, "haar", ndim=3, level=2)
        >>> rec = _fswaverecn(coeff, "haar", ndim=3)

    Note:
        ND-Transforms are generally out of this project's scope.
    """
    # Resolve the default axes exactly as _fswavedecn does, so that
    # analysis and synthesis agree on which axes are transformed.
    if axes is None:
        axes = tuple(range(-ndim, 0))
    coeffs, ds = _preprocess_coeffs(coeffs, ndim=ndim, axes=axes)
    _check_same_device_dtype(coeffs)
    res_ll = _separable_conv_waverecn(coeffs, wavelet)
    return _postprocess_tensor(res_ll, ndim=ndim, ds=ds, axes=axes)