Wavelet Packet Transform (WPT)

Packets in 1d using WaveletPacket

class ptwt.WaveletPacket(data: Tensor | None, wavelet: Wavelet | str, mode: ptwt.constants.ExtendedBoundaryMode = 'reflect', maxlevel: int | None = None, axis: int = -1, orthogonalization: ptwt.constants.OrthogonalizeMethod = 'qr')

Implements a single-dimensional wavelet packet transform.

Create a wavelet packet decomposition object.

The packet tree is initialized lazily, i.e. a coefficient is only calculated as it is retrieved. This allows for partial expansion of the wavelet packet tree.
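The sketch below makes lazy initialization concrete. It is a toy in plain numpy, not ptwt's code. The class `LazyHaarPacket` is hypothetical and uses orthonormal Haar filters. It computes a node only when its key is first requested and then caches it. A signal whose length is a power of two is assumed.

```python
import numpy as np


class LazyHaarPacket:
    """Hypothetical toy packet tree: a node is computed on first access only."""

    def __init__(self, data: np.ndarray) -> None:
        # The root node "" holds the input signal (length assumed a power of two).
        self._cache = {"": np.asarray(data, dtype=float)}
        self.computed: list[str] = []  # keys in the order they were computed

    def __getitem__(self, key: str) -> np.ndarray:
        if key in self._cache:
            return self._cache[key]  # already expanded: no new work
        if set(key) - {"a", "d"}:
            raise KeyError(key)
        parent = self[key[:-1]]  # recursively expands only the path to `key`
        even, odd = parent[0::2], parent[1::2]
        sign = 1.0 if key[-1] == "a" else -1.0  # Haar low-pass vs. high-pass
        coeffs = (even + sign * odd) / np.sqrt(2)
        self._cache[key] = coeffs
        self.computed.append(key)
        return coeffs


tree = LazyHaarPacket(np.arange(8.0))
coeffs = tree["ad"]   # approximately [-2, -2]
print(tree.computed)  # ['a', 'ad']  ("d", "aa", "da", ... were never computed)
```

Caching every visited node means repeated access to the same key costs nothing. Only the nodes along the requested path are ever filtered.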

Parameters:
  • data (torch.Tensor, optional) – The input time series to transform. By default the last axis is transformed. Use the axis argument to choose another dimension. If None, the object is initialized without performing a decomposition.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired mode to handle signal boundaries. Select either the sparse-matrix backend (boundary) or a padding mode. See ptwt.constants.ExtendedBoundaryMode. Defaults to reflect.

  • maxlevel (int, optional) – Value is passed on to transform(). The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axis (int) – The axis to transform. Defaults to -1.

  • orthogonalization – The orthogonalization method to use in the sparse matrix backend, see ptwt.constants.OrthogonalizeMethod. Only used if mode equals boundary. Defaults to qr.

Changed in version 1.10: The argument boundary_orthogonalization has been renamed to orthogonalization.

Raises:

NotImplementedError – If the selected orthogonalization mode is not supported.

Example

>>> import torch, pywt, ptwt
>>> import numpy as np
>>> import scipy.signal
>>> import matplotlib.pyplot as plt
>>> t = np.linspace(0, 10, 1500)
>>> w = scipy.signal.chirp(t, f0=1, f1=50, t1=10, method="linear")
>>> wp = ptwt.WaveletPacket(data=torch.from_numpy(w.astype(np.float32)),
...     wavelet=pywt.Wavelet("db3"), mode="reflect")
>>> np_lst = [wp[node] for node in wp.get_level(5)]
>>> viz = np.stack(np_lst).squeeze()
>>> plt.imshow(np.abs(viz))
>>> plt.show()

__getitem__(key: str) → Tensor

Access the coefficients in the wavelet packets tree.

Parameters:

key (str) – The key of the accessed coefficients. The string may only consist of the characters a and d, where a denotes the low-pass or approximation filter and d the high-pass or detail filter.

Returns:

The accessed wavelet packet coefficients.

Raises:
  • ValueError – If the wavelet packet tree is not initialized.

  • KeyError – If no wavelet coefficients are indexed by the specified key and a lazy initialization fails.

static get_level(level: int, order: Literal['freq', 'natural'] = 'freq') → list[str]

Return the paths to the filter tree nodes.

Parameters:
  • level (int) – The depth of the tree.

  • order – The order the paths are in. See ptwt.constants.PacketNodeOrder. Choose from frequency order (freq) and natural order (natural). Defaults to freq.

Returns:

A list with the paths to each node.

Raises:

ValueError – If order is neither freq nor natural.
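Here is a minimal plain-Python sketch of the two 1d orderings. It assumes ptwt builds the frequency order with the same Gray-code recursion that pywt uses. The helper names `natural_order` and `freq_order` are hypothetical.

```python
from itertools import product


def natural_order(level: int) -> list[str]:
    # All filter paths of length `level` in lexicographic order ("a" before "d").
    return ["".join(p) for p in product("ad", repeat=level)]


def freq_order(level: int) -> list[str]:
    # Gray-code recursion: the list is mirrored when "d" is prepended, because
    # a high-pass step followed by downsampling reverses the frequency order
    # of the sub-bands below it.
    order = ["a", "d"]
    for _ in range(level - 1):
        order = ["a" + p for p in order] + ["d" + p for p in order[::-1]]
    return order


print(natural_order(2))  # ['aa', 'ad', 'da', 'dd']
print(freq_order(2))     # ['aa', 'ad', 'dd', 'da']
```

Both functions return the same set of paths and differ only in their order. This is why the order argument affects plotting, as in the chirp example, but not which coefficients exist.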

initialize(keys: Iterable[str]) → None

Initialize the wavelet packet tree partially.

Parameters:

keys (Iterable[str]) – An iterable yielding the keys of the tree nodes to initialize.

reconstruct() → WaveletPacket

Recursively reconstruct the input starting from the leaf nodes.

Reconstruction replaces the input data originally assigned to this object.

Note

Only changes to leaf node data impact the results, since changes in all other nodes will be replaced with a reconstruction from the leaves.

Example

>>> import numpy as np
>>> import ptwt, torch
>>> signal = np.random.randn(1, 16)
>>> ptwp = ptwt.WaveletPacket(torch.from_numpy(signal), "haar",
...     mode="boundary", maxlevel=2)
>>> # initialize other leaf nodes
>>> ptwp.initialize(["ad", "da", "dd"])
>>> ptwp["aa"] = torch.zeros_like(ptwp["ad"])
>>> ptwp.reconstruct()
>>> print(ptwp[""])

Raises:

KeyError – if any leaf node data is not present.
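The leaf-only behaviour described in the note can be reproduced with a small numpy sketch. It uses orthonormal Haar filters and hypothetical helper names, and it is not ptwt's implementation. Every inner node is recomputed from its children. As a result, an edit to an inner node is discarded, while an edit to a leaf changes the output.

```python
from itertools import product

import numpy as np

SQRT2 = np.sqrt(2.0)


def haar_split(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # One analysis step with orthonormal Haar filters: ("a", "d") children.
    even, odd = x[0::2], x[1::2]
    return (even + odd) / SQRT2, (even - odd) / SQRT2


def haar_merge(a: np.ndarray, d: np.ndarray) -> np.ndarray:
    # Exact inverse of haar_split: recover even/odd samples and interleave them.
    out = np.empty(2 * a.size)
    out[0::2] = (a + d) / SQRT2
    out[1::2] = (a - d) / SQRT2
    return out


def decompose(x: np.ndarray, maxlevel: int) -> dict[str, np.ndarray]:
    # Full (non-lazy) expansion: every node down to `maxlevel`.
    tree = {"": x}
    for level in range(maxlevel):
        for key in ["".join(p) for p in product("ad", repeat=level)]:
            tree[key + "a"], tree[key + "d"] = haar_split(tree[key])
    return tree


def reconstruct(tree: dict[str, np.ndarray], maxlevel: int) -> np.ndarray:
    # Bottom-up: each inner node is overwritten by the merge of its children,
    # so only the leaves (keys of length `maxlevel`) influence the output.
    # A missing leaf raises KeyError.
    for level in range(maxlevel - 1, -1, -1):
        for key in ["".join(p) for p in product("ad", repeat=level)]:
            tree[key] = haar_merge(tree[key + "a"], tree[key + "d"])
    return tree[""]


signal = np.random.default_rng(0).standard_normal(16)
tree = decompose(signal, maxlevel=2)
tree["a"] = np.zeros_like(tree["a"])  # inner node: ignored by reconstruct
print(np.allclose(reconstruct(tree, maxlevel=2), signal))  # True
tree["aa"] = np.zeros_like(tree["aa"])  # leaf node: changes the result
print(np.allclose(reconstruct(tree, maxlevel=2), signal))  # False
```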

transform(data: Tensor, maxlevel: int | None = None) → WaveletPacket

Lazily calculate the 1d wavelet packet transform for the input data.

The packet tree is initialized lazily, i.e. a coefficient is only calculated as it is retrieved. This allows for partial expansion of the wavelet packet tree.

The transform function allows reusing the same object.

Parameters:
  • data (torch.Tensor) – The input time series to transform. By default the last axis is transformed. Use the axis argument to choose another dimension.

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

Returns:

This wavelet packet object (to allow call chaining).
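When maxlevel is None, the depth comes from the input shape. A sketch of that rule, assuming ptwt follows the convention of pywt.dwt_max_level(data_len, filter_len), is shown below. The level is floor(log2(data_len // (filter_len - 1))), and 0 is returned when the signal is shorter than filter_len - 1. The helper name `max_level` is hypothetical; check ptwt's source for the exact rule it applies.

```python
def max_level(data_len: int, filter_len: int) -> int:
    """Hypothetical re-implementation of the pywt.dwt_max_level formula."""
    if filter_len < 2:
        raise ValueError("filter_len must be at least 2")
    if data_len < filter_len - 1:
        return 0
    # floor(log2(n)) computed exactly on integers via the bit length
    return (data_len // (filter_len - 1)).bit_length() - 1


print(max_level(1500, 6))  # 8  (the chirp example: 1500 samples, db3 has 6 taps)
print(max_level(16, 2))    # 4  (16 samples, haar has 2 taps)
```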

Packets in 2d using WaveletPacket2D

class ptwt.WaveletPacket2D(data: Tensor | None, wavelet: Wavelet | str, mode: ptwt.constants.ExtendedBoundaryMode = 'reflect', maxlevel: int | None = None, axes: tuple[int, int] = (-2, -1), orthogonalization: ptwt.constants.OrthogonalizeMethod = 'qr', separable: bool = False)

Two-dimensional wavelet packets.

Example code illustrating the use of this class is available at: v0lta/PyTorch-Wavelet-Toolbox

Create a 2D-Wavelet packet tree.

The packet tree is initialized lazily, i.e. a coefficient is only calculated as it is retrieved. This allows for partial expansion of the wavelet packet tree.

Parameters:
  • data (torch.Tensor, optional) – The input data tensor. If None, the object is initialized without performing a decomposition.

  • wavelet (Wavelet or str) – A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from pywt.wavelist(kind='discrete') for possible choices.

  • mode – The desired mode to handle signal boundaries. Select either the sparse-matrix backend (boundary) or a padding mode. See ptwt.constants.ExtendedBoundaryMode. Defaults to reflect.

  • maxlevel (int, optional) – Value is passed on to transform(). The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

  • axes (tuple[int, int], optional) – The tensor axes that should be transformed. Defaults to (-2, -1).

  • orthogonalization – The orthogonalization method to use in the sparse matrix backend, see ptwt.constants.OrthogonalizeMethod. Only used if mode equals boundary. Defaults to qr.

  • separable (bool) – If True, a separable transform is performed, i.e. each image axis is transformed separately. Defaults to False.

Changed in version 1.10: The argument boundary_orthogonalization has been renamed to orthogonalization.

Raises:

NotImplementedError – If the selected orthogonalization mode is not supported.

__getitem__(key: str) → Tensor

Access the coefficients in the wavelet packets tree.

Parameters:

key (str) – The key of the accessed coefficients. The string may only consist of the following characters: a, h, v, d. The characters correspond to the selected coefficients for a level, where a denotes the approximation coefficients and h horizontal, v vertical and d diagonal detail coefficients.

Returns:

The accessed wavelet packet coefficients.

Raises:
  • ValueError – If the wavelet packet tree is not initialized.

  • KeyError – If no wavelet coefficients are indexed by the specified key and a lazy initialization fails.

static get_freq_order(level: int) → list[list[str]]

Get the frequency order for a given packet decomposition level.

Use this code to create two-dimensional frequency orderings.

Parameters:

level (int) – The number of decomposition scales.

Returns:

A list with the tree nodes in frequency order.

Note

Adapted from: PyWavelets/pywt

The code elements denote the filter application order. The filters are named following the pywt convention as:

  • a - LL, low-low coefficients

  • h - LH, low-high coefficients

  • v - HL, high-low coefficients

  • d - HH, high-high coefficients
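Below is a sketch of how a two-dimensional frequency ordering can be built from the naming in the note. Each 2d path is split into one 1d low/high path per axis, and both axes are then sorted by the 1d Gray-code frequency order. This mirrors the pywt approach the note refers to. The mapping (first letter for rows, second for columns) is an assumption taken from the LL/LH/HL/HH labels above; ptwt's actual grid may be the transpose. All names are hypothetical.

```python
from itertools import product

# Assumed two-letter codes from the note: first letter for the row axis,
# second letter for the column axis (l = low-pass, h = high-pass).
EXPAND = {"a": "ll", "h": "lh", "v": "hl", "d": "hh"}


def graycode(level: int, x: str = "l", y: str = "h") -> list[str]:
    # 1d frequency order of low/high paths (mirror on each high-pass step).
    order = [x, y]
    for _ in range(level - 1):
        order = [x + p for p in order] + [y + p for p in order[::-1]]
    return order


def freq_order_2d(level: int) -> list[list[str]]:
    # Split each 2d path into a per-axis path, e.g. "ah" -> rows "ll", cols "lh".
    grid: dict[str, dict[str, str]] = {}
    for path in ("".join(p) for p in product("ahvd", repeat=level)):
        rows = "".join(EXPAND[c][0] for c in path)
        cols = "".join(EXPAND[c][1] for c in path)
        grid.setdefault(rows, {})[cols] = path
    # Sort both axes by 1d frequency order to get a 2d frequency grid.
    order = graycode(level)
    return [[grid[r][c] for c in order] for r in order]


print(freq_order_2d(1))  # [['a', 'h'], ['v', 'd']]
```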

static get_level(level: int, order: Literal['freq']) → list[list[str]]
static get_level(level: int, order: Literal['natural']) → list[str]

Return the paths to the filter tree nodes.

Parameters:
  • level (int) – The depth of the tree.

  • order – The order the paths are in. See ptwt.constants.PacketNodeOrder. Choose from frequency order (freq) and natural order (natural). Defaults to freq.

Returns:

A list with the paths to each node.

Raises:

ValueError – If order is neither freq nor natural.

static get_natural_order(level: int) → list[str]

Get the natural ordering for a given decomposition level.

Parameters:

level (int) – The decomposition level.

Returns:

A list with the filter order strings.
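A minimal sketch of a natural ordering for 2d packets follows. It assumes the order is the lexicographic product over the symbol sequence a, h, v, d, as listed for the key characters. The function name is hypothetical.

```python
from itertools import product


def natural_order_2d(level: int) -> list[str]:
    # Every path of length `level` over the four sub-band symbols, in the
    # lexicographic order induced by the sequence a, h, v, d.
    return ["".join(p) for p in product("ahvd", repeat=level)]


print(natural_order_2d(1))       # ['a', 'h', 'v', 'd']
print(natural_order_2d(2)[:5])   # ['aa', 'ah', 'av', 'ad', 'ha']
print(len(natural_order_2d(3)))  # 64, i.e. 4 ** 3 nodes at level 3
```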

initialize(keys: Iterable[str]) → None

Initialize the wavelet packet tree partially.

Parameters:

keys (Iterable[str]) – An iterable yielding the keys of the tree nodes to initialize.

reconstruct() → WaveletPacket2D

Recursively reconstruct the input starting from the leaf nodes.

Note

Only changes to leaf node data impact the results, since changes in all other nodes will be replaced with a reconstruction from the leaves.

Raises:

KeyError – if any leaf node data is not present.
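The leaf-to-parent recomputation described in the note can be sketched for one 2d level with separable orthonormal Haar filters. This resembles the separable=True option, but it is plain numpy and not ptwt's code. The image is rebuilt exactly from its four sub-bands, which are the leaves of a level-1 tree. Sub-band keys use the first letter for axis 0 and the second for axis 1. How these map onto ptwt's a, h, v, d names is left out as a convention detail.

```python
import numpy as np

S = np.sqrt(2.0)


def split(x: np.ndarray, axis: int) -> tuple[np.ndarray, np.ndarray]:
    # Orthonormal Haar analysis along one axis (even length assumed).
    x = np.moveaxis(x, axis, 0)
    lo, hi = (x[0::2] + x[1::2]) / S, (x[0::2] - x[1::2]) / S
    return np.moveaxis(lo, 0, axis), np.moveaxis(hi, 0, axis)


def merge(lo: np.ndarray, hi: np.ndarray, axis: int) -> np.ndarray:
    # Inverse of split along the same axis: rebuild and interleave samples.
    lo, hi = np.moveaxis(lo, axis, 0), np.moveaxis(hi, axis, 0)
    out = np.empty((2 * lo.shape[0],) + lo.shape[1:])
    out[0::2], out[1::2] = (lo + hi) / S, (lo - hi) / S
    return np.moveaxis(out, 0, axis)


def analyze_2d(img: np.ndarray) -> dict[str, np.ndarray]:
    # Separable step: axis 0 first, then axis 1 (l = low-pass, h = high-pass).
    lo, hi = split(img, 0)
    (ll, lh), (hl, hh) = split(lo, 1), split(hi, 1)
    return {"ll": ll, "lh": lh, "hl": hl, "hh": hh}


def synthesize_2d(bands: dict[str, np.ndarray]) -> np.ndarray:
    # The parent is recomputed from its four children (the leaves of this level).
    lo = merge(bands["ll"], bands["lh"], 1)
    hi = merge(bands["hl"], bands["hh"], 1)
    return merge(lo, hi, 0)


img = np.arange(16.0).reshape(4, 4)
print(np.allclose(synthesize_2d(analyze_2d(img)), img))  # True
```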

transform(data: Tensor, maxlevel: int | None = None) → WaveletPacket2D

Lazily calculate the 2d wavelet packet transform for the input data.

The packet tree is initialized lazily, i.e. a coefficient is only calculated as it is retrieved. This allows for partial expansion of the wavelet packet tree.

The transform function allows reusing the same object.

Parameters:
  • data (torch.Tensor) – The input data tensor of at least two dimensions. By default, the last two axes are transformed. The axes class attribute allows other choices.

  • maxlevel (int, optional) – The highest decomposition level to compute. If None, the maximum level is determined from the input data shape. Defaults to None.

Returns:

This wavelet packet object (to allow call chaining).

Node ordering

ptwt.constants.PacketNodeOrder

This is a type literal for the order of wavelet packet tree nodes.

  • frequency order (freq)

  • natural order (natural)
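A typing.Literal is only enforced by static type checkers, not at runtime. A runtime guard like the one behind the documented ValueError of get_level is therefore useful. The sketch below redeclares the literal locally and uses a hypothetical helper `check_order`. It is an illustration, not ptwt's implementation.

```python
from typing import Literal, get_args

# Local stand-in for the type literal described above (assumption).
PacketNodeOrder = Literal["freq", "natural"]


def check_order(order: str) -> PacketNodeOrder:
    # Runtime counterpart of the static Literal check.
    allowed = get_args(PacketNodeOrder)  # ('freq', 'natural')
    if order not in allowed:
        raise ValueError(f"Unsupported order {order!r}; expected one of {allowed}.")
    return order  # type: ignore[return-value]


print(check_order("freq"))  # freq
```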