Skip to content

Commit 1bfde4a

Browse files
committed
feat: expose linear recurrence function and update __all__ exports
1 parent 4488dc6 commit 1bfde4a

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

torchlpc/__init__.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,41 @@
2323
warnings.warn("Custom extension not loaded. Falling back to Numba implementation.")
2424

2525
from .core import LPC
26-
27-
# from .parallel_scan import WARPSIZE
2826
from .recurrence import Recurrence
2927

30-
__all__ = ["sample_wise_lpc"]
28+
__all__ = ["sample_wise_lpc", "linear_recurrence"]
29+
30+
31+
def linear_recurrence(
32+
a: torch.Tensor, x: torch.Tensor, zi: Optional[torch.Tensor] = None
33+
) -> torch.Tensor:
34+
"""Compute linear recurrence using the recurrence relation:
35+
y[n] = x[n] + a[n] * (x[n-1] + a[n-1] * (x[n-2] + a[n-2] * (...(x[0] + a[0] * zi)...)))
36+
37+
Args:
38+
a (torch.Tensor): Coefficients of the recurrence relation.
39+
x (torch.Tensor): Input signal.
40+
zi (torch.Tensor, optional): Initial conditions. Defaults to zero if not provided.
41+
42+
Shape:
43+
- a: :math:`(B, T)`
44+
- x: :math:`(B, T)`
45+
- zi: :math:`(B,)`
46+
47+
Returns:
48+
Output signal with the same shape as x.
49+
"""
50+
assert a.shape == x.shape
51+
assert a.ndim == 2
52+
assert x.ndim == 2
53+
B, _ = a.shape
54+
55+
if zi is None:
56+
zi = a.new_zeros(B)
57+
else:
58+
assert zi.shape == (B,)
59+
60+
return Recurrence.apply(a, x, zi) # type: ignore
3161

3262

3363
def sample_wise_lpc(
@@ -64,10 +94,8 @@ def sample_wise_lpc(
6494
else:
6595
assert zi.shape == (B, order)
6696

67-
# if order == 1 and x.is_cuda and B * WARPSIZE < T:
68-
# return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
6997
if order == 1:
70-
y = Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))
98+
y = linear_recurrence(-a.squeeze(2), x, zi.squeeze(1))
7199
else:
72100
y = LPC.apply(x, a, zi)
73101

0 commit comments

Comments
 (0)