Source code for ptwt.continuous_transform

"""PyTorch compatible cwt code.

This module is based on pywt's cwt implementation.
"""

from typing import Any, Tuple, Union

import numpy as np
import torch
from pywt import ContinuousWavelet, DiscreteContinuousWavelet, Wavelet
from pywt._functions import scale2frequency
from torch.fft import fft, ifft


def _next_fast_len(n: int) -> int:
    """Round a sample count up to the nearest power of two.

    Given a number of samples `n`, returns the next power of two
    following this number to take advantage of FFT speedup.
    This fallback is less efficient than `scipy.fftpack.next_fast_len`.

    The result is computed with exact integer arithmetic. Going through
    ``np.log2`` rounds for very large inputs and can yield a power of
    two that is smaller than ``n``.

    Args:
        n (int): The number of samples. Must not be negative.

    Returns:
        int: The smallest power of two that is greater than or equal
        to ``n``, or zero if ``n`` is zero.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("Target length must not be negative, got {}.".format(n))
    size = int(n)
    if size < n:
        # Non-integral input: int() truncated, so round up instead.
        size += 1
    if size == 0:
        return 0
    # (size - 1).bit_length() is the exponent of the smallest power of
    # two that is >= size; exact for arbitrarily large integers.
    return 1 << (size - 1).bit_length()


def cwt(
    data: torch.Tensor,
    scales: Union[np.ndarray, torch.Tensor],  # type: ignore
    wavelet: Union[ContinuousWavelet, str],
    sampling_period: float = 1.0,
) -> Tuple[torch.Tensor, np.ndarray]:  # type: ignore
    """Compute the single-dimensional continuous wavelet transform.

    This function is a PyTorch port of pywt.cwt as found at:
    https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py

    Args:
        data (torch.Tensor): The input tensor of shape [batch_size, time].
        scales (torch.Tensor or np.array):
            The wavelet scales to use. One can use
            ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to
            determine what physical frequency, ``f``. Here, ``f`` is in hertz
            when the ``sampling_period`` is given in seconds.
            Tensors are detached and copied to the CPU before use, so
            CUDA tensors and tensors that require gradients are accepted.
        wavelet (ContinuousWavelet or str): The continuous wavelet to work with.
        sampling_period (float): Sampling period for the frequencies output
            (optional). The values computed for ``coefs`` are independent of
            the choice of ``sampling_period`` (i.e. ``scales`` is not scaled
            by the sampling period).

    Raises:
        ValueError: If a scale is too small for the input signal.

    Returns:
        Tuple[torch.Tensor, np.ndarray]: The first tuple-element contains
        the transformation matrix of shape [scales, batch, time].
        The second element contains an array with frequency information.

    Example:
        >>> import torch, ptwt
        >>> import numpy as np
        >>> import scipy.signal as signal
        >>> t = np.linspace(-2, 2, 800, endpoint=False)
        >>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
        >>> widths = np.arange(1, 31)
        >>> cwtmatr, freqs = ptwt.cwt(
        ...     torch.from_numpy(sig), widths, "mexh",
        ...     sampling_period=(4 / 800) * np.pi
        ... )
    """
    if not isinstance(
        wavelet, (ContinuousWavelet, Wavelet, _DifferentiableContinuousWavelet)
    ):
        wavelet = DiscreteContinuousWavelet(wavelet)

    if isinstance(scales, torch.Tensor):
        # .numpy() alone fails for CUDA tensors and tensors requiring grad.
        scales = np.atleast_1d(scales.detach().cpu().numpy())
    elif np.isscalar(scales):
        scales = np.array([scales])

    # Learnable wavelets are modules; evaluate them on the data's device.
    if isinstance(wavelet, torch.nn.Module) and data.is_cuda:
        wavelet.cuda()

    precision = 10
    int_psi, x = _integrate_wavelet(wavelet, precision=precision)
    # The exact type check is deliberate: the differentiable wavelets also
    # inherit from ContinuousWavelet but already return tensors.
    if type(wavelet) is ContinuousWavelet:
        int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
        int_psi = torch.tensor(int_psi, device=data.device)
    elif isinstance(wavelet, torch.nn.Module):
        int_psi = torch.conj(int_psi) if wavelet.complex_cwt else int_psi
    else:
        int_psi = torch.tensor(int_psi, device=data.device)
    # x may already be a tensor; as_tensor avoids the copy-construct warning.
    x = torch.as_tensor(x, device=data.device)

    # Quantities that do not depend on the scale.
    step = x[1] - x[0]
    span = x[-1] - x[0]
    psi_len = len(int_psi)
    signal_len = data.shape[-1]

    size_scale0 = -1
    fft_data = None
    out = []
    for scale in scales:
        # Indices that resample the integrated wavelet for this scale.
        j = torch.arange(
            scale * span + 1, device=data.device, dtype=data.dtype
        ) / (scale * step)
        j = torch.floor(j).type(torch.long)
        if j[-1] >= psi_len:
            j = torch.masked_select(j, j < psi_len)
        int_psi_scale = int_psi[j].flip(0)

        # The padding is selected for:
        # - optimal FFT complexity
        # - to be larger than the two signals length to avoid circular
        #   convolution
        conv_len = signal_len + len(int_psi_scale) - 1
        size_scale = _next_fast_len(conv_len)
        if size_scale != size_scale0:
            # Must recompute fft_data when the padding size changes.
            fft_data = fft(data, size_scale, dim=-1)
        size_scale0 = size_scale
        fft_wav = fft(int_psi_scale, size_scale, dim=-1)
        conv = ifft(fft_wav * fft_data, dim=-1)
        conv = conv[..., :conv_len]

        coef = -np.sqrt(scale) * torch.diff(conv, dim=-1)

        # The transform axis is always -1; crop symmetrically back to the
        # length of the input signal.
        d = (coef.shape[-1] - signal_len) / 2.0
        if d > 0:
            coef = coef[..., int(np.floor(d)) : -int(np.ceil(d))]
        elif d < 0:
            raise ValueError(f"Selected scale of {scale} too small.")

        out.append(coef)
    out_tensor = torch.stack(out)

    is_differentiable = isinstance(wavelet, _DifferentiableContinuousWavelet)
    # Discrete wavelets and real continuous wavelets yield real coefficients.
    # The short-circuit keeps complex_cwt from being read on Wavelet objects.
    if type(wavelet) is Wavelet or not wavelet.complex_cwt:
        out_tensor = out_tensor.real
    if is_differentiable:
        # The wavelet is moved to the CPU before the pywt frequency lookup.
        wavelet.cpu()

    with torch.no_grad():
        frequencies = scale2frequency(wavelet, scales, precision)
    if np.isscalar(frequencies):
        frequencies = np.array([frequencies])
    frequencies /= sampling_period

    if is_differentiable and data.is_cuda:
        # Move the learnable wavelet back to the data's device.
        wavelet.cuda()
    return out_tensor, frequencies
def _integrate_wavelet(
    wavelet: Union[ContinuousWavelet, str], precision: int = 8
) -> Any:
    """Integrate `psi` wavelet function from -Inf to x using rectangle integration.

    Modified to enable gradient flow through the cwt.

    Ported from:
    https://github.com/PyWavelets/pywt/blob/cef09e7f419aaf4c39b9f778bdc2d54b32fd7337/pywt/_functions.py#L60

    Args:
        wavelet (Wavelet instance or str): Wavelet to integrate.
            If a string, should be the name of a wavelet.
        precision (int): Precision that will be used for wavelet function
            approximation computed with the ``wavefun(level=precision)``
            Wavelet's method. Defaults to 8.

    Returns:
        ``(int_psi, x)`` for continuous and orthogonal wavelets,
        ``(int_psi_d, int_psi_r, x)`` for other wavelets.

    Raises:
        TypeError: If the wavelet function values are neither a numpy
            array nor a torch tensor.

    Example:
        >>> from pywt import Wavelet
        >>> from ptwt.continuous_transform import _integrate_wavelet
        >>> wavelet1 = Wavelet('db2')
        >>> int_psi, x = _integrate_wavelet(wavelet1, precision=5)
        >>> wavelet2 = Wavelet('bior1.3')
        >>> int_psi_d, int_psi_r, x = _integrate_wavelet(wavelet2, precision=5)
    """

    def _integrate(
        arr: Union[np.ndarray, torch.Tensor],  # type: ignore
        step: Union[np.ndarray, torch.Tensor],  # type: ignore
    ) -> Union[np.ndarray, torch.Tensor]:  # type: ignore
        # Rectangle rule: running sum of the samples times the grid spacing.
        if isinstance(arr, np.ndarray):
            return np.cumsum(arr) * step
        if isinstance(arr, torch.Tensor):
            return torch.cumsum(arr, -1) * step
        raise TypeError("Only ndarrays or tensors are integratable.")

    # Names (and anything else that is not a wavelet object yet) are
    # resolved through pywt.
    if not isinstance(
        wavelet, (Wavelet, ContinuousWavelet, _DifferentiableContinuousWavelet)
    ):
        wavelet = DiscreteContinuousWavelet(wavelet)

    functions_approximations = wavelet.wavefun(precision)

    if len(functions_approximations) == 2:  # continuous wavelet
        psi, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi, step), x

    elif len(functions_approximations) == 3:  # orthogonal wavelet
        _, psi, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi, step), x

    else:  # biorthogonal wavelet
        _, psi_d, _, psi_r, x = functions_approximations
        step = x[1] - x[0]
        return _integrate(psi_d, step), _integrate(psi_r, step), x


class _WaveletParameter(torch.nn.Parameter):
    """Parameter type used for the trainable wavelet parameters."""


class _DifferentiableContinuousWavelet(
    torch.nn.Module, ContinuousWavelet  # type: ignore
):
    """A base class for learnable Continuous Wavelets."""

    def __init__(self, name: str):
        """Set up the trainable bandwidth and center parameters.

        The parameters are initialized with the square roots of the
        wavelet's ``bandwidth_frequency`` and ``center_frequency``.

        Args:
            name (str): The name of the wavelet. It is not used in this
                method directly; presumably it is consumed when the
                ContinuousWavelet part of the object is created
                (TODO confirm).
        """
        super().__init__()
        super(ContinuousWavelet, self).__init__()

        self.dtype = torch.float64
        # Use torch nn parameter.
        # The square roots are stored, the properties below square them
        # again, which keeps the effective values non-negative.
        self.bandwidth_par = _WaveletParameter(
            torch.sqrt(torch.tensor(self.bandwidth_frequency, dtype=self.dtype)),
            requires_grad=True,
        )
        self.center_par = _WaveletParameter(
            torch.sqrt(torch.tensor(self.center_frequency, dtype=self.dtype)),
            requires_grad=True,
        )

    def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
        """Return numerical values for the wavelet on a grid.

        Subclasses must override this method.
        """
        raise NotImplementedError

    @property
    def bandwidth(self) -> torch.Tensor:
        """Square the bandwidth parameter to ensure positive values."""
        return self.bandwidth_par * self.bandwidth_par

    @property
    def center(self) -> torch.Tensor:
        """Square the center parameter to ensure positive values."""
        return self.center_par * self.center_par

    def wavefun(
        self, precision: int, dtype: torch.dtype = torch.float64
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Define a grid and evaluate the wavelet on it.

        Args:
            precision (int): The grid has ``2**precision`` points.
            dtype (torch.dtype): The data type of the grid.
                Defaults to torch.float64.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The wavelet values on the
            grid and the grid itself.
        """
        length = 2**precision
        # load the bounds from untyped pywt code.
        lower_bound: float = float(self.lower_bound)
        upper_bound: float = float(self.upper_bound)
        grid = torch.linspace(
            lower_bound,
            upper_bound,
            length,
            dtype=dtype,
            device=self.bandwidth_par.device,
        )
        return self(grid), grid


class _ShannonWavelet(_DifferentiableContinuousWavelet):
    """A differentiable Shannon wavelet."""

    def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
        """Return numerical values for the wavelet on a grid."""
        # torch.sinc(v) equals sin(pi * v) / (pi * v) and is one at v == 0.
        # The explicit quotient evaluates to NaN for a grid value of zero.
        shannon = (
            torch.sqrt(self.bandwidth)
            * torch.sinc(self.bandwidth * grid_values)
            * torch.exp(1j * 2 * torch.pi * self.center * grid_values)
        )
        return shannon


class _ComplexMorletWavelet(_DifferentiableContinuousWavelet):
    """A differentiable complex Morlet wavelet."""

    def __call__(self, grid_values: torch.Tensor) -> torch.Tensor:
        """Return numerical values for the wavelet on a grid."""
        morlet = (
            1.0
            / torch.sqrt(torch.pi * self.bandwidth)
            * torch.exp(-(grid_values**2) / self.bandwidth)
            * torch.exp(1j * 2 * torch.pi * self.center * grid_values)
        )
        return morlet