"""This module implements two-dimensional padded wavelet transforms.
The implementation relies on torch.nn.functional.conv2d and
torch.nn.functional.conv_transpose2d under the hood.
"""
from typing import List, Optional, Tuple, Union
import pywt
import torch
from ._util import Wavelet, _as_wavelet, _get_len, _is_dtype_supported, _outer
from .conv_transform import (
_adjust_padding_at_reconstruction,
_get_pad,
_translate_boundary_strings,
get_filter_tensors,
)
def construct_2d_filt(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Build the four two-dimensional filters from a one-dimensional filter pair.

    Every 2d filter is the outer product of two of the given 1d filters.

    Args:
        lo (torch.Tensor): Low-pass input filter.
        hi (torch.Tensor): High-pass input filter.

    Returns:
        torch.Tensor: Stacked 2d-filters of dimension
            [filt_no, 1, height, width].
            The four filters are ordered ll, lh, hl, hh.
    """
    # Argument pairs handed to _outer, listed in output order: ll, lh, hl, hh.
    filter_pairs = ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    filters_2d = [_outer(first, second) for first, second in filter_pairs]
    # Stack along a new leading axis, then insert a singleton axis at
    # position 1 to obtain the documented [filt_no, 1, height, width] layout.
    return torch.stack(filters_2d, 0).unsqueeze(1)
def _fwt_pad2(
    data: torch.Tensor, wavelet: Union[Wavelet, str], mode: str = "reflect"
) -> torch.Tensor:
    """Pad the last two axes of the input ahead of a 2d FWT step.

    Args:
        data (torch.Tensor): Input data with 4 dimensions.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.
        mode (str): The padding mode.
            Supported modes are::

                "reflect", "zero", "constant", "periodic".

            "reflect" is the default mode.

    Returns:
        The padded output tensor.
    """
    torch_mode = _translate_boundary_strings(mode)
    wavelet_obj = _as_wavelet(wavelet)
    height, width = data.shape[-2], data.shape[-1]
    filt_len = _get_len(wavelet_obj)
    # _get_pad returns the pair in (bottom, top) order for the height axis
    # and (right, left) order for the width axis.
    pad_bottom, pad_top = _get_pad(height, filt_len)
    pad_right, pad_left = _get_pad(width, filt_len)
    # torch.nn.functional.pad takes the pad widths starting from the last
    # axis: (left, right, top, bottom).
    return torch.nn.functional.pad(
        data, [pad_left, pad_right, pad_top, pad_bottom], mode=torch_mode
    )
def wavedec2(
    data: torch.Tensor,
    wavelet: Union[Wavelet, str],
    mode: str = "reflect",
    level: Optional[int] = None,
) -> List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Compute a non-separated two-dimensional wavelet transform.

    Args:
        data (torch.Tensor): The input data tensor with up to three dimensions.
            2d inputs are interpreted as [height, width],
            3d inputs are interpreted as [batch_size, height, width].
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet. Refer to the output of
            ``pywt.wavelist(kind="discrete")`` for a list of possible choices.
        mode (str): The padding mode. Options are::

                "reflect", "zero", "constant", "periodic".

            This function defaults to "reflect".
        level (int): The number of desired scales.
            Defaults to None, in which case the level is chosen by
            ``pywt.dwtn_max_level`` from the size of the last two axes.

    Returns:
        list: A list containing the wavelet coefficients.
            The coefficients are in pywt order. That is::

                [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] .

            A denotes approximation, H horizontal, V vertical
            and D diagonal coefficients. Every coefficient tensor has a
            leading batch axis; 2d inputs yield a batch size of one.

    Raises:
        ValueError: If the dimensionality or the dtype of the input data tensor
            is unsupported. Only 2d and 3d inputs are accepted.

    Example:
        >>> import torch
        >>> import ptwt, pywt
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = np.transpose(datasets.face(),
        ...                     [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
        ...                              level=2, mode="zero")
    """
    ndim = data.dim()
    if ndim == 2:
        # add batch and channel dimensions for torch.
        data = data.unsqueeze(0).unsqueeze(0)
    elif ndim == 3:
        # add a channel dimension for torch.
        data = data.unsqueeze(1)
    elif ndim < 2:
        raise ValueError("Wavedec2 needs more than one input dimension to work.")
    elif ndim == 4:
        raise ValueError(
            "Wavedec2 does not support four input dimensions. "
            "Optionally-batched two-dimensional inputs work."
        )
    else:
        # Reject everything else up front instead of failing later inside
        # the padding or convolution calls with an unrelated error.
        raise ValueError(
            f"Wavedec2 does not support {ndim} input dimensions. "
            "Optionally-batched two-dimensional inputs work."
        )

    if not _is_dtype_supported(data.dtype):
        raise ValueError(f"Input dtype {data.dtype} not supported")

    wavelet = _as_wavelet(wavelet)
    dec_lo, dec_hi, _, _ = get_filter_tensors(
        wavelet, flip=True, device=data.device, dtype=data.dtype
    )
    dec_filt = construct_2d_filt(lo=dec_lo, hi=dec_hi)

    if level is None:
        level = pywt.dwtn_max_level([data.shape[-1], data.shape[-2]], wavelet)

    result_lst: List[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
    ] = []
    res_ll = data
    for _ in range(level):
        # Pad, then filter and downsample by two in a single strided conv.
        res_ll = _fwt_pad2(res_ll, wavelet, mode=mode)
        res = torch.nn.functional.conv2d(res_ll, dec_filt, stride=2)
        # One output channel per filter, in the order ll, lh, hl, hh.
        res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1)
        to_append = (res_lh.squeeze(1), res_hl.squeeze(1), res_hh.squeeze(1))
        result_lst.append(to_append)
    result_lst.append(res_ll.squeeze(1))
    # Levels were collected finest-first; pywt order is coarsest-first.
    return result_lst[::-1]
def waverec2(
    coeffs: List[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]],
    wavelet: Union[Wavelet, str],
) -> torch.Tensor:
    """Reconstruct a signal from wavelet coefficients.

    Args:
        coeffs (list): The wavelet coefficient list produced by wavedec2.
            The coefficients must be in pywt order. That is::

                [cAn, (cHn, cVn, cDn), ... (cH1, cV1, cD1)] .

            A denotes approximation, H horizontal, V vertical,
            and D diagonal coefficients.
        wavelet (Wavelet or str): A pywt wavelet compatible object or
            the name of a pywt wavelet.

    Returns:
        torch.Tensor: The reconstructed signal of shape [batch, height, width].

    Raises:
        ValueError: If `coeffs` is empty or not in a shape as returned from
            `wavedec2`, or if the dtype is not supported.

    Example:
        >>> import ptwt, pywt, torch
        >>> import numpy as np
        >>> from scipy import datasets
        >>> face = np.transpose(datasets.face(),
        ...                     [2, 0, 1]).astype(np.float64)
        >>> pytorch_face = torch.tensor(face)
        >>> coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
        ...                              level=2, mode="constant")
        >>> reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
    """
    # Validate the coefficient list before any filter construction so that
    # malformed input is reported as a ValueError, not an IndexError.
    if len(coeffs) == 0:
        raise ValueError(
            "coeffs must contain at least the approximation coefficient tensor."
        )
    res_ll = coeffs[0]
    if not isinstance(res_ll, torch.Tensor):
        raise ValueError(
            "First element of coeffs must be the approximation coefficient tensor."
        )

    torch_device = res_ll.device
    torch_dtype = res_ll.dtype
    if not _is_dtype_supported(torch_dtype):
        raise ValueError(f"Input dtype {torch_dtype} not supported")

    wavelet = _as_wavelet(wavelet)
    _, _, rec_lo, rec_hi = get_filter_tensors(
        wavelet, flip=False, device=torch_device, dtype=torch_dtype
    )
    filt_len = rec_lo.shape[-1]
    rec_filt = construct_2d_filt(lo=rec_lo, hi=rec_hi)
    # Default number of samples cut from every side after each level;
    # it only depends on the filter length, so compute it once.
    default_pad = (2 * filt_len - 3) // 2

    for c_pos, coeff_tuple in enumerate(coeffs[1:]):
        if not isinstance(coeff_tuple, tuple) or len(coeff_tuple) != 3:
            raise ValueError(
                f"Unexpected detail coefficient type: {type(coeff_tuple)}. Detail "
                "coefficients must be a 3-tuple of tensors as returned by "
                "wavedec2."
            )

        curr_shape = res_ll.shape
        for coeff in coeff_tuple:
            if not isinstance(coeff, torch.Tensor):
                raise ValueError(
                    f"Unexpected detail coefficient type: {type(coeff)}. "
                    "Detail coefficients must be tensors."
                )
            if torch_device != coeff.device:
                raise ValueError("coefficients must be on the same device")
            elif torch_dtype != coeff.dtype:
                raise ValueError("coefficients must have the same dtype")
            elif coeff.shape != curr_shape:
                raise ValueError(
                    "All coefficients on each level must have the same shape"
                )

        res_lh, res_hl, res_hh = coeff_tuple
        # Stack the four sub-bands as channels, in the filter order ll, lh, hl, hh.
        res_ll = torch.stack([res_ll, res_lh, res_hl, res_hh], 1)
        res_ll = torch.nn.functional.conv_transpose2d(
            res_ll, rec_filt, stride=2
        ).squeeze(1)

        # remove the padding
        padl = padr = padt = padb = default_pad
        if c_pos < len(coeffs) - 2:
            # A finer level follows: adjust the crop against the shape of
            # that level's detail coefficients.
            next_detail = coeffs[c_pos + 2][0]
            padr, padl = _adjust_padding_at_reconstruction(
                res_ll.shape[-1], next_detail.shape[-1], padr, padl
            )
            padb, padt = _adjust_padding_at_reconstruction(
                res_ll.shape[-2], next_detail.shape[-2], padb, padt
            )
        # Guard each crop: a slice like [:-0] would return an empty tensor.
        if padt > 0:
            res_ll = res_ll[..., padt:, :]
        if padb > 0:
            res_ll = res_ll[..., :-padb, :]
        if padl > 0:
            res_ll = res_ll[..., padl:]
        if padr > 0:
            res_ll = res_ll[..., :-padr]
    return res_ll