"""This module implements two-dimensional padded wavelet transforms.
The implementation relies on torch.nn.functional.conv2d and
torch.nn.functional.conv_transpose2d under the hood.
"""
from functools import partial
from typing import List, Optional, Tuple, Union, cast
import pywt
import torch
from ._util import (
Wavelet,
_as_wavelet,
_check_axes_argument,
_check_if_tensor,
_fold_axes,
_get_len,
_is_dtype_supported,
_map_result,
_outer,
_pad_symmetric,
_swap_axes,
_undo_swap_axes,
_unfold_axes,
)
from .constants import BoundaryMode
from .conv_transform import (
_adjust_padding_at_reconstruction,
_get_filter_tensors,
_get_pad,
_translate_boundary_strings,
)
def _construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Build the four 2d filter kernels from a 1d filter pair.

    Each 2d kernel is the outer product of the 1d low- and
    high-pass filters.

    Args:
        lo (torch.Tensor): Low-pass input filter.
        hi (torch.Tensor): High-pass input filter.

    Returns:
        torch.Tensor: Stacked 2d-filters of dimension
            [filt_no, 1, height, width].
            The four filters are ordered ll, lh, hl, hh.
    """
    # Order matters: ll, lh, hl, hh.
    kernels = [_outer(a, b) for a, b in ((lo, lo), (hi, lo), (lo, hi), (hi, hi))]
    # Insert the singleton input-channel axis expected by conv2d.
    return torch.stack(kernels, 0).unsqueeze(1)
def _fwt_pad2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: Optional[BoundaryMode] = None,
) -> torch.Tensor:
    """Pad data along the last two axes for the 2d FWT.

    Args:
        data (torch.Tensor): Input data with 4 dimensions.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        mode :
            The desired padding mode for extending the signal along the edges.
            Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.

    Returns:
        The padded output tensor.
    """
    if mode is None:
        mode = cast(BoundaryMode, "reflect")
    pytorch_mode = _translate_boundary_strings(mode)
    wavelet = _as_wavelet(wavelet)
    filt_len = _get_len(wavelet)
    # Per-side pad sizes for height (top/bottom) and width (left/right).
    padb, padt = _get_pad(data.shape[-2], filt_len)
    padr, padl = _get_pad(data.shape[-1], filt_len)
    if pytorch_mode == "symmetric":
        # torch.nn.functional.pad has no symmetric mode; use our own.
        return _pad_symmetric(data, [(padt, padb), (padl, padr)])
    return torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=pytorch_mode)
def _waverec2d_fold_channels_2d_list(
    coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
) -> Tuple[
    List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
    List[int],
]:
    # Fold extra leading axes of every coefficient so conv_transpose2d
    # can process them; also return the original approximation shape
    # so the folding can be undone after reconstruction.
    ds = list(_check_if_tensor(coeffs[0]).shape)
    folded = _map_result(coeffs, lambda tensor: _fold_axes(tensor, 2)[0])
    return folded, ds
def _preprocess_tensor_dec2d(
data: torch.Tensor,
) -> Tuple[torch.Tensor, Union[List[int], None]]:
# Preprocess multidimensional input.
ds = None
if len(data.shape) == 2:
data = data.unsqueeze(0).unsqueeze(0)
elif len(data.shape) == 3:
# add a channel dimension for torch.
data = data.unsqueeze(1)
elif len(data.shape) >= 4:
data, ds = _fold_axes(data, 2)
data = data.unsqueeze(1)
elif len(data.shape) == 1:
raise ValueError("More than one input dimension required.")
return data, ds
def wavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    *,
    mode: BoundaryMode = "reflect",
    level: Optional[int] = None,
    axes: Tuple[int, int] = (-2, -1),
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    r"""Run a two-dimensional wavelet transformation.

    This function relies on two-dimensional convolutions.
    Outer products allow the construction of 2D-filters from 1D filter arrays
    :ref:`(see fwt-intro) <sec-fwt-2d>`.
    It transforms the last two axes by default.
    This function computes

    .. math::
        \mathbf{x}_s *_2 \mathbf{h}_k = \mathbf{c}_{k, s+1}

    with :math:`k \in [a, h, v, d]` and
    :math:`s \in \mathbb{N}_0` the set of natural numbers,
    where :math:`\mathbf{x}_0` is equal to
    the original input image :math:`\mathbf{X}`.
    :math:`*_2` indicates two dimensional-convolution.
    Computations at subsequent scales work exclusively with
    approximation coefficients :math:`c_{a, s}`
    as inputs. Setting the `level` argument allows choosing the largest scale.

    Args:
        data (torch.Tensor): The input data tensor with any number of dimensions.
            By default 2d inputs are interpreted as ``[height, width]``,
            3d inputs are interpreted as ``[batch_size, height, width]``.
            4d inputs are interpreted as ``[batch_size, channels, height, width]``.
            the ``axis`` argument allows other interpretations.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet. Refer to the output of
            ``pywt.wavelist(kind="discrete")`` for a list of possible choices.
        mode :
            The desired padding mode for extending the signal along the edges.
            Defaults to "reflect". See :data:`ptwt.constants.BoundaryMode`.
        level (int): The number of desired scales.
            Defaults to None.
        axes (Tuple[int, int]): Compute the transform over these axes instead of the
            last two. Defaults to (-2, -1).

    Returns:
        list: A list containing the wavelet coefficients.
        The coefficients are in pywt order. That is::

            [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .

        A denotes approximation, H horizontal, V vertical
        and D diagonal coefficients.

    Raises:
        ValueError: If the dimensionality or the dtype of the input data tensor
            is unsupported, if the provided ``axes`` input has a length other
            than two, or if the same axis is repeated.

    Example:
        >>> import torch
        >>> import ptwt, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = np.transpose(datasets.face(),
        >>>     [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face) # try unsqueeze(0)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
        >>>     level=2, mode="zero")
    """
    if not _is_dtype_supported(data.dtype):
        raise ValueError(f"Input dtype {data.dtype} not supported")

    if tuple(axes) != (-2, -1):
        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        else:
            # Reject repeated axes, consistent with waverec2.
            _check_axes_argument(list(axes))
            data = _swap_axes(data, list(axes))

    wavelet = _as_wavelet(wavelet)
    data, ds = _preprocess_tensor_dec2d(data)
    dec_lo, dec_hi, _, _ = _get_filter_tensors(
        wavelet, flip=True, device=data.device, dtype=data.dtype
    )
    dec_filt = _construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    result_lst: List[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    ] = []
    res_ll = data
    for _ in range(level):
        res_ll = _fwt_pad2(res_ll, wavelet, mode=mode)
        res = torch.nn.functional.conv2d(res_ll, dec_filt, stride=2)
        # Split the four stacked filter responses back into subbands.
        res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1)
        to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1))
        result_lst.append(to_append)
    result_lst.append(res_ll.squeeze(1))
    # pywt convention: coarsest approximation first.
    result_lst.reverse()

    if ds:
        _unfold_axes2 = partial(_unfold_axes, ds=ds, keep_no=2)
        result_lst = _map_result(result_lst, _unfold_axes2)

    # Compare as a tuple so a list-valued ``axes`` equal to [-2, -1] is
    # treated exactly like the default (mirrors the check above).
    if tuple(axes) != (-2, -1):
        undo_swap_fn = partial(_undo_swap_axes, axes=axes)
        result_lst = _map_result(result_lst, undo_swap_fn)

    return result_lst
def waverec2(
    coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
    wavelet: Union[Wavelet, str],
    axes: Tuple[int, int] = (-2, -1),
) -> torch.Tensor:
    """Reconstruct a signal from wavelet coefficients.

    This function undoes the effect of the analysis
    or forward transform by running transposed convolutions.

    Args:
        coeffs (list): The wavelet coefficient list produced by wavedec2.
            The coefficients must be in pywt order. That is::

                [cAs, (cHs, cVs, cDs), … (cH1, cV1, cD1)] .

            A denotes approximation, H horizontal, V vertical,
            and D diagonal coefficients.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        axes (Tuple[int, int]): Compute the transform over these axes instead of the
            last two. Defaults to (-2, -1).

    Returns:
        torch.Tensor:
            The reconstructed signal of shape ``[batch, height, width]`` or
            ``[batch, channel, height, width]`` depending on the input to `wavedec2`.

    Raises:
        ValueError: If coeffs is not in a shape as returned from wavedec2 or
            if the dtype is not supported or if the provided axes input has length
            other than two or if the same axis is repeated twice.

    Example:
        >>> import ptwt, pywt, torch
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = np.transpose(datasets.face(),
        >>>     [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
        >>>     level=2, mode="constant")
        >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
    """
    if tuple(axes) != (-2, -1):
        if len(axes) != 2:
            raise ValueError("2D transforms work with two axes.")
        else:
            _check_axes_argument(list(axes))
            swap_fn = partial(_swap_axes, axes=list(axes))
            coeffs = _map_result(coeffs, swap_fn)

    ds = None
    wavelet = _as_wavelet(wavelet)

    res_ll = _check_if_tensor(coeffs[0])
    torch_device = res_ll.device
    torch_dtype = res_ll.dtype

    if res_ll.dim() >= 4:
        # avoid the channel sum, fold the channels into batches.
        coeffs, ds = _waverec2d_fold_channels_2d_list(coeffs)
        res_ll = _check_if_tensor(coeffs[0])

    if not _is_dtype_supported(torch_dtype):
        raise ValueError(f"Input dtype {torch_dtype} not supported")

    _, _, rec_lo, rec_hi = _get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )
    filt_len = rec_lo.shape[-1]
    rec_filt = _construct_2d_filt(lo=rec_lo, hi=rec_hi)

    for c_pos, coeff_tuple in enumerate(coeffs[1:]):
        if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
            raise ValueError(
                f"Unexpected detail coefficient type: {type(coeff_tuple)}. Detail "
                "coefficients must be a 3-tuple of tensors as returned by "
                "wavedec2."
            )

        curr_shape = res_ll.shape
        for coeff in coeff_tuple:
            if torch_device != coeff.device:
                raise ValueError("coefficients must be on the same device")
            elif torch_dtype != coeff.dtype:
                raise ValueError("coefficients must have the same dtype")
            elif coeff.shape != curr_shape:
                raise ValueError(
                    "All coefficients on each level must have the same shape"
                )

        res_lh, res_hl, res_hh = coeff_tuple
        # Stack the subbands as input channels of the transposed conv.
        res_ll = torch.stack([res_ll, res_lh, res_hl, res_hh], 1)
        res_ll = torch.nn.functional.conv_transpose2d(
            res_ll, rec_filt, stride=2
        ).squeeze(1)

        # remove the padding
        padl = (2 * filt_len - 3) // 2
        padr = (2 * filt_len - 3) // 2
        padt = (2 * filt_len - 3) // 2
        padb = (2 * filt_len - 3) // 2
        if c_pos < len(coeffs) - 2:
            # Match the shape expected by the next (finer) level.
            padr, padl = _adjust_padding_at_reconstruction(
                res_ll.shape[-1], coeffs[c_pos + 2][0].shape[-1], padr, padl
            )
            padb, padt = _adjust_padding_at_reconstruction(
                res_ll.shape[-2], coeffs[c_pos + 2][0].shape[-2], padb, padt
            )

        if padt > 0:
            res_ll = res_ll[..., padt:, :]
        if padb > 0:
            res_ll = res_ll[..., :-padb, :]
        if padl > 0:
            res_ll = res_ll[..., padl:]
        if padr > 0:
            res_ll = res_ll[..., :-padr]

    if ds:
        res_ll = _unfold_axes(res_ll, list(ds), 2)

    # Compare as a tuple so a list-valued ``axes`` equal to [-2, -1] is
    # treated exactly like the default (mirrors the check at the top;
    # previously a plain ``axes != (-2, -1)`` triggered an unnecessary
    # axis un-swap for list inputs).
    if tuple(axes) != (-2, -1):
        res_ll = _undo_swap_axes(res_ll, list(axes))
    return res_ll