Skip to content

Commit 15bf34a

Browse files
authored
Merge pull request #159 from sp-nitech/pade
Add new mode for Padé coefficient optimization
2 parents 3dd8ca6 + 583f547 commit 15bf34a

7 files changed

Lines changed: 310 additions & 21 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ test: tool
7171
test-example: tool
7272
[ -n "$(MODULE)" ] && module=modules/$(MODULE).py || module=; \
7373
. .venv/bin/activate && export NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS=0 && \
74-
python -m pytest --doctest-modules --no-cov --ignore=diffsptk/third_party diffsptk/$$module
74+
python -m pytest --doctest-modules --no-cov --ignore=$(PROJECT)/third_party $(PROJECT)/$$module
7575

7676
test-clean:
7777
rm -rf tests/__pycache__

diffsptk/modules/acorr.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,9 @@ def _precompute(
101101
elif out_format in (2, "biased"):
102102
formatter = lambda x: x / frame_length
103103
elif out_format in (3, "unbiased"):
104-
formatter = lambda x: x / (
105-
torch.arange(
106-
frame_length, frame_length - acr_order - 1, -1, device=x.device
107-
)
104+
n = frame_length - acr_order - 1
105+
formatter = lambda x: (
106+
x / (torch.arange(frame_length, n, -1, device=x.device))
108107
)
109108
else:
110109
raise ValueError(f"out_format {out_format} is not supported.")

diffsptk/modules/mcep.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def forward(self, x: torch.Tensor):
9999
... )
100100
>>> x = diffsptk.ramp(19)
101101
>>> mc = mcep(stft(x))
102-
>>> mc
103-
tensor([[-0.8851, 0.7917, -0.1737, 0.0175],
104-
[-0.3522, 4.4222, -1.0882, -0.0510]])
102+
>>> mc.round(decimals=3)
103+
tensor([[-0.8850, 0.7920, -0.1740, 0.0170],
104+
[-0.3520, 4.4220, -1.0880, -0.0510]])
105105
106106
"""
107107
check_size(x.size(-1), self.in_dim, "dimension of spectrum")

diffsptk/modules/mglsadf.py

Lines changed: 236 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@
1616

1717
from typing import Any
1818

19+
import mpmath as mp
1920
import numpy as np
2021
import torch
2122
import torch.nn.functional as F
2223
from torch import nn
2324

24-
from ..utils.private import Lambda, check_size, get_gamma, remove_gain
25+
from ..utils.private import Lambda, check_size, get_gamma, remove_gain, to
2526
from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
2627
from .base import BaseNonFunctionalModule
2728
from .c2mpir import CepstrumToMinimumPhaseImpulseResponse
29+
from .frame import Frame
2830
from .gnorm import GeneralizedCepstrumGainNormalization
2931
from .istft import InverseShortTimeFourierTransform
3032
from .linear_intpl import LinearInterpolation
3133
from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
3234
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum
3335
from .mgc2sp import MelGeneralizedCepstrumToSpectrum
36+
from .root_pol import PolynomialToRoots
3437
from .stft import ShortTimeFourierTransform
3538

3639

@@ -74,12 +77,15 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
7477
phase : ['minimum', 'maximum', 'zero', 'mixed']
7578
The filter type.
7679
77-
mode : ['multi-stage', 'single-stage', 'freq-domain']
80+
mode : ['multi-stage', 'single-stage', 'freq-domain', 'pade-approx']
7881
'multi-stage' approximates the MLSA filter by cascading FIR filters based on the
7982
Taylor series expansion. 'single-stage' uses an FIR filter with the coefficients
8083
derived from the impulse response converted from the input mel-cepstral
8184
coefficients using FFT. 'freq-domain' performs filtering in the frequency domain
82-
rather than the time domain.
85+
rather than the time domain. 'pade-approx' implements the MLSA filter by
86+
cascading all-zero and all-pole filters derived from the factorization. While
87+
this approach is not computationally efficient, it allows for the optimization
88+
of the Pade approximation coefficients.
8389
8490
n_fft : int >= 1
8591
The number of FFT bins used for conversion. Higher values result in increased
@@ -89,12 +95,20 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
8995
The order of the Taylor series expansion (valid only if **mode** is
9096
'multi-stage').
9197
98+
pade_order : int >= 3
99+
The order of Pade approximation (valid only if **mode** is 'pade-approx').
100+
92101
cep_order : int >= 0 or tuple[int, int]
93-
The order of the linear cepstrum (valid only if **mode** is 'multi-stage').
102+
The order of the linear cepstrum (valid only if **mode** is 'multi-stage' or
103+
'pade-approx').
94104
95105
ir_length : int >= 1 or tuple[int, int]
96106
The length of the impulse response (valid only if **mode** is 'single-stage').
97107
108+
learnable : bool
109+
If True, the polynomial coefficients used in the approximation are learnable
110+
(valid only if **mode** is 'multi-stage' or 'pade-approx').
111+
98112
device : torch.device or None
99113
The device of this module.
100114
@@ -181,6 +195,16 @@ def flip(x):
181195
phase=phase,
182196
**modified_kwargs,
183197
)
198+
elif mode == "pade-approx":
199+
self.mglsadf = MultiStageIIRFilter(
200+
flipped_filter_order,
201+
frame_period,
202+
alpha=alpha,
203+
gamma=gamma,
204+
ignore_gain=ignore_gain,
205+
phase=phase,
206+
**modified_kwargs,
207+
)
184208
else:
185209
raise ValueError(f"mode {mode} is not supported.")
186210

@@ -238,6 +262,7 @@ def __init__(
238262
taylor_order: int = 20,
239263
cep_order: tuple[int, int] | int = 199,
240264
n_fft: int = 512,
265+
learnable: bool = False,
241266
device: torch.device | None = None,
242267
dtype: torch.dtype | None = None,
243268
) -> None:
@@ -248,7 +273,6 @@ def __init__(
248273

249274
self.ignore_gain = ignore_gain
250275
self.phase = phase
251-
self.taylor_order = taylor_order
252276

253277
if alpha == 0 and gamma == 0:
254278
cep_order = filter_order
@@ -294,6 +318,19 @@ def __init__(
294318

295319
self.linear_intpl = LinearInterpolation(frame_period)
296320

321+
cp = mp.taylor(mp.exp, 0, taylor_order)
322+
cp = np.array([float(x) for x in cp])
323+
weights = cp[1:] / cp[:-1]
324+
weights = np.insert(weights, 0, 1)
325+
self.register_buffer("weights", to(weights, device=device, dtype=dtype))
326+
327+
a = np.ones(taylor_order + 1)
328+
a = to(a, device=device, dtype=dtype)
329+
if learnable:
330+
self.a = nn.Parameter(a)
331+
else:
332+
self.register_buffer("a", a)
333+
297334
def forward(
298335
self,
299336
x: torch.Tensor,
@@ -322,12 +359,12 @@ def forward(
322359

323360
c = self.linear_intpl(c)
324361

325-
y = x.clone()
326-
for a in range(1, self.taylor_order + 1):
362+
y = x * self.a[0]
363+
for i in range(1, len(self.a)):
327364
x = self.pad(x)
328365
x = x.unfold(-1, c.size(-1), 1)
329-
x = (x * c).sum(-1) / a
330-
y += x
366+
x = (x * c).sum(-1) * self.weights[i]
367+
y += x * self.a[i]
331368

332369
if not self.ignore_gain:
333370
K = torch.exp(self.linear_intpl(c0))
@@ -586,3 +623,193 @@ def forward(
586623
Y = H * X
587624
y = self.istft(Y, out_length=x.size(-1))
588625
return y
626+
627+
628+
class MultiStageIIRFilter(nn.Module):
629+
def __init__(
630+
self,
631+
filter_order: tuple[int, int] | int,
632+
frame_period: int,
633+
*,
634+
alpha: float = 0,
635+
gamma: float = 0,
636+
ignore_gain: bool = False,
637+
phase: str = "minimum",
638+
pade_order: int = 5,
639+
cep_order: tuple[int, int] | int = 199,
640+
n_fft: int = 512,
641+
chunk_length: int | None = None,
642+
warmup_length: int | None = None,
643+
learnable: bool = False,
644+
per_stage_pade_coefficients: bool = False,
645+
device: torch.device | None = None,
646+
dtype: torch.dtype | None = None,
647+
) -> None:
648+
super().__init__()
649+
650+
if phase != "minimum" or is_array_like(filter_order):
651+
raise ValueError("Only minimum-phase filter is supported.")
652+
653+
self.ignore_gain = ignore_gain
654+
655+
self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
656+
filter_order,
657+
cep_order,
658+
in_alpha=alpha,
659+
in_gamma=gamma,
660+
n_fft=n_fft,
661+
device=device,
662+
dtype=dtype,
663+
)
664+
self.linear_intpl = LinearInterpolation(frame_period)
665+
self.root_pol = PolynomialToRoots(pade_order, device=device, dtype=dtype)
666+
667+
from torchlpc import sample_wise_lpc
668+
669+
self.sample_wise_lpc = sample_wise_lpc
670+
671+
if chunk_length is None:
672+
self.chuking = False
673+
else:
674+
self.chuking = True
675+
self.warmup_length = (
676+
warmup_length if warmup_length is not None else cep_order
677+
)
678+
if chunk_length <= 0:
679+
raise ValueError("chunk_length must be positive.")
680+
if self.warmup_length < 0:
681+
raise ValueError("warmup_length must be non-negative.")
682+
frame_period = chunk_length - self.warmup_length
683+
self.frame_x = Frame(chunk_length, frame_period, center=False)
684+
self.frame_c = Frame(
685+
cep_order * chunk_length, cep_order * frame_period, center=False
686+
)
687+
688+
cr = mp.taylor(mp.exp, 0, pade_order * 2)
689+
cp, cq = mp.pade(cr, pade_order, pade_order)
690+
cp = np.array([float(x) for x in cp])
691+
weights = cp[1:] / cp[:-1]
692+
weights = np.insert(weights, 0, 1)
693+
self.register_buffer("weights", to(weights, device=device, dtype=dtype))
694+
695+
if pade_order == 3:
696+
a1 = np.linspace(1.0, 0.4, pade_order + 1)
697+
elif pade_order == 4:
698+
a1 = np.linspace(1.0, 0.6, pade_order + 1)
699+
elif 5 <= pade_order <= 14:
700+
a1 = np.ones(pade_order + 1)
701+
else:
702+
raise ValueError("pade_order must be in [3, 14].")
703+
704+
if learnable and per_stage_pade_coefficients:
705+
a2 = a1
706+
a1 = np.ones(pade_order + 1)
707+
a1 = to(a1, device=device, dtype=dtype)
708+
a2 = to(a2, device=device, dtype=dtype)
709+
self.a1 = nn.Parameter(a1)
710+
self.a2 = nn.Parameter(a2)
711+
else:
712+
a1 = to(a1, device=device, dtype=dtype)
713+
if learnable:
714+
self.a1 = nn.Parameter(a1)
715+
else:
716+
self.register_buffer("a1", a1)
717+
self.a2 = self.a1
718+
719+
def forward(
720+
self, x: torch.Tensor, mc: torch.Tensor, return_roots: bool = False
721+
) -> torch.Tensor:
722+
if x.dim() == 1:
723+
x = x.unsqueeze(0)
724+
mc = mc.unsqueeze(0)
725+
unsqueezed = True
726+
else:
727+
unsqueezed = False
728+
729+
if x.dim() != 2 or mc.dim() != 3:
730+
raise ValueError("x and mc must be 2-D and 3-D tensors, respectively.")
731+
732+
c = self.mgc2c(mc)
733+
c0, c1 = torch.split(c, [1, c.size(-1) - 1], dim=-1)
734+
c_b = self.linear_intpl(c1.flip(-1))
735+
c_a = self.linear_intpl(c1)
736+
737+
T = x.size(-1)
738+
B, _, M = c_a.size()
739+
740+
a1 = torch.clip(self.a1, min=1e-1, max=1e1)
741+
a1[0] = 1.0
742+
a2 = torch.clip(self.a2, min=1e-1, max=1e1)
743+
a2[0] = 1.0
744+
745+
c_b2, c_b1 = torch.split(c_b, [c_b.size(-1) - 1, 1], dim=-1)
746+
c_b1 = c_b1.squeeze(-1)
747+
748+
# Numerator, 1st stage:
749+
y = x * a1[0]
750+
for i in range(1, len(a1)):
751+
x = F.pad(x[..., :-1], (1, 0))
752+
x = x * c_b1 * self.weights[i]
753+
y += x * a1[i]
754+
755+
# Numerator, 2nd stage:
756+
x = y
757+
y = x * a2[0]
758+
for i in range(1, len(a2)):
759+
x = F.pad(x, (M, 0))
760+
x = x.unfold(-1, M + 1, 1)
761+
x = (x[..., :-2] * c_b2).sum(-1) * self.weights[i]
762+
y += x * a2[i]
763+
764+
if self.chuking:
765+
y = F.pad(y, (self.warmup_length, 0))
766+
y = self.frame_x(y)
767+
y = y.reshape(-1, y.size(-1))
768+
769+
c_a = c_a.reshape(B, -1)
770+
c_a = F.pad(c_a, (M * self.warmup_length, 0))
771+
c_a = self.frame_c(c_a)
772+
c_a = c_a.reshape(y.size(0), y.size(1), M)
773+
774+
c_a1, c_a2 = torch.split(c_a, [1, c_a.size(-1) - 1], dim=-1)
775+
c_a2 = F.pad(c_a2, (1, 0))
776+
777+
def compute_roots(a: torch.Tensor) -> torch.Tensor:
778+
pade_coefficients = torch.cumprod(self.weights, 0) * a
779+
roots = self.root_pol(pade_coefficients.flip(0).double())
780+
roots = roots.to(
781+
torch.complex64 if a.dtype == torch.float32 else torch.complex128
782+
)
783+
return roots
784+
785+
roots1 = compute_roots(a1)
786+
roots2 = compute_roots(a2)
787+
roots = torch.stack([roots1, roots2], dim=0)
788+
789+
# Denominator, 1st stage:
790+
y = y.to(roots.dtype)
791+
p1 = torch.reciprocal(roots1)
792+
for i in range(len(p1)):
793+
y = self.sample_wise_lpc(y, (p1[i] * c_a1))
794+
795+
# Denominator, 2nd stage:
796+
p2 = torch.reciprocal(roots2)
797+
for i in range(len(p2)):
798+
y = self.sample_wise_lpc(y, (p2[i] * c_a2))
799+
y = y.real
800+
801+
if self.chuking:
802+
y = y[..., self.warmup_length :]
803+
y = y.reshape(B, -1)
804+
y = y[..., :T]
805+
806+
if not self.ignore_gain:
807+
K = torch.exp(self.linear_intpl(c0))
808+
y = y * K.squeeze(-1)
809+
810+
if unsqueezed:
811+
y = y.squeeze(0)
812+
813+
if return_roots:
814+
return y, roots
815+
return y

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ classifiers = [
2626
dependencies = [
2727
"numpy >= 1.23.0",
2828
"scipy >= 1.12.0",
29+
"mpmath >= 0.17.0",
2930
"tqdm >= 4.63.0",
3031
"soundfile >= 0.10.2",
3132
"torch >= 2.3.1",

0 commit comments

Comments (0)