Skip to content

Commit 4d58f52

Browse files
authored
Refactor FWT padding (#124)
1 parent 6beb81b commit 4d58f52

6 files changed

Lines changed: 48 additions & 27 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"PyWavelets",
3434
"numpy",
3535
"torch",
36+
"more-itertools",
3637
]
3738
license = {text = "EUPL-1.2"}
3839

src/ptwt/_util.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
import pywt
1414
import torch
15+
from more_itertools import grouper
1516
from typing_extensions import ParamSpec, TypeVar
1617

1718
from .constants import (
@@ -209,7 +210,7 @@ def _get_pad(data_len: int, filt_len: int) -> tuple[int, int]:
209210
# pad to even singal length.
210211
padr += data_len % 2
211212

212-
return padr, padl
213+
return padl, padr
213214

214215

215216
def _adjust_padding_at_reconstruction(
@@ -789,3 +790,8 @@ def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> RetType:
789790
return wrapper
790791

791792
return deco
793+
794+
795+
def _group_for_symmetric(padding: tuple[int, ...]) -> list[tuple[int, int]]:
796+
"""Repack the padding tuple for symmetric padding."""
797+
return list(reversed(list(grouper(padding, 2)))) # type:ignore[arg-type]

src/ptwt/conv_transform.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
_preprocess_coeffs,
2424
_preprocess_tensor,
2525
_translate_boundary_strings,
26+
_group_for_symmetric,
2627
)
2728
from .constants import BoundaryMode, Wavelet, WaveletCoeff1d
2829

@@ -60,13 +61,11 @@ def _fwt_pad(
6061
pytorch_mode = _translate_boundary_strings(mode)
6162

6263
if padding is None:
63-
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
64-
else:
65-
padl, padr = padding
64+
padding = _get_pad(data.shape[-1], _get_len(wavelet))
6665
if pytorch_mode == "symmetric":
67-
data_pad = _pad_symmetric(data, [(padl, padr)])
66+
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
6867
else:
69-
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=pytorch_mode)
68+
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
7069
return data_pad
7170

7271

src/ptwt/conv_transform_2.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_preprocess_coeffs,
2525
_preprocess_tensor,
2626
_translate_boundary_strings,
27+
_group_for_symmetric,
2728
)
2829
from .constants import BoundaryMode, Wavelet, WaveletCoeff2d, WaveletDetailTuple2d
2930

@@ -88,16 +89,15 @@ def _fwt_pad2(
8889
pytorch_mode = _translate_boundary_strings(mode)
8990

9091
if padding is None:
91-
padb, padt = _get_pad(data.shape[-2], _get_len(wavelet))
92-
padr, padl = _get_pad(data.shape[-1], _get_len(wavelet))
93-
else:
94-
padl, padr, padt, padb = padding
92+
_len_wavelet = _get_len(wavelet)
93+
padding = (
94+
*_get_pad(data.shape[-1], _len_wavelet),
95+
*_get_pad(data.shape[-2], _len_wavelet),
96+
)
9597
if pytorch_mode == "symmetric":
96-
data_pad = _pad_symmetric(data, [(padt, padb), (padl, padr)])
98+
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
9799
else:
98-
data_pad = torch.nn.functional.pad(
99-
data, [padl, padr, padt, padb], mode=pytorch_mode
100-
)
100+
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
101101
return data_pad
102102

103103

src/ptwt/conv_transform_3.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_preprocess_coeffs,
2525
_preprocess_tensor,
2626
_translate_boundary_strings,
27+
_group_for_symmetric,
2728
)
2829
from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict
2930

@@ -91,21 +92,16 @@ def _fwt_pad3(
9192
pytorch_mode = _translate_boundary_strings(mode)
9293

9394
if padding is None:
94-
pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet))
95-
pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet))
96-
pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet))
97-
else:
98-
pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding
99-
if pytorch_mode == "symmetric":
100-
data_pad = _pad_symmetric(
101-
data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)]
95+
_len_wavelet = _get_len(wavelet)
96+
padding = (
97+
*_get_pad(data.shape[-1], _len_wavelet),
98+
*_get_pad(data.shape[-2], _len_wavelet),
99+
*_get_pad(data.shape[-3], _len_wavelet),
102100
)
101+
if pytorch_mode == "symmetric":
102+
data_pad = _pad_symmetric(data, _group_for_symmetric(padding))
103103
else:
104-
data_pad = torch.nn.functional.pad(
105-
data,
106-
[pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back],
107-
mode=pytorch_mode,
108-
)
104+
data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode)
109105
return data_pad
110106

111107

tests/test_util.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
_fold_axes,
1111
_pad_symmetric,
1212
_pad_symmetric_1d,
13+
_group_for_symmetric,
1314
_unfold_axes,
1415
)
1516

@@ -78,3 +79,21 @@ def test_fold(keep_no: int, size: list[int]) -> None:
7879
assert len(folded.shape) == keep_no + 1
7980
rec = _unfold_axes(folded, size, keep_no)
8081
np.allclose(array.numpy(), rec.numpy())
82+
83+
84+
def test_repack_symmetric() -> None:
85+
"""Ensure channel folding works as expected."""
86+
padl, padr = padding = tuple(range(2))
87+
assert _group_for_symmetric(padding) == [(padl, padr)]
88+
89+
padl, padr, padt, padb = padding = tuple(range(4))
90+
assert _group_for_symmetric(padding) == [(padt, padb), (padl, padr)]
91+
92+
pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding = tuple(
93+
range(6)
94+
)
95+
assert _group_for_symmetric(padding) == [
96+
(pad_front, pad_back),
97+
(pad_top, pad_bottom),
98+
(pad_left, pad_right),
99+
]

0 commit comments

Comments
 (0)