3535from .mgc2sp import MelGeneralizedCepstrumToSpectrum
3636from .root_pol import PolynomialToRoots
3737from .stft import ShortTimeFourierTransform
38+ from .zerodf import AllZeroDigitalFilter
3839
3940
4041def is_array_like (x : Any ) -> bool :
@@ -277,18 +278,18 @@ def __init__(
277278 if alpha == 0 and gamma == 0 :
278279 cep_order = filter_order
279280
280- # Prepare padding module.
281281 if self .phase == "minimum" :
282- padding = (cep_order , 0 )
282+ cep_orders = (cep_order , 0 )
283283 elif self .phase == "maximum" :
284- padding = (0 , cep_order )
284+ cep_orders = (0 , cep_order )
285285 elif self .phase == "zero" :
286- padding = (cep_order , cep_order )
286+ cep_orders = (cep_order , cep_order )
287287 elif self .phase == "mixed" :
288- padding = cep_order if is_array_like (cep_order ) else (cep_order , cep_order )
288+ cep_orders = (
289+ cep_order if is_array_like (cep_order ) else (cep_order , cep_order )
290+ )
289291 else :
290292 raise ValueError (f"phase { phase } is not supported." )
291- self .pad = nn .ConstantPad1d (padding , 0 )
292293
293294 # Prepare frequency transformation module.
294295 if self .phase == "mixed" :
@@ -297,7 +298,7 @@ def __init__(
297298 self .mgc2c .append (
298299 MelGeneralizedCepstrumToMelGeneralizedCepstrum (
299300 filter_order [i ],
300- padding [i ],
301+ cep_orders [i ],
301302 in_alpha = alpha ,
302303 in_gamma = gamma ,
303304 n_fft = n_fft ,
@@ -318,6 +319,16 @@ def __init__(
318319
319320 self .linear_intpl = LinearInterpolation (frame_period )
320321
322+ self .zerodf = AllZeroDigitalFilter (
323+ sum (cep_orders ),
324+ frame_period ,
325+ ignore_gain = False ,
326+ zeroth_index = cep_orders [1 ],
327+ mode = "efficient" ,
328+ device = device ,
329+ dtype = dtype ,
330+ )
331+
321332 cp = mp .taylor (mp .exp , 0 , taylor_order )
322333 cp = np .array ([float (x ) for x in cp ])
323334 weights = cp [1 :] / cp [:- 1 ]
@@ -341,29 +352,25 @@ def forward(
341352 c_min = self .mgc2c [0 ](mc_min )
342353 c_max = self .mgc2c [1 ](mc_max )
343354 c0 = c_min [..., :1 ] + c_max [..., :1 ]
344- c1_min = c_min [..., 1 :]. flip ( - 1 )
355+ c1_min = c_min [..., 1 :]
345356 c0_dummy = torch .zeros_like (c0 )
346- c1_max = c_max [..., 1 :]
347- c = torch .cat ([c1_min , c0_dummy , c1_max ], dim = - 1 )
357+ c1_max = c_max [..., 1 :]. flip ( - 1 )
358+ c = torch .cat ([c1_max , c0_dummy , c1_min ], dim = - 1 )
348359 else :
349360 c = self .mgc2c (mc )
350361 c0 , c = remove_gain (c , value = 0 , return_gain = True )
351362 if self .phase == "minimum" :
352- c = c .flip (- 1 )
353- elif self .phase == "maximum" :
354363 pass
364+ elif self .phase == "maximum" :
365+ c = c .flip (- 1 )
355366 elif self .phase == "zero" :
356367 c = mirror (c , half = True )
357368 else :
358369 raise RuntimeError
359370
360- c = self .linear_intpl (c )
361-
362371 y = x * self .a [0 ]
363372 for i in range (1 , len (self .a )):
364- x = self .pad (x )
365- x = x .unfold (- 1 , c .size (- 1 ), 1 )
366- x = (x * c ).sum (- 1 ) * self .weights [i ]
373+ x = self .zerodf (x , c ) * self .weights [i ]
367374 y += x * self .a [i ]
368375
369376 if not self .ignore_gain :
@@ -389,28 +396,26 @@ def __init__(
389396 ) -> None :
390397 super ().__init__ ()
391398
399+ self .frame_period = frame_period
392400 self .ignore_gain = ignore_gain
393401 self .phase = phase
394402 self .n_fft = n_fft
395403
396- # Prepare padding module.
397- taps = ir_length - 1
398404 if self .phase == "minimum" :
399- padding = (taps , 0 )
405+ ir_orders = (ir_length - 1 , 0 )
400406 elif self .phase == "maximum" :
401- padding = (0 , taps )
407+ ir_orders = (0 , ir_length - 1 )
402408 elif self .phase == "zero" :
403- padding = (taps , taps )
409+ ir_orders = (ir_length - 1 , ir_length - 1 )
404410 elif self .phase == "mixed" :
405- padding = (
411+ ir_orders = (
406412 (ir_length [0 ] - 1 , ir_length [1 ] - 1 )
407413 if is_array_like (ir_length )
408- else (taps , taps )
414+ else (ir_length - 1 , ir_length - 1 )
409415 )
410416 else :
411417 raise ValueError (f"phase { phase } is not supported." )
412- self .pad = nn .ConstantPad1d (padding , 0 )
413- self .padding = padding
418+ self .ir_orders = ir_orders
414419
415420 if self .phase in ("minimum" , "maximum" ):
416421 self .mgc2ir = MelGeneralizedCepstrumToMelGeneralizedCepstrum (
@@ -444,7 +449,7 @@ def __init__(
444449 self .mgc2c .append (
445450 MelGeneralizedCepstrumToMelGeneralizedCepstrum (
446451 filter_order [i ],
447- padding [i ],
452+ ir_orders [i ],
448453 in_alpha = alpha ,
449454 in_gamma = gamma ,
450455 n_fft = n_fft ,
@@ -458,7 +463,15 @@ def __init__(
458463 else :
459464 raise ValueError (f"phase { phase } is not supported." )
460465
461- self .linear_intpl = LinearInterpolation (frame_period )
466+ self .zerodf = AllZeroDigitalFilter (
467+ sum (ir_orders ),
468+ frame_period ,
469+ ignore_gain = False ,
470+ zeroth_index = ir_orders [1 ],
471+ mode = "efficient" ,
472+ device = device ,
473+ dtype = dtype ,
474+ )
462475
463476 def forward (
464477 self ,
@@ -467,9 +480,13 @@ def forward(
467480 ) -> torch .Tensor :
468481 if self .phase == "minimum" :
469482 h = self .mgc2ir (mc )
470- h = h .flip (- 1 )
483+ if self .ignore_gain :
484+ h = h / h [..., :1 ]
471485 elif self .phase == "maximum" :
472486 h = self .mgc2ir (mc )
487+ if self .ignore_gain :
488+ h = h / h [..., :1 ]
489+ h = h .flip (- 1 )
473490 elif self .phase == "zero" :
474491 c = self .mgc2c (mc )
475492 c [..., 1 :] *= 0.5
@@ -485,25 +502,16 @@ def forward(
485502 c0 = torch .zeros_like (c_min [..., :1 ])
486503 else :
487504 c0 = c_min [..., :1 ] + c_max [..., :1 ]
488- c = torch .cat ([c_min [..., 1 :].flip (- 1 ), c0 , c_max [..., 1 :]], dim = - 1 )
505+ c = torch .cat ([c_max [..., 1 :].flip (- 1 ), c0 , c_min [..., 1 :]], dim = - 1 )
489506 c = F .pad (c , (0 , self .n_fft - c .size (- 1 )))
490- c = torch .roll (c , - self .padding [0 ], dims = - 1 )
507+ shift = self .ir_orders [1 ]
508+ c = torch .roll (c , - shift , dims = - 1 )
491509 h = self .c2ir (c )
492- h = torch .roll (h , self . padding [ 0 ] , dims = - 1 )[..., : sum (self .padding ) + 1 ]
510+ h = torch .roll (h , shift , dims = - 1 )[..., : sum (self .ir_orders ) + 1 ]
493511 else :
494512 raise RuntimeError
495513
496- h = self .linear_intpl (h )
497-
498- if self .ignore_gain :
499- if self .phase == "minimum" :
500- h = h / h [..., - 1 :]
501- elif self .phase == "maximum" :
502- h = h / h [..., :1 ]
503-
504- x = self .pad (x )
505- x = x .unfold (- 1 , h .size (- 1 ), 1 )
506- y = (x * h ).sum (- 1 )
514+ y = self .zerodf (x , h )
507515 return y
508516
509517
0 commit comments