Learnable adaptive wavelets

Experimental code for adaptive wavelet learning.

See https://arxiv.org/pdf/2004.09569.pdf for more information.

class ptwt.wavelets_learnable.ProductFilter(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)

Bases: WaveletFilter, Module

Learnable product filter implementation.

Create a ProductFilter object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.

property filter_bank: tuple[Tensor, Tensor, Tensor, Tensor]

All four filters as a tuple (dec_lo, dec_hi, rec_lo, rec_hi).

product_filter_loss() → Tensor

Get only the product filter loss.

Returns:

The loss scalar.
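For background: Strang and Nguyen call $$P_0(z) = F_0(z)H_0(z)$$ the product filter. With the alias cancellation choice $$F_0(z) = H_1(-z), F_1(z) = -H_0(-z)$$, perfect reconstruction reduces to the half-band condition $$P_0(z) - P_0(-z) = 2z^{-l}$$. The NumPy sketch below measures the deviation from that condition, storing each filter as the coefficients of a polynomial in $$z^{-1}$$. It is a hypothetical illustration, not ptwt's implementation: the helper names, the Haar test filters, and placing the delay l at the centre of P_0 (which suits even-length orthogonal banks) are all assumptions.

```python
import numpy as np


def modulate(p):
    """Return the coefficients of P(-z) for P(z) = sum_n p[n] z^{-n}."""
    p = np.asarray(p, dtype=float)
    return p * (-1.0) ** np.arange(len(p))


def product_filter_loss(h0, f0):
    """Hypothetical helper: squared deviation of P0(z) - P0(-z) from 2 z^{-l}."""
    p0 = np.convolve(f0, h0)          # polynomial product F0(z) H0(z)
    halfband = p0 - modulate(p0)      # P0(z) - P0(-z): only odd powers survive
    target = np.zeros_like(halfband)
    target[len(halfband) // 2] = 2.0  # assumed delay l at the centre of P0
    return float(np.sum((halfband - target) ** 2))


s = 1.0 / np.sqrt(2.0)
h0 = np.array([s, s])  # Haar low-pass analysis filter
f0 = np.array([s, s])  # Haar low-pass synthesis filter
loss = product_filter_loss(h0, f0)  # approximately 0.0 for Haar
```

Working with plain coefficient arrays keeps the z-domain algebra visible: a product of transfer functions is a convolution, and z to -z flips the sign of every odd coefficient.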

wavelet_loss() → Tensor

Return the sum of all loss terms.

Returns:

The loss scalar.

class ptwt.wavelets_learnable.SoftOrthogonalWavelet(dec_lo: Tensor, dec_hi: Tensor, rec_lo: Tensor, rec_hi: Tensor)

Bases: ProductFilter, Module

Orthogonal wavelets with a soft orthogonality constraint.

Create a SoftOrthogonalWavelet object.

Parameters:
  • dec_lo (torch.Tensor) – Low pass analysis filter.

  • dec_hi (torch.Tensor) – High pass analysis filter.

  • rec_lo (torch.Tensor) – Low pass synthesis filter.

  • rec_hi (torch.Tensor) – High pass synthesis filter.

filt_bank_orthogonality_loss() → Tensor

Return a soft orthogonality loss inspired by Jensen and la Cour-Harbo.

On page 79 of the book Ripples in Mathematics by Jensen and la Cour-Harbo, the constraints g0[k] = h0[-k] and g1[k] = h1[-k] are presented for orthogonal filters. This method measures how far the filter bank deviates from them.

Returns:

A tensor with the orthogonality constraint value.
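A minimal NumPy sketch of such a measurement follows. It assumes that h denotes the analysis (dec) filters and g the synthesis (rec) filters, and that for finite causal filters g[k] = h[-k] is read as a time reversal up to a shift. The function name and the Haar sign convention are hypothetical; ptwt may order or sign its filters differently.

```python
import numpy as np


def filt_bank_orthogonality_sketch(dec_lo, dec_hi, rec_lo, rec_hi):
    """Hypothetical helper: squared deviation of rec filters from flipped dec filters."""
    # g0[k] = h0[-k] and g1[k] = h1[-k], read as time reversal up to a shift.
    err_lo = np.asarray(rec_lo, float) - np.asarray(dec_lo, float)[::-1]
    err_hi = np.asarray(rec_hi, float) - np.asarray(dec_hi, float)[::-1]
    return float(np.sum(err_lo ** 2) + np.sum(err_hi ** 2))


s = 1.0 / np.sqrt(2.0)
# One Haar sign convention (assumed): the synthesis filters are the flipped analysis filters.
loss = filt_bank_orthogonality_sketch([s, s], [s, -s], [s, s], [-s, s])  # approximately 0.0
```

Because the loss is a sum of squares, it is zero exactly when both constraints hold and grows smoothly as the learned filters drift away from them.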

rec_lo_orthogonality_loss() → Tensor

Return a Strang-inspired soft orthogonality loss.

See Strang, pp. 148-149, or Jensen and la Cour-Harbo, p. 80. Orthogonality requires LL^T = I; since L is a convolution matrix, LL^T can be evaluated through convolution.

Returns:

A tensor with the orthogonality constraint value.
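The following NumPy sketch illustrates the convolution trick. It assumes the rows of L are copies of rec_lo shifted by two, so the entries of LL^T are the autocorrelation of rec_lo at even lags, which must be 1 at lag 0 and 0 elsewhere (boundary effects ignored). The function name is hypothetical and this is not ptwt's code.

```python
import numpy as np


def rec_lo_orthogonality_sketch(rec_lo):
    """Hypothetical helper: squared deviation of L L^T from the identity."""
    h = np.asarray(rec_lo, dtype=float)
    n = len(h)
    autocorr = np.convolve(h, h[::-1])    # lag 0 sits at index n - 1
    even_lags = autocorr[(n - 1) % 2::2]  # keep lags ..., -2, 0, 2, ...
    target = np.zeros_like(even_lags)
    target[(n - 1) // 2] = 1.0            # identity: 1 at lag 0 only
    return float(np.sum((even_lags - target) ** 2))


r3 = np.sqrt(3.0)
db2_lo = np.array([1 + r3, 3 + r3, 3 - r3, 1 - r3]) / (4 * np.sqrt(2.0))
loss = rec_lo_orthogonality_sketch(db2_lo)  # approximately 0.0 for Daubechies-2
```

Taking every second autocorrelation value plays the same role as a stride-2 convolution: only shifts by multiples of two appear in LL^T.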

wavelet_loss() → Tensor

Return the sum of all loss terms.

class ptwt.wavelets_learnable.WaveletFilter

Bases: ABC

Interface for learnable wavelets.

Each wavelet has a filter bank loss function and comes with functionality that tests the perfect reconstruction and anti-aliasing conditions.

alias_cancellation_loss() → tuple[Tensor, Tensor, Tensor]

Return the alias cancellation loss.

Implements the alias cancellation (AC) loss described on page 104 of Strang and Nguyen: $$F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0$$

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.
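As a sketch of how this condition can be evaluated numerically, the NumPy code below treats each filter as the coefficients of a polynomial in $$z^{-1}$$: polynomial products become convolutions, and H(-z) flips the sign of every odd coefficient. The helper names and the Haar sign convention are assumptions for illustration, not ptwt's implementation; like the documented method, it returns the loss together with both components.

```python
import numpy as np


def modulate(h):
    """Return the coefficients of H(-z) for H(z) = sum_n h[n] z^{-n}."""
    h = np.asarray(h, dtype=float)
    return h * (-1.0) ** np.arange(len(h))


def alias_cancellation_sketch(h0, h1, f0, f1):
    """Hypothetical helper: measure F0(z)H0(-z) + F1(z)H1(-z), which must vanish.

    Assumes all four filters share one length, so both products align.
    """
    term0 = np.convolve(f0, modulate(h0))  # F0(z) H0(-z)
    term1 = np.convolve(f1, modulate(h1))  # F1(z) H1(-z)
    residual = term0 + term1
    return float(np.sum(residual ** 2)), term0, term1


s = 1.0 / np.sqrt(2.0)
h0, h1 = [s, s], [s, -s]  # Haar analysis filters (assumed sign convention)
f0, f1 = [s, s], [-s, s]  # matching synthesis filters
loss, term0, term1 = alias_cancellation_sketch(h0, h1, f0, f1)  # loss is approximately 0.0
```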

abstract property filter_bank: tuple[Tensor, Tensor, Tensor, Tensor]

Return dec_lo, dec_hi, rec_lo, rec_hi.

perfect_reconstruction_loss() → tuple[Tensor, Tensor, Tensor]

Return the perfect reconstruction loss.

Returns:

The numerical value of the perfect reconstruction loss, as well as both intermediate values for analysis.
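The perfect reconstruction condition is not spelled out here; a standard statement from Strang and Nguyen is $$F_0(z)H_0(z) + F_1(z)H_1(z) = 2z^{-l}$$ for some delay l. The NumPy sketch below measures the deviation from it. The function name, the alternating-flip construction of the high-pass filter, and placing the delay at the centre of the product (valid for orthogonal banks whose synthesis filters are the time-reversed analysis filters) are assumptions; ptwt's own computation may differ.

```python
import numpy as np


def perfect_reconstruction_sketch(h0, h1, f0, f1):
    """Hypothetical helper: measure F0(z)H0(z) + F1(z)H1(z) against 2 z^{-l}.

    Assumes equal-length filters and a delay l at the centre of the product.
    """
    p = np.convolve(f0, h0) + np.convolve(f1, h1)
    target = np.zeros_like(p)
    target[len(p) // 2] = 2.0  # 2 z^{-l} with l at the centre (assumed)
    return float(np.sum((p - target) ** 2)), p, target


r3 = np.sqrt(3.0)
h0 = np.array([1 + r3, 3 + r3, 3 - r3, 1 - r3]) / (4 * np.sqrt(2.0))  # db2 low-pass
h1 = (-1.0) ** np.arange(4) * h0[::-1]  # alternating flip (assumed convention)
f0, f1 = h0[::-1], h1[::-1]            # orthogonal bank: time-reversed synthesis
loss, p, target = perfect_reconstruction_sketch(h0, h1, f0, f1)  # loss is approximately 0.0
```

Returning p and the target alongside the scalar mirrors the documented return shape, so the intermediate polynomial can be inspected when training drifts.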

pf_alias_cancellation_loss() → tuple[Tensor, Tensor, Tensor]

Return the product filter alias cancellation loss.

See Strang and Nguyen, p. 105: $$F_0(z) = H_1(-z), \quad F_1(z) = -H_0(-z)$$ The alternating signs run over the coefficients from 0 to N; see the overview on the back cover of Strang and Nguyen.

Returns:

The numerical value of the alias cancellation loss, as well as both loss components for analysis.
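Substituting this choice into the alias cancellation condition shows why it works: $$H_1(-z)H_0(-z) - H_0(-z)H_1(-z) = 0$$. The NumPy sketch below measures how far given synthesis filters are from that choice, with signs alternating from coefficient 0. Whether ptwt measures exactly this deviation is an assumption, and the function name and Haar test filters are hypothetical.

```python
import numpy as np


def modulate(h):
    """Return the coefficients of H(-z): signs alternate starting at n = 0."""
    h = np.asarray(h, dtype=float)
    return h * (-1.0) ** np.arange(len(h))


def pf_alias_cancellation_sketch(h0, h1, f0, f1):
    """Hypothetical helper: squared deviation from F0(z) = H1(-z), F1(z) = -H0(-z)."""
    err0 = np.asarray(f0, dtype=float) - modulate(h1)
    err1 = np.asarray(f1, dtype=float) + modulate(h0)
    return float(np.sum(err0 ** 2) + np.sum(err1 ** 2)), err0, err1


s = 1.0 / np.sqrt(2.0)
loss, err0, err1 = pf_alias_cancellation_sketch([s, s], [s, -s], [s, s], [-s, s])  # loss is approximately 0.0
```

Unlike the general alias cancellation loss, this variant constrains the synthesis filters coefficient by coefficient, which ties them directly to the analysis filters.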

abstractmethod wavelet_loss() → Tensor

Return the sum of all loss terms.