Skip to content

Commit 592b0bb

Browse files
authored
Add type hints to _binom and _b_basis_jvp functions
1 parent d967890 commit 592b0bb

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

uncertaintyx/b/jax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax.scipy.special import gammaln
77

88

9-
def _binom(i, k) -> Array:
9+
def _binom(i: Array, k: int) -> Array:
1010
"""
1111
Returns the binomial coefficients for the Bernstein basis
1212
of degree :math:`k`.
@@ -26,7 +26,7 @@ def _binom(i, k) -> Array:
2626
)
2727

2828

29-
def _b_basis_jvp(k, inputs: tuple[Array], perturbations: tuple[Array]):
29+
def _b_basis_jvp(k: int, inputs: tuple[Array], perturbations: tuple[Array]) -> tuple[Array, Array]:
3030
r"""
3131
Custom forward-mode differentiation (JVP) for the Bernstein basis.
3232

0 commit comments

Comments
 (0)