|
23 | 23 | warnings.warn("Custom extension not loaded. Falling back to Numba implementation.") |
24 | 24 |
|
25 | 25 | from .core import LPC |
26 | | - |
27 | | -# from .parallel_scan import WARPSIZE |
28 | 26 | from .recurrence import Recurrence |
29 | 27 |
|
30 | | -__all__ = ["sample_wise_lpc"] |
| 28 | +__all__ = ["sample_wise_lpc", "linear_recurrence"] |
| 29 | + |
| 30 | + |
| 31 | +def linear_recurrence( |
| 32 | + a: torch.Tensor, x: torch.Tensor, zi: Optional[torch.Tensor] = None |
| 33 | +) -> torch.Tensor: |
| 34 | + """Compute linear recurrence using the recurrence relation: |
| 35 | + y[n] = x[n] + a[n] * (x[n-1] + a[n-1] * (x[n-2] + a[n-2] * (...(x[0] + a[0] * zi)...))) |
| 36 | +
|
| 37 | + Args: |
| 38 | + a (torch.Tensor): Coefficients of the recurrence relation. |
| 39 | + x (torch.Tensor): Input signal. |
| 40 | + zi (torch.Tensor, optional): Initial conditions. Defaults to zero if not provided. |
| 41 | +
|
| 42 | + Shape: |
| 43 | + - a: :math:`(B, T)` |
| 44 | + - x: :math:`(B, T)` |
| 45 | + - zi: :math:`(B,)` |
| 46 | +
|
| 47 | + Returns: |
| 48 | + Output signal with the same shape as x. |
| 49 | + """ |
| 50 | + assert a.shape == x.shape |
| 51 | + assert a.ndim == 2 |
| 52 | + assert x.ndim == 2 |
| 53 | + B, _ = a.shape |
| 54 | + |
| 55 | + if zi is None: |
| 56 | + zi = a.new_zeros(B) |
| 57 | + else: |
| 58 | + assert zi.shape == (B,) |
| 59 | + |
| 60 | + return Recurrence.apply(a, x, zi) # type: ignore |
31 | 61 |
|
32 | 62 |
|
33 | 63 | def sample_wise_lpc( |
@@ -64,10 +94,8 @@ def sample_wise_lpc( |
64 | 94 | else: |
65 | 95 | assert zi.shape == (B, order) |
66 | 96 |
|
67 | | - # if order == 1 and x.is_cuda and B * WARPSIZE < T: |
68 | | - # return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)) |
69 | 97 | if order == 1: |
70 | | - y = Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1)) |
| 98 | + y = linear_recurrence(-a.squeeze(2), x, zi.squeeze(1)) |
71 | 99 | else: |
72 | 100 | y = LPC.apply(x, a, zi) |
73 | 101 |
|
|
0 commit comments