1212
1313
1414if TYPE_CHECKING :
15- from typing import Any , Literal
15+ from typing import Literal
1616
1717pytestmark = pytest .mark .skipif (not find_spec ("jax" ), reason = "jax not installed" )
1818
2121 # problem as mean_var passes dtype= np.float64 internally, which crashes without this fix
2222 import jax
2323
24- jax .config .update ("jax_enable_x64" , True ) # noqa: FBT003
24+ jax .config .update ("jax_enable_x64" , True ) # type: ignore[no-untyped-call] # noqa: FBT003
2525
2626
2727@pytest .fixture
28- def jax_arr () -> Any : # noqa: ANN401
28+ def jax_arr () -> jax . Array :
2929 import jax .numpy as jnp
3030
3131 return jnp .array ([[1 , 0 ], [2 , 0 ], [3 , 0 ]], dtype = jnp .float32 )
3232
3333
3434@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
35- def test_sum (jax_arr : Any , axis : Literal [0 , 1 ] | None ) -> None : # noqa: ANN401
35+ def test_sum (jax_arr : jax . Array , axis : Literal [0 , 1 ] | None ) -> None :
3636 import jax .numpy as jnp
3737
3838 result = stats .sum (jax_arr , axis = axis )
@@ -41,7 +41,7 @@ def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
4141
4242
4343@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
44- def test_min (jax_arr : Any , axis : Literal [0 , 1 ] | None ) -> None : # noqa: ANN401
44+ def test_min (jax_arr : jax . Array , axis : Literal [0 , 1 ] | None ) -> None :
4545 import jax .numpy as jnp
4646
4747 result = stats .min (jax_arr , axis = axis )
@@ -50,7 +50,7 @@ def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
5050
5151
5252@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
53- def test_max (jax_arr : Any , axis : Literal [0 , 1 ] | None ) -> None : # noqa: ANN401
53+ def test_max (jax_arr : jax . Array , axis : Literal [0 , 1 ] | None ) -> None :
5454 import jax .numpy as jnp
5555
5656 result = stats .max (jax_arr , axis = axis )
@@ -59,7 +59,7 @@ def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: ANN401
5959
6060
6161@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
62- def test_mean (jax_arr : Any , axis : Literal [0 , 1 ] | None ) -> None : # noqa: ANN401
62+ def test_mean (jax_arr : jax . Array , axis : Literal [0 , 1 ] | None ) -> None :
6363 import jax .numpy as jnp
6464
6565 result = stats .mean (jax_arr , axis = axis )
@@ -95,7 +95,7 @@ def test_is_constant(axis: Literal[0, 1] | None) -> None:
9595
9696
9797@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
98- def test_mean_var (jax_arr : Any , axis : Literal [0 , 1 ] | None ) -> None : # noqa: ANN401
98+ def test_mean_var (jax_arr : jax . Array , axis : Literal [0 , 1 ] | None ) -> None :
9999 import jax .numpy as jnp
100100
101101 mean , var = stats .mean_var (jax_arr , axis = axis , correction = 1 )
@@ -108,14 +108,14 @@ def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None: # noqa: AN
108108 assert jnp .allclose (var , var_expected )
109109
110110
111- def test_to_dense (jax_arr : Any ) -> None : # noqa: ANN401
111+ def test_to_dense (jax_arr : jax . Array ) -> None :
112112 import jax .numpy as jnp
113113
114114 result = to_dense (jax_arr )
115115 assert jnp .array_equal (result , jax_arr )
116116
117117
118- def test_to_dense_to_cpu (jax_arr : Any ) -> None : # noqa: ANN401
118+ def test_to_dense_to_cpu (jax_arr : jax . Array ) -> None :
119119 result = to_dense (jax_arr , to_cpu_memory = True )
120120 assert isinstance (result , np .ndarray )
121121 np .testing .assert_array_equal (result , np .asarray (jax_arr ))
0 commit comments