@@ -389,7 +389,18 @@ def g(t, y, args):
389389 or saveat .subs .ts is not None
390390 or saveat .subs .steps
391391 ):
392- assert tree_allclose (sol1 .ts , - cast (Array , sol2 .ts ), equal_nan = True )
392+ assert sol1 .ts is not None
393+ assert sol2 .ts is not None
394+ assert tree_allclose (
395+ sol1 .ts [~ jnp .isinf (sol1 .ts )],
396+ - cast (Array , sol2 .ts [~ jnp .isinf (sol2 .ts )]),
397+ equal_nan = True ,
398+ )
399+ assert tree_allclose (
400+ sol1 .ts [jnp .isinf (sol1 .ts )],
401+ cast (Array , sol2 .ts [jnp .isinf (sol2 .ts )]),
402+ equal_nan = True ,
403+ )
393404 assert tree_allclose (sol1 .ys , sol2 .ys , equal_nan = True )
394405 if saveat .dense :
395406 t = jnp .linspace (0.3 , 4 , 20 )
@@ -398,6 +409,36 @@ def g(t, y, args):
398409 assert tree_allclose (sol1 .derivative (ti ), - sol2 .derivative (- ti ))
399410
400411
412+ @pytest .mark .parametrize (
413+ "saveat" ,
414+ (
415+ diffrax .SaveAt (t0 = True , fn = lambda t , y , args : t ),
416+ diffrax .SaveAt (t1 = True , fn = lambda t , y , args : t ),
417+ diffrax .SaveAt (dense = True , fn = lambda t , y , args : t ),
418+ diffrax .SaveAt (steps = True , fn = lambda t , y , args : t ),
419+ diffrax .SaveAt (ts = jnp .linspace (3.0 , 1.0 , 5 ), fn = lambda t , y , args : t ),
420+ ),
421+ )
422+ def test_reverse_time_saveat (saveat ):
423+ def f (t , y , args ):
424+ return - y
425+
426+ t0 = 4
427+ t1 = 0.3
428+ dt0 = - 1 / 50
429+ y0 = 1.0
430+ sol1 = diffrax .diffeqsolve (
431+ diffrax .ODETerm (f ),
432+ diffrax .Euler (),
433+ t0 ,
434+ t1 ,
435+ dt0 ,
436+ y0 ,
437+ saveat = saveat ,
438+ )
439+ assert tree_allclose (sol1 .ys , sol1 .ts )
440+
441+
401442def test_semi_implicit_euler ():
402443 term1 = diffrax .ODETerm (lambda t , y , args : - y )
403444 term2 = diffrax .ODETerm (lambda t , y , args : y )
0 commit comments