@@ -148,15 +148,16 @@ def _forward(
148148 s = torch .cumsum (q .double (), dim = - 1 )
149149 bias , _ = torch .cummax (s * ~ mask , dim = - 1 )
150150 phase = (s - bias ).to (p .dtype )
151- if isinstance (init_phase , str ):
152- if init_phase == "zeros" :
153- pass
154- elif init_phase == "random" :
155- phase += torch .rand_like (p [..., :1 ])
156- else :
157- raise ValueError (f"init_phase { init_phase } is not supported." )
151+ if not isinstance (init_phase , str ):
152+ shift = init_phase / TAU
153+ elif init_phase == "zeros" :
154+ shift = 0.0
155+ elif init_phase == "random" :
156+ shift = torch .rand_like (p [..., :1 ])
158157 else :
159- phase += init_phase / TAU
158+ raise ValueError (f"init_phase { init_phase } is not supported." )
159+ if isinstance (shift , torch .Tensor ) or shift != 0.0 :
160+ phase += shift
160161
161162 # Generate excitation signal using phase.
162163 if polarity == "auto" :
@@ -170,15 +171,19 @@ def _forward(
170171
171172 def get_pulse_pos (p ):
172173 r = torch .ceil (p )
173- r = F .pad (r , (1 , 0 ))
174174 return torch .ge (torch .diff (r ), 1 )
175175
176+ if isinstance (shift , float ):
177+ padded_phase = F .pad (phase , (1 , 0 ), value = shift )
178+ else :
179+ padded_phase = torch .cat ([shift , phase ], dim = - 1 )
180+
176181 if unipolar :
177- pulse_pos = get_pulse_pos (phase )
182+ pulse_pos = get_pulse_pos (padded_phase )
178183 e [pulse_pos ] = torch .sqrt (p [pulse_pos ])
179184 else :
180- pulse_pos1 = get_pulse_pos (phase )
181- pulse_pos2 = get_pulse_pos (0.5 * phase )
185+ pulse_pos1 = get_pulse_pos (padded_phase )
186+ pulse_pos2 = get_pulse_pos (0.5 * padded_phase )
182187 e [pulse_pos1 ] = torch .sqrt (p [pulse_pos1 ])
183188 e [pulse_pos1 & ~ pulse_pos2 ] *= - 1
184189 elif voiced_region == "sinusoidal" :
0 commit comments