Skip to content

Commit 8c55beb

Browse files
authored
Merge pull request #156 from sp-nitech/excite
Fix excite and remove hilbert2
2 parents 32e4f7c + 37b02fc commit 8c55beb

8 files changed

Lines changed: 19 additions & 245 deletions

File tree

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ tool-clean:
8585

8686
update: tool
8787
. .venv/bin/activate && python -m pip install --upgrade pip
88-
@for package in $$(./tools/taplo/taplo get -f pyproject.toml project.optional-dependencies.dev); do \
89-
. .venv/bin/activate && python -m pip install --upgrade $$package; \
88+
@./tools/taplo/taplo get -f pyproject.toml project.optional-dependencies.dev | while read -r package; do \
89+
. .venv/bin/activate && python -m pip install --upgrade "$$package"; \
9090
done
9191

9292
clean: dist-clean doc-clean test-clean tool-clean

diffsptk/functional.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,33 +1138,6 @@ def hilbert(x: Tensor, fft_length: int | None = None, dim: int = -1) -> Tensor:
11381138
return nn.HilbertTransform._func(x, fft_length=fft_length, dim=dim)
11391139

11401140

1141-
def hilbert2(
1142-
x: Tensor,
1143-
fft_length: ArrayLike[int] | int | None = None,
1144-
dim: ArrayLike[int] = (-2, -1),
1145-
) -> Tensor:
1146-
"""Compute the analytic signal using the Hilbert transform.
1147-
1148-
Parameters
1149-
----------
1150-
x : Tensor [shape=(..., T1, T2, ...)]
1151-
The input signal.
1152-
1153-
fft_length : int, list[int], or None
1154-
The number of FFT bins. If None, set to (:math:`T1`, :math:`T2`).
1155-
1156-
dim : list[int]
1157-
The dimensions along which to take the Hilbert transform.
1158-
1159-
Returns
1160-
-------
1161-
out : Tensor [shape=(..., T1, T2, ...)]
1162-
The analytic signal.
1163-
1164-
"""
1165-
return nn.TwoDimensionalHilbertTransform._func(x, fft_length=fft_length, dim=dim)
1166-
1167-
11681141
def histogram(
11691142
x: Tensor,
11701143
n_bin: int = 10,

diffsptk/modules/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
from .griffin import GriffinLim
6363
from .grpdelay import GroupDelay
6464
from .hilbert import HilbertTransform
65-
from .hilbert2 import TwoDimensionalHilbertTransform
6665
from .histogram import Histogram
6766
from .ialaw import ALawExpansion
6867
from .ica import IndependentComponentAnalysis

diffsptk/modules/excite.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,16 @@ def _forward(
148148
s = torch.cumsum(q.double(), dim=-1)
149149
bias, _ = torch.cummax(s * ~mask, dim=-1)
150150
phase = (s - bias).to(p.dtype)
151-
if isinstance(init_phase, str):
152-
if init_phase == "zeros":
153-
pass
154-
elif init_phase == "random":
155-
phase += torch.rand_like(p[..., :1])
156-
else:
157-
raise ValueError(f"init_phase {init_phase} is not supported.")
151+
if not isinstance(init_phase, str):
152+
shift = init_phase / TAU
153+
elif init_phase == "zeros":
154+
shift = 0.0
155+
elif init_phase == "random":
156+
shift = torch.rand_like(p[..., :1])
158157
else:
159-
phase += init_phase / TAU
158+
raise ValueError(f"init_phase {init_phase} is not supported.")
159+
if isinstance(shift, torch.Tensor) or shift != 0.0:
160+
phase += shift
160161

161162
# Generate excitation signal using phase.
162163
if polarity == "auto":
@@ -170,15 +171,19 @@ def _forward(
170171

171172
def get_pulse_pos(p):
172173
r = torch.ceil(p)
173-
r = F.pad(r, (1, 0))
174174
return torch.ge(torch.diff(r), 1)
175175

176+
if isinstance(shift, float):
177+
padded_phase = F.pad(phase, (1, 0), value=shift)
178+
else:
179+
padded_phase = torch.cat([shift, phase], dim=-1)
180+
176181
if unipolar:
177-
pulse_pos = get_pulse_pos(phase)
182+
pulse_pos = get_pulse_pos(padded_phase)
178183
e[pulse_pos] = torch.sqrt(p[pulse_pos])
179184
else:
180-
pulse_pos1 = get_pulse_pos(phase)
181-
pulse_pos2 = get_pulse_pos(0.5 * phase)
185+
pulse_pos1 = get_pulse_pos(padded_phase)
186+
pulse_pos2 = get_pulse_pos(0.5 * padded_phase)
182187
e[pulse_pos1] = torch.sqrt(p[pulse_pos1])
183188
e[pulse_pos1 & ~pulse_pos2] *= -1
184189
elif voiced_region == "sinusoidal":

diffsptk/modules/hilbert2.py

Lines changed: 0 additions & 133 deletions
This file was deleted.

docs/source/modules/hilbert.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,3 @@ hilbert
77
:members:
88

99
.. autofunction:: diffsptk.functional.hilbert
10-
11-
.. seealso::
12-
13-
:ref:`hilbert2`

docs/source/modules/hilbert2.rst

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/test_hilbert2.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)