@@ -352,10 +352,10 @@ def slow_log_prob(amplitude, length_scale, noise):
352352 )(self .dtype (1.0 ), self .dtype (1.0 ), self .dtype (1e-3 ))
353353 np .testing .assert_allclose (value , slow_value , rtol = 1e-3 )
354354 slow_d_amp , slow_d_length_scale , slow_d_noise = slow_gradient
355- np .testing .assert_allclose (d_amp , slow_d_amp , rtol = 1e -4 )
356- np .testing .assert_allclose (d_length_scale , slow_d_length_scale , rtol = 1e -4 )
355+ np .testing .assert_allclose (d_amp , slow_d_amp , rtol = 2e -4 )
356+ np .testing .assert_allclose (d_length_scale , slow_d_length_scale , rtol = 2e -4 )
357357 # TODO(thomaswc): Investigate why the noise gradient is so noisy.
358- np .testing .assert_allclose (d_noise , slow_d_noise , rtol = 1e -4 )
358+ np .testing .assert_allclose (d_noise , slow_d_noise , rtol = 2e -4 )
359359
360360 def test_gaussian_process_log_prob_gradient_of_index_points (self ):
361361 samples = jnp .array ([
@@ -407,7 +407,7 @@ def slow_log_prob(pt1, pt2, pt3):
407407 fast_log_prob , argnums = [0 , 1 , 2 ]
408408 )(self .dtype (- 0.5 ), self .dtype (0.0 ), self .dtype (0.5 ))
409409 np .testing .assert_allclose (fast_value , slow_value , rtol = 3e-5 )
410- np .testing .assert_allclose (fast_gradient , slow_gradient , rtol = 1e -4 )
410+ np .testing .assert_allclose (fast_gradient , slow_gradient , rtol = 2e -4 )
411411
412412 def test_gaussian_process_mean (self ):
413413 mean_fn = lambda x : jnp .stack ([x [:, 0 ]** 2 , x [:, 0 ]** 3 ], axis = - 1 )
0 commit comments