The ptwt package

ptwt.conv_transform module

Fast wavelet transformations based on torch.nn.functional.conv1d and its transpose.

This module treats boundaries by padding the signal edges.

ptwt.conv_transform.wavedec(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axis: int = -1) → list[Tensor]

Compute the analysis (forward) 1d fast wavelet transform.

The transformation relies on convolution operations with filter pairs.

\[x_s * h_k = c_{k,s+1}\]

Where \(x_s\) denotes the input at scale \(s\), with \(x_0\) equal to the original input. \(h_k\) denotes the convolution filter, with \(k \in \{A, D\}\), where \(A\) stands for approximation and \(D\) for detail. Each convolution is computed with a stride of two, which halves the length at every scale. The process uses the approximation coefficients as inputs for higher scales. Set the level argument to choose the largest scale.
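
To make the recursion concrete, here is a minimal pure-Python sketch of a two-level Haar decomposition. It is not ptwt's implementation: it assumes an even length at every scale (so no padding is needed), uses the orthonormal Haar filters with the detail computed as \((x_{2i} - x_{2i+1})/\sqrt{2}\), and the helper names haar_step and haar_wavedec are hypothetical. It only illustrates how the approximation coefficients feed the next scale and how the output list is ordered.

```python
import math

SQRT2 = math.sqrt(2.0)


def haar_step(x):
    """One analysis step: filter with the Haar pair and keep every second sample."""
    if len(x) % 2:
        raise ValueError("this sketch only handles even-length inputs")
    approx = [(x[i] + x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    return approx, detail


def haar_wavedec(x, level):
    """Return [cA_s, cD_s, ..., cD_1], mirroring the list order documented for wavedec."""
    approx = list(x)
    details = []
    for _ in range(level):
        # Only the approximation coefficients are passed on to the next scale.
        approx, detail = haar_step(approx)
        details.append(detail)
    # details holds cD_1 first, so reverse it to put the coarsest scale first.
    return [approx] + details[::-1]


coeffs = haar_wavedec([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0], level=2)
print([len(c) for c in coeffs])  # [3, 3, 6]
```

Because the Haar filters are orthonormal, the squared coefficients sum to the energy of the input signal.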

Parameters:
  • data (torch.Tensor) – The input time series. By default, the last axis is transformed.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Please consider the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of scales to compute. If None, the maximum level is determined from the input length. Defaults to None.

  • axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.

Returns:

A list:

[cA_s, cD_s, cD_s-1, …, cD2, cD1]

containing the wavelet coefficients. A denotes approximation and D detail coefficients.

Raises:

ValueError – If the dtype of the input data tensor is unsupported or if more than one axis is provided.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> # compute the forward fwt coefficients
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
...              mode='zero', level=2)

ptwt.conv_transform.waverec(coeffs: Sequence[Tensor], wavelet: Wavelet | str, axis: int = -1) → Tensor

Reconstruct a signal from wavelet coefficients.

Parameters:
  • coeffs (Sequence) – The wavelet coefficient sequence produced by wavedec.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axis (int) – Transform this axis instead of the last one. Defaults to -1.

Returns:

The reconstructed signal tensor.

Raises:

ValueError – If the dtype of the coefficient tensors is unsupported, if the coefficients have incompatible shapes, dtypes, or devices, or if more than one axis is provided.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> # invert the fast wavelet transform.
>>> ptwt.waverec(ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
...                           mode='zero', level=2),
...              pywt.Wavelet('haar'))
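
The synthesis direction can be sketched the same way. This pure-Python example (hypothetical helper names, Haar filters only, even lengths assumed) undoes a two-level Haar decomposition by merging each approximation/detail pair, starting from the coarsest scale, which mirrors the [cA_s, cD_s, …, cD1] order documented for wavedec. It is an illustration of the idea, not ptwt's transposed-convolution code.

```python
import math

SQRT2 = math.sqrt(2.0)


def haar_step(x):
    """Analysis step for even-length x: returns (approximation, detail)."""
    approx = [(x[i] + x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / SQRT2 for i in range(0, len(x), 2)]
    return approx, detail


def haar_inverse_step(approx, detail):
    """Synthesis step: upsample and merge one approximation/detail pair."""
    out = []
    for a, d in zip(approx, detail):
        out.append((a + d) / SQRT2)  # recovers x[2i]
        out.append((a - d) / SQRT2)  # recovers x[2i + 1]
    return out


def haar_waverec(coeffs):
    """Invert [cA_s, cD_s, ..., cD_1] by working from the coarsest scale outward."""
    approx = list(coeffs[0])
    for detail in coeffs[1:]:
        approx = haar_inverse_step(approx, detail)
    return approx


signal = [0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0]
a1, d1 = haar_step(signal)
a2, d2 = haar_step(a1)
reconstruction = haar_waverec([a2, d2, d1])
```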

ptwt.conv_transform_2 module

This module implements two-dimensional padded wavelet transforms.

The implementation relies on torch.nn.functional.conv2d and torch.nn.functional.conv_transpose2d under the hood.

ptwt.conv_transform_2.wavedec2(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: tuple[int, int] = (-2, -1)) → ptwt.constants.WaveletCoeff2d

Run a two-dimensional wavelet transformation.

This function relies on two-dimensional convolutions. Outer products allow the construction of 2D-filters from 1D filter arrays (see fwt-intro). It transforms the last two axes by default. This function computes

\[\mathbf{x}_s *_2 \mathbf{h}_k = \mathbf{c}_{k, s+1}\]

with \(k \in \{a, h, v, d\}\) and \(s \in \mathbb{N}_0\), where \(\mathbb{N}_0\) denotes the natural numbers including zero and \(\mathbf{x}_0\) is equal to the original input image \(\mathbf{X}\). \(*_2\) indicates a two-dimensional convolution. Computations at subsequent scales work exclusively with the approximation coefficients \(\mathbf{c}_{a, s}\) as inputs. Setting the level argument allows choosing the largest scale.
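
The outer-product construction can be illustrated without PyTorch. The following pure-Python sketch builds four 2x2 Haar filters as outer products of the 1D low- and high-pass filters and applies them with stride two to a small image. It is a single-level illustration only and assumes even height and width. Assigning 'h' to high-pass filtering along the height axis and 'v' to high-pass filtering along the width axis is our reading of the pywt convention and should be treated as an assumption; ptwt itself uses torch.nn.functional.conv2d.

```python
import math

LO = [1 / math.sqrt(2), 1 / math.sqrt(2)]   # 1D low-pass (approximation) filter
HI = [1 / math.sqrt(2), -1 / math.sqrt(2)]  # 1D high-pass (detail) filter


def outer(u, v):
    """Outer product u v^T: the first factor acts along height, the second along width."""
    return [[ui * vj for vj in v] for ui in u]


FILTERS_2D = {
    "a": outer(LO, LO),
    "h": outer(HI, LO),
    "v": outer(LO, HI),
    "d": outer(HI, HI),
}


def haar_dwt2_level(image):
    """Apply each 2x2 filter to non-overlapping 2x2 blocks.

    This is a stride-2 cross-correlation, the operation conv2d computes.
    """
    rows, cols = len(image), len(image[0])
    if rows % 2 or cols % 2:
        raise ValueError("this sketch expects even height and width")
    bands = {}
    for key, filt in FILTERS_2D.items():
        bands[key] = [
            [
                sum(filt[i][j] * image[r + i][c + j] for i in range(2) for j in range(2))
                for c in range(0, cols, 2)
            ]
            for r in range(0, rows, 2)
        ]
    return bands


image = [[4 * r + c for c in range(4)] for r in range(4)]
bands = haar_dwt2_level(image)
```

For this linear ramp image, every 'h' entry is -4, every 'v' entry is -1, and the diagonal band vanishes.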

Parameters:
  • data (torch.Tensor) – The input data tensor with any number of dimensions. By default, 2d inputs are interpreted as [height, width], 3d inputs as [batch_size, height, width], and 4d inputs as [batch_size, channels, height, width]. The axes argument allows other interpretations.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes (tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

A tuple containing the wavelet coefficients in pywt order, see ptwt.constants.WaveletCoeff2d.

Raises:

ValueError – If the dimensionality or the dtype of the input data tensor is unsupported or if the provided axes input has a length other than two.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = np.transpose(datasets.face(),
...                     [2, 0, 1]).astype(np.float64)
>>> pytorch_face = torch.tensor(face)  # shape [3, 768, 1024]; try unsqueeze(0)
>>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
...                              level=2, mode="zero")

ptwt.conv_transform_2.waverec2(coeffs: ptwt.constants.WaveletCoeff2d, wavelet: Wavelet | str, axes: tuple[int, int] = (-2, -1)) → Tensor

Reconstruct a signal from wavelet coefficients.

This function undoes the effect of the analysis or forward transform by running transposed convolutions.

Parameters:
  • coeffs (WaveletCoeff2d) – The wavelet coefficient tuple produced by wavedec2. See ptwt.constants.WaveletCoeff2d.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

The reconstructed signal tensor of shape [batch, height, width] or [batch, channel, height, width] depending on the input to wavedec2.

Raises:

ValueError – If coeffs is not in a shape as returned from wavedec2, if the dtype is not supported, if the provided axes input has a length other than two, or if the same axis is repeated.

Example

>>> import ptwt, pywt, torch
>>> import numpy as np
>>> from scipy import datasets
>>> face = np.transpose(datasets.face(),
...                     [2, 0, 1]).astype(np.float64)
>>> pytorch_face = torch.tensor(face)
>>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
...                              level=2, mode="constant")
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))

ptwt.conv_transform_3 module

Code for three dimensional padded transforms.

The functions here are based on torch.nn.functional.conv3d and its transpose.

ptwt.conv_transform_3.wavedec3(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'zero', level: int | None = None, axes: tuple[int, int, int] = (-3, -2, -1)) → ptwt.constants.WaveletCoeffNd

Compute a three-dimensional wavelet transform.

Parameters:
  • data (torch.Tensor) – The input data, for example of shape [batch_size, depth, height, width].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “zero”. See ptwt.constants.BoundaryMode.

  • level (Optional[int]) – The maximum decomposition level. This argument defaults to None.

  • axes (tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

A tuple containing the wavelet coefficients, see ptwt.constants.WaveletCoeffNd.

Raises:

ValueError – If the input has fewer than three dimensions or if the dtype is not supported or if the provided axes input has length other than three.

Example

>>> import ptwt, torch
>>> data = torch.randn(5, 16, 16, 16)
>>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")

ptwt.conv_transform_3.waverec3(coeffs: ptwt.constants.WaveletCoeffNd, wavelet: Wavelet | str, axes: tuple[int, int, int] = (-3, -2, -1)) → Tensor

Reconstruct a signal from wavelet coefficients.

Parameters:
  • coeffs (WaveletCoeffNd) – The wavelet coefficient tuple produced by wavedec3, see ptwt.constants.WaveletCoeffNd.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

The reconstructed four-dimensional signal tensor of shape [batch, depth, height, width].

Raises:

ValueError – If coeffs is not in a shape as returned from wavedec3, if the dtype is not supported, if the provided axes input has a length other than three, or if an axis is repeated.

Example

>>> import ptwt, torch
>>> data = torch.randn(5, 16, 16, 16)
>>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
>>> reconstruction = ptwt.waverec3(transformed, "haar")

ptwt.packets module

Compute analysis wavelet packet representations.

class ptwt.packets.WaveletPacket(data: Tensor | None, wavelet: Wavelet | str, mode: Literal['boundary'] | Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', maxlevel: int | None = None, axis: int = -1, boundary_orthogonalization: Literal['qr', 'gramschmidt'] = 'qr')

Bases: UserDict

Implements a single-dimensional wavelet packet analysis transform.

Create a wavelet packet decomposition object.

The decompositions will rely on padded fast wavelet transforms.

Parameters:
  • data (torch.Tensor, optional) – The input data array of shape [time], [batch_size, time] or [batch_size, channels, time]. If None, the object is initialized without performing a decomposition. The time axis is transformed by default. Use the axis argument to choose another dimension.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired padding method. If you select ‘boundary’, the sparse matrix backend will be used. Defaults to ‘reflect’.

  • maxlevel (int, optional) – Value is passed on to transform. The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axis (int) – The axis to transform. Defaults to -1.

  • boundary_orthogonalization – The orthogonalization method to use in the sparse matrix backend, see ptwt.constants.OrthogonalizeMethod. Only used if mode equals ‘boundary’. Defaults to ‘qr’.

Example

>>> import torch, pywt, ptwt
>>> import numpy as np
>>> import scipy.signal
>>> import matplotlib.pyplot as plt
>>> t = np.linspace(0, 10, 1500)
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
...     wavelet=pywt.Wavelet("db3"), mode="reflect")
>>> np_lst = []
>>> for node in wp.get_level(5):
...     np_lst.append(wp[node])
...
>>> viz = np.stack(np_lst).squeeze()
>>> plt.imshow(np.abs(viz))
>>> plt.show()

__getitem__(key: str) → Tensor

Access the coefficients in the wavelet packets tree.

Parameters:

key (str) – The key of the accessed coefficients. The string may only consist of the chars ‘a’ and ‘d’ where ‘a’ denotes the low pass or approximation filter and ‘d’ the high-pass or detail filter.

Returns:

The accessed wavelet packet coefficients.

Raises:
  • ValueError – If the wavelet packet tree is not initialized.

  • KeyError – If no wavelet coefficients are indexed by the specified key.

get_level(level: int) → list[str]

Return the Gray-code-ordered paths to the filter tree nodes.

Parameters:

level (int) – The depth of the tree.

Returns:

A list with the paths to each node.
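
The Gray-code ordering can be reproduced with a short pure-Python sketch, assuming the usual reflected construction with 'a' for the low-pass and 'd' for the high-pass path. graycode_paths is a hypothetical helper, not part of ptwt's public API. In this ordering neighbouring nodes differ in exactly one filter choice, which for wavelet packets sorts the nodes by frequency band.

```python
def graycode_paths(level, low="a", high="d"):
    """Reflected Gray-code ordering of filter paths (hypothetical helper)."""
    if level < 1:
        raise ValueError("level must be at least 1")
    order = [low, high]
    for _ in range(level - 1):
        # Prefix the current list with the low-pass char, then the reversed
        # list with the high-pass char, so neighbouring paths differ in one char.
        order = [low + p for p in order] + [high + p for p in reversed(order)]
    return order


print(graycode_paths(2))  # ['aa', 'ad', 'dd', 'da']
```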

reconstruct() → WaveletPacket

Recursively reconstruct the input starting from the leaf nodes.

Reconstruction replaces the input data originally assigned to this object.

Note

Only changes to leaf node data impact the results, since changes in all other nodes will be replaced with a reconstruction from the leaves.

Example

>>> import numpy as np
>>> import ptwt, torch
>>> signal = np.random.randn(1, 16)
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
...     mode="boundary", maxlevel=2)
>>> ptwp["aa"].data *= 0
>>> ptwp.reconstruct()
>>> print(ptwp[""])

transform(data: Tensor, maxlevel: int | None = None) → WaveletPacket

Calculate the 1d wavelet packet transform for the input data.

Parameters:
  • data (torch.Tensor) – The input data array of shape [time] or [batch_size, time].

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

class ptwt.packets.WaveletPacket2D(data: Tensor | None, wavelet: Wavelet | str, mode: Literal['boundary'] | Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', maxlevel: int | None = None, axes: tuple[int, int] = (-2, -1), boundary_orthogonalization: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = False)

Bases: UserDict

Two-dimensional wavelet packets.

Example code illustrating the use of this class is available in the GitHub repository v0lta/PyTorch-Wavelet-Toolbox.

Create a 2D-Wavelet packet tree.

Parameters:
  • data (torch.Tensor, optional) – The input data tensor, for example of shape [batch_size, height, width] or [batch_size, channels, height, width]. If None, the object is initialized without performing a decomposition.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – A string indicating the desired padding mode. If you select ‘boundary’, the sparse matrix backend is used. Defaults to ‘reflect’.

  • maxlevel (int, optional) – Value is passed on to transform. The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axes ([int, int], optional) – The tensor axes that should be transformed. Defaults to (-2, -1).

  • boundary_orthogonalization – The orthogonalization method to use in the sparse matrix backend, see ptwt.constants.OrthogonalizeMethod. Only used if mode equals ‘boundary’. Defaults to ‘qr’.

  • separable (bool) – If True, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False.

__getitem__(key: str) → Tensor

Access the coefficients in the wavelet packets tree.

Parameters:

key (str) – The key of the accessed coefficients. The string may only consist of the following chars: ‘a’, ‘h’, ‘v’, ‘d’. The chars correspond to the selected coefficients for a level, where ‘a’ denotes the approximation coefficients and ‘h’ horizontal, ‘v’ vertical, and ‘d’ diagonal detail coefficients.

Returns:

The accessed wavelet packet coefficients.

Raises:
  • ValueError – If the wavelet packet tree is not initialized.

  • KeyError – If no wavelet coefficients are indexed by the specified key.

static get_freq_order(level: int) → list[list[str]]

Get the frequency order for a given packet decomposition level.

Use this code to create two-dimensional frequency orderings.

Parameters:

level (int) – The number of decomposition scales.

Returns:

A list with the tree nodes in frequency order.

Note

Adapted from: PyWavelets/pywt

The code elements denote the filter application order. The filters are named following the pywt convention:
  • a - LL, low-low coefficients

  • h - LH, low-high coefficients

  • v - HL, high-low coefficients

  • d - HH, high-high coefficients

static get_natural_order(level: int) → list[str]

Get the natural ordering for a given decomposition level.

Parameters:

level (int) – The decomposition level.

Returns:

A list with the filter order strings.
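
A minimal sketch of such a natural ordering, assuming it enumerates every combination of the four per-level filter choices in the order 'a', 'h', 'v', 'd' (natural_order is a hypothetical helper, not ptwt's code). Each additional decomposition level multiplies the number of nodes by four.

```python
from itertools import product


def natural_order(level, chars="ahvd"):
    """All length-`level` filter paths in lexicographic order of `chars` (hypothetical helper)."""
    return ["".join(path) for path in product(chars, repeat=level)]


print(natural_order(1))  # ['a', 'h', 'v', 'd']
print(len(natural_order(2)))  # 16
```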

reconstruct() → WaveletPacket2D

Recursively reconstruct the input starting from the leaf nodes.

Note

Only changes to leaf node data impact the results, since changes in all other nodes will be replaced with a reconstruction from the leaves.

transform(data: Tensor, maxlevel: int | None = None) → WaveletPacket2D

Calculate the 2d wavelet packet transform for the input data.

The transform function allows reusing the same object.

Parameters:
  • data (torch.Tensor) – The input data tensor of shape [batch_size, height, width].

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

ptwt.continuous_transform module

PyTorch-compatible continuous wavelet transform (cwt) code.

This module is based on pywt’s cwt implementation.

ptwt.continuous_transform.cwt(data: Tensor, scales: ndarray | Tensor, wavelet: ContinuousWavelet | str, sampling_period: float = 1.0) → tuple[Tensor, ndarray]

Compute the single-dimensional continuous wavelet transform.

This function is a PyTorch port of pywt.cwt as found at: PyWavelets/pywt

Parameters:
  • data (torch.Tensor) – The input tensor of shape [batch_size, time].

  • scales (torch.Tensor or np.array) – The wavelet scales to use. Use f = pywt.scale2frequency(wavelet, scale)/sampling_period to determine the physical frequency f that corresponds to a scale. Here, f is in hertz when the sampling_period is given in seconds.

  • wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.

  • sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e., the scales are not scaled by the sampling period).
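
The relation between scales and physical frequencies can be sketched in plain Python. We assume, as pywt.scale2frequency does, that the frequency of a scale is the wavelet's center frequency divided by the scale. The center frequency below is a placeholder value; in practice you would obtain it from pywt.central_frequency for the wavelet in use. The helper name is hypothetical.

```python
def scales_to_frequencies(scales, center_frequency, sampling_period=1.0):
    """Convert wavelet scales to frequencies in Hz (if sampling_period is in seconds).

    Implements f = (center_frequency / scale) / sampling_period.
    """
    return [center_frequency / scale / sampling_period for scale in scales]


ASSUMED_CENTER_FREQUENCY = 0.25  # placeholder; query pywt.central_frequency(...) in practice
widths = range(1, 31)
freqs = scales_to_frequencies(widths, ASSUMED_CENTER_FREQUENCY, sampling_period=1 / 800)
# Larger scales stretch the wavelet, so they map to lower frequencies.
```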

Raises:

ValueError – If a scale is too small for the input signal.

Returns:

A tuple (out_tensor, frequencies). The first tuple-element contains the transformation matrix of shape [scales, batch, time]. The second element contains an array with frequency information.

Example

>>> import torch, ptwt
>>> import numpy as np
>>> import scipy.signal as signal
>>> t = np.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = np.arange(1, 31)
>>> cwtmatr, freqs = ptwt.cwt(
...     torch.from_numpy(sig), widths, "mexh", sampling_period=(4 / 800) * np.pi
... )

ptwt.separable_conv_transform module

Compute separable convolution-based transforms.

This module takes multi-dimensional convolutions apart. It uses single-dimensional convolutions to transform axes individually. Under the hood, code in this module transforms all dimensions using torch.nn.functional.conv1d and its transpose.

ptwt.separable_conv_transform.fswavedec2(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: tuple[int, int] = (-2, -1)) → ptwt.constants.WaveletCoeff2dSeparable

Compute a fully separable 2D-padded analysis wavelet transform.

Parameters:
  • data (torch.Tensor) – An input signal of shape [batch, height, width] or [batch, channels, height, width].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of pywt.wavelist(kind="discrete") for a list of possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes ([int, int]) – The axes to transform. Defaults to (-2, -1).

Returns:

A tuple with the ll coefficients and for each scale a dictionary containing the detail coefficients, see ptwt.constants.WaveletCoeff2dSeparable. The dictionaries use the filter order strings:

("ad", "da", "dd")

as keys. ‘a’ denotes the low pass or approximation filter and ‘d’ the high-pass or detail filter.

Raises:

ValueError – If the data is not a batched 2D signal.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10)
>>> coeff = ptwt.fswavedec2(data, "haar", level=2)

ptwt.separable_conv_transform.fswavedec3(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: tuple[int, int, int] = (-3, -2, -1)) → ptwt.constants.WaveletCoeffNd

Compute a fully separable 3D-padded analysis wavelet transform.

Parameters:
  • data (torch.Tensor) – An input signal of shape [batch, depth, height, width].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of pywt.wavelist(kind="discrete") for possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes (tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

A tuple with the lll coefficients and for each scale a dictionary containing the detail coefficients, see ptwt.constants.WaveletCoeffNd. The dictionaries use the filter order strings:

("aad", "ada", "add", "daa", "dad", "dda", "ddd")

as keys. ‘a’ denotes the low pass or approximation filter and ‘d’ the high-pass or detail filter.
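
The key tuples listed for fswavedec2 and fswavedec3 share a simple structure: one character per transformed axis, covering every combination except the all-'a' string, which names the coarse approximation rather than a detail band. The following pure-Python sketch reproduces both tuples; detail_keys is a hypothetical helper for illustration.

```python
from itertools import product


def detail_keys(ndim):
    """Filter-order strings for the detail bands of one separable level (hypothetical helper).

    Each character picks the approximation ('a') or detail ('d') filter for one
    transformed axis; the all-approximation string is excluded.
    """
    keys = ["".join(p) for p in product("ad", repeat=ndim)]
    return tuple(k for k in keys if k != "a" * ndim)


print(detail_keys(2))  # ('ad', 'da', 'dd')
print(detail_keys(3))  # ('aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd')
```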

Raises:

ValueError – If the input is not a batched 3D signal.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10, 10)
>>> coeff = ptwt.fswavedec3(data, "haar", level=2)

ptwt.separable_conv_transform.fswaverec2(coeffs: ptwt.constants.WaveletCoeff2dSeparable, wavelet: Wavelet | str, axes: tuple[int, int] = (-2, -1)) → Tensor

Compute a fully separable 2D-padded synthesis wavelet transform.

The function uses separate single-dimensional convolutions under the hood.

Parameters:
  • coeffs (WaveletCoeff2dSeparable) – The wavelet coefficients as computed by fswavedec2, see ptwt.constants.WaveletCoeff2dSeparable.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

A reconstruction of the signal encoded in the wavelet coefficients.

Raises:

ValueError – If the axes argument is not a tuple of two integers.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10)
>>> coeff = ptwt.fswavedec2(data, "haar", level=2)
>>> rec = ptwt.fswaverec2(coeff, "haar")

ptwt.separable_conv_transform.fswaverec3(coeffs: ptwt.constants.WaveletCoeffNd, wavelet: Wavelet | str, axes: tuple[int, int, int] = (-3, -2, -1)) → Tensor

Compute a fully separable 3D-padded synthesis wavelet transform.

Parameters:
  • coeffs (WaveletCoeffNd) – The wavelet coefficients as computed by fswavedec3, see ptwt.constants.WaveletCoeffNd.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

A reconstruction of the signal encoded in the wavelet coefficients.

Raises:

ValueError – If the axes argument is not a tuple with three ints.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10, 10)
>>> coeff = ptwt.fswavedec3(data, "haar", level=2)
>>> rec = ptwt.fswaverec3(coeff, "haar")

ptwt.stationary_transform module

This module implements stationary wavelet transforms.

ptwt.stationary_transform.iswt(coeffs: Sequence[Tensor], wavelet: Wavelet | str, axis: int | None = -1) → Tensor

Invert a 1d stationary wavelet transform.

Parameters:
  • coeffs (Sequence[torch.Tensor]) – The coefficients as computed by the swt function.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet, as used in the forward transform.

  • axis (int, optional) – The axis the forward transform was computed over. Defaults to -1.

Returns:

A reconstruction of the original swt input.

Raises:

ValueError – If the axis argument is not an integer.

ptwt.stationary_transform.swt(data: Tensor, wavelet: Wavelet | str, level: int | None = None, axis: int = -1) → list[Tensor]

Compute a multilevel 1d stationary wavelet transform.

This function is equivalent to pywt’s swt with trim_approx=True and norm=False.

Parameters:
  • data (torch.Tensor) – The input data of shape [batch_size, time].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • level (int, optional) – The number of levels to compute.

  • axis (int) – The axis to transform along. Defaults to the last axis.

Returns:

Same as wavedec. Equivalent to pywt.swt with trim_approx=True.

Raises:

ValueError – If the axis argument is not an integer.
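
The stationary transform skips the downsampling step and instead spreads the filter taps further apart at each level (the "à trous" scheme), so every band keeps the input length. The pure-Python sketch below shows this for the Haar wavelet with periodic extension, together with its inverse. It is an illustration under stated assumptions: the helper names are hypothetical, and pywt and ptwt may align (circularly shift) the coefficients differently or impose length requirements that this sketch does not attempt to match.

```python
import math

SQRT2 = math.sqrt(2.0)


def haar_swt(x, level):
    """Undecimated Haar transform with periodic extension (hypothetical helper).

    Returns [cA_level, cD_level, ..., cD_1]; every band keeps the input length.
    """
    n = len(x)
    approx = list(x)
    details = []
    for j in range(1, level + 1):
        # A trous: the two Haar taps are 2**(j - 1) samples apart at level j,
        # instead of downsampling the signal after each level.
        step = 2 ** (j - 1)
        details.append([(approx[i] - approx[(i + step) % n]) / SQRT2 for i in range(n)])
        approx = [(approx[i] + approx[(i + step) % n]) / SQRT2 for i in range(n)]
    return [approx] + details[::-1]


def haar_iswt(coeffs):
    """Invert haar_swt: at each level x[i] = (a[i] + d[i]) / sqrt(2)."""
    approx = list(coeffs[0])
    for detail in coeffs[1:]:
        approx = [(a + d) / SQRT2 for a, d in zip(approx, detail)]
    return approx


signal = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 4.0]
coeffs = haar_swt(signal, level=2)
reconstruction = haar_iswt(coeffs)
```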

ptwt.matmul_transform module

Implement matrix-based fwt and ifwt.

This module uses boundary filters instead of padding.

The implementation is based on the description in Strang and Nguyen (p. 32), as well as the description of boundary filters in “Ripples in Mathematics”, section 10.3.

class ptwt.matmul_transform.BaseMatrixWaveDec

Bases: object

A base class for matrix wavedec.

class ptwt.matmul_transform.MatrixWavedec(wavelet: Wavelet | str, level: int | None = None, axis: int | None = -1, boundary: Literal['qr', 'gramschmidt'] = 'qr')

Bases: BaseMatrixWaveDec

Compute the sparse matrix fast wavelet transform.

Intermediate scale results must be divisible by two. For example, a working third-level transform could use an input length of 128, which leads to intermediate resolutions of 64 and 32, both divisible by two.
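
The divisibility requirement is easy to check ahead of time. The sketch below (a hypothetical helper, not part of ptwt) lists the input length seen at each of the first `level` scales and returns None as soon as one of them is odd. It reproduces the 128-sample case and shows why a 12-sample input, as in the example that follows, supports two padding-free levels but not three.

```python
def padding_free_lengths(length, level):
    """Return the signal lengths seen at each scale, or None if some scale is odd.

    A padding-free matrix transform needs every intermediate length to be even
    (hypothetical helper for illustration).
    """
    lengths = []
    for _ in range(level):
        if length % 2:
            return None
        lengths.append(length)
        length //= 2
    return lengths


print(padding_free_lengths(128, 3))  # [128, 64, 32]
print(padding_free_lengths(12, 3))   # None, since 12 -> 6 -> 3 and 3 is odd
```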

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> matrix_wavedec = ptwt.MatrixWavedec(
...     pywt.Wavelet('haar'), level=2)
>>> coefficients = matrix_wavedec(data_torch)

Create a sparse matrix fast wavelet transform object.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • level (int, optional) – The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None.

  • axis (int, optional) – The axis we would like to transform. Defaults to -1.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths or if axis is not an integer.

__call__(input_signal: Tensor) → list[Tensor]

Compute the matrix fwt for the given input signal.

Matrix FWTs are used to avoid padding.

Parameters:

input_signal (torch.Tensor) – Batched input data. An example shape could be [batch_size, time]. Inputs can have any dimension. This transform affects the last axis by default. Use the axis argument in the constructor to choose another axis.

Returns:

A list with the coefficient tensor for each scale.

Raises:

ValueError – If the decomposition level is not a positive integer or if the input signal does not have the expected shape.

property sparse_fwt_operator: Tensor

The sparse transformation operator.

If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed as a single matrix multiply.

The operation torch.sparse.mm(sparse_fwt_operator, data.T) computes a batched fwt.

This property exists to make the operator matrix transparent. Calling the object will handle odd-length inputs properly.
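
For an orthogonal wavelet and an even-length input, the single-matrix view can be illustrated with a dense Haar analysis matrix in pure Python. This is a sketch under simplifying assumptions: one level, Haar filters, dense lists instead of torch sparse tensors, and a hypothetical row layout with all approximation rows before the detail rows. It shows that the transform is one matrix-vector product and that the transpose of the analysis matrix undoes it, which is the role the synthesis operator plays.

```python
import math


def haar_analysis_matrix(length):
    """Dense single-level Haar analysis matrix: approximation rows first, then detail rows.

    This row layout is an illustrative choice, not necessarily the one ptwt uses.
    """
    if length % 2:
        raise ValueError("length must be even")
    s = 1 / math.sqrt(2)
    half = length // 2
    matrix = [[0.0] * length for _ in range(length)]
    for k in range(half):
        matrix[k][2 * k], matrix[k][2 * k + 1] = s, s  # low-pass row
        matrix[half + k][2 * k], matrix[half + k][2 * k + 1] = s, -s  # high-pass row
    return matrix


def matvec(matrix, vector):
    """Plain matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def transpose(matrix):
    """Swap rows and columns."""
    return [list(col) for col in zip(*matrix)]


signal = [0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0]
analysis = haar_analysis_matrix(len(signal))
coeffs = matvec(analysis, signal)  # approximation coefficients, then detail coefficients
synthesis = transpose(analysis)  # orthogonal matrix: the inverse is the transpose
reconstruction = matvec(synthesis, coeffs)
```

Conceptually, a multilevel operator is the product of such per-level matrices, where each later level acts on the approximation rows only and leaves the already computed detail rows unchanged through an identity block.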

Raises:
  • NotImplementedError – If padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

class ptwt.matmul_transform.MatrixWaverec(wavelet: Wavelet | str, axis: int = -1, boundary: Literal['qr', 'gramschmidt'] = 'qr')

Bases: object

Matrix-based inverse fast wavelet transform.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> matrix_wavedec = ptwt.MatrixWavedec(
...     pywt.Wavelet('haar'), level=2)
>>> coefficients = matrix_wavedec(data_torch)
>>> matrix_waverec = ptwt.MatrixWaverec(
...     pywt.Wavelet('haar'))
>>> reconstruction = matrix_waverec(coefficients)

Create the inverse matrix-based fast wavelet transformation.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axis (int) – The axis transformed by the original decomposition. Defaults to -1, the last axis.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths or if axis is not an integer.

__call__(coefficients: Sequence[Tensor]) → Tensor

Run the synthesis or inverse matrix fwt.

Parameters:

coefficients (Sequence[torch.Tensor]) – The coefficients produced by the forward transform.

Returns:

The input signal reconstruction.

Raises:

ValueError – If the decomposition level is not a positive integer or if the coefficients are not in the shape returned by a MatrixWavedec object.

property sparse_ifwt_operator: Tensor

The sparse transformation operator.

If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed as a single matrix multiply.

Having concatenated the analysis coefficients, torch.sparse.mm(sparse_ifwt_operator, coefficients.T) computes a batched iFWT.

This functionality is mainly here to make the operator-matrix transparent. Calling the object handles padding for odd inputs.

Raises:
  • NotImplementedError – If padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

ptwt.matmul_transform.construct_boundary_a(wavelet: Wavelet | str, length: int, device: device | str = 'cpu', boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) → Tensor

Construct a boundary-wavelet filter 1d-analysis matrix.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • length (int) – The number of entries in the input signal.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • device – Where to place the matrix. Choose cpu or cuda. Defaults to cpu.

  • dtype – Choose torch.float32 or torch.float64. Defaults to torch.float64.

Returns:

The sparse analysis matrix.

ptwt.matmul_transform.construct_boundary_s(wavelet: Wavelet | str, length: int, device: device | str = 'cpu', boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a boundary-wavelet filter 1d-synthesis matrix.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • length (int) – The number of entries in the input signal.

  • device (torch.device) – Where to place the matrix. Choose cpu or cuda. Defaults to cpu.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • dtype – Choose torch.float32 or torch.float64. Defaults to torch.float64.

Returns:

The sparse synthesis matrix.

ptwt.matmul_transform.orthogonalize(matrix: Tensor, filt_len: int, method: Literal['qr', 'gramschmidt'] = 'qr') Tensor[source]#

Orthogonalization for sparse filter matrices.

Parameters:
  • matrix (torch.Tensor) – The sparse filter matrix to orthogonalize.

  • filt_len (int) – The length of the wavelet filter coefficients.

  • method – The orthogonalization method to use. Choose qr or gramschmidt. The dense qr code runs much faster than the sparse gramschmidt code. Choose gramschmidt if qr fails. Defaults to qr.

Returns:

Orthogonal sparse transformation matrix.

Raises:

ValueError – If an invalid orthogonalization method is given.
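The two methods can be compared on a small example. This is a dense NumPy sketch under stated assumptions: it orthonormalizes all rows of a full-rank square matrix, whereas ptwt works on sparse filter matrices. Both helper names are hypothetical. With a sign fix on the QR factors, both routes produce the same orthonormal rows.

```python
import numpy as np

def orthogonalize_rows_qr(matrix: np.ndarray) -> np.ndarray:
    """Orthonormalize the rows of a full-rank square matrix with a dense QR."""
    q, r = np.linalg.qr(matrix.T)
    # Flip signs so that each new row points in the direction of the original row.
    q = q * np.sign(np.diag(r))
    return q.T

def orthogonalize_rows_gram_schmidt(matrix: np.ndarray) -> np.ndarray:
    """Orthonormalize the rows with modified Gram-Schmidt, one row at a time."""
    result = matrix.astype(float).copy()
    for i in range(result.shape[0]):
        for j in range(i):
            # Remove the component along every previously orthonormalized row.
            result[i] -= (result[j] @ result[i]) * result[j]
        result[i] /= np.linalg.norm(result[i])
    return result

rng = np.random.default_rng(1)
matrix = rng.standard_normal((5, 5))
q_qr = orthogonalize_rows_qr(matrix)
q_gs = orthogonalize_rows_gram_schmidt(matrix)
```

The QR route delegates to an optimized dense routine, while the Gram-Schmidt loop touches one row at a time, which is what makes a sparse, memory-friendly variant possible.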

ptwt.matmul_transform_2 module#

Two-dimensional matrix based fast wavelet transform implementations.

This module uses boundary filters to minimize padding.

class ptwt.matmul_transform_2.MatrixWavedec2(wavelet: Wavelet | str, level: int | None = None, axes: tuple[int, int] = (-2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = True)[source]#

Bases: BaseMatrixWaveDec

Experimental sparse matrix 2d wavelet transform.

For a completely pad-free transform, the input height and width are expected to be divisible by two. For multiscale transforms, all intermediate scale dimensions should be divisible by two, e.g. 128, 128 -> 64, 64 -> 32, 32 would work well for a level three transform. In this case, multiplication with the sparse_fwt_operator property is equivalent to calling the object.
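The divisibility requirement can be checked before building any matrices. The helper `is_pad_free` below is hypothetical (it is not part of ptwt) and simply halves the image size once per level, checking that each level's input is even.

```python
def is_pad_free(height: int, width: int, level: int) -> bool:
    """Return True if the input to every level has an even height and width."""
    for _ in range(level):
        if height % 2 or width % 2:
            return False
        height, width = height // 2, width // 2
    return True

print(is_pad_free(128, 128, 3))  # 128 -> 64 -> 32: every level is even
print(is_pad_free(96, 100, 3))   # the width reaches 25 at the third level
```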

Note

Constructing the sparse fwt-matrix is expensive. For longer wavelets, high-level transforms, and large input images this may take a while. The matrix is therefore constructed only once. In the non-separable case, it can be accessed via the sparse_fwt_operator property.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = datasets.face()[:256, :256, :].astype(np.float32)
>>> pt_face = torch.tensor(face).permute([2, 0, 1])
>>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
>>> mat_coeff = matrixfwt(pt_face)

Create a new matrix fwt object.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • level (int, optional) – The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None.

  • axes (int, int) – A tuple with the axes to transform. Defaults to (-2, -1).

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • separable (bool) – If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. Matrix construction is significantly faster for separable transformations since only a small constant-size part of the matrices must be orthogonalized. Defaults to True.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.
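The separable and non-separable routes can be related through the Kronecker product: for row-major flattening, applying A along the height axis and B along the width axis equals multiplying the flattened image by kron(A, B). A minimal NumPy sketch, assuming the Haar wavelet (which needs no boundary filters) and a hypothetical dense helper `haar_matrix`; it does not model ptwt's boundary orthogonalization.

```python
import numpy as np

def haar_matrix(n: int) -> np.ndarray:
    """Dense one-level Haar analysis matrix for an even length n (hypothetical helper)."""
    s = 1.0 / np.sqrt(2.0)
    mat = np.zeros((n, n))
    for row in range(n // 2):
        mat[row, 2 * row : 2 * row + 2] = [s, s]            # approximation rows
        mat[n // 2 + row, 2 * row : 2 * row + 2] = [s, -s]  # detail rows
    return mat

rng = np.random.default_rng(0)
image = rng.standard_normal((4, 6))  # [height, width]
height_op, width_op = haar_matrix(4), haar_matrix(6)

# Separable: a 1d transform along the height axis, then along the width axis.
separable = height_op @ image @ width_op.T

# Non-separable: one [height*width, height*width] operator on the flattened image.
operator = np.kron(height_op, width_op)
non_separable = (operator @ image.reshape(-1)).reshape(4, 6)
```

The separable form only ever builds two small matrices, while the Kronecker operator grows with the square of the number of pixels, which is one reason the separable construction is cheaper.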

__call__(input_signal: Tensor) ptwt.constants.WaveletCoeff2d[source]#

Compute the fwt for the given input signal.

The fwt matrix is set up during the first call and stored for future use.

Parameters:

input_signal (torch.Tensor) – An input signal of shape [batch_size, height, width]. 2d inputs are interpreted as [height, width], 4d inputs as [batch_size, channels, height, width]. By default, this transform affects the last two dimensions; see the axes argument.

Returns:

The resulting coefficients per level are stored in a pywt style tuple, see ptwt.constants.WaveletCoeff2d.

Raises:

ValueError – If the decomposition level is not a positive integer or if the input signal does not have the expected shape.

property sparse_fwt_operator: Tensor#

Compute the operator matrix for padding-free cases.

This property exists to make the transformation matrix available. To benefit from the code that handles odd-length levels, call the object instead.

Returns:

The sparse 2d-fwt operator matrix.

Raises:
  • NotImplementedError – If a separable transformation was used or if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

class ptwt.matmul_transform_2.MatrixWaverec2(wavelet: Wavelet | str, axes: tuple[int, int] = (-2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = True)[source]#

Bases: object

Synthesis or inverse matrix-based wavelet transformation object.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = datasets.face()[:256, :256, :].astype(np.float32)
>>> pt_face = torch.tensor(face).permute([2, 0, 1])
>>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
>>> mat_coeff = matrixfwt(pt_face)
>>> matrixifwt = ptwt.MatrixWaverec2(pywt.Wavelet("haar"))
>>> reconstruction = matrixifwt(mat_coeff)

Create the inverse matrix-based fast wavelet transformation.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (int, int) – The axes transformed by the original decomposition. Defaults to (-2, -1).

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • separable (bool) – If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. This is significantly faster than a non-separable transformation since only a small constant-size part of the matrices must be orthogonalized. For invertibility, the separable flag must have the same value for analysis and synthesis! Defaults to True.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.

__call__(coefficients: ptwt.constants.WaveletCoeff2d) Tensor[source]#

Compute the inverse matrix 2d fast wavelet transform.

Parameters:

coefficients (WaveletCoeff2d) – The coefficient tuple as returned by the MatrixWavedec2 object, see ptwt.constants.WaveletCoeff2d.

Returns:

The original signal reconstruction. For example of shape [batch_size, height, width] or [batch_size, channels, height, width], depending on the input to the forward transform and the value of the axes argument.

Raises:

ValueError – If the decomposition level is not a positive integer or if the coefficients do not have the shape returned by a MatrixWavedec2 object.

property sparse_ifwt_operator: Tensor#

Compute the ifwt operator matrix for pad-free cases.

Returns:

The sparse 2d ifwt operator matrix.

Raises:
  • NotImplementedError – If a separable transformation was used or if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

ptwt.matmul_transform_2.construct_boundary_a2(wavelet: Wavelet | str, height: int, width: int, device: device | str, boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a boundary fwt matrix for the input wavelet.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • height (int) – The height of the input matrix. Should be divisible by two.

  • width (int) – The width of the input matrix. Should be divisible by two.

  • device (torch.device) – Where to place the matrix. Either on the CPU or GPU.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • dtype (torch.dtype, optional) – The desired data type for the matrix. Defaults to torch.float64.

Returns:

A sparse fwt matrix, with orthogonalized boundary wavelets.

ptwt.matmul_transform_2.construct_boundary_s2(wavelet: Wavelet | str, height: int, width: int, device: device | str, *, boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a 2d synthesis (inverse fwt) matrix with boundary wavelets.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • height (int) – The original height of the input matrix.

  • width (int) – The width of the original input matrix.

  • device (torch.device) – Choose CPU or GPU.

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

  • dtype (torch.dtype, optional) – The data type of the sparse matrix; choose torch.float32 or torch.float64. Defaults to torch.float64.

Returns:

The synthesis matrix, used to compute the inverse fast wavelet transform.

ptwt.matmul_transform_3 module#

Implement 3D separable boundary transforms.

class ptwt.matmul_transform_3.MatrixWavedec3(wavelet: Wavelet | str, level: int | None = None, axes: tuple[int, int, int] = (-3, -2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: object

Compute 3d separable transforms.

Create a separable three-dimensional fast boundary wavelet transform.

Input signals should have the shape [batch_size, depth, height, width]. By default, this object transforms the last three dimensions.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

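  • level (int, optional) – The desired decomposition level. Defaults to None.

  • axes (tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.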

Raises:
  • NotImplementedError – If the chosen orthogonalization method is not implemented.

  • ValueError – If the analysis and synthesis filters do not have the same length.

__call__(input_signal: Tensor) ptwt.constants.WaveletCoeffNd[source]#

Compute a separable 3d-boundary wavelet transform.

Parameters:

input_signal (torch.Tensor) – An input signal. For example of shape [batch_size, depth, height, width].

Returns:

The resulting coefficients for each level are stored in a tuple, see ptwt.constants.WaveletCoeffNd.

Raises:

ValueError – If the input dimensions are not supported.

class ptwt.matmul_transform_3.MatrixWaverec3(wavelet: Wavelet | str, axes: tuple[int, int, int] = (-3, -2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: object

Reconstruct a signal from 3d-separable-fwt coefficients.

Compute a three-dimensional separable boundary wavelet synthesis transform.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • axes (tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

  • boundary – The method used for boundary filter treatment, see ptwt.constants.OrthogonalizeMethod. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.

__call__(coefficients: ptwt.constants.WaveletCoeffNd) Tensor[source]#

Reconstruct a batched 3d-signal from its coefficients.

Parameters:

coefficients (WaveletCoeffNd) – The output from the MatrixWavedec3 object, see ptwt.constants.WaveletCoeffNd.

Returns:

A reconstruction of the original signal.

Return type:

torch.Tensor

Raises:

ValueError – If the data structure is inconsistent.

ptwt.sparse_math module#

Efficiently construct fwt operations using sparse matrices.

ptwt.sparse_math.batch_mm(matrix: Tensor, matrix_batch: Tensor) Tensor[source]#

Calculate a batched matrix-matrix product using torch tensors.

This calculates the product of a matrix with a batch of dense matrices. The former can be dense or sparse.

Parameters:
  • matrix (torch.Tensor) – Sparse or dense matrix, of shape [m, n].

  • matrix_batch (torch.Tensor) – Batched dense matrices, of shape [b, n, k].

Returns:

The batched matrix-matrix product of shape [b, m, k].

Raises:

ValueError – If the matrices cannot be multiplied due to incompatible matrix shapes.
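One way to realize this product is to lay the batch side by side so that a single (possibly sparse) matrix multiply handles all of it. The sketch below, with the hypothetical name `batch_mm_sketch`, is an illustration of that idea in PyTorch, not necessarily how ptwt implements it.

```python
import torch

def batch_mm_sketch(matrix: torch.Tensor, matrix_batch: torch.Tensor) -> torch.Tensor:
    """Multiply an [m, n] matrix (dense or sparse COO) with a [b, n, k] batch."""
    batch_size, n, k = matrix_batch.shape
    if matrix.shape[-1] != n:
        raise ValueError("incompatible matrix shapes")
    # Place the batch side by side: [b, n, k] -> [n, b, k] -> [n, b * k].
    stacked = matrix_batch.permute(1, 0, 2).reshape(n, batch_size * k)
    if matrix.is_sparse:
        product = torch.sparse.mm(matrix, stacked)
    else:
        product = matrix @ stacked
    # Undo the stacking: [m, b * k] -> [m, b, k] -> [b, m, k].
    return product.reshape(matrix.shape[0], batch_size, k).permute(1, 0, 2)

dense = torch.randn(3, 4)
batch = torch.randn(5, 4, 2)
result = batch_mm_sketch(dense.to_sparse(), batch)  # shape [5, 3, 2]
```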

ptwt.sparse_math.cat_sparse_identity_matrix(sparse_matrix: Tensor, new_length: int) Tensor[source]#

Concatenate a sparse input matrix and a sparse identity matrix.

Parameters:
  • sparse_matrix (torch.Tensor) – The input matrix.

  • new_length (int) – The length up to which the diagonal should be elongated.

Returns:

Square [input, eye] matrix of size [new_length, new_length]
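Reading [input, eye] as a block-diagonal matrix, with the input in the top-left corner and ones on the remaining diagonal, gives the following PyTorch sketch. That reading and the name `cat_sparse_identity_sketch` are assumptions for illustration; the sketch also assumes a square input and ignores devices.

```python
import torch

def cat_sparse_identity_sketch(sparse_matrix: torch.Tensor, new_length: int) -> torch.Tensor:
    """Extend a square sparse matrix with ones on the diagonal up to new_length."""
    size = sparse_matrix.shape[0]
    if sparse_matrix.shape[1] != size or new_length < size:
        raise ValueError("expected a square matrix no larger than new_length")
    sparse_matrix = sparse_matrix.coalesce()
    # Diagonal positions that the identity block occupies.
    extra = torch.arange(size, new_length)
    indices = torch.cat([sparse_matrix.indices(), torch.stack([extra, extra])], dim=1)
    values = torch.cat(
        [sparse_matrix.values(), torch.ones(new_length - size, dtype=sparse_matrix.dtype)]
    )
    return torch.sparse_coo_tensor(indices, values, (new_length, new_length)).coalesce()

block = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
extended = cat_sparse_identity_sketch(block.to_sparse(), 4)
```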

ptwt.sparse_math.construct_conv2d_matrix(filter: Tensor, input_rows: int, input_columns: int, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid', fully_sparse: bool = True) Tensor[source]#

Create a two-dimensional sparse convolution matrix.

Convolving with this matrix should be equivalent to a call to scipy.signal.convolve2d and a reshape.

Parameters:
  • filter (torch.Tensor) – A filter of shape [height, width] to convolve with.

  • input_rows (int) – The number of rows in the input matrix.

  • input_columns (int) – The number of columns in the input matrix.

  • mode – The desired padding method. Options are full, same, and valid. Defaults to ‘valid’, i.e. no padding.

  • fully_sparse (bool) – Use a sparse implementation of the Kronecker product to save memory. Defaults to True.

Returns:

A sparse convolution matrix.

Raises:

ValueError – If the padding mode is neither full, same, nor valid.

ptwt.sparse_math.construct_conv_matrix(filter: Tensor, input_rows: int, *, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid') Tensor[source]#

Construct a convolution matrix.

Full, valid, and same padding are supported. For reference, see RoyiAvital/StackExchangeCodes master/StackOverflow/Q2080835/CreateConvMtxSparse.m

Parameters:
  • filter (torch.Tensor) – The 1D-filter to convolve with.

  • input_rows (int) – The number of rows in the input, i.e. the input length.

  • mode – String identifier for the desired padding. Choose ‘full’, ‘valid’ or ‘same’. Defaults to valid.

Returns:

The sparse convolution tensor.

Raises:

ValueError – If the padding is not ‘full’, ‘same’ or ‘valid’.
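The structure of such a matrix is easiest to see in a dense NumPy sketch: column c holds the filter shifted down by c rows, and the ‘valid’ matrix keeps only the rows where the filter fully overlaps the input. The helper `conv_matrix_sketch` is hypothetical, returns a dense array, and covers only ‘full’ and ‘valid’; ptwt returns sparse tensors and has its own conventions for ‘same’ and ‘sameshift’.

```python
import numpy as np

def conv_matrix_sketch(filt, input_rows: int, mode: str = "valid") -> np.ndarray:
    """Dense M with M @ x == np.convolve(x, filt, mode) for len(x) == input_rows."""
    filt = np.asarray(filt, dtype=float)
    filt_len = len(filt)
    full = np.zeros((input_rows + filt_len - 1, input_rows))
    for col in range(input_rows):
        # Each input sample contributes a copy of the filter, shifted by one row.
        full[col : col + filt_len, col] = filt
    if mode == "full":
        return full
    if mode == "valid":
        # Keep only the rows where the filter fully overlaps the input.
        return full[filt_len - 1 : input_rows]
    raise ValueError("this sketch only supports 'full' and 'valid'")

x = np.random.default_rng(0).standard_normal(10)
f = np.array([1.0, 2.0, 3.0])
full_result = conv_matrix_sketch(f, 10, "full") @ x
valid_result = conv_matrix_sketch(f, 10, "valid") @ x
```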

ptwt.sparse_math.construct_strided_conv2d_matrix(filter: Tensor, input_rows: int, input_columns: int, stride: int = 2, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'full') Tensor[source]#

Create a strided sparse two-dimensional convolution matrix.

Parameters:
  • filter (torch.Tensor) – The two-dimensional convolution filter.

  • input_rows (int) – The number of rows in the 2d-input matrix.

  • input_columns (int) – The number of columns in the 2d-input matrix.

  • stride (int) – The stride between the filter positions. Defaults to 2.

  • mode – The convolution type. Defaults to ‘full’. The ‘sameshift’ mode starts at index 1 instead of 0.

Returns:

The sparse convolution tensor.

Raises:

ValueError – Raised if an unknown convolution string is provided.

ptwt.sparse_math.construct_strided_conv_matrix(filter: Tensor, input_rows: int, stride: int = 2, *, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid') Tensor[source]#

Construct a strided convolution matrix.

Parameters:
  • filter (torch.Tensor) – The filter coefficients to convolve with.

  • input_rows (int) – The number of rows in the input vector.

  • stride (int) – The step size of the convolution. Defaults to two.

  • mode – Choose ‘valid’, ‘same’ or ‘sameshift’. Defaults to ‘valid’.

Returns:

The strided sparse convolution matrix.
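A stride of two corresponds to keeping every second row of the ordinary convolution matrix, i.e. convolving and then downsampling. A dense NumPy sketch for the ‘valid’ case follows; the name `strided_conv_matrix_sketch` is hypothetical, and the row offsets of ptwt's ‘same’ and ‘sameshift’ modes are not modeled.

```python
import numpy as np

def strided_conv_matrix_sketch(filt, input_rows: int, stride: int = 2) -> np.ndarray:
    """Dense 'valid' convolution matrix that keeps every stride-th output row."""
    filt = np.asarray(filt, dtype=float)
    filt_len = len(filt)
    full = np.zeros((input_rows + filt_len - 1, input_rows))
    for col in range(input_rows):
        full[col : col + filt_len, col] = filt
    valid = full[filt_len - 1 : input_rows]
    # Striding the convolution drops output rows, i.e. it downsamples the result.
    return valid[::stride]

x = np.random.default_rng(0).standard_normal(10)
f = np.array([1.0, 2.0, 3.0])
strided = strided_conv_matrix_sketch(f, 10)
```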

ptwt.sparse_math.sparse_diag(diagonal: Tensor, col_offset: int, rows: int, cols: int) Tensor[source]#

Create a rows-by-cols sparse diagonal-matrix.

The matrix is constructed by taking the columns of the input and placing them along the diagonal specified by col_offset.

Parameters:
  • diagonal (torch.Tensor) – The values for the diagonal.

  • col_offset (int) – Move the diagonal to the right by offset columns.

  • rows (int) – The number of rows in the final matrix.

  • cols (int) – The number of columns in the final matrix.

Returns:

A sparse matrix with a shifted diagonal.
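A PyTorch sketch of the idea for non-negative offsets: entry i of the diagonal goes to row i and column i + col_offset, and entries that fall outside the rows-by-cols shape are dropped. The name `sparse_diag_sketch`, the truncation behavior, and the restriction to non-negative offsets are assumptions of this sketch.

```python
import torch

def sparse_diag_sketch(diagonal: torch.Tensor, col_offset: int, rows: int, cols: int) -> torch.Tensor:
    """Place `diagonal` in a rows-by-cols sparse matrix, shifted right by col_offset."""
    if col_offset < 0:
        raise ValueError("this sketch only handles non-negative offsets")
    row_idx = torch.arange(len(diagonal))
    col_idx = row_idx + col_offset
    # Drop entries that would land outside the requested shape.
    keep = (row_idx < rows) & (col_idx < cols)
    indices = torch.stack([row_idx[keep], col_idx[keep]])
    return torch.sparse_coo_tensor(indices, diagonal[keep], (rows, cols))

values = torch.tensor([1.0, 2.0, 3.0])
shifted = sparse_diag_sketch(values, 1, 3, 4).to_dense()
```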

ptwt.sparse_math.sparse_kron(sparse_tensor_a: Tensor, sparse_tensor_b: Tensor) Tensor[source]#

Compute a sparse Kronecker product.

As defined at https://en.wikipedia.org/wiki/Kronecker_product. Adapted from scipy/scipy.

Parameters:
  • sparse_tensor_a (torch.Tensor) – Sparse 2d-Tensor a of shape [m, n].

  • sparse_tensor_b (torch.Tensor) – Sparse 2d-Tensor b of shape [p, q].

Returns:

The resulting tensor of shape [mp, nq].
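In COO form, the Kronecker product only needs index arithmetic: every nonzero a[i, j] spawns a copy of b scaled by a[i, j] in block (i, j), so a[i, j] * b[k, l] lands at row i * p + k and column j * q + l. A PyTorch sketch with the hypothetical name `sparse_kron_sketch`, checked against the dense `torch.kron`:

```python
import torch

def sparse_kron_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Kronecker product of two sparse COO matrices of shapes [m, n] and [p, q]."""
    a, b = a.coalesce(), b.coalesce()
    p, q = b.shape
    a_idx, b_idx = a.indices(), b.indices()
    # Pair every nonzero of a with every nonzero of b via broadcasting.
    rows = (a_idx[0].unsqueeze(1) * p + b_idx[0].unsqueeze(0)).reshape(-1)
    cols = (a_idx[1].unsqueeze(1) * q + b_idx[1].unsqueeze(0)).reshape(-1)
    values = (a.values().unsqueeze(1) * b.values().unsqueeze(0)).reshape(-1)
    shape = (a.shape[0] * p, a.shape[1] * q)
    return torch.sparse_coo_tensor(torch.stack([rows, cols]), values, shape).coalesce()

a = torch.tensor([[1.0, 0.0], [2.0, 3.0]])
b = torch.tensor([[0.0, 4.0, 5.0], [6.0, 0.0, 7.0]])
product = sparse_kron_sketch(a.to_sparse(), b.to_sparse())
```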

ptwt.sparse_math.sparse_replace_row(matrix: Tensor, row_index: int, row: Tensor) Tensor[source]#

Replace a row within a sparse [rows, cols] matrix.

That is, matrix[row_index, :] = row.

Parameters:
  • matrix (torch.Tensor) – A sparse two-dimensional matrix.

  • row_index (int) – The row to replace.

  • row (torch.Tensor) – The row to insert into the sparse matrix.

Returns:

A sparse matrix, with the new row inserted at row_index.
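Since sparse COO tensors do not support this kind of item assignment directly, the replacement can be expressed as "drop the stored entries of the old row, then append the nonzeros of the new one". A PyTorch sketch, assuming `row` is a dense 1d tensor of length cols; the name `sparse_replace_row_sketch` is hypothetical.

```python
import torch

def sparse_replace_row_sketch(matrix: torch.Tensor, row_index: int, row: torch.Tensor) -> torch.Tensor:
    """Return a sparse copy of `matrix` with matrix[row_index, :] set to the dense `row`."""
    matrix = matrix.coalesce()
    indices, values = matrix.indices(), matrix.values()
    # Drop all stored entries of the old row.
    keep = indices[0] != row_index
    # Collect the nonzero entries of the new row.
    new_cols = torch.nonzero(row, as_tuple=True)[0]
    new_rows = torch.full_like(new_cols, row_index)
    new_indices = torch.cat([indices[:, keep], torch.stack([new_rows, new_cols])], dim=1)
    new_values = torch.cat([values[keep], row[new_cols]])
    return torch.sparse_coo_tensor(new_indices, new_values, matrix.shape).coalesce()

dense = torch.tensor([[1.0, 2.0, 0.0], [0.0, 3.0, 4.0], [5.0, 0.0, 6.0]])
new_row = torch.tensor([7.0, 0.0, 8.0])
replaced = sparse_replace_row_sketch(dense.to_sparse(), 1, new_row)
```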

ptwt.wavelets_learnable module#

Experimental code for adaptive wavelet learning.

See https://arxiv.org/pdf/2004.09569.pdf for more information.

class ptwt.wavelets_learnable.ProductFilter(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)[source]#

Bases: WaveletFilter, Module

Learnable product filter implementation.

Create a Product filter object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.

property filter_bank: tuple[Tensor, Tensor, Tensor, Tensor]#

All filters as a tuple.

product_filter_loss() Tensor[source]#

Get only the product filter loss.

Returns:

The loss scalar.

wavelet_loss() Tensor[source]#

Return the sum of all loss terms.

Returns:

The loss scalar.

class ptwt.wavelets_learnable.SoftOrthogonalWavelet(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)[source]#

Bases: ProductFilter, Module

Orthogonal wavelets with a soft orthogonality constraint.

Create a SoftOrthogonalWavelet object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.

filt_bank_orthogonality_loss() Tensor[source]#

Return a Jensen+Harbo inspired soft orthogonality loss.

On page 79 of the book Ripples in Mathematics by Jensen and la Cour-Harbo, the constraints g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters are presented. This method measures how far the filter bank deviates from these constraints.

Returns:

A tensor with the orthogonality constraint value.
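For finite filters, the time reversal h[-k] can be read as reversing the coefficient array, so the constraint says that each synthesis filter is its analysis filter reversed. A minimal NumPy sketch of such a measurement, assuming g denotes the synthesis and h the analysis filters; the function name is hypothetical and the exact form of ptwt's loss may differ.

```python
import numpy as np

def filter_bank_orthogonality_loss_sketch(dec_lo, dec_hi, rec_lo, rec_hi) -> float:
    """Squared deviation from rec[k] == dec[-k], i.e. synthesis = reversed analysis."""
    lo_error = np.sum((np.asarray(rec_lo) - np.asarray(dec_lo)[::-1]) ** 2)
    hi_error = np.sum((np.asarray(rec_hi) - np.asarray(dec_hi)[::-1]) ** 2)
    return float(lo_error + hi_error)

# Haar filters written out by hand.
s = 1.0 / np.sqrt(2.0)
haar_loss = filter_bank_orthogonality_loss_sketch([s, s], [-s, s], [s, s], [s, -s])
```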

rec_lo_orthogonality_loss() Tensor[source]#

Return a Strang inspired soft orthogonality loss.

See Strang p. 148/149 or Harbo p. 80. Since \(L\) is a convolution matrix, \(LL^T\) can be evaluated through convolution.

Returns:

A tensor with the orthogonality constraint value.
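If \(L\) has the filter \(h\) in each row, shifted by two samples per row, then, ignoring boundaries, entry \((i, j)\) of \(LL^T\) is the autocorrelation of \(h\) at lag \(2(i - j)\). Orthogonality therefore requires the even-lag autocorrelation to be a unit impulse. A NumPy sketch of that check follows; the function name is hypothetical and it is not ptwt's implementation.

```python
import numpy as np

def lowpass_orthogonality_loss_sketch(rec_lo) -> float:
    """Compare the even-lag autocorrelation of rec_lo with a unit impulse."""
    h = np.asarray(rec_lo, dtype=float)
    autocorr = np.correlate(h, h, mode="full")
    center = len(h) - 1  # index of lag zero
    # Keep the lags that differ from zero by an even amount.
    even_lags = autocorr[center % 2 :: 2]
    target = np.zeros_like(even_lags)
    target[center // 2] = 1.0
    return float(np.sum((even_lags - target) ** 2))

sqrt3 = np.sqrt(3.0)
# Daubechies-2 low-pass filter in closed form.
db2_lo = np.array([1 - sqrt3, 3 - sqrt3, 3 + sqrt3, 1 + sqrt3]) / (4 * np.sqrt(2.0))
db2_loss = lowpass_orthogonality_loss_sketch(db2_lo)
```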

wavelet_loss() Tensor[source]#

Return the sum of all loss terms.

class ptwt.wavelets_learnable.WaveletFilter[source]#

Bases: ABC

Interface for learnable wavelets.

Each wavelet has a filter bank loss function and comes with functionality that tests the perfect reconstruction and anti-aliasing conditions.

alias_cancellation_loss() tuple[Tensor, Tensor, Tensor][source]#

Return the alias cancellation loss.

Implementation of the alias cancellation loss as described on page 104 of Strang and Nguyen:

\[F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0\]

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.
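The condition can be evaluated on coefficient arrays: \(H(-z)\) flips the sign of every odd power of \(z^{-1}\), and multiplying polynomials convolves their coefficients. A NumPy sketch follows. It reads \(H_0, H_1\) as the analysis and \(F_0, F_1\) as the synthesis filters, and the helper `orthogonal_filter_bank` assumes a reversal and alternating-sign convention that reproduces the Haar filters written out by hand; all names are hypothetical.

```python
import numpy as np

def orthogonal_filter_bank(dec_lo):
    """Build (dec_lo, dec_hi, rec_lo, rec_hi) from an orthogonal low-pass filter.

    Assumed convention: rec_lo is dec_lo reversed, rec_hi alternates the signs of
    dec_lo, and dec_hi is rec_hi reversed. For Haar this gives
    [s, s], [-s, s], [s, s], [s, -s] with s = 1/sqrt(2).
    """
    dec_lo = np.asarray(dec_lo, dtype=float)
    rec_lo = dec_lo[::-1]
    rec_hi = dec_lo * (-1.0) ** np.arange(len(dec_lo))
    dec_hi = rec_hi[::-1]
    return dec_lo, dec_hi, rec_lo, rec_hi

def modulate(coeffs):
    """Coefficients of H(-z) from those of H(z): flip the sign of odd powers."""
    coeffs = np.asarray(coeffs, dtype=float)
    return coeffs * (-1.0) ** np.arange(len(coeffs))

def alias_terms(dec_lo, dec_hi, rec_lo, rec_hi):
    """Coefficients of F0(z)H0(-z) + F1(z)H1(-z); all zeros means no aliasing."""
    return np.convolve(rec_lo, modulate(dec_lo)) + np.convolve(rec_hi, modulate(dec_hi))

sqrt3 = np.sqrt(3.0)
db2_lo = np.array([1 - sqrt3, 3 - sqrt3, 3 + sqrt3, 1 + sqrt3]) / (4 * np.sqrt(2.0))
residual = alias_terms(*orthogonal_filter_bank(db2_lo))  # entries should be numerically zero
```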

abstract property filter_bank: tuple[Tensor, Tensor, Tensor, Tensor]#

Return dec_lo, dec_hi, rec_lo, rec_hi.

perfect_reconstruction_loss() tuple[Tensor, Tensor, Tensor][source]#

Return the perfect reconstruction loss.

Returns:

The numerical value of the perfect reconstruction loss, as well as both intermediate values for analysis.
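Perfect reconstruction additionally requires the distortion term \(F_0(z)H_0(z) + F_1(z)H_1(z)\) to be a scaled pure delay. For an orthogonal bank built with the reversal and alternating-sign convention assumed in this sketch, the delay is \(N - 1\) for filters of length \(N\), so the target is \(2z^{-(N-1)}\); other conventions shift the delay. A NumPy sketch with hypothetical names:

```python
import numpy as np

def orthogonal_filter_bank(dec_lo):
    """(dec_lo, dec_hi, rec_lo, rec_hi) under an assumed reversal/alternating-sign convention."""
    dec_lo = np.asarray(dec_lo, dtype=float)
    rec_lo = dec_lo[::-1]
    rec_hi = dec_lo * (-1.0) ** np.arange(len(dec_lo))
    return dec_lo, rec_hi[::-1], rec_lo, rec_hi

def perfect_reconstruction_residual(dec_lo, dec_hi, rec_lo, rec_hi) -> float:
    """Squared distance between F0 H0 + F1 H1 and 2 * z^-(N-1)."""
    distortion = np.convolve(rec_lo, dec_lo) + np.convolve(rec_hi, dec_hi)
    target = np.zeros_like(distortion)
    target[len(dec_lo) - 1] = 2.0  # a pure delay, scaled by two
    return float(np.sum((distortion - target) ** 2))

sqrt3 = np.sqrt(3.0)
db2_lo = np.array([1 - sqrt3, 3 - sqrt3, 3 + sqrt3, 1 + sqrt3]) / (4 * np.sqrt(2.0))
db2_residual = perfect_reconstruction_residual(*orthogonal_filter_bank(db2_lo))
```

An unnormalized low-pass filter such as [1, 1] fails the check, because its distortion term is \(4z^{-1}\) instead of \(2z^{-1}\).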

pf_alias_cancellation_loss() tuple[Tensor, Tensor, Tensor][source]#

Return the product filter-alias cancellation loss.

See Strang and Nguyen, p. 105:

\[F_0(z) = H_1(-z), \quad F_1(z) = -H_0(-z)\]

For the alternating sign convention from 0 to N, see the Strang overview on the back of the cover.

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.

abstract wavelet_loss() Tensor[source]#

Return the sum of all loss terms.

ptwt.constants#

Constants and types used throughout the PyTorch Wavelet Toolbox.

ptwt.constants.BoundaryMode#

This is a type literal for the way of padding used at boundaries.

  • Reflection padding mirrors samples along the border without repeating the edge sample (reflect)

  • Zero padding pads zeros (zero)

  • Constant padding replicates border values (constant)

  • Periodic padding cyclically repeats samples (periodic)

  • Symmetric padding mirrors samples along the border, repeating the edge sample (symmetric)

alias of Literal[‘constant’, ‘zero’, ‘reflect’, ‘periodic’, ‘symmetric’]
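The modes can be compared on a short signal with `numpy.pad`. The mapping from ptwt's names to NumPy's mode names below is an assumption based on the descriptions above; note in particular that ptwt's ‘constant’ (replicate the border) corresponds to NumPy's ‘edge’, while NumPy's own ‘constant’ pads zeros.

```python
import numpy as np

# Assumed correspondence between BoundaryMode names and np.pad modes.
PTWT_TO_NUMPY = {
    "reflect": "reflect",      # mirror, edge sample not repeated
    "symmetric": "symmetric",  # mirror, edge sample repeated
    "constant": "edge",        # replicate the border value
    "periodic": "wrap",        # repeat the signal cyclically
    "zero": "constant",        # np.pad's "constant" pads zeros by default
}

signal = np.array([1, 2, 3])
padded = {name: np.pad(signal, 2, mode=np_mode) for name, np_mode in PTWT_TO_NUMPY.items()}
for name, values in padded.items():
    print(f"{name:>9}: {values}")
```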

ptwt.constants.ExtendedBoundaryMode#

This is a type literal for the way of handling signal boundaries.

This is either a form of padding (see ptwt.constants.BoundaryMode for padding options) or ‘boundary’ to use boundary wavelets.

alias of Literal[‘boundary’] | Literal[‘constant’, ‘zero’, ‘reflect’, ‘periodic’, ‘symmetric’]

ptwt.constants.OrthogonalizeMethod#

The method for orthogonalizing a matrix.

  1. qr relies on PyTorch’s dense QR implementation; it is fast but memory-hungry.

  2. The gramschmidt option is sparse, memory-efficient, and slow.

Choose gramschmidt if qr runs out of memory.

alias of Literal[‘qr’, ‘gramschmidt’]

ptwt.constants.PaddingMode#

The padding mode used when constructing convolution matrices.

alias of Literal[‘full’, ‘valid’, ‘same’, ‘sameshift’]

ptwt.constants.WaveletCoeff2d#

Type alias for 2d wavelet transform results.

This type alias represents the result of a 2d wavelet transform with \(n\) levels as a tuple (A, Tn, ..., T1) of length \(n + 1\). A denotes a tensor of approximation coefficients for the n-th level of decomposition. Tl is a tuple of detail coefficients for level l, see ptwt.constants.WaveletDetailTuple2d.

Note that this type always contains an approximation coefficient tensor but does not necessarily contain any detail coefficients.

Alias of tuple[torch.Tensor, *tuple[WaveletDetailTuple2d, ...]]

alias of tuple[Tensor, Unpack[tuple[WaveletDetailTuple2d, …]]]

ptwt.constants.WaveletCoeff2dSeparable#

Type alias for separable 2d wavelet transform results.

This is an alias of ptwt.constants.WaveletCoeffNd. It is used to emphasize the use of ptwt.constants.WaveletDetailDict for detail coefficients in a 2d setting – in contrast to ptwt.constants.WaveletCoeff2d.

Alias of ptwt.constants.WaveletCoeffNd, i.e. of tuple[torch.Tensor, *tuple[WaveletDetailDict, ...]].

alias of tuple[Tensor, Unpack[tuple[dict[str, Tensor], …]]]

ptwt.constants.WaveletCoeffNd#

Type alias for wavelet transform results in any dimension.

This type alias represents the result of a Nd wavelet transform with \(n\) levels as a tuple (A, Dn, ..., D1) of length \(n + 1\). A denotes a tensor of approximation coefficients for the n-th level of decomposition. Dl is a dictionary of detail coefficients for level l, see ptwt.constants.WaveletDetailDict.

Note that this type always contains an approximation coefficient tensor but does not necessarily contain any detail coefficients.

Alias of tuple[torch.Tensor, *tuple[WaveletDetailDict, ...]]

alias of tuple[Tensor, Unpack[tuple[dict[str, Tensor], …]]]

ptwt.constants.WaveletDetailDict#

Type alias for a dict containing the detail coefficients for a given level.

This type alias represents the detail coefficient tensors of a given level for a wavelet transform in \(N\) dimensions as the values of a dictionary. Its keys are strings of length \(N\) that describe each detail coefficient by the filter applied along each axis. The strings consist only of the characters ‘a’ and ‘d’, where ‘a’ denotes the low-pass or approximation filter and ‘d’ the high-pass or detail filter. For a 3d transform, the dictionary thus uses the keys:

("aad", "ada", "add", "daa", "dad", "dda", "ddd")

Alias of dict[str, torch.Tensor]

alias of dict[str, Tensor]
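The key set follows directly from this description: every ‘a’/‘d’ string of length \(N\) except the all-‘a’ string, which corresponds to the approximation tensor stored separately. A small sketch with the hypothetical helper `detail_keys` reproduces the 3d keys listed above in the same order.

```python
from itertools import product

def detail_keys(ndim: int) -> tuple[str, ...]:
    """All 'a'/'d' strings of length ndim except the pure approximation key."""
    keys = ("".join(chars) for chars in product("ad", repeat=ndim))
    return tuple(key for key in keys if key != "a" * ndim)

print(detail_keys(2))
print(detail_keys(3))
```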

class ptwt.constants.WaveletDetailTuple2d(horizontal: Tensor, vertical: Tensor, diagonal: Tensor)[source]#

Bases: NamedTuple

Detail coefficients of a 2d wavelet transform for a given level.

This is a type alias for a named tuple (H, V, D) of detail coefficient tensors where H denotes horizontal, V vertical and D diagonal coefficients.

Create new instance of WaveletDetailTuple2d(horizontal, vertical, diagonal)

diagonal: Tensor#

Alias for field number 2

horizontal: Tensor#

Alias for field number 0

vertical: Tensor#

Alias for field number 1

ptwt.version module#

Version information for ptwt.

Run with python -m ptwt.version

ptwt.version.get_version(with_git_hash: bool = False) str[source]#

Get the ptwt version string, optionally including the git hash.