"""Implement 3D separable boundary transforms."""
import sys
from functools import partial
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from ._util import (
Wavelet,
_as_wavelet,
_check_axes_argument,
_check_if_tensor,
_fold_axes,
_is_boundary_mode_supported,
_is_dtype_supported,
_map_result,
_swap_axes,
_undo_swap_axes,
_unfold_axes,
)
from .constants import OrthogonalizeMethod
from .conv_transform_3 import _waverec3d_fold_channels_3d_list
from .matmul_transform import construct_boundary_a, construct_boundary_s
from .sparse_math import _batch_dim_mm
class _PadTuple(NamedTuple):
"""Replaces _PadTuple = namedtuple("_PadTuple", ("depth", "height", "width"))."""
depth: bool
height: bool
width: bool
def _matrix_pad_3(
depth: int, height: int, width: int
) -> Tuple[int, int, int, _PadTuple]:
pad_depth, pad_height, pad_width = (False, False, False)
if height % 2 != 0:
height += 1
pad_height = True
if width % 2 != 0:
width += 1
pad_width = True
if depth % 2 != 0:
depth += 1
pad_depth = True
return depth, height, width, _PadTuple(pad_depth, pad_height, pad_width)
[docs]
class MatrixWavedec3(object):
"""Compute 3d separable transforms."""
def __init__(
self,
wavelet: Union[Wavelet, str],
level: Optional[int] = None,
axes: Tuple[int, int, int] = (-3, -2, -1),
boundary: OrthogonalizeMethod = "qr",
):
"""Create a *separable* three-dimensional fast boundary wavelet transform.
Input signals should have the shape [batch_size, depth, height, width],
this object transforms the last three dimensions.
Args:
wavelet (Union[Wavelet, str]): The wavelet to use.
level (Optional[int]): The desired decomposition level.
Defaults to None.
boundary: The matrix orthogonalization method.
Defaults to "qr".
Raises:
NotImplementedError: If the chosen orthogonalization method
is not implemented.
ValueError: If the analysis and synthesis filters do not have
the same length.
"""
self.wavelet = _as_wavelet(wavelet)
self.level = level
self.boundary = boundary
if len(axes) != 3:
raise ValueError("3D transforms work with three axes.")
else:
_check_axes_argument(list(axes))
self.axes = axes
self.input_signal_shape: Optional[Tuple[int, int, int]] = None
self.fwt_matrix_list: List[List[torch.Tensor]] = []
if not _is_boundary_mode_supported(self.boundary):
raise NotImplementedError
if self.wavelet.dec_len != self.wavelet.rec_len:
raise ValueError("All filters must have the same length")
def _construct_analysis_matrices(
self,
device: Union[torch.device, str],
dtype: torch.dtype,
) -> None:
if self.level is None or self.input_signal_shape is None:
raise AssertionError
self.fwt_matrix_list = []
self.size_list = []
self.pad_list = []
self.padded = False
filt_len = self.wavelet.dec_len
current_depth, current_height, current_width = self.input_signal_shape
for curr_level in range(1, self.level + 1):
if (
current_height < filt_len
or current_width < filt_len
or current_depth < filt_len
):
# we have reached the max decomposition depth.
sys.stderr.write(
f"Warning: The selected number of decomposition levels {self.level}"
f" is too large for the given input shape {self.input_signal_shape}"
f". At level {curr_level}, at least one of the current signal "
f"depth, height, and width ({current_depth}, {current_height},"
f"{current_width}) is smaller "
f"then the filter length {filt_len}. Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
)
break
# the conv matrices require even length inputs.
current_depth, current_height, current_width, pad_tuple = _matrix_pad_3(
depth=current_depth, height=current_height, width=current_width
)
if any(pad_tuple):
self.padded = True
self.pad_list.append(pad_tuple)
self.size_list.append((current_depth, current_height, current_width))
matrix_construction_fun = partial(
construct_boundary_a,
wavelet=self.wavelet,
boundary=self.boundary,
device=device,
dtype=dtype,
)
analysis_matrics = [
matrix_construction_fun(length=dimension_length)
for dimension_length in (current_depth, current_height, current_width)
]
self.fwt_matrix_list.append(analysis_matrics)
current_depth, current_height, current_width = (
current_depth // 2,
current_height // 2,
current_width // 2,
)
self.size_list.append((current_depth, current_height, current_width))
def __call__(
self, input_signal: torch.Tensor
) -> List[Union[torch.Tensor, Dict[str, torch.Tensor]]]:
"""Compute a separable 3d-boundary wavelet transform.
Args:
input_signal (torch.Tensor): An input signal. For example
of shape [batch_size, depth, height, width].
Raises:
ValueError: If the input dimensions don't work.
Returns:
List[Union[torch.Tensor, TypedDict[str, torch.Tensor]]]:
A list with the approximation coefficients,
and a coefficient dict for each scale.
"""
if self.axes != (-3, -2, -1):
input_signal = _swap_axes(input_signal, list(self.axes))
ds = None
if input_signal.dim() < 3:
raise ValueError("At least three dimensions are required for 3d wavedec.")
elif len(input_signal.shape) == 3:
input_signal = input_signal.unsqueeze(1)
else:
input_signal, ds = _fold_axes(input_signal, 3)
_, depth, height, width = input_signal.shape
if not _is_dtype_supported(input_signal.dtype):
raise ValueError(f"Input dtype {input_signal.dtype} not supported")
re_build = False
if (
self.input_signal_shape is None
or self.input_signal_shape[0] != depth
or self.input_signal_shape[1] != height
or self.input_signal_shape[2] != width
):
self.input_signal_shape = depth, height, width
re_build = True
if self.level is None:
wlen = len(self.wavelet)
self.level = int(
np.min(
[
np.log2(depth / (wlen - 1)),
np.log2(height / (wlen - 1)),
np.log2(width / (wlen - 1)),
]
)
)
re_build = True
elif self.level <= 0:
raise ValueError("level must be a positive integer.")
if not self.fwt_matrix_list or re_build:
self._construct_analysis_matrices(
device=input_signal.device, dtype=input_signal.dtype
)
split_list: List[Union[torch.Tensor, Dict[str, torch.Tensor]]] = []
lll = input_signal
for scale, fwt_mats in enumerate(self.fwt_matrix_list):
# fwt_depth_matrix, fwt_row_matrix, fwt_col_matrix = fwt_mats
pad_tuple = self.pad_list[scale]
# current_depth, current_height, current_width = self.size_list[scale]
if pad_tuple.width:
lll = torch.nn.functional.pad(lll, [0, 1, 0, 0, 0, 0])
if pad_tuple.height:
lll = torch.nn.functional.pad(lll, [0, 0, 0, 1, 0, 0])
if pad_tuple.depth:
lll = torch.nn.functional.pad(lll, [0, 0, 0, 0, 0, 1])
for dim, mat in enumerate(fwt_mats[::-1]):
lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1))
def _split_rec(
tensor: torch.Tensor,
key: str,
depth: int,
dict: Dict[str, torch.Tensor],
) -> None:
if key:
dict[key] = tensor
if len(key) < depth:
dim = len(key) + 1
ca, cd = torch.split(tensor, tensor.shape[-dim] // 2, dim=-dim)
_split_rec(ca, "a" + key, depth, dict)
_split_rec(cd, "d" + key, depth, dict)
coeff_dict: Dict[str, torch.Tensor] = {}
_split_rec(lll, "", 3, coeff_dict)
lll = coeff_dict["aaa"]
result_keys = list(
filter(lambda x: len(x) == 3 and not x == "aaa", coeff_dict.keys())
)
coeff_dict = {
key: tensor for key, tensor in coeff_dict.items() if key in result_keys
}
split_list.append(coeff_dict)
split_list.append(lll)
if ds:
_unfold_axes_fn = partial(_unfold_axes, ds=ds, keep_no=3)
split_list = _map_result(split_list, _unfold_axes_fn)
if self.axes != (-3, -2, -1):
undo_swap_fn = partial(_undo_swap_axes, axes=self.axes)
split_list = _map_result(split_list, undo_swap_fn)
return split_list[::-1]
[docs]
class MatrixWaverec3(object):
"""Reconstruct a signal from 3d-separable-fwt coefficients."""
def __init__(
self,
wavelet: Union[Wavelet, str],
axes: Tuple[int, int, int] = (-3, -2, -1),
boundary: OrthogonalizeMethod = "qr",
):
"""Compute a three-dimensional separable boundary wavelet synthesis transform.
Args:
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
axes (Tuple[int, int, int]): Transform these axes instead of the
last three. Defaults to (-3, -2, -1).
boundary : The method used for boundary filter treatment.
Choose 'qr' or 'gramschmidt'. 'qr' relies on Pytorch's dense qr
implementation, it is fast but memory hungry. The 'gramschmidt' option
is sparse, memory efficient, and slow. Choose 'gramschmidt' if 'qr' runs
out of memory. Defaults to 'qr'.
Raises:
NotImplementedError: If the selected `boundary` mode is not supported.
ValueError: If the wavelet filters have different lengths.
"""
self.wavelet = _as_wavelet(wavelet)
if len(axes) != 3:
raise ValueError("3D transforms work with three axes")
else:
_check_axes_argument(list(axes))
self.axes = axes
self.boundary = boundary
self.ifwt_matrix_list: List[List[torch.Tensor]] = []
self.input_signal_shape: Optional[Tuple[int, int, int]] = None
self.level: Optional[int] = None
if not _is_boundary_mode_supported(self.boundary):
raise NotImplementedError
if self.wavelet.dec_len != self.wavelet.rec_len:
raise ValueError("All filters must have the same length")
def _construct_synthesis_matrices(
self,
device: Union[torch.device, str],
dtype: torch.dtype,
) -> None:
self.ifwt_matrix_list = []
self.padded = False
if self.level is None or self.input_signal_shape is None:
raise AssertionError
current_depth, current_height, current_width = self.input_signal_shape
filt_len = self.wavelet.rec_len
for curr_level in range(1, self.level + 1):
if (
current_depth < filt_len
or current_height < filt_len
or current_width < filt_len
):
sys.stderr.write(
f"Warning: The selected number of decomposition levels {self.level}"
f" is too large for the given input shape {self.input_signal_shape}"
f". At level {curr_level}, at least one of the current signal "
f" depth, height and width ({current_depth}, {current_height}, "
f"{current_width}) is smaller than the filter length {filt_len}."
f" Therefore, the transformation "
f"is only computed up to the decomposition level {curr_level-1}.\n"
)
break
# the conv matrices require even length inputs.
current_depth, current_height, current_width, pad_tuple = _matrix_pad_3(
depth=current_depth, height=current_height, width=current_width
)
if any(pad_tuple):
self.padded = True
matrix_construction_fun = partial(
construct_boundary_s,
wavelet=self.wavelet,
boundary=self.boundary,
device=device,
dtype=dtype,
)
synthesis_matrices = [
matrix_construction_fun(length=dimension_length)
for dimension_length in (current_depth, current_height, current_width)
]
self.ifwt_matrix_list.append(synthesis_matrices)
current_depth, current_height, current_width = (
current_depth // 2,
current_height // 2,
current_width // 2,
)
def _cat_coeff_recursive(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
done_dict = {}
a_initial_keys = list(filter(lambda x: x[0] == "a", input_dict.keys()))
for a_key in a_initial_keys:
d_key = "d" + a_key[1:]
cat_d = input_dict[d_key]
d_shape = cat_d.shape
# undo any analysis padding.
cat_a = input_dict[a_key][:, : d_shape[1], : d_shape[2], : d_shape[3]]
cat_tensor = torch.cat([cat_a, cat_d], dim=-len(a_key))
if a_key[1:]:
done_dict[a_key[1:]] = cat_tensor
else:
return cat_tensor
return self._cat_coeff_recursive(done_dict)
def __call__(
self, coefficients: List[Union[torch.Tensor, Dict[str, torch.Tensor]]]
) -> torch.Tensor:
"""Reconstruct a batched 3d-signal from its coefficients.
Args:
coefficients (List[Union[torch.Tensor, Dict[str, torch.Tensor]]]):
The output from MatrixWavedec3.
Returns:
torch.Tensor: A reconstruction of the original signal.
Raises:
ValueError: If the data structure is inconsistent.
"""
if self.axes != (-3, -2, -1):
swap_axes_fn = partial(_swap_axes, axes=list(self.axes))
coefficients = _map_result(coefficients, swap_axes_fn)
ds = None
# the Union[tensor, dict] idea is coming from pywt. We don't change it here.
res_lll = _check_if_tensor(coefficients[0])
if res_lll.dim() < 3:
raise ValueError(
"Three dimensional transforms require at least three dimensions."
)
elif res_lll.dim() >= 5:
coefficients, ds = _waverec3d_fold_channels_3d_list(coefficients)
res_lll = _check_if_tensor(coefficients[0])
level = len(coefficients) - 1
if type(coefficients[-1]) is dict:
depth, height, width = tuple(
c * 2 for c in coefficients[-1]["ddd"].shape[-3:]
)
else:
raise ValueError("Waverec3 expects dicts of tensors.")
re_build = False
if (
self.input_signal_shape is None
or self.input_signal_shape[0] != depth
or self.input_signal_shape[1] != height
or self.input_signal_shape[2] != width
):
self.input_signal_shape = depth, height, width
re_build = True
if self.level != level:
self.level = level
re_build = True
lll = coefficients[0]
if not isinstance(lll, torch.Tensor):
raise ValueError(
"First element of coeffs must be the approximation coefficient tensor."
)
torch_device = lll.device
torch_dtype = lll.dtype
if not _is_dtype_supported(torch_dtype):
if not _is_dtype_supported(torch_dtype):
raise ValueError(f"Input dtype {torch_dtype} not supported")
if not self.ifwt_matrix_list or re_build:
self._construct_synthesis_matrices(
device=torch_device,
dtype=torch_dtype,
)
for c_pos, coeff_dict in enumerate(coefficients[1:]):
if not isinstance(coeff_dict, dict) or len(coeff_dict) != 7:
raise ValueError(
f"Unexpected detail coefficient type: {type(coeff_dict)}. Detail "
"coefficients must be a dict containing 7 tensors as returned by "
"MatrixWavedec3."
)
test_shape = None
for coeff in coeff_dict.values():
if test_shape is None:
test_shape = coeff.shape
if torch_device != coeff.device:
raise ValueError("coefficients must be on the same device")
elif torch_dtype != coeff.dtype:
raise ValueError("coefficients must have the same dtype")
elif test_shape != coeff.shape:
raise ValueError(
"All coefficients on each level must have the same shape"
)
coeff_dict["a" * len(list(coeff_dict.keys())[-1])] = lll
lll = self._cat_coeff_recursive(coeff_dict)
for dim, mat in enumerate(self.ifwt_matrix_list[level - 1 - c_pos][::-1]):
lll = _batch_dim_mm(mat, lll, dim=(-1) * (dim + 1))
if ds:
lll = _unfold_axes(lll, ds, 3)
if self.axes != (-3, -2, -1):
lll = _undo_swap_axes(lll, list(self.axes))
return lll