diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d98dcbe..1e99a8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: pull_request: branches: - - '**' + - "**" jobs: test: diff --git a/README.md b/README.md index 985ef44..93a7379 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ It provides various speech signal processing modules as PyTorch layers, allowing users to integrate classic signal processing algorithms directly into neural network architectures and optimize them through backpropagation. -[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/3.4.0/) +[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/stable/) [![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk) [![ClickPy](https://img.shields.io/badge/downloads-clickpy-yellow.svg)](https://clickpy.clickhouse.com/dashboard/diffsptk) [![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk) @@ -22,7 +22,7 @@ allowing users to integrate classic signal processing algorithms directly into n ## Documentation -- [**Reference Manual**](https://sp-nitech.github.io/diffsptk/3.4.0/) - Detailed API documentation and module specifications. +- [**Reference Manual**](https://sp-nitech.github.io/diffsptk/stable/) - Detailed API documentation and module specifications. - [**Interactive Tutorial**](https://colab.research.google.com/drive/1xAoUKqXadvJXJ7RzN0OceB6y7q5i7Sn6?usp=drive_link) (Google Colab) - Hands-on examples to get started with `diffsptk` in your browser. - [**Conference Paper**](https://www.isca-archive.org/ssw_2023/yoshimura23_ssw.html) - Technical background and implementation details available on the ISCA Archive. diff --git a/diffsptk/functional.py b/diffsptk/functional.py index d56feee..6ff61f6 100644 --- a/diffsptk/functional.py +++ b/diffsptk/functional.py @@ -3248,7 +3248,12 @@ def zcross( def zerodf( - x: Tensor, b: Tensor, frame_period: int = 80, ignore_gain: bool = False + x: Tensor, + b: Tensor, + frame_period: int = 80, + ignore_gain: bool = False, + zeroth_index: int = 0, + mode: str = "direct", ) -> Tensor: """Apply an all-zero digital filter. @@ -3266,6 +3271,12 @@ def zerodf( ignore_gain : bool If True, perform filtering without the gain. + zeroth_index : int >= 0 + The index of the zeroth coefficient in the filter coefficients. + + mode : ['direct', 'efficient'] + The implementation mode for time-varying convolution. + Returns ------- out : Tensor [shape=(..., T)] @@ -3273,5 +3284,10 @@ def zerodf( """ return nn.AllZeroDigitalFilter._func( - x, b, frame_period=frame_period, ignore_gain=ignore_gain + x, + b, + frame_period=frame_period, + ignore_gain=ignore_gain, + zeroth_index=zeroth_index, + mode=mode, ) diff --git a/diffsptk/modules/mglsadf.py b/diffsptk/modules/mglsadf.py index 85ae910..24f1b38 100644 --- a/diffsptk/modules/mglsadf.py +++ b/diffsptk/modules/mglsadf.py @@ -35,6 +35,7 @@ from .mgc2sp import MelGeneralizedCepstrumToSpectrum from .root_pol import PolynomialToRoots from .stft import ShortTimeFourierTransform +from .zerodf import AllZeroDigitalFilter def is_array_like(x: Any) -> bool: @@ -277,18 +278,18 @@ def __init__( if alpha == 0 and gamma == 0: cep_order = filter_order - # Prepare padding module. if self.phase == "minimum": - padding = (cep_order, 0) + cep_orders = (cep_order, 0) elif self.phase == "maximum": - padding = (0, cep_order) + cep_orders = (0, cep_order) elif self.phase == "zero": - padding = (cep_order, cep_order) + cep_orders = (cep_order, cep_order) elif self.phase == "mixed": - padding = cep_order if is_array_like(cep_order) else (cep_order, cep_order) + cep_orders = ( + cep_order if is_array_like(cep_order) else (cep_order, cep_order) + ) else: raise ValueError(f"phase {phase} is not supported.") - self.pad = nn.ConstantPad1d(padding, 0) # Prepare frequency transformation module. if self.phase == "mixed": @@ -297,7 +298,7 @@ def __init__( self.mgc2c.append( MelGeneralizedCepstrumToMelGeneralizedCepstrum( filter_order[i], - padding[i], + cep_orders[i], in_alpha=alpha, in_gamma=gamma, n_fft=n_fft, @@ -318,6 +319,16 @@ def __init__( self.linear_intpl = LinearInterpolation(frame_period) + self.zerodf = AllZeroDigitalFilter( + sum(cep_orders), + frame_period, + ignore_gain=False, + zeroth_index=cep_orders[1], + mode="efficient", + device=device, + dtype=dtype, + ) + cp = mp.taylor(mp.exp, 0, taylor_order) cp = np.array([float(x) for x in cp]) weights = cp[1:] / cp[:-1] @@ -341,29 +352,25 @@ def forward( c_min = self.mgc2c[0](mc_min) c_max = self.mgc2c[1](mc_max) c0 = c_min[..., :1] + c_max[..., :1] - c1_min = c_min[..., 1:].flip(-1) + c1_min = c_min[..., 1:] c0_dummy = torch.zeros_like(c0) - c1_max = c_max[..., 1:] - c = torch.cat([c1_min, c0_dummy, c1_max], dim=-1) + c1_max = c_max[..., 1:].flip(-1) + c = torch.cat([c1_max, c0_dummy, c1_min], dim=-1) else: c = self.mgc2c(mc) c0, c = remove_gain(c, value=0, return_gain=True) if self.phase == "minimum": - c = c.flip(-1) - elif self.phase == "maximum": pass + elif self.phase == "maximum": + c = c.flip(-1) elif self.phase == "zero": c = mirror(c, half=True) else: raise RuntimeError - c = self.linear_intpl(c) - y = x * self.a[0] for i in range(1, len(self.a)): - x = self.pad(x) - x = x.unfold(-1, c.size(-1), 1) - x = (x * c).sum(-1) * self.weights[i] + x = self.zerodf(x, c) * self.weights[i] y += x * self.a[i] if not self.ignore_gain: @@ -389,28 +396,26 @@ def __init__( ) -> None: super().__init__() + self.frame_period = frame_period self.ignore_gain = ignore_gain self.phase = phase self.n_fft = n_fft - # Prepare padding module. - taps = ir_length - 1 if self.phase == "minimum": - padding = (taps, 0) + ir_orders = (ir_length - 1, 0) elif self.phase == "maximum": - padding = (0, taps) + ir_orders = (0, ir_length - 1) elif self.phase == "zero": - padding = (taps, taps) + ir_orders = (ir_length - 1, ir_length - 1) elif self.phase == "mixed": - padding = ( + ir_orders = ( (ir_length[0] - 1, ir_length[1] - 1) if is_array_like(ir_length) - else (taps, taps) + else (ir_length - 1, ir_length - 1) ) else: raise ValueError(f"phase {phase} is not supported.") - self.pad = nn.ConstantPad1d(padding, 0) - self.padding = padding + self.ir_orders = ir_orders if self.phase in ("minimum", "maximum"): self.mgc2ir = MelGeneralizedCepstrumToMelGeneralizedCepstrum( @@ -444,7 +449,7 @@ def __init__( self.mgc2c.append( MelGeneralizedCepstrumToMelGeneralizedCepstrum( filter_order[i], - padding[i], + ir_orders[i], in_alpha=alpha, in_gamma=gamma, n_fft=n_fft, @@ -458,7 +463,15 @@ def __init__( else: raise ValueError(f"phase {phase} is not supported.") - self.linear_intpl = LinearInterpolation(frame_period) + self.zerodf = AllZeroDigitalFilter( + sum(ir_orders), + frame_period, + ignore_gain=False, + zeroth_index=ir_orders[1], + mode="efficient", + device=device, + dtype=dtype, + ) def forward( self, @@ -467,9 +480,13 @@ def forward( ) -> torch.Tensor: if self.phase == "minimum": h = self.mgc2ir(mc) - h = h.flip(-1) + if self.ignore_gain: + h = h / h[..., :1] elif self.phase == "maximum": h = self.mgc2ir(mc) + if self.ignore_gain: + h = h / h[..., :1] + h = h.flip(-1) elif self.phase == "zero": c = self.mgc2c(mc) c[..., 1:] *= 0.5 @@ -485,25 +502,16 @@ def forward( c0 = torch.zeros_like(c_min[..., :1]) else: c0 = c_min[..., :1] + c_max[..., :1] - c = torch.cat([c_min[..., 1:].flip(-1), c0, c_max[..., 1:]], dim=-1) + c = torch.cat([c_max[..., 1:].flip(-1), c0, c_min[..., 1:]], dim=-1) c = F.pad(c, (0, self.n_fft - c.size(-1))) - c = torch.roll(c, -self.padding[0], dims=-1) + shift = self.ir_orders[1] + c = torch.roll(c, -shift, dims=-1) h = self.c2ir(c) - h = torch.roll(h, self.padding[0], dims=-1)[..., : sum(self.padding) + 1] + h = torch.roll(h, shift, dims=-1)[..., : sum(self.ir_orders) + 1] else: raise RuntimeError - h = self.linear_intpl(h) - - if self.ignore_gain: - if self.phase == "minimum": - h = h / h[..., -1:] - elif self.phase == "maximum": - h = h / h[..., :1] - - x = self.pad(x) - x = x.unfold(-1, h.size(-1), 1) - y = (x * h).sum(-1) + y = self.zerodf(x, h) return y diff --git a/diffsptk/modules/zerodf.py b/diffsptk/modules/zerodf.py index 1772513..a2b94fb 100644 --- a/diffsptk/modules/zerodf.py +++ b/diffsptk/modules/zerodf.py @@ -17,8 +17,8 @@ import torch import torch.nn.functional as F -from ..typing import Precomputed -from ..utils.private import check_size, filter_values +from ..typing import Callable, Precomputed +from ..utils.private import check_size, filter_values, to from .base import BaseFunctionalModule from .linear_intpl import LinearInterpolation @@ -38,16 +38,47 @@ class AllZeroDigitalFilter(BaseFunctionalModule): ignore_gain : bool If True, perform filtering without the gain. + zeroth_index : int >= 0 + The index of the zeroth coefficient in the filter coefficients. If 0, the filter + is assumed to be minimum-phase. If `M`, the filter is assumed to be + maximum-phase. + + mode : ['direct', 'efficient'] + The implementation mode for time-varying convolution. 'direct' applies + convolution at the sample level, linearly interpolating the filter coefficients + to match the length of the input signal. This approach is simple to understand + but requires substantial memory. 'efficient' instead performs two separate + convolutions - one with the original filter coefficients and one with the + shifted coefficients - and interpolates between their outputs. This avoids the + need to interpolate the filter coefficients themselves, resulting in lower + memory usage. Both modes are mathematically equivalent, though 'efficient' may + produce slight numerical differences owing to the different order of operations. + + device : torch.device or None + The device of this module. + + dtype : torch.dtype or None + The data type of this module. + """ def __init__( - self, filter_order: int, frame_period: int, ignore_gain: bool = False + self, + filter_order: int, + frame_period: int, + ignore_gain: bool = False, + zeroth_index: int = 0, + mode: str = "direct", + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: super().__init__() self.in_dim = filter_order + 1 - self.values = self._precompute(**filter_values(locals())) + self.values, _, tensors = self._precompute(**filter_values(locals())) + if len(tensors) > 0: + self.register_buffer("ramp", tensors[0]) def forward(self, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Apply an all-zero digital filter. @@ -77,7 +108,7 @@ def forward(self, x: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ check_size(b.size(-1), self.in_dim, "dimension of impulse response") - return self._forward(x, b, *self.values) + return self._forward(x, b, *self.values, **self._buffers) @staticmethod def _takes_input_size() -> bool: @@ -85,34 +116,121 @@ def _takes_input_size() -> bool: @staticmethod def _func(x: torch.Tensor, b: torch.Tensor, *args, **kwargs) -> torch.Tensor: - values = AllZeroDigitalFilter._precompute(b.size(-1) - 1, *args, **kwargs) - return AllZeroDigitalFilter._forward(x, b, *values) + values, _, tensors = AllZeroDigitalFilter._precompute( + b.size(-1) - 1, *args, **kwargs, device=b.device, dtype=b.dtype + ) + return AllZeroDigitalFilter._forward(x, b, *values, *tensors) @staticmethod - def _check(filter_order: int, frame_period: int) -> None: + def _check( + filter_order: int, + frame_period: int, + ignore_gain: bool, + zeroth_index: int, + ) -> None: if filter_order < 0: raise ValueError("filter_order must be non-negative.") if frame_period <= 0: raise ValueError("frame_period must be positive.") + if ignore_gain and zeroth_index not in (0, filter_order): + raise ValueError( + "zeroth_index must be 0 or filter_order when ignore_gain is True." + ) + if zeroth_index < 0 or zeroth_index > filter_order: + raise ValueError("zeroth_index must be in [0, filter_order].") @staticmethod def _precompute( - filter_order: int, frame_period: int, ignore_gain: bool = False + filter_order: int, + frame_period: int, + ignore_gain: bool = False, + zeroth_index: int = 0, + mode: str = "direct", + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> Precomputed: - AllZeroDigitalFilter._check(filter_order, frame_period) - return (frame_period, ignore_gain) + AllZeroDigitalFilter._check( + filter_order, frame_period, ignore_gain, zeroth_index + ) + + padding = (filter_order - zeroth_index, zeroth_index) + + if mode == "direct": + impl = AllZeroDigitalFilter._forward_direct + tensors = () + elif mode == "efficient": + impl = AllZeroDigitalFilter._forward_efficient + ramp = torch.arange(frame_period, device=device) / frame_period + ramp = ramp.view(1, 1, -1) + tensors = (to(ramp, dtype=dtype),) + else: + raise ValueError("mode must be 'direct' or 'efficient'.") + + return (frame_period, ignore_gain, padding, impl), None, tensors @staticmethod def _forward( - x: torch.Tensor, b: torch.Tensor, frame_period: int, ignore_gain: bool + x: torch.Tensor, + b: torch.Tensor, + frame_period: int, + ignore_gain: bool, + padding: tuple[int, int], + impl: Callable, + *args, + **kwargs, ) -> torch.Tensor: check_size(x.size(-1), b.size(-2) * frame_period, "sequence length") + return impl(x, b, frame_period, ignore_gain, padding, *args, **kwargs) + @staticmethod + def _forward_direct( + x: torch.Tensor, + b: torch.Tensor, + frame_period: int, + ignore_gain: bool, + padding: tuple[int, int], + ) -> torch.Tensor: M = b.size(-1) - 1 - x = F.pad(x, (M, 0)) + x = F.pad(x, padding) x = x.unfold(-1, M + 1, 1) h = LinearInterpolation._func(b.flip(-1), frame_period) if ignore_gain: - h = h / h[..., -1:] + h = h / (h[..., :1] if padding[0] == 0 else h[..., -1:]) y = (x * h).sum(-1) return y + + @staticmethod + def _forward_efficient( + x: torch.Tensor, + b: torch.Tensor, + frame_period: int, + ignore_gain: bool, + padding: tuple[int, int], + ramp: torch.Tensor, + ) -> torch.Tensor: + x_org_shape = x.shape + x = x.view(-1, x.size(-1)) + b = b.view(-1, b.size(-2), b.size(-1)) + B, N, L = b.size() + BN = B * N + + b1 = b.flip(-1) + b2 = F.pad(b1[:, 1:], (0, 0, 0, 1), mode="replicate") + weight1 = b1.view(BN, 1, L) + weight2 = b2.view(BN, 1, L) + + x = F.pad(x, padding) + x = x.unfold(-1, L - 1 + frame_period, frame_period) + x = x.reshape(1, BN, L - 1 + frame_period) + + y1 = F.conv1d(x, weight1, groups=BN) + y2 = F.conv1d(x, weight2, groups=BN) + y = torch.lerp(y1, y2, ramp) + y = y.view(*x_org_shape) + + if ignore_gain: + b0 = b1[..., :1] if padding[0] == 0 else b1[..., -1:] + g = LinearInterpolation._func(b0, frame_period) + g = g.reshape(*x_org_shape) + y = y / g + return y diff --git a/diffsptk/version.py b/diffsptk/version.py index 6cd661e..dcbfb52 100644 --- a/diffsptk/version.py +++ b/diffsptk/version.py @@ -1 +1 @@ -__version__ = "3.4.1.dev0" +__version__ = "3.5.0" diff --git a/pyproject.toml b/pyproject.toml index 43f625b..75f6ec3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dev = [ [project.urls] Homepage = "https://sp-tk.sourceforge.net/" -Documentation = "https://sp-nitech.github.io/diffsptk/latest/" +Documentation = "https://sp-nitech.github.io/diffsptk/stable/" Source = "https://github.com/sp-nitech/diffsptk" [tool.hatch.build.targets.sdist] diff --git a/tests/test_zerodf.py b/tests/test_zerodf.py index 9fd1e6a..21428db 100644 --- a/tests/test_zerodf.py +++ b/tests/test_zerodf.py @@ -15,6 +15,7 @@ # ------------------------------------------------------------------------ # import pytest +import torch import diffsptk import tests.utils as U @@ -45,3 +46,30 @@ def test_compatibility(device, dtype, module, ignore_gain, M=3, T=100, P=10): ) U.check_differentiability(device, dtype, zerodf, [(T,), (T // P, M + 1)]) + + +@pytest.mark.parametrize("ignore_gain", [False, True]) +def test_efficient_mode(device, dtype, ignore_gain, M=3, T=100, P=10, B=2): + zerodf_direct = diffsptk.AllZeroDigitalFilter( + filter_order=M, + frame_period=P, + ignore_gain=ignore_gain, + mode="direct", + device=device, + dtype=dtype, + ) + + zerodf_efficient = diffsptk.AllZeroDigitalFilter( + filter_order=M, + frame_period=P, + ignore_gain=ignore_gain, + mode="efficient", + device=device, + dtype=dtype, + ) + + x = torch.randn(B, T, device=device, dtype=dtype) + b = torch.randn(B, T // P, M + 1, device=device, dtype=dtype) + y_direct = zerodf_direct(x, b).cpu().numpy() + y_efficient = zerodf_efficient(x, b).cpu().numpy() + assert U.allclose(y_direct, y_efficient, dtype=dtype, factor=10) diff --git a/tests/utils.py b/tests/utils.py index 7c916b0..78f3bf3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -172,7 +172,7 @@ def check_compatibility( if eq is None: assert allclose(y_hat, y, dtype=dtype, **kwargs), ( - f"Output: {y_hat}\nTarget: {y}" + f"Output: {y_hat}\nTarget: {y}\nError: {np.abs(y_hat - y).max()}" ) else: assert eq(y_hat, y, **kwargs), f"Output: {y_hat}\nTarget: {y}" @@ -237,6 +237,8 @@ def check_differentiability( for i in range(load): if i == 1: + if device == "cuda": + torch.cuda.synchronize() s = time.process_time() y = module(*xs, **opt) optimizer.zero_grad() @@ -244,6 +246,9 @@ def check_differentiability( loss.backward() optimizer.step() + if device == "cuda": + torch.cuda.synchronize() + if 1 < load: e = time.process_time() print(f"time: {e - s}") diff --git a/tools/Makefile b/tools/Makefile index fdf1a8e..beec1b7 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -15,7 +15,7 @@ # ------------------------------------------------------------------------ # TAPLO_VERSION := 0.10.0 -YAMLFMT_VERSION := 0.20.0 +YAMLFMT_VERSION := 0.21.0 all: SPTK taplo yamlfmt diff --git a/yamlfmt.yml b/yamlfmt.yml index 0d7c56d..ba0a2b7 100644 --- a/yamlfmt.yml +++ b/yamlfmt.yml @@ -1,3 +1,5 @@ formatter: type: basic retain_line_breaks_single: true + force_array_style: block + force_quote_style: double