|
24 | 24 | _preprocess_coeffs, |
25 | 25 | _preprocess_tensor, |
26 | 26 | _translate_boundary_strings, |
| 27 | + _group_for_symmetric, |
27 | 28 | ) |
28 | 29 | from .constants import BoundaryMode, Wavelet, WaveletCoeffNd, WaveletDetailDict |
29 | 30 |
|
@@ -91,21 +92,16 @@ def _fwt_pad3( |
91 | 92 | pytorch_mode = _translate_boundary_strings(mode) |
92 | 93 |
|
93 | 94 | if padding is None: |
94 | | - pad_back, pad_front = _get_pad(data.shape[-3], _get_len(wavelet)) |
95 | | - pad_bottom, pad_top = _get_pad(data.shape[-2], _get_len(wavelet)) |
96 | | - pad_right, pad_left = _get_pad(data.shape[-1], _get_len(wavelet)) |
97 | | - else: |
98 | | - pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back = padding |
99 | | - if pytorch_mode == "symmetric": |
100 | | - data_pad = _pad_symmetric( |
101 | | - data, [(pad_front, pad_back), (pad_top, pad_bottom), (pad_left, pad_right)] |
| 95 | + _len_wavelet = _get_len(wavelet) |
| 96 | + padding = ( |
| 97 | + *_get_pad(data.shape[-1], _len_wavelet), |
| 98 | + *_get_pad(data.shape[-2], _len_wavelet), |
| 99 | + *_get_pad(data.shape[-3], _len_wavelet), |
102 | 100 | ) |
| 101 | + if pytorch_mode == "symmetric": |
| 102 | + data_pad = _pad_symmetric(data, _group_for_symmetric(padding)) |
103 | 103 | else: |
104 | | - data_pad = torch.nn.functional.pad( |
105 | | - data, |
106 | | - [pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back], |
107 | | - mode=pytorch_mode, |
108 | | - ) |
| 104 | + data_pad = torch.nn.functional.pad(data, padding, mode=pytorch_mode) |
109 | 105 | return data_pad |
110 | 106 |
|
111 | 107 |
|
|
0 commit comments