# Source code for ptwt.conv_transform_2

"""This module implements two-dimensional padded wavelet transforms.

The implementation relies on torch.nn.functional.conv2d and
torch.nn.functional.conv_transpose2d under the hood.
"""

from __future__ import annotations

from typing import Optional, Union

import pywt
import torch

from ._util import (
    _adjust_padding_at_reconstruction,
    _check_same_device_dtype,
    _get_filter_tensors,
    _get_len,
    _get_pad,
    _outer,
    _pad_symmetric,
    _postprocess_coeffs,
    _postprocess_tensor,
    _preprocess_coeffs,
    _preprocess_tensor,
    _translate_boundary_strings,
)
from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d

__all__ = ["wavedec2", "waverec2"]


def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Build the four 2d analysis/synthesis kernels from a 1d filter pair.

    Each kernel is the outer product (via ``_outer``) of two of the given
    1d filters.

    Args:
        lo (torch.Tensor): Low-pass input filter.
        hi (torch.Tensor): High-pass input filter

    Returns:
        The kernels stacked along a new leading axis, with a singleton
        axis inserted at position one, i.e.

        [filt_no, 1, height, width].

        The four filters are ordered ll, lh, hl, hh.

    """
    # The order of these (first, second) argument pairs fixes the order of
    # the stacked output: ll, lh, hl, hh.
    filter_pairs = ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    kernels = [_outer(first, second) for first, second in filter_pairs]
    # Stack to [4, ...] and add the singleton axis at position one.
    return torch.stack(kernels, 0).unsqueeze(1)


def _fwt_pad2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: Optional[BoundaryMode] = None,
    padding: Optional[tuple[int, int, int, int]] = None,
) -> torch.Tensor:
    """Extend the last two axes of ``data`` before a 2d FWT step.

    Args:
        data (torch.Tensor): Input data with 4 dimensions.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        mode: The desired padding mode for extending the signal along the edges.
            ``None`` is treated as "reflect".
            See :data:`ptwt.constants.BoundaryMode`.
        padding (tuple[int, int, int, int], optional): A tuple
            (padl, padr, padt, padb) with the number of padded values
            on the left, right, top and bottom side of the last two
            axes of `data`. If None, the padding values are derived from
            the signal shape and the wavelet length. Defaults to None.

    Returns:
        The padded output tensor.

    """
    torch_mode = _translate_boundary_strings("reflect" if mode is None else mode)

    if padding is not None:
        left, right, top, bottom = padding
    else:
        filt_len = _get_len(wavelet)
        # _get_pad's two results are consumed as (after, before):
        # (bottom, top) for the height axis, (right, left) for the width axis.
        bottom, top = _get_pad(data.shape[-2], filt_len)
        right, left = _get_pad(data.shape[-1], filt_len)

    # Symmetric extension is delegated to the project's own helper
    # instead of torch.nn.functional.pad.
    if torch_mode == "symmetric":
        return _pad_symmetric(data, [(top, bottom), (left, right)])

    # torch's pad list starts with the last axis: width first, then height.
    return torch.nn.functional.pad(
        data, [left, right, top, bottom], mode=torch_mode
    )


def wavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
    axes: tuple[int, int] = (-2, -1),
) -> WaveletCoeff2d:
    r"""Compute the two-dimensional fast wavelet transformation.

    This function relies on two-dimensional convolutions.
    Outer products allow the construction of 2d filters
    :math:`\mathbf{h}_k` for :math:`k\in\{a, h, v, d\}` from the
    1d filter pair of the wavelet where :math:`a` denotes approximation,
    :math:`h` horizontal details, :math:`v` vertical details,
    and :math:`d` diagonal details.
    See the :ref:`FWT intro <sec-fwt-2d>`.
    The coefficients on level :math:`s` are calculated iteratively as

    .. math::
        \mathbf{c}_{k, s} = \mathbf{c}_{a, s-1} *_2 \mathbf{h}_k
        \quad \text{for $k \in \{a, h, v, d\}$}

    with :math:`\mathbf{c}_{a, 0} = \mathbf{x}_0` the original input image.
    :math:`*_2` indicates two dimensional-convolution.
    Set the `level` argument to choose the largest scale.

    Args:
        data (torch.Tensor): The input data tensor with at least two
            dimensions. By default, the last two axes are transformed.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        mode: The desired padding mode for extending the signal along the edges.
            See :data:`ptwt.constants.BoundaryMode`.
            Defaults to ``reflect``.
        level (int, optional): The maximum decomposition level.
            If None, the level is computed based on the signal shape.
            Defaults to None.
        axes (tuple[int, int]): Compute the transform over these axes
            of the `data` tensor. Defaults to (-2, -1).

    Returns:
        A tuple containing the wavelet coefficients in pywt order,
        see :data:`ptwt.constants.WaveletCoeff2d`.

    Example:
        >>> import ptwt, torch
        >>> from scipy import datasets
        >>> data = torch.tensor(datasets.face(), dtype=torch.float64)
        >>> # permute [H, W, C] -> [C, H, W]
        >>> data = data.permute(2, 0, 1)
        >>> # compute the FWT coefficients
        >>> coefficients = ptwt.wavedec2(data, "haar", level=2, mode="zero")
    """
    # ``ds`` is handed back to _postprocess_coeffs at the end so that the
    # preprocessing applied here can be undone on every coefficient tensor.
    data, ds = _preprocess_tensor(data, ndim=2, axes=axes)
    # NOTE(review): flip=True presumably compensates for conv2d computing a
    # cross-correlation -- confirm against _get_filter_tensors.
    dec_lo, dec_hi, _, _ = _get_filter_tensors(
        wavelet, flip=True, device=data.device, dtype=data.dtype
    )
    # Shape [4, 1, height, width]: one input channel, four output channels.
    dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    result_lst: list[WaveletDetailTuple2d] = []
    res_ll = data
    for _ in range(level):
        # Each level pads the current approximation and filters it with a
        # stride of two.
        res_ll = _fwt_pad2(res_ll, wavelet, mode=mode)
        res = torch.nn.functional.conv2d(res_ll, dec_filt, stride=2)
        # Channel order follows the filter order: ll, lh, hl, hh.
        # res_ll keeps its channel axis because it feeds the next conv2d.
        res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1)
        to_append = WaveletDetailTuple2d(
            res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1)
        )
        result_lst.append(to_append)

    # Details were collected from the first to the last computed level;
    # the output lists the last computed level first.
    result_lst.reverse()
    res_ll = res_ll.squeeze(1)
    result: WaveletCoeff2d = res_ll, *result_lst

    result = _postprocess_coeffs(result, ndim=2, ds=ds, axes=axes)

    return result


def waverec2(
    coeffs: WaveletCoeff2d,
    wavelet: Union[Wavelet, str],
    axes: tuple[int, int] = (-2, -1),
) -> torch.Tensor:
    """Reconstruct a 2d signal from wavelet coefficients.

    Args:
        coeffs: The wavelet coefficient tuple produced by
            :func:`ptwt.wavedec2`.
            See :data:`ptwt.constants.WaveletCoeff2d`
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        axes (tuple[int, int]): Compute the transform over these axes
            of the `data` tensor. Defaults to (-2, -1).

    Returns:
        The reconstructed signal tensor of shape ``[batch, height, width]``
        or ``[batch, channels, height, width]`` depending on the input to
        :func:`ptwt.wavedec2`.

    Raises:
        ValueError: If `coeffs` is not in a shape as returned
            from :func:`ptwt.wavedec2` or if the dtype is not supported or
            if the provided axes input has length other than two or
            if the same axes is repeated twice.

    Example:
        >>> import ptwt, torch
        >>> from scipy import datasets
        >>> data = torch.tensor(datasets.face(), dtype=torch.float64)
        >>> # permute [H, W, C] -> [C, H, W]
        >>> data = data.permute(2, 0, 1)
        >>> # compute the forward fwt coefficients and the reconstruction
        >>> coefficients = ptwt.wavedec2(data, "haar", level=2, mode="constant")
        >>> reconstruction = ptwt.waverec2(coefficients, "haar")
    """
    # ``ds`` is handed back to _postprocess_tensor at the end so that the
    # preprocessing applied here can be undone on the reconstruction.
    coeffs, ds = _preprocess_coeffs(coeffs, ndim=2, axes=axes)
    torch_device, torch_dtype = _check_same_device_dtype(coeffs)

    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )
    filt_len = rec_lo.shape[-1]
    # Shape [4, 1, height, width]: four input channels, one output channel.
    rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)

    # Default number of samples trimmed from every border after each
    # transposed convolution; depends only on the filter length.
    base_pad = (2 * filt_len - 3) // 2

    res_ll = coeffs[0]
    for c_pos, coeff_tuple in enumerate(coeffs[1:]):
        if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
            raise ValueError(
                f"Unexpected detail coefficient type: {type(coeff_tuple)}. Detail "
                "coefficients must be a 3-tuple of tensors as returned by "
                "wavedec2."
            )

        curr_shape = res_ll.shape
        for coeff in coeff_tuple:
            if coeff.shape != curr_shape:
                raise ValueError(
                    "All coefficients on each level must have the same shape"
                )

        res_lh, res_hl, res_hh = coeff_tuple
        # Stack in filter order (ll, lh, hl, hh) along a new channel axis,
        # apply the stride-two transposed convolution, then drop the single
        # output channel again.
        res_ll = torch.stack([res_ll, res_lh, res_hl, res_hh], 1)
        res_ll = torch.nn.functional.conv_transpose2d(
            res_ll, rec_filt, stride=2
        ).squeeze(1)

        # remove the padding
        padl = padr = padt = padb = base_pad
        if c_pos < len(coeffs) - 2:
            # A further detail level follows; its first detail tensor
            # supplies the target size for the adjustment.
            # NOTE(review): the helper presumably enlarges one trim value
            # when the sizes would otherwise differ -- confirm in _util.
            padr, padl = _adjust_padding_at_reconstruction(
                res_ll.shape[-1], coeffs[c_pos + 2][0].shape[-1], padr, padl
            )
            padb, padt = _adjust_padding_at_reconstruction(
                res_ll.shape[-2], coeffs[c_pos + 2][0].shape[-2], padb, padt
            )

        # The ``> 0`` guards matter for the end slices: ``[..., :-0]``
        # would select nothing instead of everything.
        if padt > 0:
            res_ll = res_ll[..., padt:, :]
        if padb > 0:
            res_ll = res_ll[..., :-padb, :]
        if padl > 0:
            res_ll = res_ll[..., padl:]
        if padr > 0:
            res_ll = res_ll[..., :-padr]

    res_ll = _postprocess_tensor(res_ll, ndim=2, ds=ds, axes=axes)

    return res_ll