@@ -52,24 +52,32 @@ psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
5252final = jax_galsim.Convolve([gal, psf])
5353image = final.drawImage(scale = 0.2 )
5454
55- # Add noise (changes underlying image array)
55+ # Add noise (overwrites underlying image array with new array)
5656image.addNoise(jax_galsim.GaussianNoise(sigma = 30.0 ))
5757```
5858
5959JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline:
6060
6161``` python
6262@jax.jit
63- def simulate (flux , sigma ):
63+ def simulate (flux , sigma , * , slen = 21 , fft_size = 128 ):
64+ gsparams = GSParams(minimum_fft_size = fft_size, maximum_fft_size = fft_size)
65+
6466 gal = jax_galsim.Gaussian(flux = flux, sigma = sigma)
6567 psf = jax_galsim.Gaussian(flux = 1.0 , sigma = 1.0 )
66- return jax_galsim.Convolve([gal, psf]).drawImage(scale = 0.2 ).array.sum()
68+ gal_convolved = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
69+ image = gal_convolved.drawImage(nx = slen, ny = slen, scale = 0.2 )
70+ return image.array.sum()
6771
6872# Compute gradients with respect to galaxy parameters
6973grad_fn = jax.grad(simulate, argnums = (0 , 1 ))
7074dflux, dsigma = grad_fn(1e5 , 2.0 )
7175```
7276
77+ Note that the size of the image in real space (` slen ` ) and fourier space
78+ (` minimum_fft_size = maximum_fft_size ` ) need to be specified in advance for jitting. See the rest
79+ of the documentation for more details and examples.
80+
7381---
7482
7583## Next Steps
0 commit comments