@@ -68,6 +68,38 @@ def f(x):
6868 jax .tree .all (jax .tree .map (assert_allclose , tree , expected_tree ))
6969
7070
71+ def test_fori_collect_no_recompilation ():
72+ def f (x ):
73+ return x + 1
74+
75+ result1 = fori_collect (0 , 10 , f , jnp .array ([0.0 ]), progbar = False )
76+ result2 = fori_collect (0 , 10 , f , jnp .array ([5.0 ]), progbar = False )
77+
78+ assert_allclose (result1 , np .arange (1 , 11 ).reshape (- 1 , 1 ))
79+ assert_allclose (result2 , np .arange (6 , 16 ).reshape (- 1 , 1 ))
80+
81+
82+ def test_fori_collect_repeated_mcmc_no_recompilation ():
83+ from numpyro .infer import MCMC , NUTS
84+
85+ def model ():
86+ numpyro .sample ("x" , dist .Normal (0 , 1 ))
87+
88+ mcmc = MCMC (
89+ NUTS (model ), num_warmup = 5 , num_samples = 10 , num_chains = 1 , progress_bar = False
90+ )
91+
92+ mcmc .run (random .PRNGKey (0 ))
93+ samples1 = mcmc .get_samples ()["x" ]
94+
95+ mcmc .run (random .PRNGKey (1 ))
96+ samples2 = mcmc .get_samples ()["x" ]
97+
98+ assert samples1 .shape == (10 ,)
99+ assert samples2 .shape == (10 ,)
100+ assert not np .allclose (samples1 , samples2 )
101+
102+
71103@pytest .mark .parametrize (
72104 "pytree" ,
73105 [
0 commit comments