def medfilt(
    x: Tensor,
    filter_length: int = 3,
    across_features: bool = False,
    magic_number: float | None = None,
) -> Tensor:
    """Run a sliding-window median filter along the time axis.

    This is the functional counterpart of ``nn.MedianFilter``; all options
    are forwarded to it unchanged.

    Parameters
    ----------
    x : Tensor [shape=(B, N, D) or (N, D) or (N,)]
        The input sequence.

    filter_length : int > 0
        The number of frames in each median window, :math:`L`.

    across_features : bool
        If True, the median is taken jointly over the window and the
        feature dimension, giving a single value per frame.

    magic_number : float or None
        The value that marks unvoiced frames. If given, such values are
        left out of the median, and an output is set to this value when
        they outnumber the remaining samples in its window.

    Returns
    -------
    out : Tensor [shape=(B, N, D) or (B, N) or (N, D) or (N,)]
        The filtered sequence.

    """
    options = {
        "filter_length": filter_length,
        "across_features": across_features,
        "magic_number": magic_number,
    }
    return nn.MedianFilter._func(x, **options)
class MedianFilter(BaseFunctionalModule):
    """Sliding-window median filter along the time axis of a sequence.

    Each output frame is the median of ``filter_length`` input frames
    around it. The sequence is padded with NaN at both ends and NaNs are
    ignored when taking the median, so windows at the edges only use the
    samples that actually exist.

    Parameters
    ----------
    filter_length : int > 0
        The length of the median filter, :math:`L`.

    across_features : bool
        If True, the median is taken jointly over the window and the
        feature dimension, giving a single value per frame.

    magic_number : float or None
        The magic number representing unvoiced frames. If given, input
        values equal to it are left out of the median, and an output is
        set to it when such values outnumber the remaining samples in
        its window.

    """

    def __init__(
        self,
        filter_length: int,
        across_features: bool = False,
        magic_number: float | None = None,
    ) -> None:
        super().__init__()

        # NOTE: locals() must be taken before any extra local is introduced,
        # since its contents are forwarded to _precompute as keyword arguments.
        self.values = self._precompute(**filter_values(locals()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Filter the input sequence with the configured median filter.

        Parameters
        ----------
        x : Tensor [shape=(B, N, D) or (N, D) or (N,)]
            The input sequence.

        Returns
        -------
        out : Tensor [shape=(B, N, D) or (B, N) or (N, D) or (N,)]
            The filtered sequence.

        Examples
        --------
        >>> import torch
        >>> import diffsptk
        >>> medfilt = diffsptk.MedianFilter(3)
        >>> x = torch.tensor([0, 2, -2, 7, 4, 8]).float()
        >>> y = medfilt(x)
        >>> y
        tensor([1., 0., 2., 4., 7., 6.])

        """
        return self._forward(x, *self.values)

    @staticmethod
    def _func(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # One-shot path: derive the settings and filter in a single call.
        return MedianFilter._forward(x, *MedianFilter._precompute(*args, **kwargs))

    @staticmethod
    def _takes_input_size() -> bool:
        return False

    @staticmethod
    def _check(filter_length: int) -> None:
        # Reject empty or negative windows before any padding is derived.
        if filter_length <= 0:
            raise ValueError("filter_length must be positive.")

    @staticmethod
    def _precompute(
        filter_length: int, across_features: bool, magic_number: float | None
    ) -> Precomputed:
        MedianFilter._check(filter_length)
        # The window spans `filter_length` frames in total. For an even
        # length the extra frame goes on the left (earlier) side.
        left = filter_length // 2
        right = (filter_length - 1) // 2
        # F.pad lists the last dimension first: none for features, then time.
        return (filter_length, (0, 0, left, right), across_features, magic_number)

    @staticmethod
    def _forward(
        x: torch.Tensor,
        filter_length: int,
        padding: tuple[int, int, int, int],
        across_features: bool,
        magic_number: float | None,
    ) -> torch.Tensor:
        # Bring every accepted input to the (B, N, D) layout.
        ndim = x.dim()
        if ndim == 1:
            x = x.reshape(1, -1, 1)
        elif ndim == 2:
            x = x.unsqueeze(0)
        if x.dim() != 3:
            raise ValueError("Input must be 1D, 2D, or 3D tensor.")

        nan = float("nan")
        has_magic = magic_number is not None

        def _windows(seq: torch.Tensor) -> torch.Tensor:
            # NaN-pad the time axis and slice it into overlapping windows:
            # (B, N, D) -> (B, N, D, L), or (B, N, D * L) across features.
            framed = F.pad(seq, padding, value=nan).unfold(1, filter_length, 1)
            return framed.flatten(start_dim=-2) if across_features else framed

        if has_magic:
            # Turn magic values into NaN so that the median skips them.
            mask = x == magic_number
            x = x.masked_fill(mask, nan)

        y = _windows(x).nanquantile(0.5, dim=-1)

        if has_magic:
            # Count magic and valid samples per window. The NaN padding is
            # skipped by nansum, so it is counted as neither.
            m = _windows(mask.float())
            magic_count = m.nansum(dim=-1)
            valid_count = (1 - m).nansum(dim=-1)
            y = torch.where(
                magic_count > valid_count, torch.full_like(y, magic_number), y
            )

        # Undo the layout change made at the top.
        if ndim == 1:
            return y.view(-1)
        if ndim == 2:
            return y.squeeze(0)
        return y
@pytest.mark.parametrize("module", [False, True])
@pytest.mark.parametrize("K", [1, 2, 3])
@pytest.mark.parametrize("L", [1, 2])
@pytest.mark.parametrize("across_features", [False, True])
@pytest.mark.parametrize("magic_number", [None, 0])
def test_compatibility(
    device, dtype, module, K, L, across_features, magic_number, N=10
):
    """Compare the median filter with the SPTK ``medfilt`` command."""
    params = {
        "filter_length": K,
        "across_features": across_features,
        "magic_number": magic_number,
    }
    # `module` switches between the module class and the functional form.
    medfilt = U.choice(
        module,
        diffsptk.MedianFilter,
        diffsptk.functional.medfilt,
        params,
    )

    # Command-line equivalents of the Python options.
    w = int(across_features)
    opt = f"--magic {magic_number}" if magic_number is not None else ""

    # With across_features the output is passed with dy=None instead of L.
    U.check_compatibility(
        device,
        dtype,
        medfilt,
        [],
        f"nrand -l {N * L} | sopr -ROUND",
        f"medfilt -l {L} -K {K} -w {w} {opt}",
        [],
        dx=L,
        dy=L if not across_features else None,
    )

    U.check_differentiability(device, dtype, medfilt, [N])


def test_various_shape(K=3, N=10):
    """Run the shape check on 1D, 2D and 3D inputs of the same sequence length."""
    shapes = [(N,), (N, 1), (1, N, 1)]
    U.check_various_shape(diffsptk.MedianFilter(K), shapes)