"""Implement matrix-based FWT and iFWT.

This module uses boundary filters instead of padding.
The implementation is based on the description
in Strang and Nguyen (p. 32), as well as the description
of boundary filters in "Ripples in Mathematics" section 10.3.
"""
import sys
from typing import Optional, Union
import numpy as np
import torch
from ._util import (
_as_wavelet,
_check_same_device_dtype,
_deprecated_alias,
_get_filter_tensors,
_is_orthogonalize_method_supported,
_postprocess_coeffs,
_postprocess_tensor,
_preprocess_coeffs,
_preprocess_tensor,
)
from .constants import (
BoundaryMode,
OrthogonalizeMethod,
PaddingMode,
Wavelet,
WaveletCoeff1d,
)
from .conv_transform import _fwt_pad
from .sparse_math import (
_orth_by_gram_schmidt,
_orth_by_qr,
cat_sparse_identity_matrix,
construct_strided_conv_matrix,
)
# Names exported by ``from <module> import *``. The construct_boundary_a/s
# helpers defined below are not listed here.
__all__ = ["orthogonalize", "MatrixWavedec", "MatrixWaverec"]
def _construct_a(
    wavelet: Union[Wavelet, str],
    length: int,
    device: Union[torch.device, str] = "cpu",
    dtype: torch.dtype = torch.float64,
    *,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Build the raw, not yet orthogonalized, analysis matrix.

    Apart from the Haar case the result is generally not orthogonal;
    use ``construct_boundary_a`` when boundary filters are required.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        length (int): The length of the input signal to transform.
        device (torch.device or str, optional): Where to create the matrix.
            Choose a torch device or device name. Defaults to "cpu".
        dtype (torch.dtype): The desired torch datatype. Choose torch.float32
            or torch.float64. Defaults to torch.float64.
        mode: The padding mode to use.
            See :data:`ptwt.constants.PaddingMode`. Defaults to 'sameshift'.

    Returns:
        The sparse raw analysis matrix. The stride-two low-pass rows are
        stacked on top of the stride-two high-pass rows.
    """
    dec_lo, dec_hi, _rec_lo, _rec_hi = _get_filter_tensors(
        _as_wavelet(wavelet), flip=False, device=device, dtype=dtype
    )
    # Only the decomposition pair is needed for the analysis operator.
    row_blocks = [
        construct_strided_conv_matrix(filt.squeeze(), length, 2, mode=mode)
        for filt in (dec_lo, dec_hi)
    ]
    return torch.cat(row_blocks)
def _construct_s(
    wavelet: Union[Wavelet, str],
    length: int,
    device: Union[torch.device, str] = "cpu",
    dtype: torch.dtype = torch.float64,
    *,
    mode: PaddingMode = "sameshift",
) -> torch.Tensor:
    """Build the raw, not yet orthogonalized, synthesis matrix.

    The result is NOT necessarily orthogonal. In most cases
    ``construct_boundary_s`` should be used instead.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        length (int): The length of the originally transformed signal.
        device (torch.device, optional): Choose cuda or cpu.
            Defaults to torch.device("cpu").
        dtype (torch.dtype): The desired data type. Choose torch.float32
            or torch.float64. Defaults to torch.float64.
        mode: The padding mode to use.
            See :data:`ptwt.constants.PaddingMode`. Defaults to 'sameshift'.

    Returns:
        The raw sparse synthesis matrix. It is the transpose of the stacked
        (reconstruction low-pass, reconstruction high-pass) strided rows.
    """
    _dec_lo, _dec_hi, rec_lo, rec_hi = _get_filter_tensors(
        _as_wavelet(wavelet), flip=True, device=device, dtype=dtype
    )
    stacked_rows = torch.cat(
        [
            construct_strided_conv_matrix(filt.squeeze(), length, 2, mode=mode)
            for filt in (rec_lo, rec_hi)
        ]
    )
    # Synthesis maps coefficients back to the signal, hence the transpose.
    return stacked_rows.transpose(0, 1)
def _get_to_orthogonalize(matrix: torch.Tensor, filt_len: int) -> torch.Tensor:
    """Return indices of rows whose stored-entry count differs from ``filt_len``.

    These rows will need to be orthogonalized. Rows without any stored
    entry do not appear in the coalesced index list, so they are never
    reported.

    Args:
        matrix (torch.Tensor): The sparse wavelet matrix under consideration.
        filt_len (int): The number of entries we would expect per row.

    Returns:
        A 1d tensor with the row indices whose entry count is not
        ``filt_len``.
    """
    # After coalescing, the row indices are sorted, so equal rows are adjacent.
    row_of_entry = matrix.coalesce().indices()[0, :]
    row_ids, entries_per_row = torch.unique_consecutive(
        row_of_entry, return_counts=True
    )
    mismatched = entries_per_row != filt_len
    return row_ids[mismatched]
def orthogonalize(
    matrix: torch.Tensor, filt_len: int, method: OrthogonalizeMethod = "qr"
) -> torch.Tensor:
    """Orthogonalization for sparse filter matrices.

    Args:
        matrix (torch.Tensor): The sparse filter matrix to orthogonalize.
        filt_len (int): The length of the wavelet filter coefficients.
        method : The orthogonalization method to use. Choose qr
            or gramschmidt. The dense qr code will run much faster
            than sparse gramschmidt. Choose gramschmidt if qr fails.
            Defaults to qr.

    Returns:
        Orthogonal sparse transformation matrix. If no row has an entry
        count different from ``filt_len``, the input matrix is returned
        unchanged.

    Raises:
        ValueError: If an invalid orthogonalization method is given.
    """
    # Validate first: previously an invalid method was silently accepted
    # whenever there was nothing to orthogonalize (e.g. the Haar case).
    if method not in ("qr", "gramschmidt"):
        raise ValueError(f"Invalid orthogonalization method: {method}")

    to_orthogonalize = _get_to_orthogonalize(matrix, filt_len)
    if len(to_orthogonalize) == 0:
        return matrix

    if method == "qr":
        return _orth_by_qr(matrix, to_orthogonalize)
    return _orth_by_gram_schmidt(matrix, to_orthogonalize)
class BaseMatrixWaveDec:
    """A base class for matrix wavedec.

    It defines no behavior of its own; ``MatrixWavedec`` derives from it.
    """
class MatrixWavedec(BaseMatrixWaveDec):
    """Compute the 1d fast wavelet transform using sparse matrices.

    This transform is the sparse matrix counterpart of
    :func:`ptwt.wavedec`. The convolution operations are
    implemented as a matrix-vector product between a
    sparse transformation matrix and the input signal.
    This transform uses boundary wavelets instead of padding to
    handle the signal boundaries, see the
    :ref:`boundary wavelet docs <modes.boundary wavelets>`.

    Note:
        Constructing the sparse FWT matrix can be expensive.
        For longer wavelets, high-level transforms, and large
        input signals this may take a while.
        The matrix is therefore constructed only once and reused
        in further calls. It is rebuilt when the input length,
        device or dtype changes.
        The sparse transformation matrix can be accessed
        via the :attr:`sparse_fwt_operator` property.

    Note:
        On each level of the transform the convolved signal
        is required to be of even length. This transform uses
        padding to transform coefficients with an odd length,
        with the padding mode specified by `odd_coeff_padding_mode`.
        To avoid padding consider transforming signals
        with a length divisible by :math:`2^L`
        for a :math:`L`-level transform.

    Example:
        >>> import ptwt, torch
        >>> # generate an input of even length.
        >>> data = torch.arange(8, dtype=torch.float32)
        >>> # First, construct the transformation object
        >>> matrix_wavedec = ptwt.MatrixWavedec('haar', level=2)
        >>> # Then, the FWT is computed by calling the object.
        >>> coefficients = matrix_wavedec(data)
    """

    @_deprecated_alias(boundary="orthogonalization")
    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        level: Optional[int] = None,
        *,
        axis: int = -1,
        orthogonalization: OrthogonalizeMethod = "qr",
        odd_coeff_padding_mode: BoundaryMode = "zero",
    ) -> None:
        """Create a sparse matrix fast wavelet transform object.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
                Refer to the output from ``pywt.wavelist(kind='discrete')``
                for possible choices.
            level (int, optional): The level up to which to compute the FWT.
                If None, the level is derived from the length of the first
                transformed signal. Defaults to None.
            axis (int): The axis we would like to transform. Defaults to -1.
            orthogonalization: The method used to orthogonalize
                boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
                Defaults to ``qr``.
            odd_coeff_padding_mode: The constructed FWT matrices require inputs
                with even lengths. Thus, any odd-length approximation
                coefficients are padded to an even length using this mode,
                see :data:`ptwt.constants.BoundaryMode`.
                Defaults to ``zero``.

        .. versionchanged:: 1.10
            The argument `boundary` has been renamed to `orthogonalization`.

        Raises:
            NotImplementedError: If the selected `orthogonalization` mode
                is not supported.
            ValueError: If the wavelet filters have different lengths or
                if axis is not an integer.
        """
        self.wavelet = _as_wavelet(wavelet)
        self.level = level
        self.odd_coeff_padding_mode = odd_coeff_padding_mode
        self.orthogonalization = orthogonalization

        if isinstance(axis, int):
            self.axis = axis
        else:
            raise ValueError("MatrixWavedec transforms a single axis only.")

        # Cache of per-level state, filled by _construct_analysis_matrices.
        self.input_length: Optional[int] = None
        self.fwt_matrix_list: list[torch.Tensor] = []
        self.pad_list: list[bool] = []
        self.padded = False
        self.size_list: list[int] = []

        if not _is_orthogonalize_method_supported(self.orthogonalization):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_fwt_operator(self) -> torch.Tensor:
        """The sparse transformation operator.

        If the input signal at all levels is divisible by two,
        the whole operation is padding-free and can be expressed
        as a single matrix multiply.

        The operation

        .. code-block:: python

            torch.sparse.mm(sparse_fwt_operator, data.T)

        computes a batched FWT.

        This property exists to make the operator matrix transparent.
        Calling the object will handle odd-length inputs properly.

        Raises:
            NotImplementedError: if padding had to be used in the creation of
                the transformation matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if len(self.fwt_matrix_list) == 1:
            return self.fwt_matrix_list[0]
        elif len(self.fwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            fwt_matrix = self.fwt_matrix_list[0]
            for scale_mat in self.fwt_matrix_list[1:]:
                # Extend the coarser-level matrix with an identity block so
                # that already computed detail coefficients pass through.
                scale_mat = cat_sparse_identity_matrix(scale_mat, fwt_matrix.shape[0])
                fwt_matrix = torch.sparse.mm(scale_mat, fwt_matrix)
            return fwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_analysis_matrices(
        self, device: Union[torch.device, str], dtype: torch.dtype
    ) -> None:
        """Build one boundary analysis matrix per decomposition level.

        Fills ``fwt_matrix_list``, ``size_list`` and ``pad_list`` and sets
        ``padded`` if any level needed an odd-to-even length padding.
        Stops early, with a warning on stderr, once the current length
        drops below the filter length.
        """
        if self.level is None or self.input_length is None:
            raise AssertionError

        self.fwt_matrix_list = []
        self.size_list = []
        self.pad_list = []
        self.padded = False

        filt_len = self.wavelet.dec_len
        curr_length = self.input_length
        for curr_level in range(1, self.level + 1):
            if curr_length < filt_len:
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input size {self.input_length}. At "
                    f"level {curr_level}, the current signal length {curr_length} is "
                    f"smaller than the filter length {filt_len}. Therefore, the "
                    "transformation is only computed up to the decomposition level "
                    f"{curr_level-1}.\n"
                )
                break

            if curr_length % 2 != 0:
                # padding
                curr_length += 1
                self.padded = True
                self.pad_list.append(True)
            else:
                self.pad_list.append(False)

            self.size_list.append(curr_length)

            an = construct_boundary_a(
                self.wavelet,
                curr_length,
                orthogonalization=self.orthogonalization,
                device=device,
                dtype=dtype,
            )
            self.fwt_matrix_list.append(an)
            curr_length = curr_length // 2

        self.size_list.append(curr_length)

    def __call__(self, input_signal: torch.Tensor) -> list[torch.Tensor]:
        """Compute the matrix FWT for the given input signal.

        Matrix FWTs are used to avoid padding.

        Args:
            input_signal (torch.Tensor): Input data to transform.
                This transform affects the last axis by default.
                Use the `axis` argument in the constructor to choose
                another axis.

        Returns:
            A list with the coefficient tensor for each scale.

        Raises:
            ValueError: If the decomposition level is not a positive integer
                or if the input signal has not the expected shape.
        """
        input_signal, ds = _preprocess_tensor(
            input_signal,
            ndim=1,
            axes=self.axis,
            add_channel_dim=False,
        )

        if input_signal.shape[-1] % 2 != 0:
            # odd length input
            input_signal = _fwt_pad(
                input_signal,
                wavelet=self.wavelet,
                mode=self.odd_coeff_padding_mode,
                padding=(0, 1),
            )

        # The preprocessed input must be two-dimensional (batch, length).
        _, length = input_signal.shape

        re_build = False
        if self.input_length != length:
            self.input_length = length
            re_build = True

        if self.level is None:
            wlen = len(self.wavelet)
            self.level = int(np.log2(length / (wlen - 1)))
            re_build = True
        elif self.level <= 0:
            raise ValueError("level must be a positive integer.")

        # The cached matrices were built for one device and dtype. The sparse
        # product cannot mix devices (or dtypes), so rebuild on a mismatch.
        if self.fwt_matrix_list and (
            self.fwt_matrix_list[0].device != input_signal.device
            or self.fwt_matrix_list[0].dtype != input_signal.dtype
        ):
            re_build = True

        if not self.fwt_matrix_list or re_build:
            self._construct_analysis_matrices(
                device=input_signal.device, dtype=input_signal.dtype
            )

        # Work in (length, batch) layout so a single sparse mm handles the batch.
        lo = input_signal.T
        split_list = []
        for scale, fwt_matrix in enumerate(self.fwt_matrix_list):
            if self.pad_list[scale]:
                # fix odd coefficients lengths for the conv matrix to work.
                lo = lo.T.unsqueeze(1)
                lo = _fwt_pad(
                    lo,
                    wavelet=self.wavelet,
                    mode=self.odd_coeff_padding_mode,
                    padding=(0, 1),
                )
                lo = lo.squeeze(1).T
            coefficients = torch.sparse.mm(fwt_matrix, lo)
            # Upper half: approximation (low-pass), lower half: detail.
            lo, hi = torch.split(coefficients, coefficients.shape[0] // 2, dim=0)
            split_list.append(hi)
        split_list.append(lo)
        # undo the transpose we used to handle the batch dimension.
        result_list = [s.T for s in split_list[::-1]]
        # unfold if necessary
        return _postprocess_coeffs(result_list, ndim=1, ds=ds, axes=self.axis)
@_deprecated_alias(boundary="orthogonalization")
def construct_boundary_a(
    wavelet: Union[Wavelet, str],
    length: int,
    device: Union[torch.device, str] = "cpu",
    orthogonalization: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a boundary-wavelet filter 1d-analysis matrix.

    The raw strided analysis matrix is built first. Then its rows whose
    entry count differs from the filter length are orthogonalized.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        length (int): The number of entries in the input signal.
        device: Where to place the matrix. Choose cpu or cuda.
            Defaults to cpu.
        orthogonalization: The method used to orthogonalize
            boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
            Defaults to 'qr'.
        dtype: Choose float32 or float64. Defaults to torch.float64.

    .. versionchanged:: 1.10
        The argument `boundary` has been renamed to `orthogonalization`.

    Returns:
        The sparse analysis matrix.
    """
    wavelet = _as_wavelet(wavelet)
    return orthogonalize(
        _construct_a(wavelet, length, dtype=dtype, device=device),
        wavelet.dec_len,
        method=orthogonalization,
    )
@_deprecated_alias(boundary="orthogonalization")
def construct_boundary_s(
    wavelet: Union[Wavelet, str],
    length: int,
    device: Union[torch.device, str] = "cpu",
    orthogonalization: OrthogonalizeMethod = "qr",
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Construct a boundary-wavelet filter 1d-synthesis matrix.

    Args:
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        length (int): The number of entries in the input signal.
        device (torch.device): Where to place the matrix.
            Choose cpu or cuda. Defaults to cpu.
        orthogonalization: The method used to orthogonalize
            boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
            Defaults to 'qr'.
        dtype: Choose torch.float32 or torch.float64.
            Defaults to torch.float64.

    .. versionchanged:: 1.10
        The argument `boundary` has been renamed to `orthogonalization`.

    Returns:
        The sparse synthesis matrix.
    """
    wavelet = _as_wavelet(wavelet)
    # ``orthogonalize`` works row-wise, so orthogonalize the transposed
    # (row-per-filter) layout and transpose back afterwards.
    filter_rows = _construct_s(wavelet, length, dtype=dtype, device=device).transpose(
        1, 0
    )
    orthogonal_rows = orthogonalize(
        filter_rows, wavelet.rec_len, method=orthogonalization
    )
    return orthogonal_rows.transpose(1, 0)
class MatrixWaverec(object):
    """Matrix-based inverse fast wavelet transform.

    Example:
        >>> import ptwt, torch
        >>> # generate an input of even length.
        >>> data = torch.arange(8, dtype=torch.float32)
        >>> matrix_wavedec = ptwt.MatrixWavedec('haar', level=2)
        >>> coefficients = matrix_wavedec(data)
        >>> matrix_waverec = ptwt.MatrixWaverec('haar')
        >>> reconstruction = matrix_waverec(coefficients)
    """

    @_deprecated_alias(boundary="orthogonalization")
    def __init__(
        self,
        wavelet: Union[Wavelet, str],
        *,
        axis: int = -1,
        orthogonalization: OrthogonalizeMethod = "qr",
    ) -> None:
        """Create the inverse matrix-based fast wavelet transformation.

        Args:
            wavelet (Wavelet or str): A pywt wavelet compatible object or
                the name of a pywt wavelet.
                Refer to the output from ``pywt.wavelist(kind='discrete')``
                for possible choices.
            axis (int): The axis transformed by the original decomposition
                defaults to -1 or the last axis.
            orthogonalization: The method used to orthogonalize
                boundary filters, see :data:`ptwt.constants.OrthogonalizeMethod`.
                Defaults to ``qr``.

        .. versionchanged:: 1.10
            The argument `boundary` has been renamed to `orthogonalization`.

        Raises:
            NotImplementedError: If the selected `orthogonalization` mode
                is not supported.
            ValueError: If the wavelet filters have different lengths or if
                axis is not an integer.
        """
        self.wavelet = _as_wavelet(wavelet)
        self.orthogonalization = orthogonalization

        if isinstance(axis, int):
            self.axis = axis
        else:
            raise ValueError("MatrixWaverec transforms a single axis only.")

        # Cache of per-level state, filled by _construct_synthesis_matrices.
        self.ifwt_matrix_list: list[torch.Tensor] = []
        self.size_list: list[int] = []
        self.level: Optional[int] = None
        self.input_length: Optional[int] = None
        self.padded = False

        if not _is_orthogonalize_method_supported(self.orthogonalization):
            raise NotImplementedError

        if self.wavelet.dec_len != self.wavelet.rec_len:
            raise ValueError("All filters must have the same length")

    @property
    def sparse_ifwt_operator(self) -> torch.Tensor:
        """The sparse transformation operator.

        If the input signal at all levels is divisible by two,
        the whole operation is padding-free and can be expressed
        as a single matrix multiply.

        Having concatenated the analysis coefficients,

        .. code-block:: python

            torch.sparse.mm(sparse_ifwt_operator, coefficients.T)

        computes a batched iFWT.

        This functionality is mainly here to make the operator-matrix
        transparent. Calling the object handles padding for odd inputs.

        Raises:
            NotImplementedError: if padding had to be used in the creation of
                the transformation matrices.
            ValueError: If no level transformation matrices are stored (most
                likely since the object was not called yet).
        """
        if len(self.ifwt_matrix_list) == 1:
            return self.ifwt_matrix_list[0]
        elif len(self.ifwt_matrix_list) > 1:
            if self.padded:
                raise NotImplementedError

            ifwt_matrix = self.ifwt_matrix_list[-1]
            for scale_matrix in self.ifwt_matrix_list[:-1][::-1]:
                # Extend with an identity block so that detail coefficients of
                # finer levels pass through the coarser synthesis step.
                ifwt_matrix = cat_sparse_identity_matrix(
                    ifwt_matrix, scale_matrix.shape[0]
                )
                ifwt_matrix = torch.sparse.mm(scale_matrix, ifwt_matrix)
            return ifwt_matrix
        else:
            raise ValueError(
                "Call this object first to create the transformation matrices for each "
                "level."
            )

    def _construct_synthesis_matrices(
        self, device: Union[torch.device, str], dtype: torch.dtype
    ) -> None:
        """Build one boundary synthesis matrix per level, finest level first.

        Fills ``ifwt_matrix_list`` and ``size_list`` and sets ``padded`` if
        any level needed an odd-to-even length padding. Stops early, with a
        warning on stderr, once the current length drops below the filter
        length.
        """
        self.ifwt_matrix_list = []
        self.size_list = []
        self.padded = False
        if self.level is None or self.input_length is None:
            raise AssertionError

        filt_len = self.wavelet.rec_len
        curr_length = self.input_length

        for curr_level in range(1, self.level + 1):
            if curr_length < filt_len:
                sys.stderr.write(
                    f"Warning: The selected number of decomposition levels {self.level}"
                    f" is too large for the given input size {self.input_length}. At "
                    f"level {curr_level}, the current signal length {curr_length} is "
                    f"smaller than the filter length {filt_len}. Therefore, the "
                    "transformation is only computed up to the decomposition level "
                    f"{curr_level-1}.\n"
                )
                break

            if curr_length % 2 != 0:
                # padding
                curr_length += 1
                self.padded = True

            self.size_list.append(curr_length)

            sn = construct_boundary_s(
                self.wavelet,
                curr_length,
                orthogonalization=self.orthogonalization,
                device=device,
                dtype=dtype,
            )
            self.ifwt_matrix_list.append(sn)
            curr_length = curr_length // 2

    def __call__(self, coefficients: WaveletCoeff1d) -> torch.Tensor:
        """Run the synthesis or inverse matrix FWT.

        Args:
            coefficients: The coefficients produced by the forward transform
                :class:`MatrixWavedec`. See :data:`ptwt.constants.WaveletCoeff1d`.

        Returns:
            The input signal reconstruction.

        Raises:
            ValueError: If the decomposition level is not a positive integer or
                if the coefficients are not in the shape as it is returned from
                a :class:`MatrixWavedec` object.
            AssertionError: If the internal padding bookkeeping is
                inconsistent.
        """
        if not isinstance(coefficients, list):
            coefficients = list(coefficients)
        coefficients, ds = _preprocess_coeffs(coefficients, ndim=1, axes=self.axis)
        torch_device, torch_dtype = _check_same_device_dtype(coefficients)

        level = len(coefficients) - 1
        # The last entry holds the finest detail coefficients, i.e. half of
        # the (possibly padded) signal length.
        input_length = coefficients[-1].shape[-1] * 2

        re_build = False
        if self.level != level or self.input_length != input_length:
            self.level = level
            self.input_length = input_length
            re_build = True

        # The cached matrices were built for one device and dtype. The sparse
        # product cannot mix devices (or dtypes), so rebuild on a mismatch.
        if self.ifwt_matrix_list and (
            self.ifwt_matrix_list[0].device != torch_device
            or self.ifwt_matrix_list[0].dtype != torch_dtype
        ):
            re_build = True

        if not self.ifwt_matrix_list or re_build:
            self._construct_synthesis_matrices(
                device=torch_device,
                dtype=torch_dtype,
            )

        # transpose the coefficients to handle the batch dimension efficiently.
        coefficients = [c.T for c in coefficients]

        lo = coefficients[0]
        for c_pos, hi in enumerate(coefficients[1:]):
            if lo.shape != hi.shape:
                raise ValueError("coefficients must have the same shape")

            lo = torch.cat([lo, hi], 0)
            # Matrices are stored finest level first; apply them coarsest first.
            lo = torch.sparse.mm(self.ifwt_matrix_list[-1 - c_pos], lo)

            # remove padding
            if c_pos < len(coefficients) - 2:
                pred_len = lo.shape[0]
                next_len = coefficients[c_pos + 2].shape[0]
                if next_len != pred_len:
                    lo = lo[:-1, :]
                    pred_len = lo.shape[0]
                    # Explicit raise instead of ``assert`` so the check
                    # survives ``python -O``.
                    if pred_len != next_len:
                        raise AssertionError(
                            "padding error, please open an issue on github"
                        )

        res_lo = lo.T
        return _postprocess_tensor(res_lo, ndim=1, ds=ds, axes=self.axis)