Source code for ptwt.separable_conv_transform

"""Implement separable convolution based transforms.

Under the hood, the code in this module transforms all dimensions
individually using torch.nn.functional.conv1d and its
transpose.
"""
from typing import Dict, List, Optional, Union

import numpy as np
import pywt
import torch

from ._util import _as_wavelet
from .conv_transform import wavedec, waverec


def _separable_conv_dwtn_(
    rec_dict: Dict[str, torch.Tensor],
    input: torch.Tensor,
    wavelet: Union[str, pywt.Wavelet],
    mode: str = "reflect",
    key: str = "",
) -> None:
    """Compute a single level separable fast wavelet transform.

    All but the first axes are transformed.

    Args:
        input (torch.Tensor): Tensor of shape [batch, data_1, ... data_n].
        wavelet (Union[str, pywt.Wavelet]): The Wavelet to work with.
        mode (str): The padding mode. The following methods are supported::

                "reflect", "zero", "constant", "periodic".

            Defaults to "reflect".
        key (str): The filter application path. Defaults to "".
        dict (Dict[str, torch.Tensor]): The result will be stored here
            in place. Defaults to {}.
    """
    axis_total = len(input.shape) - 1
    if len(key) == axis_total:
        rec_dict[key] = input
    if len(key) < axis_total:
        current_axis = len(key) + 1
        transposed = input.transpose(-current_axis, -1)
        flat = transposed.reshape(-1, transposed.shape[-1])
        res_a, res_d = wavedec(flat, wavelet, level=1, mode=mode)
        res_a = res_a.reshape(list(transposed.shape[:-1]) + [res_a.shape[-1]])
        res_d = res_d.reshape(list(transposed.shape[:-1]) + [res_d.shape[-1]])
        res_a = res_a.transpose(-1, -current_axis)
        res_d = res_d.transpose(-1, -current_axis)
        _separable_conv_dwtn_(rec_dict, res_a, wavelet, mode, "a" + key)
        _separable_conv_dwtn_(rec_dict, res_d, wavelet, mode, "d" + key)


def _separable_conv_idwtn(
    in_dict: Dict[str, torch.Tensor], wavelet: Union[str, pywt.Wavelet]
) -> torch.Tensor:
    """Separable single level inverse fast wavelet transform.

    Args:
        in_dict (Dict[str, torch.Tensor]): The dictionary produced
            by _separable_conv_dwtn_ .
        wavelet (Union[str, pywt.Wavelet]): The wavelet used by
            _separable_conv_dwtn_ .

    Returns:
        torch.Tensor: A reconstruction of the original signal.
    """
    done_dict = {}
    a_initial_keys = list(filter(lambda x: x[0] == "a", in_dict.keys()))
    for a_key in a_initial_keys:
        current_axis = len(a_key)
        d_key = "d" + a_key[1:]
        coeff_d = in_dict[d_key]
        d_shape = coeff_d.shape
        # undo any analysis padding.
        coeff_a = in_dict[a_key][tuple(slice(0, ds) for ds in d_shape)]
        trans_a, trans_d = (
            coeff.transpose(-1, -current_axis) for coeff in (coeff_a, coeff_d)
        )
        flat_a, flat_d = (
            coeff.reshape(-1, coeff.shape[-1]) for coeff in (trans_a, trans_d)
        )
        rec_ad = waverec([flat_a, flat_d], wavelet)
        rec_ad = rec_ad.reshape(list(trans_a.shape[:-1]) + [rec_ad.shape[-1]])
        rec_ad = rec_ad.transpose(-current_axis, -1)
        if a_key[1:]:
            done_dict[a_key[1:]] = rec_ad
        else:
            return rec_ad
    return _separable_conv_idwtn(done_dict, wavelet)


def _separable_conv_wavedecn(
    input: torch.Tensor,
    wavelet: Union[str, pywt.Wavelet],
    mode: str = "reflect",
    level: Optional[int] = None,
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Compute a multilevel separable padded wavelet analysis transform.

    Args:
        input (torch.Tensor): A tensor of shape [batch, axis_1, ... axis_n].
            Everything but the batch axis will be transformed.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Please consider the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        mode (str): The desired padding mode. Padding extends the signal along
            the edges. Supported methods are::

                "reflect", "zero", "constant", "periodic".

            Defaults to "reflect".
        level (int): The desired decomposition level. If None the
            largest possible decomposition value is used.

    Returns:
        List[Union[torch.Tensor, Dict[str, torch.Tensor]]]: _description_
    """
    result: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = []
    approx = input

    if level is None:
        wlen = len(_as_wavelet(wavelet))
        level = int(
            min([np.log2(axis_len / (wlen - 1)) for axis_len in input.shape[1:]])
        )

    for _ in range(level):
        level_dict: Dict[str, torch.Tensor] = {}
        _separable_conv_dwtn_(level_dict, approx, wavelet, mode, "")
        approx_key = "a" * (len(input.shape) - 1)
        approx = level_dict.pop(approx_key)
        result.append(level_dict)
    result.append(approx)
    return result[::-1]


def _separable_conv_waverecn(
    coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    wavelet: Union[str, pywt.Wavelet],
) -> torch.Tensor:
    """Separable n-dimensional wavelet synthesis transform.

    Args:
        coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]):
            The output as produced by `_separable_conv_wavedecn`.
        wavelet (Union[str, pywt.Wavelet]):
            The wavelet used by `_separable_conv_wavedecn`.

    Returns:
        torch.Tensor: The reconstruction of the original signal.

    Raises:
        ValueError: If the coeff_list is no not structured as expected.
    """
    if not isinstance(coeff_list[0], torch.Tensor):
        raise ValueError("approximation tensor must be first in coefficient list.")
    if not all(map(lambda x: isinstance(x, dict), coeff_list[1:])):
        raise ValueError("All entries after approximation tensor must be dicts.")

    approx: torch.Tensor = coeff_list[0]
    for level_dict in coeff_list[1:]:
        keys = list(level_dict.keys())  # type: ignore
        level_dict["a" * max(map(len, keys))] = approx  # type: ignore
        approx = _separable_conv_idwtn(level_dict, wavelet)  # type: ignore
    return approx


def _fswavedec(
    input: torch.Tensor,
    wavelet: Union[str, pywt.Wavelet],
    mode: str = "reflect",
    level: Optional[int] = None,
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Compute a fully separable 1D-padded analysis wavelet transform.

       This function is private. Results are the identical to wavedec.
       Use wavedec instead.

    Args:
        input (torch.Tensor): An input signal of shape [batch, length].
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet. Refer to the output of
            ``pywt.wavelist(kind="discrete")`` for a list of possible choices.
        mode (str): The padding mode. Options are::

                "reflect", "zero", "constant", "periodic".

            This function defaults to "reflect".
        level (int): The number of desired scales.
            Defaults to None.


    Raises:
        ValueError: If the input is not a batched 1d-signal.

    Returns:
        List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
            The transformed signal.

    Example:
        >>> import torch
        >>> import ptwt
        >>> data = torch.randn(5,10)
        >>> coeff = ptwt.fswavedec(data, "haar", level=2)

    """
    if len(input.shape) == 1:
        input = input.unsqueeze(0)
    if len(input.shape) != 2:
        raise ValueError("Batched 1d inputs required for a 1d transform.")
    return _separable_conv_wavedecn(input, wavelet, mode, level)


def fswavedec2(
    input: torch.Tensor,
    wavelet: Union[str, pywt.Wavelet],
    mode: str = "reflect",
    level: Optional[int] = None,
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Compute a fully separable 2D-padded analysis wavelet transform.

    Args:
        input (torch.Tensor): An input signal of shape [batch, height, width].
            A signal of shape [height, width] is accepted as well and
            receives a leading batch axis of size one.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet. Refer to the output of
            ``pywt.wavelist(kind="discrete")`` for a list of possible choices.
        mode (str): The padding mode. Options are::

                "reflect", "zero", "constant", "periodic".

            This function defaults to "reflect".
        level (int): The number of desired scales.
            Defaults to None.

    Raises:
        ValueError: If the input is not a batched 2d-signal.

    Returns:
        List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
            A list with the approximation coefficients and dictionaries
            with the filter order strings::

                ("ad", "da", "dd")

            as keys. With a for the low pass or approximation filter and
            d for the high-pass or detail filter.

    Example:
        >>> import torch
        >>> import ptwt
        >>> data = torch.randn(5, 10, 10)
        >>> coeff = ptwt.fswavedec2(data, "haar", level=2)
    """
    # Unbatched images get a batch axis of size one.
    if len(input.shape) == 2:
        input = input.unsqueeze(0)
    if len(input.shape) != 3:
        raise ValueError("Batched 2d inputs required for a 2d transform.")
    return _separable_conv_wavedecn(input, wavelet, mode, level)
def fswavedec3(
    input: torch.Tensor,
    wavelet: Union[str, pywt.Wavelet],
    mode: str = "reflect",
    level: Optional[int] = None,
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """Compute a fully separable 3D-padded analysis wavelet transform.

    Args:
        input (torch.Tensor): An input signal of shape
            [batch, depth, height, width].
            A signal of shape [depth, height, width] is accepted as well
            and receives a leading batch axis of size one.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet. Refer to the output of
            ``pywt.wavelist(kind="discrete")`` for a list of possible choices.
        mode (str): The padding mode. Options are::

                "reflect", "zero", "constant", "periodic".

            This function defaults to "reflect".
        level (int): The number of desired scales.
            Defaults to None.

    Raises:
        ValueError: If the input is not a batched 3d-signal.

    Returns:
        List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
            A list with the approximation coefficients and dictionaries
            with the filter order strings::

                ("aad", "ada", "add", "daa", "dad", "dda", "ddd")

            as keys. With a for the low pass or approximation filter and
            d for the high-pass or detail filter.

    Example:
        >>> import torch
        >>> import ptwt
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = ptwt.fswavedec3(data, "haar", level=2)
    """
    # Unbatched volumes get a batch axis of size one.
    if len(input.shape) == 3:
        input = input.unsqueeze(0)
    if len(input.shape) != 4:
        raise ValueError("Batched 3d inputs required for a 3d transform.")
    return _separable_conv_wavedecn(input, wavelet, mode, level)
def _fswaverec(
    coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    wavelet: Union[str, pywt.Wavelet],
) -> torch.Tensor:
    """Compute a fully separable 1D-padded synthesis wavelet transform.

    This is a thin wrapper which forwards its arguments to
    ``_separable_conv_waverecn``.

    Args:
        coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]):
            The wavelet coefficients as computed by `_fswavedec`.
        wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the
            synthesis transform.

    Returns:
        torch.Tensor: A reconstruction of the signal encoded in the
            wavelet coefficients.

    Example:
        >>> import torch
        >>> data = torch.randn(5, 10)
        >>> coeff = _fswavedec(data, "haar", level=2)
        >>> rec = _fswaverec(coeff, "haar")
    """
    return _separable_conv_waverecn(coeff_list, wavelet)
def fswaverec2(
    coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    wavelet: Union[str, pywt.Wavelet],
) -> torch.Tensor:
    """Compute a fully separable 2D-padded synthesis wavelet transform.

    This is a thin wrapper which forwards its arguments to
    ``_separable_conv_waverecn``.

    Args:
        coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]):
            The wavelet coefficients as computed by `fswavedec2`.
        wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the
            synthesis transform.

    Returns:
        torch.Tensor: A reconstruction of the signal encoded in the
            wavelet coefficients.

    Example:
        >>> import torch
        >>> import ptwt
        >>> data = torch.randn(5, 10, 10)
        >>> coeff = ptwt.fswavedec2(data, "haar", level=2)
        >>> rec = ptwt.fswaverec2(coeff, "haar")
    """
    return _separable_conv_waverecn(coeff_list, wavelet)
def fswaverec3(
    coeff_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]],
    wavelet: Union[str, pywt.Wavelet],
) -> torch.Tensor:
    """Compute a fully separable 3D-padded synthesis wavelet transform.

    This is a thin wrapper which forwards its arguments to
    ``_separable_conv_waverecn``.

    Args:
        coeff_list (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]):
            The wavelet coefficients as computed by `fswavedec3`.
        wavelet (Union[str, pywt.Wavelet]): The wavelet to use for the
            synthesis transform.

    Returns:
        torch.Tensor: A reconstruction of the signal encoded in the
            wavelet coefficients.

    Example:
        >>> import torch
        >>> import ptwt
        >>> data = torch.randn(5, 10, 10, 10)
        >>> coeff = ptwt.fswavedec3(data, "haar", level=2)
        >>> rec = ptwt.fswaverec3(coeff, "haar")
    """
    return _separable_conv_waverecn(coeff_list, wavelet)