File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -336,3 +336,27 @@ def test_benchmark_moffat_conv(benchmark, kind):
336336 benchmark , kind , lambda : _run_moffat_bench_conv_jit ().block_until_ready ()
337337 )
338338 print (f"time: { dt :0.4g} ms" , end = " " )
339+
340+
341+ def _run_moffat_bench_conv_grad (scale_radius ):
342+ obj = jgs .Spergel (nu = - 0.6 , scale_radius = scale_radius )
343+ psf = jgs .Moffat (beta = 2.5 , fwhm = 0.9 )
344+ obj = jgs .Convolve (
345+ [obj , psf ],
346+ gsparams = jgs .GSParams (minimum_fft_size = 2048 , maximum_fft_size = 2048 ),
347+ )
348+ return jnp .sum (obj .drawImage (nx = 50 , ny = 50 , scale = 0.2 ).array ** 2 )
349+
350+
351+ _run_moffat_bench_conv_grad_jit = jax .jit (jax .grad (_run_moffat_bench_conv_grad ))
352+
353+
354+ @pytest .mark .parametrize ("kind" , ["run" ])
355+ def test_benchmark_moffat_conv_grad (benchmark , kind ):
356+ scale_radius = jnp .array (0.5 )
357+ dt = _run_benchmarks (
358+ benchmark ,
359+ kind ,
360+ lambda : _run_moffat_bench_conv_grad_jit (scale_radius ).block_until_ready (),
361+ )
362+ print (f"time: { dt :0.4g} ms" , end = " " )
You can’t perform that action at this time.
0 commit comments