Skip to content

Commit 6168b98

Browse files
committed
fix bug when init_phase is not zero
1 parent 32e4f7c commit 6168b98

1 file changed

Lines changed: 17 additions & 12 deletions

File tree

diffsptk/modules/excite.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,16 @@ def _forward(
148148
s = torch.cumsum(q.double(), dim=-1)
149149
bias, _ = torch.cummax(s * ~mask, dim=-1)
150150
phase = (s - bias).to(p.dtype)
151-
if isinstance(init_phase, str):
152-
if init_phase == "zeros":
153-
pass
154-
elif init_phase == "random":
155-
phase += torch.rand_like(p[..., :1])
156-
else:
157-
raise ValueError(f"init_phase {init_phase} is not supported.")
151+
if not isinstance(init_phase, str):
152+
shift = init_phase / TAU
153+
elif init_phase == "zeros":
154+
shift = 0.0
155+
elif init_phase == "random":
156+
shift = torch.rand_like(p[..., :1])
158157
else:
159-
phase += init_phase / TAU
158+
raise ValueError(f"init_phase {init_phase} is not supported.")
159+
if isinstance(shift, torch.Tensor) or shift != 0.0:
160+
phase += shift
160161

161162
# Generate excitation signal using phase.
162163
if polarity == "auto":
@@ -170,15 +171,19 @@ def _forward(
170171

171172
def get_pulse_pos(p):
172173
r = torch.ceil(p)
173-
r = F.pad(r, (1, 0))
174174
return torch.ge(torch.diff(r), 1)
175175

176+
if isinstance(shift, float):
177+
padded_phase = F.pad(phase, (1, 0), value=shift)
178+
else:
179+
padded_phase = torch.cat([shift, phase], dim=-1)
180+
176181
if unipolar:
177-
pulse_pos = get_pulse_pos(phase)
182+
pulse_pos = get_pulse_pos(padded_phase)
178183
e[pulse_pos] = torch.sqrt(p[pulse_pos])
179184
else:
180-
pulse_pos1 = get_pulse_pos(phase)
181-
pulse_pos2 = get_pulse_pos(0.5 * phase)
185+
pulse_pos1 = get_pulse_pos(padded_phase)
186+
pulse_pos2 = get_pulse_pos(0.5 * padded_phase)
182187
e[pulse_pos1] = torch.sqrt(p[pulse_pos1])
183188
e[pulse_pos1 & ~pulse_pos2] *= -1
184189
elif voiced_region == "sinusoidal":

0 commit comments

Comments
 (0)