@@ -33,39 +33,18 @@ def jax_arr() -> jax.Array:
3333
3434
3535@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
36- def test_sum (jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
36+ @pytest .mark .parametrize ("func" , ["sum" , "min" , "max" , "mean" ])
37+ def test_simple_stat (jax_arr : jax .Array , func : Literal ["sum" , "min" , "max" , "mean" ], axis : Literal [0 , 1 ] | None ) -> None :
3738 import jax .numpy as jnp
3839
39- result = stats .sum (jax_arr , axis = axis )
40- expected = jnp .sum (jax_arr , axis = axis )
41- assert jnp .array_equal (result , expected )
40+ result = getattr (stats , func )(jax_arr , axis = axis )
41+ expected = getattr (jnp , func )(jax_arr , axis = axis )
4242
43-
44- @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
45- def test_min (jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
46- import jax .numpy as jnp
47-
48- result = stats .min (jax_arr , axis = axis )
49- expected = jnp .min (jax_arr , axis = axis )
50- assert jnp .array_equal (result , expected )
51-
52-
53- @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
54- def test_max (jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
55- import jax .numpy as jnp
56-
57- result = stats .max (jax_arr , axis = axis )
58- expected = jnp .max (jax_arr , axis = axis )
59- assert jnp .array_equal (result , expected )
60-
61-
62- @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
63- def test_mean (jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
64- import jax .numpy as jnp
65-
66- result = stats .mean (jax_arr , axis = axis )
67- expected = jnp .mean (jax_arr , axis = axis )
68- assert jnp .allclose (result , expected )
43+ assert type (result ) is type (expected )
44+ if func == "mean" :
45+ assert jnp .allclose (result , expected )
46+ else :
47+ assert jnp .array_equal (result , expected )
6948
7049
7150@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
@@ -86,37 +65,43 @@ def test_is_constant(axis: Literal[0, 1] | None) -> None:
8665 result = stats .is_constant (x , axis = axis )
8766
8867 if axis is None :
89- assert bool ( result ) is False
68+ assert not result
9069 elif axis == 0 :
9170 expected = jnp .array ([True , True , False , False ])
71+ assert type (result ) is type (expected )
9272 assert jnp .array_equal (result , expected )
9373 else :
9474 expected = jnp .array ([False , False , True , True , False , True ])
75+ assert type (result ) is type (expected )
9576 assert jnp .array_equal (result , expected )
9677
9778
9879@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
99- def test_mean_var (jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
80+ def test_mean_var (subtests : pytest . Subtests , jax_arr : jax .Array , axis : Literal [0 , 1 ] | None ) -> None :
10081 import jax .numpy as jnp
10182
10283 mean , var = stats .mean_var (jax_arr , axis = axis , correction = 1 )
10384
104- mean_expected = jnp .mean (jax_arr , axis = axis )
105- n = jax_arr .size if axis is None else jax_arr .shape [axis ]
106- var_expected = jnp .var (jax_arr , axis = axis ) * n / (n - 1 )
85+ for name , result in dict (mean = mean , var = var ).items ():
86+ if name == "mean" :
87+ expected = jnp .mean (jax_arr , axis = axis )
88+ else :
89+ n = jax_arr .size if axis is None else jax_arr .shape [axis ]
90+ expected = jnp .var (jax_arr , axis = axis ) * n / (n - 1 )
10791
108- assert jnp .allclose (mean , mean_expected )
109- assert jnp .allclose (var , var_expected )
92+ with subtests .test (name ):
93+ assert type (result ) is type (expected )
94+ assert jnp .allclose (result , expected )
11095
11196
112- def test_to_dense (jax_arr : jax .Array ) -> None :
97+ @pytest .mark .parametrize ("to_cpu_memory" , [True , False ], ids = ["to_cpu_memory" , "not_to_cpu_memory" ])
98+ def test_to_dense (* , jax_arr : jax .Array , to_cpu_memory : bool ) -> None :
11399 import jax .numpy as jnp
114100
115- result = to_dense (jax_arr )
116- assert jnp .array_equal (result , jax_arr )
117-
101+ result = to_dense (jax_arr , to_cpu_memory = to_cpu_memory )
118102
119- def test_to_dense_to_cpu (jax_arr : jax .Array ) -> None :
120- result = to_dense (jax_arr , to_cpu_memory = True )
121- assert isinstance (result , np .ndarray )
122- np .testing .assert_array_equal (result , np .asarray (jax_arr ))
103+ if to_cpu_memory :
104+ assert isinstance (result , np .ndarray )
105+ else :
106+ assert isinstance (result , jax .Array )
107+ assert jnp .array_equal (result , jax_arr )
0 commit comments