1616
1717from typing import Any
1818
19+ import mpmath as mp
1920import numpy as np
2021import torch
2122import torch .nn .functional as F
2223from torch import nn
2324
24- from ..utils .private import Lambda , check_size , get_gamma , remove_gain
25+ from ..utils .private import Lambda , check_size , get_gamma , remove_gain , to
2526from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum
2627from .base import BaseNonFunctionalModule
2728from .c2mpir import CepstrumToMinimumPhaseImpulseResponse
29+ from .frame import Frame
2830from .gnorm import GeneralizedCepstrumGainNormalization
2931from .istft import InverseShortTimeFourierTransform
3032from .linear_intpl import LinearInterpolation
3133from .mc2b import MelCepstrumToMLSADigitalFilterCoefficients
3234from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum
3335from .mgc2sp import MelGeneralizedCepstrumToSpectrum
36+ from .root_pol import PolynomialToRoots
3437from .stft import ShortTimeFourierTransform
3538
3639
@@ -74,12 +77,15 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
7477 phase : ['minimum', 'maximum', 'zero', 'mixed']
7578 The filter type.
7679
77- mode : ['multi-stage', 'single-stage', 'freq-domain']
80+ mode : ['multi-stage', 'single-stage', 'freq-domain', 'pade-approx' ]
7881 'multi-stage' approximates the MLSA filter by cascading FIR filters based on the
7982 Taylor series expansion. 'single-stage' uses an FIR filter with the coefficients
8083 derived from the impulse response converted from the input mel-cepstral
8184 coefficients using FFT. 'freq-domain' performs filtering in the frequency domain
82- rather than the time domain.
85+ rather than the time domain. 'pade-approx' implements the MLSA filter by
86+ cascading all-zero and all-pole filters derived from the factorization. While
87+ this approach is not computationally efficient, it allows for the optimization
88+ of the Pade approximation coefficients.
8389
8490 n_fft : int >= 1
8591 The number of FFT bins used for conversion. Higher values result in increased
@@ -89,12 +95,20 @@ class PseudoMGLSADigitalFilter(BaseNonFunctionalModule):
8995 The order of the Taylor series expansion (valid only if **mode** is
9096 'multi-stage').
9197
98+ pade_order : int in [3, 14]
99+ The order of Pade approximation (valid only if **mode** is 'pade-approx').
100+
92101 cep_order : int >= 0 or tuple[int, int]
93- The order of the linear cepstrum (valid only if **mode** is 'multi-stage').
102+ The order of the linear cepstrum (valid only if **mode** is 'multi-stage' or
103+ 'pade-approx').
94104
95105 ir_length : int >= 1 or tuple[int, int]
96106 The length of the impulse response (valid only if **mode** is 'single-stage').
97107
108+ learnable : bool
109+ If True, the polynomial coefficients used in the approximation are learnable
110+ (valid only if **mode** is 'multi-stage' or 'pade-approx').
111+
98112 device : torch.device or None
99113 The device of this module.
100114
@@ -181,6 +195,16 @@ def flip(x):
181195 phase = phase ,
182196 ** modified_kwargs ,
183197 )
198+ elif mode == "pade-approx" :
199+ self .mglsadf = MultiStageIIRFilter (
200+ flipped_filter_order ,
201+ frame_period ,
202+ alpha = alpha ,
203+ gamma = gamma ,
204+ ignore_gain = ignore_gain ,
205+ phase = phase ,
206+ ** modified_kwargs ,
207+ )
184208 else :
185209 raise ValueError (f"mode { mode } is not supported." )
186210
@@ -238,6 +262,7 @@ def __init__(
238262 taylor_order : int = 20 ,
239263 cep_order : tuple [int , int ] | int = 199 ,
240264 n_fft : int = 512 ,
265+ learnable : bool = False ,
241266 device : torch .device | None = None ,
242267 dtype : torch .dtype | None = None ,
243268 ) -> None :
@@ -248,7 +273,6 @@ def __init__(
248273
249274 self .ignore_gain = ignore_gain
250275 self .phase = phase
251- self .taylor_order = taylor_order
252276
253277 if alpha == 0 and gamma == 0 :
254278 cep_order = filter_order
@@ -294,6 +318,19 @@ def __init__(
294318
295319 self .linear_intpl = LinearInterpolation (frame_period )
296320
321+ cp = mp .taylor (mp .exp , 0 , taylor_order )
322+ cp = np .array ([float (x ) for x in cp ])
323+ weights = cp [1 :] / cp [:- 1 ]
324+ weights = np .insert (weights , 0 , 1 )
325+ self .register_buffer ("weights" , to (weights , device = device , dtype = dtype ))
326+
327+ a = np .ones (taylor_order + 1 )
328+ a = to (a , device = device , dtype = dtype )
329+ if learnable :
330+ self .a = nn .Parameter (a )
331+ else :
332+ self .register_buffer ("a" , a )
333+
297334 def forward (
298335 self ,
299336 x : torch .Tensor ,
@@ -322,12 +359,12 @@ def forward(
322359
323360 c = self .linear_intpl (c )
324361
325- y = x . clone ()
326- for a in range (1 , self .taylor_order + 1 ):
362+ y = x * self . a [ 0 ]
363+ for i in range (1 , len ( self .a ) ):
327364 x = self .pad (x )
328365 x = x .unfold (- 1 , c .size (- 1 ), 1 )
329- x = (x * c ).sum (- 1 ) / a
330- y += x
366+ x = (x * c ).sum (- 1 ) * self . weights [ i ]
367+ y += x * self . a [ i ]
331368
332369 if not self .ignore_gain :
333370 K = torch .exp (self .linear_intpl (c0 ))
@@ -586,3 +623,193 @@ def forward(
586623 Y = H * X
587624 y = self .istft (Y , out_length = x .size (- 1 ))
588625 return y
626+
627+
class MultiStageIIRFilter(nn.Module):
    """MGLSA filter based on the Pade approximation of the exponential function.

    The exponential transfer function exp(F(z)) is factorized into a cascade of
    all-zero (FIR) and all-pole (IIR) stages derived from the [L/L] Pade
    approximant of exp(x). Only the minimum-phase case with a scalar filter
    order is supported.

    Parameters
    ----------
    filter_order : int
        The order of the mel-cepstral coefficients.

    frame_period : int
        The frame period in samples.

    alpha : float
        The frequency warping factor.

    gamma : float
        The gamma parameter of the generalized cepstrum.

    ignore_gain : bool
        If True, the gain term exp(c0) is not applied.

    phase : str
        Must be 'minimum'.

    pade_order : int in [3, 14]
        The order of the Pade approximation.

    cep_order : int
        The order of the linear cepstrum used internally.

    n_fft : int
        The number of FFT bins used for the cepstral conversion.

    chunk_length : int or None
        If given, the signal is processed in overlapping chunks of this length
        to bound the recursion length of the IIR stages.

    warmup_length : int or None
        The number of warm-up samples prepended to each chunk (defaults to
        cep_order when chunking is enabled).

    learnable : bool
        If True, the Pade coefficients are learnable parameters.

    per_stage_pade_coefficients : bool
        If True (and learnable), the two cascaded stages use independent
        Pade coefficients.

    device : torch.device or None
        The device of this module.

    dtype : torch.dtype or None
        The data type of this module.
    """

    def __init__(
        self,
        filter_order: tuple[int, int] | int,
        frame_period: int,
        *,
        alpha: float = 0,
        gamma: float = 0,
        ignore_gain: bool = False,
        phase: str = "minimum",
        pade_order: int = 5,
        cep_order: tuple[int, int] | int = 199,
        n_fft: int = 512,
        chunk_length: int | None = None,
        warmup_length: int | None = None,
        learnable: bool = False,
        per_stage_pade_coefficients: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        # Mixed-phase (tuple order) and non-minimum-phase filters are not
        # representable by this all-pole/all-zero factorization.
        if phase != "minimum" or is_array_like(filter_order):
            raise ValueError("Only minimum-phase filter is supported.")

        self.ignore_gain = ignore_gain

        self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
            filter_order,
            cep_order,
            in_alpha=alpha,
            in_gamma=gamma,
            n_fft=n_fft,
            device=device,
            dtype=dtype,
        )
        self.linear_intpl = LinearInterpolation(frame_period)
        self.root_pol = PolynomialToRoots(pade_order, device=device, dtype=dtype)

        # Optional dependency: imported lazily so the rest of the package does
        # not require torchlpc.
        from torchlpc import sample_wise_lpc

        self.sample_wise_lpc = sample_wise_lpc

        if chunk_length is None:
            self.chunking = False
        else:
            self.chunking = True
            self.warmup_length = (
                warmup_length if warmup_length is not None else cep_order
            )
            if chunk_length <= 0:
                raise ValueError("chunk_length must be positive.")
            if self.warmup_length < 0:
                raise ValueError("warmup_length must be non-negative.")
            frame_period = chunk_length - self.warmup_length
            self.frame_x = Frame(chunk_length, frame_period, center=False)
            # The per-sample cepstrum is flattened to (B, T * M) before
            # framing, hence the factor of cep_order in length and period.
            self.frame_c = Frame(
                cep_order * chunk_length, cep_order * frame_period, center=False
            )

        # Ratios of consecutive Pade numerator coefficients of exp(x);
        # cumprod of the weights reconstructs the coefficients themselves.
        cr = mp.taylor(mp.exp, 0, pade_order * 2)
        cp, _ = mp.pade(cr, pade_order, pade_order)
        cp = np.array([float(x) for x in cp])
        weights = cp[1:] / cp[:-1]
        weights = np.insert(weights, 0, 1)
        self.register_buffer("weights", to(weights, device=device, dtype=dtype))

        # Initial scaling of the Pade coefficients. For low orders a linear
        # taper is used as a stable starting point.
        if pade_order == 3:
            a1 = np.linspace(1.0, 0.4, pade_order + 1)
        elif pade_order == 4:
            a1 = np.linspace(1.0, 0.6, pade_order + 1)
        elif 5 <= pade_order <= 14:
            a1 = np.ones(pade_order + 1)
        else:
            raise ValueError("pade_order must be in [3, 14].")

        if learnable and per_stage_pade_coefficients:
            # Independent coefficients for the two cascaded stages.
            a2 = a1
            a1 = np.ones(pade_order + 1)
            a1 = to(a1, device=device, dtype=dtype)
            a2 = to(a2, device=device, dtype=dtype)
            self.a1 = nn.Parameter(a1)
            self.a2 = nn.Parameter(a2)
        else:
            a1 = to(a1, device=device, dtype=dtype)
            if learnable:
                self.a1 = nn.Parameter(a1)
            else:
                self.register_buffer("a1", a1)
            # Both stages share the same coefficients.
            self.a2 = self.a1

    def forward(
        self, x: torch.Tensor, mc: torch.Tensor, return_roots: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply the MGLSA filter to the given excitation signal.

        Parameters
        ----------
        x : Tensor [shape=(B, T) or (T,)]
            The excitation signal.

        mc : Tensor [shape=(B, T/P, M+1) or (T/P, M+1)]
            The mel-generalized cepstral coefficients.

        return_roots : bool
            If True, also return the roots of the two Pade polynomials.

        Returns
        -------
        out : Tensor [shape=(B, T) or (T,)]
            The filtered signal, and, if return_roots is True, the stacked
            roots of shape (2, pade_order).
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
            mc = mc.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False

        if x.dim() != 2 or mc.dim() != 3:
            raise ValueError("x and mc must be 2-D and 3-D tensors, respectively.")

        # Convert to linear cepstrum, split off the gain term c0, and
        # interpolate frame-rate coefficients to sample rate.
        c = self.mgc2c(mc)
        c0, c1 = torch.split(c, [1, c.size(-1) - 1], dim=-1)
        c_b = self.linear_intpl(c1.flip(-1))  # numerator (all-zero) side
        c_a = self.linear_intpl(c1)  # denominator (all-pole) side

        T = x.size(-1)
        B, _, M = c_a.size()

        # Keep the learnable coefficients within a sane range and pin the
        # leading coefficient of each polynomial to 1.
        a1 = torch.clip(self.a1, min=1e-1, max=1e1)
        a1[0] = 1.0
        a2 = torch.clip(self.a2, min=1e-1, max=1e1)
        a2[0] = 1.0

        c_b2, c_b1 = torch.split(c_b, [c_b.size(-1) - 1, 1], dim=-1)
        c_b1 = c_b1.squeeze(-1)

        # Numerator, 1st stage: cascade of one-tap FIR filters.
        y = x * a1[0]
        for i in range(1, len(a1)):
            x = F.pad(x[..., :-1], (1, 0))
            x = x * c_b1 * self.weights[i]
            y += x * a1[i]

        # Numerator, 2nd stage: cascade of M-tap FIR filters.
        x = y
        y = x * a2[0]
        for i in range(1, len(a2)):
            x = F.pad(x, (M, 0))
            x = x.unfold(-1, M + 1, 1)
            x = (x[..., :-2] * c_b2).sum(-1) * self.weights[i]
            y += x * a2[i]

        if self.chunking:
            # Split the signal (and the matching per-sample coefficients) into
            # overlapping chunks so the IIR recursion length stays bounded.
            y = F.pad(y, (self.warmup_length, 0))
            y = self.frame_x(y)
            y = y.reshape(-1, y.size(-1))

            c_a = c_a.reshape(B, -1)
            c_a = F.pad(c_a, (M * self.warmup_length, 0))
            c_a = self.frame_c(c_a)
            c_a = c_a.reshape(y.size(0), y.size(1), M)

        c_a1, c_a2 = torch.split(c_a, [1, c_a.size(-1) - 1], dim=-1)
        c_a2 = F.pad(c_a2, (1, 0))

        def compute_roots(a: torch.Tensor) -> torch.Tensor:
            # Reconstruct the Pade coefficients from the ratio parametrization
            # and solve for the roots in double precision for accuracy.
            pade_coefficients = torch.cumprod(self.weights, 0) * a
            roots = self.root_pol(pade_coefficients.flip(0).double())
            return roots.to(
                torch.complex64 if a.dtype == torch.float32 else torch.complex128
            )

        roots1 = compute_roots(a1)
        roots2 = compute_roots(a2)

        # Denominator, 1st stage: cascade of first-order all-pole filters.
        y = y.to(roots1.dtype)
        p1 = torch.reciprocal(roots1)
        for i in range(len(p1)):
            y = self.sample_wise_lpc(y, p1[i] * c_a1)

        # Denominator, 2nd stage: cascade of M-th order all-pole filters.
        p2 = torch.reciprocal(roots2)
        for i in range(len(p2)):
            y = self.sample_wise_lpc(y, p2[i] * c_a2)
        y = y.real

        if self.chunking:
            # Drop the warm-up samples and stitch the chunks back together.
            y = y[..., self.warmup_length :]
            y = y.reshape(B, -1)
            y = y[..., :T]

        if not self.ignore_gain:
            # Apply the gain term K = exp(c0), interpolated to sample rate.
            K = torch.exp(self.linear_intpl(c0))
            y = y * K.squeeze(-1)

        if unsqueezed:
            y = y.squeeze(0)

        if return_roots:
            return y, torch.stack([roots1, roots2], dim=0)
        return y
0 commit comments