# Source code for ptwt.conv_transform_3

"""Code for three dimensional padded transforms.

The functions here are based on torch.nn.functional.conv3d and its transpose.
"""

from __future__ import annotations

from typing import Optional, Union, cast

import pywt
import torch

from ._util import (
    AxisHint,
    _adjust_padding_at_reconstruction,
    _as_wavelet,
    _check_same_device_dtype,
    _construct_3d_filt,
    _get_filter_tensors,
    _get_padding_n,
    _group_for_symmetric,
    _pad_symmetric,
    _postprocess_coeffs,
    _postprocess_tensor,
    _preprocess_coeffs,
    _preprocess_deconstruction,
    _translate_boundary_strings,
)
from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict

__all__ = ["wavedec3", "waverec3"]


def _fwt_pad3(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode,
    padding: Optional[tuple[int, int, int, int, int, int]] = None,
) -> torch.Tensor:
    """Extend the last three axes of ``data`` before a 3d-FWT step.

    Leading axes are passed through unchanged; only the last three axes
    receive padding.

    Args:
        data (torch.Tensor): The tensor to pad. Within this module,
            :func:`wavedec3` calls this after inserting a singleton axis at
            position 1 when its working tensor is four-dimensional.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        mode: The boundary mode used to extend the signal at its edges.
            See :data:`ptwt.constants.BoundaryMode`.
        padding (tuple[int, int, int, int, int, int], optional): Amounts
            ``(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)``
            applied to the last, second-to-last and third-to-last axis of
            `data` respectively. When None, they are derived from the
            signal shape and the wavelet length. Defaults to None.

    Returns:
        The padded tensor.
    """
    torch_mode = _translate_boundary_strings(mode)

    if padding is None:
        padding = cast(
            tuple[int, int, int, int, int, int], _get_padding_n(data, wavelet, n=3)
        )

    if torch_mode != "symmetric":
        return torch.nn.functional.pad(data, padding, mode=torch_mode)

    # torch.nn.functional.pad offers no symmetric mode, so delegate to the
    # package's own routine, which takes the flat padding tuple regrouped
    # by _group_for_symmetric.
    return _pad_symmetric(data, _group_for_symmetric(padding))


def wavedec3(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "zero",
    level: Optional[int] = None,
    axes: tuple[int, int, int] = (-3, -2, -1),
) -> WaveletCoeffNd:
    """Compute the three-dimensional fast wavelet transformation.

    Args:
        data (torch.Tensor): The input data tensor with at least three
            dimensions. By default, the last three axes are transformed.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        mode: The boundary mode used to extend the signal at its edges.
            See :data:`ptwt.constants.BoundaryMode`. Defaults to ``zero``.
        level (int, optional): The number of decomposition levels.
            When None, ``pywt.dwtn_max_level`` picks the largest level
            supported by the three transformed axes. Defaults to None.
        axes (tuple[int, int, int]): The axes of `data` to transform.
            Defaults to (-3, -2, -1).

    Returns:
        A tuple containing the wavelet coefficients,
        see :data:`ptwt.constants.WaveletCoeffNd`.
        The first entry is the coarsest approximation. It is followed by
        one dict of seven detail tensors per level, ordered from the
        coarsest level to the finest.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 16, 16, 16)
        >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
    """
    wavelet = _as_wavelet(wavelet)
    data, ds, _, _, dec_filt = _preprocess_deconstruction(
        data, wavelet, axes=axes, ndim=3
    )
    if level is None:
        level = pywt.dwtn_max_level(
            [data.shape[-1], data.shape[-2], data.shape[-3]], wavelet
        )

    details_per_level: list[WaveletDetailDict] = []
    approx = data
    for _ in range(level):
        # conv3d needs an explicit channel axis at position 1.
        if approx.dim() == 4:
            approx = approx.unsqueeze(1)
        padded = _fwt_pad3(approx, wavelet, mode=mode)
        # Stride 2 downsamples all three axes; each output channel is one
        # filter combination of dec_filt.
        subbands = torch.nn.functional.conv3d(padded, dec_filt, stride=2)
        bands = [band.squeeze(1) for band in torch.split(subbands, 1, 1)]
        approx, b_aad, b_ada, b_add, b_daa, b_dad, b_dda, b_ddd = bands
        details_per_level.append(
            {
                "aad": b_aad,
                "ada": b_ada,
                "add": b_add,
                "daa": b_daa,
                "dad": b_dad,
                "dda": b_dda,
                "ddd": b_ddd,
            }
        )

    # The loop produced the finest level first; the result lists coarse first.
    coeffs: WaveletCoeffNd = (approx, *details_per_level[::-1])
    return _postprocess_coeffs(coeffs, ndim=3, ds=ds, axes=axes)
def _validate_detail_dicts3(
    coeff_dicts: tuple[object, ...],
) -> list[WaveletDetailDict]:
    """Check that every detail entry is a dict as produced by ``wavedec3``.

    Args:
        coeff_dicts: The detail coefficient entries, i.e. ``coeffs[1:]``.

    Returns:
        The same entries, in the same order, as a list.

    Raises:
        ValueError: If an entry is not a dict with seven entries, or if its
            keys differ from 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'.
    """
    expected_keys = {"aad", "ada", "add", "daa", "dad", "dda", "ddd"}
    validated: list[WaveletDetailDict] = []
    for coeff_dict in coeff_dicts:
        if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7:
            raise ValueError(
                f"Unexpected detail coefficient type: {type(coeff_dict)}. Detail "
                "coefficients must be a dict containing 7 tensors as returned by "
                "wavedec3."
            )
        if set(coeff_dict) != expected_keys:
            raise ValueError(
                f"Unexpected detail coefficient keys: {list(coeff_dict)}. Detail "
                "coefficients must use the keys 'aad', 'ada', 'add', 'daa', "
                "'dad', 'dda' and 'ddd' as returned by wavedec3."
            )
        validated.append(cast("WaveletDetailDict", coeff_dict))
    return validated


def waverec3(
    coeffs: WaveletCoeffNd,
    wavelet: Union[Wavelet, str],
    *,
    axes: AxisHint = None,
) -> torch.Tensor:
    """Reconstruct a 3d signal from wavelet coefficients.

    Args:
        coeffs: The wavelet coefficient tuple produced by :func:`ptwt.wavedec3`,
            see :data:`ptwt.constants.WaveletCoeffNd`.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        axes: Compute the transform over these axes. If None, the last
            three axes are used.

    Returns:
        The reconstructed signal tensor.
        Its shape depends on the shape of the input to :func:`ptwt.wavedec3`.

    Raises:
        ValueError: If `coeffs` is not in a shape as returned from
            :func:`ptwt.wavedec3` (including detail dicts whose keys differ
            from those produced by :func:`ptwt.wavedec3`), if the dtype is
            not supported, if the provided axes input has length other than
            three, or if the same axis is repeated.

    Example:
        >>> import ptwt, torch
        >>> data = torch.randn(5, 16, 16, 16)
        >>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
        >>> reconstruction = ptwt.waverec3(transformed, "haar")
    """
    coeffs, ds = _preprocess_coeffs(coeffs, ndim=3, axes=axes)
    torch_device, torch_dtype = _check_same_device_dtype(coeffs)
    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )
    filt_len = rec_lo.shape[-1]
    rec_filt = _construct_3d_filt(lo=rec_lo, hi=rec_hi)

    res_lll = coeffs[0]
    # Validate every level up front. The padding adjustment below reads the
    # next level's "aad" entry before the loop reaches that level, so a
    # malformed later level would otherwise surface as a KeyError.
    coeff_dicts = _validate_detail_dicts3(coeffs[1:])

    # Border cropped from each side of every axis after the transposed
    # convolution; identical for all three axes before any adjustment.
    base_pad = (2 * filt_len - 3) // 2
    for c_pos, coeff_dict in enumerate(coeff_dicts):
        for coeff in coeff_dict.values():
            if res_lll.shape != coeff.shape:
                raise ValueError(
                    "All coefficients on each level must have the same shape"
                )
        # Channel order mirrors the unpacking order used by wavedec3;
        # assumes _construct_3d_filt orders rec_filt the same way.
        res_lll = torch.stack(
            [
                res_lll,
                coeff_dict["aad"],
                coeff_dict["ada"],
                coeff_dict["add"],
                coeff_dict["daa"],
                coeff_dict["dad"],
                coeff_dict["dda"],
                coeff_dict["ddd"],
            ],
            1,
        )
        res_lll = torch.nn.functional.conv_transpose3d(res_lll, rec_filt, stride=2)
        res_lll = res_lll.squeeze(1)

        # remove the padding
        padfr = padba = padl = padr = padt = padb = base_pad
        if c_pos + 1 < len(coeff_dicts):
            # Match the size expected by the next level; presumably this
            # enlarges the trailing crop for odd-sized levels.
            next_shape = coeff_dicts[c_pos + 1]["aad"].shape
            padr, padl = _adjust_padding_at_reconstruction(
                res_lll.shape[-1], next_shape[-1], padr, padl
            )
            padb, padt = _adjust_padding_at_reconstruction(
                res_lll.shape[-2], next_shape[-2], padb, padt
            )
            padba, padfr = _adjust_padding_at_reconstruction(
                res_lll.shape[-3], next_shape[-3], padba, padfr
            )
        # The ``> 0`` guards matter: a trailing slice ``:-0`` would empty the axis.
        if padt > 0:
            res_lll = res_lll[..., padt:, :]
        if padb > 0:
            res_lll = res_lll[..., :-padb, :]
        if padl > 0:
            res_lll = res_lll[..., padl:]
        if padr > 0:
            res_lll = res_lll[..., :-padr]
        if padfr > 0:
            res_lll = res_lll[..., padfr:, :, :]
        if padba > 0:
            res_lll = res_lll[..., :-padba, :, :]
    return _postprocess_tensor(res_lll, ndim=3, ds=ds, axes=axes)