Source code for ptwt.stationary_transform

"""This module implements stationary wavelet transforms."""

from collections.abc import Sequence
from typing import Optional, Union

import pywt
import torch
import torch.nn.functional as F  # noqa:N812

from ._util import Wavelet, _as_wavelet, _unfold_axes
from .conv_transform import (
    _get_filter_tensors,
    _postprocess_result_list_dec1d,
    _preprocess_result_list_rec1d,
    _preprocess_tensor_dec1d,
)


def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor:
    """Pad a tensor in circular mode, wrapping more than once if needed.

    ``F.pad(..., mode="circular")`` refuses pad widths larger than the padded
    dimension. This helper lifts that restriction: pad widths that exceed the
    dimension are split into whole periods, which are produced by tiling the
    tensor, and a remainder that a single circular ``F.pad`` call can handle.

    Args:
        x (torch.Tensor): The tensor to pad.
        padding_dimensions (Sequence[int]): Pad widths in ``F.pad`` order,
            i.e. ``(left, right)`` for the last axis, optionally followed
            by the pairs for the preceding axes.

    Returns:
        The periodically extended tensor.

    Raises:
        ValueError: If more than zero elements are requested along an
            empty dimension, which has no periodic extension.
    """
    trailing_dimension = x.shape[-1]

    # If no pad width exceeds the trailing dimension, a single circular
    # pad is sufficient and no manual wrapping is required.
    if all(pad <= trailing_dimension for pad in padding_dimensions):
        return F.pad(x, padding_dimensions, mode="circular")

    # Repeatedly padding an already padded tensor is only correct while its
    # length stays a multiple of the original period, because F.pad wraps
    # relative to the *current* tensor. Tiling whole periods first keeps the
    # length a multiple of the period, so the final pad is always in phase.
    repeats = [1] * x.dim()
    remainders = []
    for position, pad in enumerate(padding_dimensions):
        # F.pad lists pairs starting from the last axis and moving forward.
        axis = x.dim() - 1 - position // 2
        size = x.shape[axis]
        if pad > size:
            if size == 0:
                raise ValueError("Cannot circularly pad along an empty dimension.")
            whole_periods, pad = divmod(pad, size)
            repeats[axis] += whole_periods
        remainders.append(pad)

    return F.pad(x.repeat(repeats), remainders, mode="circular")


def swt(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    level: Optional[int] = None,
    axis: int = -1,
) -> list[torch.Tensor]:
    """Compute a multilevel 1d stationary wavelet transform.

    This function is equivalent to pywt's swt with `trim_approx=True` and
    `norm=False`.

    Args:
        data (torch.Tensor): The input data of shape ``[batch_size, time]``.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
            Refer to the output from ``pywt.wavelist(kind='discrete')``
            for possible choices.
        level (int, optional): The number of levels to compute.
            If None, ``pywt.swt_max_level`` of the signal length is used.
        axis (int): The axis to transform along. Defaults to the last axis.

    Returns:
        Same as wavedec. Equivalent to pywt.swt with trim_approx=True.
        The final approximation is collected last and the list is reversed
        before it is returned.

    Raises:
        ValueError: If the axis argument is not an integer.
    """
    if axis != -1:
        if not isinstance(axis, int):
            raise ValueError("swt transforms a single axis only.")
        # move the transformed axis to the end, the convolution runs there
        data = data.swapaxes(axis, -1)

    data, ds = _preprocess_tensor_dec1d(data)

    dec_lo, dec_hi, _, _ = _get_filter_tensors(
        wavelet, flip=True, device=data.device, dtype=data.dtype
    )
    half_filter_length = dec_lo.shape[-1] // 2
    analysis_filters = torch.stack([dec_lo, dec_hi], 0)

    num_levels = pywt.swt_max_level(data.shape[-1]) if level is None else level

    detail_coefficients = []
    approximation = data
    for current_level in range(num_levels):
        # the filters are dilated instead of downsampling the signal
        dilation = 2**current_level
        pad_left = dilation * (half_filter_length - 1)
        pad_right = dilation * half_filter_length
        padded = _circular_pad(approximation, [pad_left, pad_right])
        filtered = F.conv1d(padded, analysis_filters, stride=1, dilation=dilation)
        # channel 0 is the low-pass result, channel 1 the high-pass result
        approximation, detail = torch.split(filtered, 1, 1)
        detail_coefficients.append(detail.squeeze(1))

    # only the last approximation is kept (trim_approx=True)
    coefficients = [*detail_coefficients, approximation.squeeze(1)]
    coefficients = _postprocess_result_list_dec1d(coefficients, ds, axis)
    return coefficients[::-1]
def iswt(
    coeffs: Sequence[torch.Tensor],
    wavelet: Union[pywt.Wavelet, str],
    axis: Optional[int] = -1,
) -> torch.Tensor:
    """Invert a 1d stationary wavelet transform.

    Args:
        coeffs (Sequence[torch.Tensor]): The coefficients as computed
            by the swt function.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet, as used in the forward transform.
        axis (int, optional): The axis the forward transform was computed over.
            Defaults to -1.

    Returns:
        A reconstruction of the original swt input.

    Raises:
        ValueError: If the axis argument is not an integer.
    """
    if axis != -1:
        if not isinstance(axis, int):
            raise ValueError("iswt transforms a single axis only.")
        # move the transformed axis to the end for every coefficient tensor
        coeffs = [coeff.swapaxes(axis, -1) for coeff in coeffs]

    coeffs, ds = _preprocess_result_list_rec1d(coeffs)
    wavelet = _as_wavelet(wavelet)

    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, dtype=coeffs[0].dtype, device=coeffs[0].device
    )
    half_filter_length = rec_lo.shape[-1] // 2
    synthesis_filters = torch.stack([rec_lo, rec_hi], 0)

    reconstruction = coeffs[0]
    details = coeffs[1:]
    num_levels = len(details)
    for position, detail in enumerate(details):
        # the first detail entry is processed with the largest dilation
        dilation = 2 ** (num_levels - position - 1)
        pad_left = dilation * half_filter_length
        pad_right = dilation * (half_filter_length - 1)
        stacked = torch.stack([reconstruction, detail], 1)
        padded = _circular_pad(stacked, (pad_left, pad_right))
        # groups=2 filters the low- and high-pass channels separately
        filtered = F.conv_transpose1d(
            padded,
            synthesis_filters,
            dilation=dilation,
            groups=2,
            padding=(pad_left + pad_right),
        )
        # average the two channel contributions
        reconstruction = torch.mean(filtered, 1)

    if len(ds) == 1:
        reconstruction = reconstruction.squeeze(0)
    elif len(ds) > 2:
        reconstruction = _unfold_axes(reconstruction, ds, 1)

    if axis != -1:
        # undo the initial axis swap
        reconstruction = reconstruction.swapaxes(axis, -1)

    return reconstruction