Skip to content

Commit 51df286

Browse files
authored
feat: add benchmark for gradient w/ moffat (#187)
1 parent c0c4e3a commit 51df286

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

tests/jax/test_benchmarks.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,27 @@ def test_benchmark_moffat_conv(benchmark, kind):
336336
benchmark, kind, lambda: _run_moffat_bench_conv_jit().block_until_ready()
337337
)
338338
print(f"time: {dt:0.4g} ms", end=" ")
339+
340+
341+
def _run_moffat_bench_conv_grad(scale_radius):
342+
obj = jgs.Spergel(nu=-0.6, scale_radius=scale_radius)
343+
psf = jgs.Moffat(beta=2.5, fwhm=0.9)
344+
obj = jgs.Convolve(
345+
[obj, psf],
346+
gsparams=jgs.GSParams(minimum_fft_size=2048, maximum_fft_size=2048),
347+
)
348+
return jnp.sum(obj.drawImage(nx=50, ny=50, scale=0.2).array ** 2)
349+
350+
351+
_run_moffat_bench_conv_grad_jit = jax.jit(jax.grad(_run_moffat_bench_conv_grad))
352+
353+
354+
@pytest.mark.parametrize("kind", ["run"])
355+
def test_benchmark_moffat_conv_grad(benchmark, kind):
356+
scale_radius = jnp.array(0.5)
357+
dt = _run_benchmarks(
358+
benchmark,
359+
kind,
360+
lambda: _run_moffat_bench_conv_grad_jit(scale_radius).block_until_ready(),
361+
)
362+
print(f"time: {dt:0.4g} ms", end=" ")

0 commit comments

Comments
 (0)