The ptwt package#

ptwt.conv_transform module#

Fast wavelet transformations based on torch.nn.functional.conv1d and its transpose.

This module treats boundaries with edge-padding.

ptwt.conv_transform.wavedec(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axis: int = -1) List[Tensor][source]#

Compute the analysis (forward) 1d fast wavelet transform.

The transformation relies on convolution operations with filter pairs.

\[x_s * h_k = c_{k,s+1}\]

Where \(x_s\) denotes the input at scale \(s\), with \(x_0\) equal to the original input. \(h_k\) denotes the convolution filter, with \(k \in \{A, D\}\), where \(A\) stands for approximation and \(D\) for detail. The process uses the approximation coefficients as inputs for higher scales. Set the level argument to choose the largest scale.

Parameters:
  • data (torch.Tensor) – The input time series. By default, the last axis is transformed.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Please consider the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The scale level to be computed. Defaults to None.

  • axis (int) – Compute the transform over this axis instead of the last one. Defaults to -1.

Returns:

A list:

[cA_s, cD_s, cD_s-1, …, cD2, cD1]

containing the wavelet coefficients. A denotes approximation and D detail coefficients.

Return type:

list

Raises:

ValueError – If the dtype of the input data tensor is unsupported or if more than one axis is provided.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> # compute the forward fwt coefficients
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
>>>              mode='zero', level=2)
ptwt.conv_transform.waverec(coeffs: List[Tensor], wavelet: Wavelet | str, axis: int = -1) Tensor[source]#

Reconstruct a signal from wavelet coefficients.

Parameters:
  • coeffs (list) – The wavelet coefficient list produced by wavedec.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axis (int) – Transform this axis instead of the last one. Defaults to -1.

Returns:

The reconstructed signal.

Return type:

torch.Tensor

Raises:

ValueError – If the dtype of the coeffs tensor is unsupported or if the coefficients have incompatible shapes, dtypes or devices or if more than one axis is provided.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> # invert the fast wavelet transform.
>>> ptwt.waverec(ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
>>>                           mode='zero', level=2),
>>>              pywt.Wavelet('haar'))

ptwt.conv_transform_2 module#

This module implements two-dimensional padded wavelet transforms.

The implementation relies on torch.nn.functional.conv2d and torch.nn.functional.conv_transpose2d under the hood.

ptwt.conv_transform_2.wavedec2(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: Tuple[int, int] = (-2, -1)) List[Tensor | Tuple[Tensor, Tensor, Tensor]][source]#

Run a two-dimensional wavelet transformation.

This function relies on two-dimensional convolutions. Outer products allow the construction of 2D filters from 1D filter arrays (see the FWT introduction in the documentation). It transforms the last two axes by default. This function computes

\[\mathbf{x}_s *_2 \mathbf{h}_k = \mathbf{c}_{k, s+1}\]

with \(k \in \{a, h, v, d\}\) and \(s \in \mathbb{N}_0\), the natural numbers including zero, where \(\mathbf{x}_0\) is equal to the original input image \(\mathbf{X}\). \(*_2\) indicates two-dimensional convolution. Computations at subsequent scales work exclusively with the approximation coefficients \(c_{a, s}\) as inputs. Setting the level argument allows choosing the largest scale.

Parameters:
  • data (torch.Tensor) – The input data tensor with any number of dimensions. By default, 2d inputs are interpreted as [height, width], 3d inputs as [batch_size, height, width], and 4d inputs as [batch_size, channels, height, width]. The axes argument allows other interpretations.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of pywt.wavelist(kind="discrete") for a list of possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

A list containing the wavelet coefficients. The coefficients are in pywt order. That is:

[cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .

A denotes approximation, H horizontal, V vertical and D diagonal coefficients.

Return type:

list

Raises:

ValueError – If the dimensionality or the dtype of the input data tensor is unsupported or if the provided axes input has a length other than two.

Example

>>> import torch
>>> import ptwt, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = np.transpose(datasets.face(),
>>>                     [2, 0, 1]).astype(np.float64)
>>> pytorch_face = torch.tensor(face) # try unsqueeze(0)
>>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
>>>                              level=2, mode="zero")
ptwt.conv_transform_2.waverec2(coeffs: List[Tensor | Tuple[Tensor, Tensor, Tensor]], wavelet: Wavelet | str, axes: Tuple[int, int] = (-2, -1)) Tensor[source]#

Reconstruct a signal from wavelet coefficients.

This function undoes the effect of the analysis or forward transform by running transposed convolutions.

Parameters:
  • coeffs (list) –

    The wavelet coefficient list produced by wavedec2. The coefficients must be in pywt order. That is:

    [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .
    

    A denotes approximation, H horizontal, V vertical, and D diagonal coefficients.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

The reconstructed signal of shape [batch, height, width] or [batch, channel, height, width] depending on the input to wavedec2.

Return type:

torch.Tensor

Raises:

ValueError – If coeffs is not in a shape as returned from wavedec2, if the dtype is not supported, if the provided axes input has a length other than two, or if the same axis is repeated twice.

Example

>>> import ptwt, pywt, torch
>>> import numpy as np
>>> from scipy import datasets
>>> face = np.transpose(datasets.face(),
>>>                     [2, 0, 1]).astype(np.float64)
>>> pytorch_face = torch.tensor(face)
>>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
>>>                              level=2, mode="constant")
>>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))

ptwt.conv_transform_3 module#

Code for three dimensional padded transforms.

The functions here are based on torch.nn.functional.conv3d and its transpose.

ptwt.conv_transform_3.wavedec3(data: Tensor, wavelet: Wavelet | str, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'zero', level: int | None = None, axes: Tuple[int, int, int] = (-3, -2, -1)) List[Tensor | Dict[str, Tensor]][source]#

Compute a three-dimensional wavelet transform.

Parameters:
  • data (torch.Tensor) – The input data, for example of shape [batch_size, length, height, width].

  • wavelet (Union[Wavelet, str]) – The wavelet to transform with. pywt.wavelist(kind='discrete') lists possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “zero”. See ptwt.constants.BoundaryMode.

  • level (Optional[int]) – The maximum decomposition level. This argument defaults to None.

  • axes (Tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

A list with the lll coefficients and dictionaries with the filter order strings:

("aad", "ada", "add", "daa", "dad", "dda", "ddd")

as keys, where a denotes the low-pass or approximation filter and d the high-pass or detail filter.

Return type:

list

Raises:

ValueError – If the input has fewer than three dimensions or if the dtype is not supported or if the provided axes input has length other than three.

Example

>>> import ptwt, torch
>>> data = torch.randn(5, 16, 16, 16)
>>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
ptwt.conv_transform_3.waverec3(coeffs: List[Tensor | Dict[str, Tensor]], wavelet: Wavelet | str, axes: Tuple[int, int, int] = (-3, -2, -1)) Tensor[source]#

Reconstruct a signal from wavelet coefficients.

Parameters:
  • coeffs (list) – The wavelet coefficient list produced by wavedec3.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axes (Tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

The reconstructed four-dimensional signal of shape [batch, depth, height, width].

Return type:

torch.Tensor

Raises:

ValueError – If coeffs is not in a shape as returned from wavedec3, if the dtype is not supported, if the provided axes input has a length other than three, or if the same axis is repeated.

Example

>>> import ptwt, torch
>>> data = torch.randn(5, 16, 16, 16)
>>> transformed = ptwt.wavedec3(data, "haar", level=2, mode="reflect")
>>> reconstruction = ptwt.waverec3(transformed, "haar")

ptwt.packets module#

Compute analysis wavelet packet representations.

class ptwt.packets.WaveletPacket(data: Tensor | None, wavelet: Wavelet | str, mode: Literal['boundary'] | Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', maxlevel: int | None = None, axis: int = -1, boundary_orthogonalization: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: UserDict

Implements a one-dimensional wavelet packet analysis transform.

Create a wavelet packet decomposition object.

The decompositions will rely on padded fast wavelet transforms.

Parameters:
  • data (torch.Tensor, optional) – The input data array of shape [time], [batch_size, time] or [batch_size, channels, time]. If None, the object is initialized without performing a decomposition. The time axis is transformed by default. Use the axis argument to choose another dimension.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • mode – The desired padding method. If you select ‘boundary’, the sparse matrix backend will be used. Defaults to ‘reflect’.

  • maxlevel (int, optional) – Value is passed on to transform. The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axis (int) – The axis to transform. Defaults to -1.

  • boundary_orthogonalization – The orthogonalization method to use. Only used if mode equals ‘boundary’. Choose from ‘qr’ or ‘gramschmidt’. Defaults to ‘qr’.

Example

>>> import torch, pywt, ptwt
>>> import numpy as np
>>> import scipy.signal
>>> import matplotlib.pyplot as plt
>>> t = np.linspace(0, 10, 1500)
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
>>>     wavelet=pywt.Wavelet("db3"), mode="reflect")
>>> np_lst = []
>>> for node in wp.get_level(5):
>>>     np_lst.append(wp[node])
>>> viz = np.stack(np_lst).squeeze()
>>> plt.imshow(np.abs(viz))
>>> plt.show()
get_level(level: int) List[str][source]#

Return the graycode-ordered paths to the filter tree nodes.

Parameters:

level (int) – The depth of the tree.

Returns:

A list with the paths to each node.

Return type:

list

reconstruct() WaveletPacket[source]#

Recursively reconstruct the input starting from the leaf nodes.

Reconstruction replaces the input data originally assigned to this object.

Note

Only changes to leaf node data impact the results, since the data in all other nodes is overwritten by the reconstruction from the leaves.

Example

>>> import numpy as np
>>> import ptwt, torch
>>> signal = np.random.randn(1, 16)
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
>>>     mode="boundary", maxlevel=2)
>>> ptwp["aa"].data *= 0
>>> ptwp.reconstruct()
>>> print(ptwp[""])
transform(data: Tensor, maxlevel: int | None = None) WaveletPacket[source]#

Calculate the 1d wavelet packet transform for the input data.

Parameters:
  • data (torch.Tensor) – The input data array of shape [time] or [batch_size, time].

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.
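
Example

A minimal sketch; per the signature above the object accepts data=None at construction, so the assumption here is that one can create an empty packet object and fill it later via transform:

>>> import torch, ptwt
>>> # create an empty packet object and reuse it for new data
>>> wp = ptwt.WaveletPacket(None, "haar")
>>> wp.transform(torch.randn(1, 64), maxlevel=3)
>>> # "aaa" addresses the level-3 node reached via three approximation filters
>>> print(wp["aaa"].shape)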

class ptwt.packets.WaveletPacket2D(data: Tensor | None, wavelet: Wavelet | str, mode: Literal['boundary'] | Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', maxlevel: int | None = None, axes: Tuple[int, int] = (-2, -1), boundary_orthogonalization: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = False)[source]#

Bases: UserDict

Two-dimensional wavelet packets.

Example code illustrating the use of this class is available at: v0lta/PyTorch-Wavelet-Toolbox

Create a 2D-Wavelet packet tree.

Parameters:
  • data (torch.tensor, optional) – The input data tensor. For example of shape [batch_size, height, width] or [batch_size, channels, height, width]. If None, the object is initialized without performing a decomposition.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • mode – A string indicating the desired padding mode. If you select ‘boundary’, the sparse matrix backend is used. Defaults to ‘reflect’

  • maxlevel (int, optional) – Value is passed on to transform. The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axes ([int, int], optional) – The tensor axes that should be transformed. Defaults to (-2, -1).

  • boundary_orthogonalization – The orthogonalization method to use in the sparse matrix backend. Only used if mode equals ‘boundary’. Choose from ‘qr’ or ‘gramschmidt’. Defaults to ‘qr’.

  • separable (bool) – If true, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False.
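
Example

A minimal usage sketch, assuming the two-dimensional packet keys are strings over the filter codes a, h, v and d (as listed for get_freq_order below):

>>> import torch, pywt, ptwt
>>> # decompose a batch of random 32 by 32 images
>>> data = torch.randn(3, 32, 32)
>>> wp2 = ptwt.WaveletPacket2D(data, pywt.Wavelet("haar"),
>>>                            mode="reflect", maxlevel=2)
>>> # "aa" addresses the node where both levels applied the approximation filter
>>> print(wp2["aa"].shape)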

get_natural_order(level: int) List[str][source]#

Get the natural ordering for a given decomposition level.

Parameters:

level (int) – The decomposition level.

Returns:

A list with the filter order strings.

Return type:

list

reconstruct() WaveletPacket2D[source]#

Recursively reconstruct the input starting from the leaf nodes.

Note

Only changes to leaf node data impact the results, since the data in all other nodes is overwritten by the reconstruction from the leaves.
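
Example

A sketch mirroring the one-dimensional reconstruct example; the assumption that the root node is stored under the empty string "" carries over from the 1d case:

>>> import torch, ptwt
>>> signal = torch.randn(1, 16, 16)
>>> ptwp = ptwt.WaveletPacket2D(signal, "haar", mode="boundary", maxlevel=2)
>>> # zero out the approximation leaf before reconstructing
>>> ptwp["aa"].data *= 0
>>> ptwp.reconstruct()
>>> print(ptwp[""].shape)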

transform(data: Tensor, maxlevel: int | None = None) WaveletPacket2D[source]#

Calculate the 2d wavelet packet transform for the input data.

The transform function allows reusing the same object.

Parameters:
  • data (torch.tensor) – The input data tensor of shape [batch_size, height, width]

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

ptwt.packets.get_freq_order(level: int) List[List[Tuple[str, ...]]][source]#

Get the frequency order for a given packet decomposition level.

Use this code to create two-dimensional frequency orderings.

Parameters:

level (int) – The number of decomposition scales.

Returns:

A list with the tree nodes in frequency order.

Return type:

list

Note

Adapted from: PyWavelets/pywt

The code elements denote the filter application order. The filters are named following the pywt convention:

  • a – LL, low-low coefficients

  • h – LH, low-high coefficients

  • v – HL, high-low coefficients

  • d – HH, high-high coefficients
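
Example

A short sketch; the printed node entries are per-level filter codes drawn from a, h, v and d:

>>> from ptwt.packets import get_freq_order
>>> # nodes of a two-level 2d packet tree in frequency order
>>> freq_order = get_freq_order(level=2)
>>> for row in freq_order:
>>>     print(row)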

ptwt.continuous_transform module#

PyTorch compatible cwt code.

This module is based on pywt’s cwt implementation.

ptwt.continuous_transform.cwt(data: Tensor, scales: ndarray | Tensor, wavelet: ContinuousWavelet | str, sampling_period: float = 1.0) Tuple[Tensor, ndarray][source]#

Compute the single-dimensional continuous wavelet transform.

This function is a PyTorch port of pywt.cwt as found at: PyWavelets/pywt

Parameters:
  • data (torch.Tensor) – The input tensor of shape [batch_size, time].

  • scales (torch.Tensor or np.array) – The wavelet scales to use. One can use f = pywt.scale2frequency(wavelet, scale)/sampling_period to determine the corresponding physical frequency f. Here, f is in hertz when the sampling_period is given in seconds.

  • wavelet (ContinuousWavelet or str) – The continuous wavelet to work with.

  • sampling_period (float) – Sampling period for the frequencies output (optional). The values computed for coefs are independent of the choice of sampling_period (i.e. scales is not scaled by the sampling period).

Raises:

ValueError – If a scale is too small for the input signal.

Returns:

The first tuple element contains the transformation matrix of shape [scales, batch, time]. The second element contains an array with frequency information.

Return type:

Tuple[torch.Tensor, np.ndarray]

Example

>>> import torch, ptwt
>>> import numpy as np
>>> import scipy.signal as signal
>>> t = np.linspace(-2, 2, 800, endpoint=False)
>>> sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
>>> widths = np.arange(1, 31)
>>> cwtmatr, freqs = ptwt.cwt(
>>>     torch.from_numpy(sig), widths, "mexh", sampling_period=(4 / 800) * np.pi
>>> )

ptwt.separable_conv_transform module#

Compute separable convolution-based transforms.

This module takes multi-dimensional convolutions apart. It uses single-dimensional convolutions to transform the axes individually. Under the hood, code in this module transforms all dimensions using torch.nn.functional.conv1d and its transpose.

ptwt.separable_conv_transform.fswavedec2(data: Tensor, wavelet: str | Wavelet, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: Tuple[int, int] = (-2, -1)) List[Tensor | Dict[str, Tensor]][source]#

Compute a fully separable 2D-padded analysis wavelet transform.

Parameters:
  • data (torch.Tensor) – An input signal of shape [batch, height, width] or [batch, channels, height, width].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of pywt.wavelist(kind="discrete") for a list of possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes ([int, int]) – The axes we want to transform, defaults to (-2, -1).

Raises:

ValueError – If the data is not a batched 2D signal.

Returns:

A list with the ll (approximation) coefficients and dictionaries with the filter order strings:

("ad", "da", "dd")

as keys, where a denotes the low-pass or approximation filter and d the high-pass or detail filter.

Return type:

List[Union[torch.Tensor, Dict[str, torch.Tensor]]]

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10)
>>> coeff = ptwt.fswavedec2(data, "haar", level=2)
ptwt.separable_conv_transform.fswavedec3(data: Tensor, wavelet: str | Wavelet, *, mode: Literal['constant', 'zero', 'reflect', 'periodic', 'symmetric'] = 'reflect', level: int | None = None, axes: Tuple[int, int, int] = (-3, -2, -1)) List[Tensor | Dict[str, Tensor]][source]#

Compute a fully separable 3D-padded analysis wavelet transform.

Parameters:
  • data (torch.Tensor) – An input signal of shape [batch, depth, height, width].

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output of pywt.wavelist(kind="discrete") for a list of possible choices.

  • mode – The desired padding mode for extending the signal along the edges. Defaults to “reflect”. See ptwt.constants.BoundaryMode.

  • level (int) – The number of desired scales. Defaults to None.

  • axes (Tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Raises:

ValueError – If the input is not a batched 3D signal.

Returns:

A list with the lll coefficients and dictionaries with the filter order strings:

("aad", "ada", "add", "daa", "dad", "dda", "ddd")

as keys, where a denotes the low-pass or approximation filter and d the high-pass or detail filter.

Return type:

List[Union[torch.Tensor, Dict[str, torch.Tensor]]]

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10, 10)
>>> coeff = ptwt.fswavedec3(data, "haar", level=2)
ptwt.separable_conv_transform.fswaverec2(coeffs: List[Tensor | Dict[str, Tensor]], wavelet: str | Wavelet, axes: Tuple[int, int] = (-2, -1)) Tensor[source]#

Compute a fully separable 2D-padded synthesis wavelet transform.

The function uses separate single-dimensional convolutions under the hood.

Parameters:
  • coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) – The wavelet coefficients as computed by fswavedec2.

  • wavelet (Union[str, pywt.Wavelet]) – The wavelet to use for the synthesis transform.

  • axes (Tuple[int, int]) – Compute the transform over these axes instead of the last two. Defaults to (-2, -1).

Returns:

A reconstruction of the signal encoded in the wavelet coefficients.

Return type:

torch.Tensor

Raises:

ValueError – If the axes argument is not a tuple of two integers.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10)
>>> coeff = ptwt.fswavedec2(data, "haar", level=2)
>>> rec = ptwt.fswaverec2(coeff, "haar")
ptwt.separable_conv_transform.fswaverec3(coeffs: List[Tensor | Dict[str, Tensor]], wavelet: str | Wavelet, axes: Tuple[int, int, int] = (-3, -2, -1)) Tensor[source]#

Compute a fully separable 3D-padded synthesis wavelet transform.

Parameters:
  • coeffs (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]) – The wavelet coefficients as computed by fswavedec3.

  • wavelet (Union[str, pywt.Wavelet]) – The wavelet to use for the synthesis transform.

  • axes (Tuple[int, int, int]) – Compute the transform over these axes instead of the last three. Defaults to (-3, -2, -1).

Returns:

A reconstruction of the signal encoded in the wavelet coefficients.

Return type:

torch.Tensor

Raises:

ValueError – If the axes argument is not a tuple with three ints.

Example

>>> import torch
>>> import ptwt
>>> data = torch.randn(5, 10, 10, 10)
>>> coeff = ptwt.fswavedec3(data, "haar", level=2)
>>> rec = ptwt.fswaverec3(coeff, "haar")

ptwt.matmul_transform module#

Implement matrix-based fwt and ifwt.

This module uses boundary filters instead of padding.

The implementation is based on the description in Strang and Nguyen (p. 32), as well as the description of boundary filters in “Ripples in Mathematics”, section 10.3.

class ptwt.matmul_transform.BaseMatrixWaveDec[source]#

Bases: object

A base class for matrix wavedec.

class ptwt.matmul_transform.MatrixWavedec(wavelet: Wavelet | str, level: int | None = None, axis: int | None = -1, boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: BaseMatrixWaveDec

Compute the sparse matrix fast wavelet transform.

Intermediate scale results must be divisible by two. A working third-level transform could use an input length of 128. This would lead to intermediate resolutions of 64 and 32. All are divisible by two.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> matrix_wavedec = ptwt.MatrixWavedec(
>>>     pywt.Wavelet('haar'), level=2)
>>> coefficients = matrix_wavedec(data_torch)

Create a matrix-fwt object.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • level (int, optional) – The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None.

  • axis (int, optional) – The axis we would like to transform. Defaults to -1.

  • boundary – The method used for boundary filter treatment. Choose ‘qr’ or ‘gramschmidt’. ‘qr’ relies on pytorch’s dense qr implementation, it is fast but memory hungry. The ‘gramschmidt’ option is sparse, memory efficient, and slow. Choose ‘gramschmidt’ if ‘qr’ runs out of memory. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths or if axis is not an integer.

property sparse_fwt_operator: Tensor#

Return the sparse transformation operator.

If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed as a single matrix multiply.

The operation torch.sparse.mm(sparse_fwt_operator, data.T) computes a batched fwt.

This property exists to make the operator matrix transparent. Calling the object will handle odd-length inputs properly.

Returns:

The sparse operator matrix.

Return type:

torch.Tensor

Raises:
  • NotImplementedError – if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).
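
Example

A minimal sketch of the matrix form described above; the operator is only available after the object has been called once:

>>> import ptwt, torch, pywt
>>> data = torch.randn(4, 8)
>>> matrix_wavedec = ptwt.MatrixWavedec(pywt.Wavelet("haar"), level=2)
>>> coefficients = matrix_wavedec(data)  # builds the transformation matrices
>>> fwt_operator = matrix_wavedec.sparse_fwt_operator
>>> # the batched fwt as a single sparse matrix multiply
>>> batched_fwt = torch.sparse.mm(fwt_operator, data.T)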

class ptwt.matmul_transform.MatrixWaverec(wavelet: Wavelet | str, axis: int = -1, boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: object

Matrix-based inverse fast wavelet transform.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> # generate an input of even length.
>>> data = np.array([0, 1, 2, 3, 4, 5, 5, 4, 3, 2, 1, 0])
>>> data_torch = torch.from_numpy(data.astype(np.float32))
>>> matrix_wavedec = ptwt.MatrixWavedec(
>>>     pywt.Wavelet('haar'), level=2)
>>> coefficients = matrix_wavedec(data_torch)
>>> matrix_waverec = ptwt.MatrixWaverec(
>>>     pywt.Wavelet('haar'))
>>> reconstruction = matrix_waverec(coefficients)

Create the inverse matrix-based fast wavelet transformation.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axis (int) – The axis transformed by the original decomposition defaults to -1 or the last axis.

  • boundary – The method used for boundary filter treatment. Choose ‘qr’ or ‘gramschmidt’. ‘qr’ relies on pytorch’s dense qr implementation, it is fast but memory hungry. The ‘gramschmidt’ option is sparse, memory efficient, and slow. Choose ‘gramschmidt’ if ‘qr’ runs out of memory. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths or if axis is not an integer.

property sparse_ifwt_operator: Tensor#

Return the sparse transformation operator.

If the input signal at all levels is divisible by two, the whole operation is padding-free and can be expressed as a single matrix multiply.

Having concatenated the analysis coefficients, torch.sparse.mm(sparse_ifwt_operator, coefficients.T) computes a batched iFWT.

This functionality is mainly here to make the operator-matrix transparent. Calling the object handles padding for odd inputs.

Returns:

The sparse operator matrix.

Return type:

torch.Tensor

Raises:
  • NotImplementedError – if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).
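
Example

A hedged sketch of the matrix form mentioned above, assuming the analysis coefficients can be concatenated along the last axis in the order returned by MatrixWavedec:

>>> import ptwt, torch, pywt
>>> data = torch.randn(4, 8)
>>> matrix_wavedec = ptwt.MatrixWavedec(pywt.Wavelet("haar"), level=2)
>>> coefficients = matrix_wavedec(data)
>>> matrix_waverec = ptwt.MatrixWaverec(pywt.Wavelet("haar"))
>>> reconstruction = matrix_waverec(coefficients)  # builds the synthesis matrices
>>> ifwt_operator = matrix_waverec.sparse_ifwt_operator
>>> # concatenate the coefficient list and apply the operator in one sparse multiply
>>> coeff_cat = torch.cat(coefficients, dim=-1)
>>> batched_ifwt = torch.sparse.mm(ifwt_operator, coeff_cat.T)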

ptwt.matmul_transform.construct_boundary_a(wavelet: Wavelet | str, length: int, device: device | str = 'cpu', boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a boundary-wavelet filter 1d-analysis matrix.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • length (int) – The number of entries in the input signal.

  • boundary – A string indicating the desired boundary treatment. Possible options are qr and gramschmidt. Defaults to qr.

  • device – Where to place the matrix. Choose cpu or cuda. Defaults to cpu.

  • dtype – Choose float32 or float64.

Returns:

The sparse analysis matrix.

Return type:

torch.Tensor
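
Example

A small sketch; passing the wavelet name as a string is assumed to be accepted, as for the other functions in this module:

>>> import torch
>>> from ptwt.matmul_transform import construct_boundary_a
>>> # analysis matrix for a signal with eight entries
>>> analysis_matrix = construct_boundary_a("haar", length=8)
>>> print(analysis_matrix.to_dense().shape)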

ptwt.matmul_transform.construct_boundary_s(wavelet: Wavelet | str, length: int, device: device | str = 'cpu', boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a boundary-wavelet filter 1d-synthesis matrix.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • length (int) – The number of entries in the input signal.

  • device (torch.device) – Where to place the matrix. Choose cpu or cuda. Defaults to cpu.

  • boundary – A string indicating the desired boundary treatment. Possible options are qr and gramschmidt. Defaults to qr.

  • dtype – Choose torch.float32 or torch.float64. Defaults to torch.float64.

Returns:

The sparse synthesis matrix.

Return type:

torch.Tensor

ptwt.matmul_transform.orthogonalize(matrix: Tensor, filt_len: int, method: Literal['qr', 'gramschmidt'] = 'qr') Tensor[source]#

Orthogonalization for sparse filter matrices.

Parameters:
  • matrix (torch.Tensor) – The sparse filter matrix to orthogonalize.

  • filt_len (int) – The length of the wavelet filter coefficients.

  • method – The orthogonalization method to use. Choose qr or gramschmidt. The dense qr code will run much faster than sparse gramschmidt. Choose gramschmidt if qr fails. Defaults to qr.

Returns:

Orthogonal sparse transformation matrix.

Return type:

torch.Tensor

Raises:

ValueError – If an invalid orthogonalization method is given.

ptwt.matmul_transform_2 module#

Two-dimensional matrix based fast wavelet transform implementations.

This module uses boundary filters to minimize padding.

class ptwt.matmul_transform_2.MatrixWavedec2(wavelet: Wavelet | str, level: int | None = None, axes: Tuple[int, int] = (-2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = True)[source]#

Bases: BaseMatrixWaveDec

Experimental sparse matrix 2d wavelet transform.

For a completely pad-free transform, input images are expected to be divisible by two. For multiscale transforms all intermediate scale dimensions should be divisible by two, i.e. 128, 128 -> 64, 64 -> 32, 32 would work well for a level three transform. In this case, multiplication with the sparse_fwt_operator property is equivalent to calling the object.

Note

Constructing the sparse fwt-matrix is expensive. For longer wavelets, high-level transforms, and large input images this may take a while. The matrix is therefore constructed only once. In the non-separable case, it can be accessed via the sparse_fwt_operator property.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = datasets.face()[:256, :256, :].astype(np.float32)
>>> pt_face = torch.tensor(face).permute([2, 0, 1])
>>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
>>> mat_coeff = matrixfwt(pt_face)

Create a new matrix fwt object.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • level (int, optional) – The level up to which to compute the fwt. If None, the maximum level based on the signal length is chosen. Defaults to None.

  • axes (int, int) – A tuple with the axes to transform. Defaults to (-2, -1).

  • boundary – The method used for boundary filter treatment. Choose ‘qr’ or ‘gramschmidt’. ‘qr’ relies on Pytorch’s dense qr implementation, it is fast but memory hungry. The ‘gramschmidt’ option is sparse, memory efficient, and slow. Choose ‘gramschmidt’ if ‘qr’ runs out of memory. Defaults to ‘qr’.

  • separable (bool) – If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. Matrix construction is significantly faster for separable transformations since only a small constant-size part of the matrices must be orthogonalized. Defaults to True.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.

property sparse_fwt_operator: Tensor#

Compute the operator matrix for padding-free cases.

This property exists to make the transformation matrix available. To benefit from code handling odd-length levels call the object.

Returns:

The sparse 2d-fwt operator matrix.

Return type:

torch.Tensor

Raises:
  • NotImplementedError – if a separable transformation was used or if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

class ptwt.matmul_transform_2.MatrixWaverec2(wavelet: Wavelet | str, axes: Tuple[int, int] = (-2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr', separable: bool = True)[source]#

Bases: object

Synthesis or inverse matrix based-wavelet transformation object.

Example

>>> import ptwt, torch, pywt
>>> import numpy as np
>>> from scipy import datasets
>>> face = datasets.face()[:256, :256, :].astype(np.float32)
>>> pt_face = torch.tensor(face).permute([2, 0, 1])
>>> matrixfwt = ptwt.MatrixWavedec2(pywt.Wavelet("haar"), level=2)
>>> mat_coeff = matrixfwt(pt_face)
>>> matrixifwt = ptwt.MatrixWaverec2(pywt.Wavelet("haar"))
>>> reconstruction = matrixifwt(mat_coeff)

Create the inverse matrix-based fast wavelet transformation.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axes (int, int) – The axes transformed by waverec2. Defaults to (-2, -1).

  • boundary – The method used for boundary filter treatment. Choose ‘qr’ or ‘gramschmidt’. ‘qr’ relies on pytorch’s dense qr implementation, it is fast but memory hungry. The ‘gramschmidt’ option is sparse, memory efficient, and slow. Choose ‘gramschmidt’ if ‘qr’ runs out of memory. Defaults to ‘qr’.

  • separable (bool) – If this flag is set, a separable transformation is used, i.e. a 1d transformation along each axis. This is significantly faster than a non-separable transformation since only a small constant-size part of the matrices must be orthogonalized. For invertibility, the analysis and synthesis values must be identical! Defaults to True.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.

property sparse_ifwt_operator: Tensor#

Compute the ifwt operator matrix for pad-free cases.

Returns:

The sparse 2d ifwt operator matrix.

Return type:

torch.Tensor

Raises:
  • NotImplementedError – if a separable transformation was used or if padding had to be used in the creation of the transformation matrices.

  • ValueError – If no level transformation matrices are stored (most likely since the object was not called yet).

ptwt.matmul_transform_2.construct_boundary_a2(wavelet: Wavelet | str, height: int, width: int, device: device | str, boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a boundary fwt matrix for the input wavelet.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • height (int) – The height of the input matrix. Should be divisible by two.

  • width (int) – The width of the input matrix. Should be divisible by two.

  • device (torch.device) – Where to place the matrix. Either on the CPU or GPU.

  • boundary – The method to use for matrix orthogonalization. Choose “qr” or “gramschmidt”. Defaults to “qr”.

  • dtype (torch.dtype, optional) – The desired data type for the matrix. Defaults to torch.float64.

Returns:

A sparse fwt matrix with orthogonalized boundary wavelets.

Return type:

torch.Tensor

ptwt.matmul_transform_2.construct_boundary_s2(wavelet: Wavelet | str, height: int, width: int, device: device | str, *, boundary: Literal['qr', 'gramschmidt'] = 'qr', dtype: dtype = torch.float64) Tensor[source]#

Construct a 2d synthesis (inverse fwt) matrix with boundary wavelets.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • height (int) – The original height of the input matrix.

  • width (int) – The width of the original input matrix.

  • device (torch.device) – Choose CPU or GPU.

  • boundary – The method to use for matrix orthogonalization. Choose qr or gramschmidt. Defaults to qr.

  • dtype (torch.dtype, optional) – The data type of the sparse matrix, choose float32 or 64. Defaults to torch.float64.

Returns:

The synthesis matrix, used to compute the inverse fast wavelet transform.

Return type:

torch.Tensor

ptwt.matmul_transform_3 module#

Implement 3D separable boundary transforms.

class ptwt.matmul_transform_3.MatrixWavedec3(wavelet: Wavelet | str, level: int | None = None, axes: Tuple[int, int, int] = (-3, -2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: object

Compute 3d separable transforms.

Create a separable three-dimensional fast boundary wavelet transform.

Input signals should have the shape [batch_size, depth, height, width], this object transforms the last three dimensions.

Parameters:
  • wavelet (Union[Wavelet, str]) – The wavelet to use.

  • level (Optional[int]) – The desired decomposition level. Defaults to None.

  • boundary – The matrix orthogonalization method. Defaults to “qr”.

Raises:
  • NotImplementedError – If the chosen orthogonalization method is not implemented.

  • ValueError – If the analysis and synthesis filters do not have the same length.
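
Example

A minimal sketch, assuming MatrixWavedec3 is exposed at the package level like its 1d and 2d counterparts:

>>> import ptwt, torch
>>> # transform a batch of 16 by 16 by 16 volumes
>>> data = torch.randn(2, 16, 16, 16)
>>> matrix_wavedec3 = ptwt.MatrixWavedec3("haar", level=2)
>>> coefficients = matrix_wavedec3(data)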

class ptwt.matmul_transform_3.MatrixWaverec3(wavelet: Wavelet | str, axes: Tuple[int, int, int] = (-3, -2, -1), boundary: Literal['qr', 'gramschmidt'] = 'qr')[source]#

Bases: object

Reconstruct a signal from 3d-separable-fwt coefficients.

Compute a three-dimensional separable boundary wavelet synthesis transform.

Parameters:
  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet.

  • axes (Tuple[int, int, int]) – Transform these axes instead of the last three. Defaults to (-3, -2, -1).

  • boundary – The method used for boundary filter treatment. Choose ‘qr’ or ‘gramschmidt’. ‘qr’ relies on Pytorch’s dense qr implementation, it is fast but memory hungry. The ‘gramschmidt’ option is sparse, memory efficient, and slow. Choose ‘gramschmidt’ if ‘qr’ runs out of memory. Defaults to ‘qr’.

Raises:
  • NotImplementedError – If the selected boundary mode is not supported.

  • ValueError – If the wavelet filters have different lengths.
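
Example

A companion sketch to the MatrixWavedec3 example above, under the same assumption about package-level exposure:

>>> import ptwt, torch
>>> data = torch.randn(2, 16, 16, 16)
>>> matrix_wavedec3 = ptwt.MatrixWavedec3("haar", level=2)
>>> coefficients = matrix_wavedec3(data)
>>> matrix_waverec3 = ptwt.MatrixWaverec3("haar")
>>> reconstruction = matrix_waverec3(coefficients)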

ptwt.sparse_math module#

Efficiently construct fwt operations using sparse matrices.

ptwt.sparse_math.batch_mm(matrix: Tensor, matrix_batch: Tensor) Tensor[source]#

Calculate a batched matrix-matrix product using torch tensors.

This calculates the product of a matrix with a batch of dense matrices. The former can be dense or sparse.

Parameters:
  • matrix (torch.Tensor) – Sparse or dense matrix, size (m, n).

  • matrix_batch (torch.Tensor) – Batched dense matrices, size (b, n, k).

Returns:

The batched matrix-matrix product, size (b, m, k).

Return type:

torch.Tensor

Raises:

ValueError – If the matrices cannot be multiplied due to incompatible matrix shapes.
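
Example

A minimal sketch with a sparse identity matrix standing in for a real fwt operator:

>>> import torch
>>> from ptwt.sparse_math import batch_mm
>>> matrix = torch.eye(4).to_sparse()        # size (m, n) = (4, 4)
>>> matrix_batch = torch.randn(8, 4, 3)      # size (b, n, k) = (8, 4, 3)
>>> product = batch_mm(matrix, matrix_batch) # size (b, m, k) = (8, 4, 3)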

ptwt.sparse_math.cat_sparse_identity_matrix(sparse_matrix: Tensor, new_length: int) Tensor[source]#

Concatenate a sparse input matrix and a sparse identity matrix.

Parameters:
  • sparse_matrix (torch.Tensor) – The input matrix.

  • new_length (int) – The length up to which the diagonal should be elongated.

Returns:

A square [input, eye] matrix of size [new_length, new_length].

Return type:

torch.Tensor

ptwt.sparse_math.construct_conv2d_matrix(filter: Tensor, input_rows: int, input_columns: int, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid', fully_sparse: bool = True) Tensor[source]#

Create a two-dimensional sparse convolution matrix.

Convolving with this matrix should be equivalent to a call to scipy.signal.convolve2d and a reshape.

Parameters:
  • filter (torch.tensor) – A filter of shape [height, width] to convolve with.

  • input_rows (int) – The number of rows in the input matrix.

  • input_columns (int) – The number of columns in the input matrix.

  • mode (str) – The desired padding method. Options are full, same and valid. Defaults to ‘valid’, i.e. no padding.

  • fully_sparse (bool) – Use a sparse implementation of the Kronecker to save memory. Defaults to True.

Returns:

A sparse convolution matrix.

Return type:

torch.Tensor

Raises:

ValueError – If the padding mode is neither full, same or valid.

ptwt.sparse_math.construct_conv_matrix(filter: Tensor, input_rows: int, *, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid') Tensor[source]#

Construct a convolution matrix.

Full, valid and same padding are supported. For reference see: RoyiAvital/StackExchangeCodes master/StackOverflow/Q2080835/CreateConvMtxSparse.m

Parameters:
  • filter (torch.tensor) – The 1D-filter to convolve with.

  • input_rows (int) – The number of rows in the input vector.

  • mode – String identifier for the desired padding. Choose ‘full’, ‘valid’ or ‘same’. Defaults to valid.

Returns:

The sparse convolution tensor.

Return type:

torch.Tensor

Raises:

ValueError – If the padding is not ‘full’, ‘same’ or ‘valid’.

ptwt.sparse_math.construct_strided_conv2d_matrix(filter: Tensor, input_rows: int, input_columns: int, stride: int = 2, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'full') Tensor[source]#

Create a strided sparse two-dimensional convolution matrix.

Parameters:
  • filter (torch.Tensor) – The two-dimensional convolution filter.

  • input_rows (int) – The number of rows in the 2d-input matrix.

  • input_columns (int) – The number of columns in the 2d-input matrix.

  • stride (int) – The stride between the filter positions. Defaults to 2.

  • mode – The convolution type. Defaults to ‘full’. Sameshift starts at 1 instead of 0.

Raises:

ValueError – Raised if an unknown convolution string is provided.

Returns:

The sparse convolution tensor.

Return type:

torch.Tensor

ptwt.sparse_math.construct_strided_conv_matrix(filter: Tensor, input_rows: int, stride: int = 2, *, mode: Literal['full', 'valid', 'same', 'sameshift'] = 'valid') Tensor[source]#

Construct a strided convolution matrix.

Parameters:
  • filter (torch.Tensor) – The filter coefficients to convolve with.

  • input_rows (int) – The number of rows in the input vector.

  • stride (int) – The step size of the convolution. Defaults to two.

  • mode – Choose ‘valid’, ‘same’ or ‘sameshift’. Defaults to ‘valid’.

Returns:

The strided sparse convolution matrix.

Return type:

torch.Tensor

ptwt.sparse_math.sparse_diag(diagonal: Tensor, col_offset: int, rows: int, cols: int) Tensor[source]#

Create a rows-by-cols sparse diagonal-matrix.

The matrix is constructed by taking the columns of the input and placing them along the diagonal specified by col_offset.

Parameters:
  • diagonal (torch.Tensor) – The values for the diagonal.

  • col_offset (int) – Move the diagonal to the right by offset columns.

  • rows (int) – The number of rows in the final matrix.

  • cols (int) – The number of columns in the final matrix.

Returns:

A sparse matrix with a shifted diagonal.

Return type:

torch.Tensor
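
Example

A short sketch showing a diagonal shifted one column to the right:

>>> import torch
>>> from ptwt.sparse_math import sparse_diag
>>> shifted = sparse_diag(torch.ones(3), col_offset=1, rows=3, cols=4)
>>> print(shifted.to_dense())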

ptwt.sparse_math.sparse_kron(sparse_tensor_a: Tensor, sparse_tensor_b: Tensor) Tensor[source]#

Compute a sparse Kronecker product.

As defined at https://en.wikipedia.org/wiki/Kronecker_product. Adapted from scipy/scipy.

Parameters:
  • sparse_tensor_a (torch.Tensor) – Sparse 2d-Tensor a of shape [m, n].

  • sparse_tensor_b (torch.Tensor) – Sparse 2d-Tensor b of shape [p, q].

Returns:

The resulting [mp, nq] tensor.

Return type:

torch.Tensor
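
Example

A small sketch following the shape convention stated above:

>>> import torch
>>> from ptwt.sparse_math import sparse_kron
>>> a = torch.eye(2).to_sparse()                        # shape [m, n] = [2, 2]
>>> b = torch.tensor([[0., 1.], [1., 0.]]).to_sparse()  # shape [p, q] = [2, 2]
>>> kron_ab = sparse_kron(a, b)                         # shape [mp, nq] = [4, 4]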

ptwt.sparse_math.sparse_replace_row(matrix: Tensor, row_index: int, row: Tensor) Tensor[source]#

Replace a row within a sparse [rows, cols] matrix.

I.e. matrix[row_index, :] = row.

Parameters:
  • matrix (torch.Tensor) – A sparse two-dimensional matrix.

  • row_index (int) – The row to replace.

  • row (torch.Tensor) – The row to insert into the sparse matrix.

Returns:

A sparse matrix, with the new row inserted at row_index.

Return type:

torch.Tensor

ptwt.version module#

Version information for ptwt.

Run with python -m ptwt.version

ptwt.version.get_version(with_git_hash: bool = False) str[source]#

Get the ptwt version string, optionally including the git hash.

ptwt.wavelets_learnable module#

Experimental code for adaptive wavelet learning.

See https://arxiv.org/pdf/2004.09569.pdf for more information.

class ptwt.wavelets_learnable.ProductFilter(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)[source]#

Bases: WaveletFilter, Module

Learnable product filter implementation.

Create a Product filter object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.
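
Example

A hedged sketch that initializes a learnable product filter from the Haar filter bank; the requires_grad flags are an assumption about how one would typically make the filters trainable:

>>> import torch, pywt
>>> from ptwt.wavelets_learnable import ProductFilter
>>> haar = pywt.Wavelet("haar")
>>> filt = ProductFilter(
>>>     torch.tensor(haar.dec_lo, requires_grad=True),
>>>     torch.tensor(haar.dec_hi, requires_grad=True),
>>>     torch.tensor(haar.rec_lo, requires_grad=True),
>>>     torch.tensor(haar.rec_hi, requires_grad=True))
>>> # the combined loss can be minimized with any torch optimizer
>>> loss = filt.wavelet_loss()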

property filter_bank: Tuple[Tensor, Tensor, Tensor, Tensor]#

Return all filters as a tuple.

product_filter_loss() Tensor[source]#

Get only the product filter loss.

Returns:

The loss scalar.

Return type:

torch.Tensor

wavelet_loss() Tensor[source]#

Return the sum of all loss terms.

Returns:

The loss scalar.

Return type:

torch.Tensor

class ptwt.wavelets_learnable.SoftOrthogonalWavelet(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)[source]#

Bases: ProductFilter, Module

Orthogonal wavelets with a soft orthogonality constraint.

Create a SoftOrthogonalWavelet object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.

filt_bank_orthogonality_loss() Tensor[source]#

Return a Jensen+Harbo inspired soft orthogonality loss.

On page 79 of the book Ripples in Mathematics by Jensen and la Cour-Harbo, the constraints g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters are presented. A measure of the deviation from these constraints is implemented below.

Returns:

A tensor with the orthogonality constraint value.

Return type:

torch.Tensor

rec_lo_orthogonality_loss() Tensor[source]#

Return a Strang inspired soft orthogonality loss.

See Strang p. 148/149 or Harbo p. 80. Since L is a convolution matrix, LL^T can be evaluated through convolution.

Returns:

A tensor with the orthogonality constraint value.

Return type:

torch.Tensor

wavelet_loss() Tensor[source]#

Return the sum of all terms.

class ptwt.wavelets_learnable.WaveletFilter[source]#

Bases: ABC

Interface for learnable wavelets.

Each wavelet has a filter bank loss function and comes with functionality that tests the perfect reconstruction and anti-aliasing conditions.

alias_cancellation_loss() Tuple[Tensor, Tensor, Tensor][source]#

Return the alias cancellation loss.

Implementation of the alias cancellation loss as described on page 104 of Strang and Nguyen:

\[F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0\]

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.

Return type:

list

abstract property filter_bank: Tuple[Tensor, Tensor, Tensor, Tensor]#

Return dec_lo, dec_hi, rec_lo, rec_hi.

perfect_reconstruction_loss() Tuple[Tensor, Tensor, Tensor][source]#

Return the perfect reconstruction loss.

Returns:

The numerical value of the perfect reconstruction loss, as well as both intermediate values for analysis.

Return type:

list

pf_alias_cancellation_loss() Tuple[Tensor, Tensor, Tensor][source]#

Return the product filter-alias cancellation loss.

See Strang and Nguyen, p. 105: \(F_0(z) = H_1(-z)\); \(F_1(z) = -H_0(-z)\). The alternating sign convention from 0 to N follows the Strang overview on the back of the cover.

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.

Return type:

list

abstract wavelet_loss() Tensor[source]#

Return the sum of all loss terms.

ptwt.constants#

Constants and types used throughout the PyTorch Wavelet Toolbox.

ptwt.constants.BoundaryMode#

This is a type literal for the way of padding.

  • Reflection padding mirrors samples along the border.

  • Zero padding pads zeros.

  • Constant padding replicates border values.

  • Periodic padding cyclically repeats samples.

  • Symmetric padding mirrors samples along the border.

alias of Literal[‘constant’, ‘zero’, ‘reflect’, ‘periodic’, ‘symmetric’]
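
Example

A short usage sketch; the literal values listed above are passed as the mode argument of the padded transforms:

>>> import torch, ptwt
>>> # periodic padding for a short one-dimensional signal
>>> coefficients = ptwt.wavedec(torch.randn(8), "haar", mode="periodic", level=1)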

ptwt.constants.OrthogonalizeMethod#

The method for orthogonalizing a matrix.

  1. ‘qr’ relies on pytorch’s dense qr implementation, it is fast but memory hungry.

  2. ‘gramschmidt’ option is sparse, memory efficient, and slow.

Choose ‘gramschmidt’ if ‘qr’ runs out of memory.

alias of Literal[‘qr’, ‘gramschmidt’]

ptwt.constants.PaddingMode#

The padding mode is used when constructing convolution matrices.

alias of Literal[‘full’, ‘valid’, ‘same’, ‘sameshift’]