Skip to content

Commit 2b84690

Browse files
authored
Merge pull request #161 from sp-nitech/zerodf [skip ci]
Improve computational efficiency of FIR
2 parents cadec71 + 95656b6 commit 2b84690

11 files changed

Lines changed: 243 additions & 66 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77

88
pull_request:
99
branches:
10-
- '**'
10+
- "**"
1111

1212
jobs:
1313
test:

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
It provides various speech signal processing modules as PyTorch layers,
55
allowing users to integrate classic signal processing algorithms directly into neural network architectures and optimize them through backpropagation.
66

7-
[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/3.4.0/)
7+
[![Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/stable/)
88
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
99
[![ClickPy](https://img.shields.io/badge/downloads-clickpy-yellow.svg)](https://clickpy.clickhouse.com/dashboard/diffsptk)
1010
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
@@ -22,7 +22,7 @@ allowing users to integrate classic signal processing algorithms directly into n
2222

2323
## Documentation
2424

25-
- [**Reference Manual**](https://sp-nitech.github.io/diffsptk/3.4.0/) - Detailed API documentation and module specifications.
25+
- [**Reference Manual**](https://sp-nitech.github.io/diffsptk/stable/) - Detailed API documentation and module specifications.
2626
- [**Interactive Tutorial**](https://colab.research.google.com/drive/1xAoUKqXadvJXJ7RzN0OceB6y7q5i7Sn6?usp=drive_link) (Google Colab) - Hands-on examples to get started with `diffsptk` in your browser.
2727
- [**Conference Paper**](https://www.isca-archive.org/ssw_2023/yoshimura23_ssw.html) - Technical background and implementation details available on the ISCA Archive.
2828

diffsptk/functional.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3248,7 +3248,12 @@ def zcross(
32483248

32493249

32503250
def zerodf(
3251-
x: Tensor, b: Tensor, frame_period: int = 80, ignore_gain: bool = False
3251+
x: Tensor,
3252+
b: Tensor,
3253+
frame_period: int = 80,
3254+
ignore_gain: bool = False,
3255+
zeroth_index: int = 0,
3256+
mode: str = "direct",
32523257
) -> Tensor:
32533258
"""Apply an all-zero digital filter.
32543259
@@ -3266,12 +3271,23 @@ def zerodf(
32663271
ignore_gain : bool
32673272
If True, perform filtering without the gain.
32683273
3274+
zeroth_index : int >= 0
3275+
The index of the zeroth coefficient in the filter coefficients.
3276+
3277+
mode : ['direct', 'efficient']
3278+
The implementation mode for time-varying convolution.
3279+
32693280
Returns
32703281
-------
32713282
out : Tensor [shape=(..., T)]
32723283
The output signal.
32733284
32743285
"""
32753286
return nn.AllZeroDigitalFilter._func(
3276-
x, b, frame_period=frame_period, ignore_gain=ignore_gain
3287+
x,
3288+
b,
3289+
frame_period=frame_period,
3290+
ignore_gain=ignore_gain,
3291+
zeroth_index=zeroth_index,
3292+
mode=mode,
32773293
)

diffsptk/modules/mglsadf.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .mgc2sp import MelGeneralizedCepstrumToSpectrum
3636
from .root_pol import PolynomialToRoots
3737
from .stft import ShortTimeFourierTransform
38+
from .zerodf import AllZeroDigitalFilter
3839

3940

4041
def is_array_like(x: Any) -> bool:
@@ -277,18 +278,18 @@ def __init__(
277278
if alpha == 0 and gamma == 0:
278279
cep_order = filter_order
279280

280-
# Prepare padding module.
281281
if self.phase == "minimum":
282-
padding = (cep_order, 0)
282+
cep_orders = (cep_order, 0)
283283
elif self.phase == "maximum":
284-
padding = (0, cep_order)
284+
cep_orders = (0, cep_order)
285285
elif self.phase == "zero":
286-
padding = (cep_order, cep_order)
286+
cep_orders = (cep_order, cep_order)
287287
elif self.phase == "mixed":
288-
padding = cep_order if is_array_like(cep_order) else (cep_order, cep_order)
288+
cep_orders = (
289+
cep_order if is_array_like(cep_order) else (cep_order, cep_order)
290+
)
289291
else:
290292
raise ValueError(f"phase {phase} is not supported.")
291-
self.pad = nn.ConstantPad1d(padding, 0)
292293

293294
# Prepare frequency transformation module.
294295
if self.phase == "mixed":
@@ -297,7 +298,7 @@ def __init__(
297298
self.mgc2c.append(
298299
MelGeneralizedCepstrumToMelGeneralizedCepstrum(
299300
filter_order[i],
300-
padding[i],
301+
cep_orders[i],
301302
in_alpha=alpha,
302303
in_gamma=gamma,
303304
n_fft=n_fft,
@@ -318,6 +319,16 @@ def __init__(
318319

319320
self.linear_intpl = LinearInterpolation(frame_period)
320321

322+
self.zerodf = AllZeroDigitalFilter(
323+
sum(cep_orders),
324+
frame_period,
325+
ignore_gain=False,
326+
zeroth_index=cep_orders[1],
327+
mode="efficient",
328+
device=device,
329+
dtype=dtype,
330+
)
331+
321332
cp = mp.taylor(mp.exp, 0, taylor_order)
322333
cp = np.array([float(x) for x in cp])
323334
weights = cp[1:] / cp[:-1]
@@ -341,29 +352,25 @@ def forward(
341352
c_min = self.mgc2c[0](mc_min)
342353
c_max = self.mgc2c[1](mc_max)
343354
c0 = c_min[..., :1] + c_max[..., :1]
344-
c1_min = c_min[..., 1:].flip(-1)
355+
c1_min = c_min[..., 1:]
345356
c0_dummy = torch.zeros_like(c0)
346-
c1_max = c_max[..., 1:]
347-
c = torch.cat([c1_min, c0_dummy, c1_max], dim=-1)
357+
c1_max = c_max[..., 1:].flip(-1)
358+
c = torch.cat([c1_max, c0_dummy, c1_min], dim=-1)
348359
else:
349360
c = self.mgc2c(mc)
350361
c0, c = remove_gain(c, value=0, return_gain=True)
351362
if self.phase == "minimum":
352-
c = c.flip(-1)
353-
elif self.phase == "maximum":
354363
pass
364+
elif self.phase == "maximum":
365+
c = c.flip(-1)
355366
elif self.phase == "zero":
356367
c = mirror(c, half=True)
357368
else:
358369
raise RuntimeError
359370

360-
c = self.linear_intpl(c)
361-
362371
y = x * self.a[0]
363372
for i in range(1, len(self.a)):
364-
x = self.pad(x)
365-
x = x.unfold(-1, c.size(-1), 1)
366-
x = (x * c).sum(-1) * self.weights[i]
373+
x = self.zerodf(x, c) * self.weights[i]
367374
y += x * self.a[i]
368375

369376
if not self.ignore_gain:
@@ -389,28 +396,26 @@ def __init__(
389396
) -> None:
390397
super().__init__()
391398

399+
self.frame_period = frame_period
392400
self.ignore_gain = ignore_gain
393401
self.phase = phase
394402
self.n_fft = n_fft
395403

396-
# Prepare padding module.
397-
taps = ir_length - 1
398404
if self.phase == "minimum":
399-
padding = (taps, 0)
405+
ir_orders = (ir_length - 1, 0)
400406
elif self.phase == "maximum":
401-
padding = (0, taps)
407+
ir_orders = (0, ir_length - 1)
402408
elif self.phase == "zero":
403-
padding = (taps, taps)
409+
ir_orders = (ir_length - 1, ir_length - 1)
404410
elif self.phase == "mixed":
405-
padding = (
411+
ir_orders = (
406412
(ir_length[0] - 1, ir_length[1] - 1)
407413
if is_array_like(ir_length)
408-
else (taps, taps)
414+
else (ir_length - 1, ir_length - 1)
409415
)
410416
else:
411417
raise ValueError(f"phase {phase} is not supported.")
412-
self.pad = nn.ConstantPad1d(padding, 0)
413-
self.padding = padding
418+
self.ir_orders = ir_orders
414419

415420
if self.phase in ("minimum", "maximum"):
416421
self.mgc2ir = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
@@ -444,7 +449,7 @@ def __init__(
444449
self.mgc2c.append(
445450
MelGeneralizedCepstrumToMelGeneralizedCepstrum(
446451
filter_order[i],
447-
padding[i],
452+
ir_orders[i],
448453
in_alpha=alpha,
449454
in_gamma=gamma,
450455
n_fft=n_fft,
@@ -458,7 +463,15 @@ def __init__(
458463
else:
459464
raise ValueError(f"phase {phase} is not supported.")
460465

461-
self.linear_intpl = LinearInterpolation(frame_period)
466+
self.zerodf = AllZeroDigitalFilter(
467+
sum(ir_orders),
468+
frame_period,
469+
ignore_gain=False,
470+
zeroth_index=ir_orders[1],
471+
mode="efficient",
472+
device=device,
473+
dtype=dtype,
474+
)
462475

463476
def forward(
464477
self,
@@ -467,9 +480,13 @@ def forward(
467480
) -> torch.Tensor:
468481
if self.phase == "minimum":
469482
h = self.mgc2ir(mc)
470-
h = h.flip(-1)
483+
if self.ignore_gain:
484+
h = h / h[..., :1]
471485
elif self.phase == "maximum":
472486
h = self.mgc2ir(mc)
487+
if self.ignore_gain:
488+
h = h / h[..., :1]
489+
h = h.flip(-1)
473490
elif self.phase == "zero":
474491
c = self.mgc2c(mc)
475492
c[..., 1:] *= 0.5
@@ -485,25 +502,16 @@ def forward(
485502
c0 = torch.zeros_like(c_min[..., :1])
486503
else:
487504
c0 = c_min[..., :1] + c_max[..., :1]
488-
c = torch.cat([c_min[..., 1:].flip(-1), c0, c_max[..., 1:]], dim=-1)
505+
c = torch.cat([c_max[..., 1:].flip(-1), c0, c_min[..., 1:]], dim=-1)
489506
c = F.pad(c, (0, self.n_fft - c.size(-1)))
490-
c = torch.roll(c, -self.padding[0], dims=-1)
507+
shift = self.ir_orders[1]
508+
c = torch.roll(c, -shift, dims=-1)
491509
h = self.c2ir(c)
492-
h = torch.roll(h, self.padding[0], dims=-1)[..., : sum(self.padding) + 1]
510+
h = torch.roll(h, shift, dims=-1)[..., : sum(self.ir_orders) + 1]
493511
else:
494512
raise RuntimeError
495513

496-
h = self.linear_intpl(h)
497-
498-
if self.ignore_gain:
499-
if self.phase == "minimum":
500-
h = h / h[..., -1:]
501-
elif self.phase == "maximum":
502-
h = h / h[..., :1]
503-
504-
x = self.pad(x)
505-
x = x.unfold(-1, h.size(-1), 1)
506-
y = (x * h).sum(-1)
514+
y = self.zerodf(x, h)
507515
return y
508516

509517

0 commit comments

Comments
 (0)