@@ -191,9 +191,12 @@ def test_saveat_solution():
191191
192192
193193def test_saveat_solution_skip_steps ():
194- def _step_integrate (saveat : diffrax .SaveAt ):
194+ def _step_integrate (saveat : diffrax .SaveAt , with_7 : bool ):
195195 term = diffrax .ODETerm (lambda t , y , args : - 0.5 * y )
196- ts = jnp .array ([0.0 , 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ])
196+ if with_7 :
197+ ts = jnp .array ([0.0 , 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 ])
198+ else :
199+ ts = jnp .array ([0.0 , 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ])
197200 sol_ts = diffrax .diffeqsolve (
198201 term ,
199202 t0 = ts [0 ],
@@ -208,24 +211,35 @@ def _step_integrate(saveat: diffrax.SaveAt):
208211 assert sol_ts is not None
209212 return sol_ts [jnp .isfinite (sol_ts )]
210213
211- saveat = diffrax .SaveAt (steps = 2 )
212- ts = _step_integrate (saveat )
213- assert jnp .allclose (ts , jnp .array ([1.0 , 3.0 , 5.0 ]))
214- saveat = diffrax .SaveAt (steps = 2 , t1 = True )
215- ts = _step_integrate (saveat )
216- assert jnp .allclose (ts , jnp .array ([1.0 , 3.0 , 5.0 , 6.0 ]))
217- saveat = diffrax .SaveAt (steps = 2 , t1 = True , t0 = True )
218- ts = _step_integrate (saveat )
219- assert jnp .allclose (ts , jnp .array ([0.0 , 1.0 , 3.0 , 5.0 , 6.0 ]))
220- saveat = diffrax .SaveAt (steps = 3 )
221- ts = _step_integrate (saveat )
222- assert jnp .allclose (ts , jnp .array ([1.0 , 4.0 ]))
223- saveat = diffrax .SaveAt (steps = 3 , t1 = True )
224- ts = _step_integrate (saveat )
225- assert jnp .allclose (ts , jnp .array ([1.0 , 4.0 , 6.0 ]))
226- saveat = diffrax .SaveAt (steps = 3 , t1 = True , t0 = True )
227- ts = _step_integrate (saveat )
228- assert jnp .allclose (ts , jnp .array ([0.0 , 1.0 , 4.0 , 6.0 ]))
214+ ts = _step_integrate (diffrax .SaveAt (steps = 2 ), with_7 = True )
215+ assert jnp .allclose (ts , jnp .array ([2.0 , 4.0 , 6.0 ]))
216+ ts = _step_integrate (diffrax .SaveAt (steps = 2 ), with_7 = False )
217+ assert jnp .allclose (ts , jnp .array ([2.0 , 4.0 , 6.0 ]))
218+
219+ ts = _step_integrate (diffrax .SaveAt (steps = 2 , t1 = True ), with_7 = True )
220+ assert jnp .allclose (ts , jnp .array ([2.0 , 4.0 , 6.0 , 7.0 ]))
221+ ts = _step_integrate (diffrax .SaveAt (steps = 2 , t1 = True ), with_7 = False )
222+ assert jnp .allclose (ts , jnp .array ([2.0 , 4.0 , 6.0 ]))
223+
224+ ts = _step_integrate (diffrax .SaveAt (steps = 2 , t1 = True , t0 = True ), with_7 = True )
225+ assert jnp .allclose (ts , jnp .array ([0.0 , 2.0 , 4.0 , 6.0 , 7.0 ]))
226+ ts = _step_integrate (diffrax .SaveAt (steps = 2 , t1 = True , t0 = True ), with_7 = False )
227+ assert jnp .allclose (ts , jnp .array ([0.0 , 2.0 , 4.0 , 6.0 ]))
228+
229+ ts = _step_integrate (diffrax .SaveAt (steps = 3 ), with_7 = True )
230+ assert jnp .allclose (ts , jnp .array ([3.0 , 6.0 ]))
231+ ts = _step_integrate (diffrax .SaveAt (steps = 3 ), with_7 = False )
232+ assert jnp .allclose (ts , jnp .array ([3.0 , 6.0 ]))
233+
234+ ts = _step_integrate (diffrax .SaveAt (steps = 3 , t1 = True ), with_7 = True )
235+ assert jnp .allclose (ts , jnp .array ([3.0 , 6.0 , 7.0 ]))
236+ ts = _step_integrate (diffrax .SaveAt (steps = 3 , t1 = True ), with_7 = False )
237+ assert jnp .allclose (ts , jnp .array ([3.0 , 6.0 ]))
238+
239+ ts = _step_integrate (diffrax .SaveAt (steps = 3 , t1 = True , t0 = True ), with_7 = True )
240+ assert jnp .allclose (ts , jnp .array ([0.0 , 3.0 , 6.0 , 7.0 ]))
241+ ts = _step_integrate (diffrax .SaveAt (steps = 3 , t1 = True , t0 = True ), with_7 = False )
242+ assert jnp .allclose (ts , jnp .array ([0.0 , 3.0 , 6.0 ]))
229243
230244
231245def test_saveat_solution_skip_vs_saveat ():
0 commit comments