1414# limitations under the License. #
1515# ------------------------------------------------------------------------ #
1616
17+ import numpy as np
1718import torch
18- import torch .nn .functional as F
1919
2020from ..typing import Precomputed
21- from ..utils .private import check_size , filter_values , remove_gain
21+ from ..utils .private import check_size , filter_values , remove_gain , to
2222from .base import BaseFunctionalModule
2323
2424
@@ -31,6 +31,9 @@ class ReverseLevinsonDurbin(BaseFunctionalModule):
3131 lpc_order : int >= 0
3232 The order of the LPC coefficients, :math:`M`.
3333
34+ n_fft : int >> M
35+ The number of FFT bins. Accurate conversion requires a large value.
36+
3437 device : torch.device or None
3538 The device of this module.
3639
@@ -42,6 +45,7 @@ class ReverseLevinsonDurbin(BaseFunctionalModule):
4245 def __init__ (
4346 self ,
4447 lpc_order : int ,
48+ n_fft : int = 1024 ,
4549 device : torch .device | None = None ,
4650 dtype : torch .dtype | None = None ,
4751 ) -> None :
@@ -50,7 +54,7 @@ def __init__(
5054 self .in_dim = lpc_order + 1
5155
5256 _ , _ , tensors = self ._precompute (** filter_values (locals ()))
53- self .register_buffer ("eye " , tensors [0 ])
57+ self .register_buffer ("phase_factors " , tensors [0 ])
5458
5559 def forward (self , a : torch .Tensor ) -> torch .Tensor :
5660 """Solve a Yule-Walker linear system given the LPC coefficients.
@@ -95,39 +99,30 @@ def _takes_input_size() -> bool:
9599 return True
96100
97101 @staticmethod
98- def _check (lpc_order : int ) -> None :
102+ def _check (lpc_order : int , n_fft : int ) -> None :
99103 if lpc_order < 0 :
100104 raise ValueError ("lpc_order must be non-negative." )
105+ if n_fft <= lpc_order + 1 :
106+ raise ValueError ("n_fft must be much larger than lpc_order." )
101107
102108 @staticmethod
103109 def _precompute (
104110 lpc_order : int ,
111+ n_fft : int ,
105112 device : torch .device | None ,
106113 dtype : torch .dtype | None ,
107114 ) -> Precomputed :
108- ReverseLevinsonDurbin ._check (lpc_order )
109- eye = torch .eye (lpc_order + 1 , device = device , dtype = dtype )
110- return None , None , (eye ,)
115+ ReverseLevinsonDurbin ._check (lpc_order , n_fft )
116+ n_freq = n_fft // 2 + 1
117+ omega = torch .linspace (0 , np .pi , n_freq , device = device , dtype = torch .double )
118+ m = torch .arange (lpc_order + 1 , device = device , dtype = torch .double )
119+ phase_factors = torch .exp (- 1j * omega * m .unsqueeze (- 1 ))
120+ return None , None , (to (phase_factors , dtype = dtype ),)
111121
112122 @staticmethod
113- def _forward (a : torch .Tensor , eye : torch .Tensor ) -> torch .Tensor :
123+ def _forward (a : torch .Tensor , phase_factors : torch .Tensor ) -> torch .Tensor :
114124 M = a .size (- 1 ) - 1
115125 K , a = remove_gain (a , return_gain = True )
116-
117- U = [a .flip (- 1 )]
118- E = [K ** 2 ]
119- for m in range (M ):
120- u0 = U [- 1 ][..., :1 ]
121- u1 = U [- 1 ][..., 1 : M - m ]
122- t = 1 / (1 - u0 ** 2 )
123- u = (u1 - u0 * u1 .flip (- 1 )) * t
124- u = F .pad (u , (0 , m + 2 ))
125- e = E [- 1 ] * t
126- U .append (u )
127- E .append (e )
128- U = torch .stack (U [::- 1 ], dim = - 1 )
129- E = torch .stack (E [::- 1 ], dim = - 1 )
130-
131- V = torch .linalg .solve_triangular (U , eye , upper = True , unitriangular = True )
132- r = torch .matmul (V [..., :1 ].transpose (- 2 , - 1 ) * E , V ).squeeze (- 2 )
126+ A = torch .sum (a .unsqueeze (- 1 ) * phase_factors , dim = - 2 )
127+ r = torch .fft .irfft ((K / A .abs ()) ** 2 )[..., : M + 1 ]
133128 return r
0 commit comments