Source code for ptwt.wavelets_learnable

"""Experimental code for adaptive wavelet learning.

See https://arxiv.org/pdf/2004.09569.pdf for more information.
"""

# Inspired by Ripples in Mathematics, Jensen and La Cour-Harbo, Chapter 7.7
from abc import ABC, abstractmethod
from typing import Tuple

import torch


class WaveletFilter(ABC):
    """Interface for learnable wavelets.

    Each wavelet has a filter bank and a loss function,
    and comes with functionality to test the perfect
    reconstruction and anti-aliasing conditions.
    """

    @property
    @abstractmethod
    def filter_bank(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return dec_lo, dec_hi, rec_lo, rec_hi."""
        raise NotImplementedError

    @abstractmethod
    def wavelet_loss(self) -> torch.Tensor:
        """Return the sum of all loss terms."""
        return (
            self.alias_cancellation_loss()[0]
            + self.perfect_reconstruction_loss()[0]
        )

    @abstractmethod
    def __len__(self) -> int:
        """Return the filter length."""
        raise NotImplementedError

    # @abstractmethod
    # def parameters(self):
    #     raise NotImplementedError

    def pf_alias_cancellation_loss(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the product-filter alias cancellation loss.

        See Strang+Nguyen, p. 105: F0(z) = H1(-z); F1(z) = -H0(-z).
        The alternating sign convention runs from 0 to N, see the
        Strang overview on the back of the cover.

        Returns:
            tuple: The numerical value of the alias cancellation loss,
            as well as both loss components for analysis.
        """
        dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank
        m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype)
        length = dec_lo.shape[0]
        mask = torch.tensor(
            [torch.pow(m1, n) for n in range(length)][::-1],
            device=dec_lo.device,
            dtype=dec_lo.dtype,
        )
        # F0(z) = H1(-z)
        err1 = rec_lo - mask * dec_hi
        err1s = torch.sum(err1 * err1)
        # F1(z) = -H0(-z)
        err2 = rec_hi - m1 * mask * dec_lo
        err2s = torch.sum(err2 * err2)
        return err1s + err2s, err1, err2

    def alias_cancellation_loss(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the alias cancellation loss.

        Implementation of the ac-loss as described on page 104
        of Strang+Nguyen:
        F0(z)H0(-z) + F1(z)H1(-z) = 0

        Returns:
            tuple: The numerical value of the alias cancellation loss,
            as well as the polynomial p(z) and the zero target
            for analysis.
        """
        dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank
        m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype)
        length = dec_lo.shape[0]
        mask = torch.tensor(
            [torch.pow(m1, n) for n in range(length)][::-1],
            device=dec_lo.device,
            dtype=dec_lo.dtype,
        )
        # Polynomial multiplication is convolution; compute p(z):
        pad = dec_lo.shape[0] - 1
        p_lo = torch.nn.functional.conv1d(
            dec_lo.unsqueeze(0).unsqueeze(0) * mask,
            torch.flip(rec_lo, [-1]).unsqueeze(0).unsqueeze(0),
            padding=pad,
        )
        pad = dec_hi.shape[0] - 1
        p_hi = torch.nn.functional.conv1d(
            dec_hi.unsqueeze(0).unsqueeze(0) * mask,
            torch.flip(rec_hi, [-1]).unsqueeze(0).unsqueeze(0),
            padding=pad,
        )
        p_test = p_lo + p_hi
        zeros = torch.zeros(p_test.shape, device=p_test.device, dtype=p_test.dtype)
        errs = (p_test - zeros) * (p_test - zeros)
        return torch.sum(errs), p_test, zeros

    def perfect_reconstruction_loss(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the perfect reconstruction loss.

        Returns:
            tuple: The numerical value of the perfect reconstruction
            loss, as well as both intermediate values for analysis.
        """
        # Strang, p. 107: Assuming alias cancellation holds,
        # P(z) = F(z)H(z), and the product filter must satisfy
        # P(z) + P(-z) = 2.
        # Since alias cancellation is only a soft constraint here,
        # we check P_0 + P_1 = 2 directly.
        # NumPy and PyTorch implement convolution differently: what
        # the machine learning frameworks call convolution is actually
        # cross-correlation.
        # https://discuss.pytorch.org/t/numpy-convolve-and-conv1d-in-pytorch/12172
        # Therefore, for true convolution, one operand has to be flipped.
        dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank
        # Polynomial multiplication is convolution; compute p(z):
        pad = dec_lo.shape[0] - 1
        p_lo = torch.nn.functional.conv1d(
            dec_lo.unsqueeze(0).unsqueeze(0),
            torch.flip(rec_lo, [-1]).unsqueeze(0).unsqueeze(0),
            padding=pad,
        )
        pad = dec_hi.shape[0] - 1
        p_hi = torch.nn.functional.conv1d(
            dec_hi.unsqueeze(0).unsqueeze(0),
            torch.flip(rec_hi, [-1]).unsqueeze(0).unsqueeze(0),
            padding=pad,
        )
        p_test = p_lo + p_hi
        two_at_power_zero = torch.zeros(
            p_test.shape, device=p_test.device, dtype=p_test.dtype
        )
        # NumPy comparison for debugging:
        # np.convolve(self.init_wavelet.filter_bank[0],
        #             self.init_wavelet.filter_bank[2])
        # np.convolve(self.init_wavelet.filter_bank[1],
        #             self.init_wavelet.filter_bank[3])
        two_at_power_zero[..., p_test.shape[-1] // 2] = 2
        # Square the error.
        errs = (p_test - two_at_power_zero) * (p_test - two_at_power_zero)
        return torch.sum(errs), p_test, two_at_power_zero
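

# A worked example of the halfband condition above (an illustration, not
# part of the original module): for the orthogonal Haar filter bank,
#   p_lo(z): conv([1, 1], [1, 1]) / 2 = [0.5, 1.0, 0.5]
#   p_hi(z): conv([-1, 1], [1, -1]) / 2 = [-0.5, 1.0, -0.5]
# so p_test = p_lo + p_hi = [0.0, 2.0, 0.0]. The value 2 sits exactly at
# the center tap selected by p_test.shape[-1] // 2, and the perfect
# reconstruction loss evaluates to zero.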


class ProductFilter(WaveletFilter, torch.nn.Module):
    """Learnable product filter implementation."""

    def __init__(
        self,
        dec_lo: torch.Tensor,
        dec_hi: torch.Tensor,
        rec_lo: torch.Tensor,
        rec_hi: torch.Tensor,
    ):
        """Create a ProductFilter object.

        Args:
            dec_lo (torch.Tensor): Low pass analysis filter.
            dec_hi (torch.Tensor): High pass analysis filter.
            rec_lo (torch.Tensor): Low pass synthesis filter.
            rec_hi (torch.Tensor): High pass synthesis filter.
        """
        super().__init__()
        self.dec_lo = torch.nn.Parameter(dec_lo)
        self.dec_hi = torch.nn.Parameter(dec_hi)
        self.rec_lo = torch.nn.Parameter(rec_lo)
        self.rec_hi = torch.nn.Parameter(rec_hi)

    @property
    def filter_bank(
        self,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return all filters as a tuple."""
        return self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi

    # def parameters(self):
    #     return [self.dec_lo, self.dec_hi, self.rec_lo, self.rec_hi]

    def __len__(self) -> int:
        """Return the length of all filter arrays."""
        return self.dec_lo.shape[-1]

    def product_filter_loss(self) -> torch.Tensor:
        """Return only the product filter loss.

        Returns:
            torch.Tensor: The loss scalar.
        """
        return (
            self.perfect_reconstruction_loss()[0]
            + self.alias_cancellation_loss()[0]
        )

    def wavelet_loss(self) -> torch.Tensor:
        """Return the sum of all loss terms.

        Returns:
            torch.Tensor: The loss scalar.
        """
        return self.product_filter_loss()
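

# A minimal usage sketch (not part of the original module; the helper
# name `_haar_product_filter_demo` is hypothetical): the orthogonal Haar
# filters satisfy both soft constraints exactly, so the product filter
# loss of a freshly initialized ProductFilter should be numerically zero.
def _haar_product_filter_demo() -> torch.Tensor:
    """Evaluate the product filter loss for the Haar filter bank."""
    sq2 = torch.sqrt(torch.tensor(2.0))
    haar = ProductFilter(
        dec_lo=torch.tensor([1.0, 1.0]) / sq2,
        dec_hi=torch.tensor([-1.0, 1.0]) / sq2,
        rec_lo=torch.tensor([1.0, 1.0]) / sq2,
        rec_hi=torch.tensor([1.0, -1.0]) / sq2,
    )
    # Expected to be ~0 (up to float32 rounding) for Haar.
    return haar.product_filter_loss()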


class SoftOrthogonalWavelet(ProductFilter, torch.nn.Module):
    """Orthogonal wavelets with a soft orthogonality constraint."""

    def __init__(
        self,
        dec_lo: torch.Tensor,
        dec_hi: torch.Tensor,
        rec_lo: torch.Tensor,
        rec_hi: torch.Tensor,
    ):
        """Create a SoftOrthogonalWavelet object.

        Args:
            dec_lo (torch.Tensor): Low pass analysis filter.
            dec_hi (torch.Tensor): High pass analysis filter.
            rec_lo (torch.Tensor): Low pass synthesis filter.
            rec_hi (torch.Tensor): High pass synthesis filter.
        """
        super().__init__(dec_lo, dec_hi, rec_lo, rec_hi)

    def rec_lo_orthogonality_loss(self) -> torch.Tensor:
        """Return a Strang-inspired soft orthogonality loss.

        See Strang, pp. 148-149, or Harbo, p. 80. Since L is a
        convolution matrix, LL^T can be evaluated through convolution.

        Returns:
            torch.Tensor: A tensor with the orthogonality constraint value.
        """
        filt_len = self.dec_lo.shape[-1]
        pad_dec_lo = torch.cat(
            [
                self.dec_lo,
                torch.zeros(
                    [filt_len],
                    device=self.dec_lo.device,
                    dtype=self.dec_lo.dtype,
                ),
            ],
            -1,
        )
        res = torch.nn.functional.conv1d(
            pad_dec_lo.unsqueeze(0).unsqueeze(0),
            self.dec_lo.unsqueeze(0).unsqueeze(0),
            stride=2,
        )
        test = torch.zeros_like(res.squeeze(0).squeeze(0))
        test[0] = 1
        err = res - test
        return torch.sum(err * err)

    def filt_bank_orthogonality_loss(self) -> torch.Tensor:
        """Return a Jensen+Harbo-inspired soft orthogonality loss.

        Page 79 of the book Ripples in Mathematics by Jensen and
        la Cour-Harbo states the constraints g0[k] = h0[-k] and
        g1[k] = h1[-k] for orthogonal filters. This method measures
        how far the current filter bank deviates from them.

        Returns:
            torch.Tensor: A tensor with the orthogonality constraint value.
        """
        eq0 = self.dec_lo - self.rec_lo.flip(-1)
        eq1 = self.dec_hi - self.rec_hi.flip(-1)
        seq0 = torch.sum(eq0 * eq0)
        seq1 = torch.sum(eq1 * eq1)
        return seq0 + seq1

    def wavelet_loss(self) -> torch.Tensor:
        """Return the sum of all loss terms."""
        return self.product_filter_loss() + self.filt_bank_orthogonality_loss()
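

# A minimal training sketch (an illustration under assumptions, not part
# of the original module): starting from a noisy Haar initialization,
# gradient descent on wavelet_loss should drive the filters back towards
# a valid orthogonal filter bank, in the spirit of
# https://arxiv.org/pdf/2004.09569.pdf. The learning rate and step count
# are arbitrary choices.
if __name__ == "__main__":
    torch.manual_seed(42)
    sq2 = torch.sqrt(torch.tensor(2.0))
    noise = 0.1 * torch.randn(4, 2)
    wavelet = SoftOrthogonalWavelet(
        torch.tensor([1.0, 1.0]) / sq2 + noise[0],
        torch.tensor([-1.0, 1.0]) / sq2 + noise[1],
        torch.tensor([1.0, 1.0]) / sq2 + noise[2],
        torch.tensor([1.0, -1.0]) / sq2 + noise[3],
    )
    optimizer = torch.optim.Adam(wavelet.parameters(), lr=1e-3)
    for step in range(2500):
        optimizer.zero_grad()
        loss = wavelet.wavelet_loss()
        loss.backward()
        optimizer.step()
        if step % 500 == 0:
            print(f"step {step:4d}, wavelet loss {loss.item():.2e}")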