Skip to content

Commit 75f9767

Browse files
small corrections regarding jitting and modifying images
1 parent 2fbeb8a commit 75f9767

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

docs/index.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,32 @@ psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
5252
final = jax_galsim.Convolve([gal, psf])
5353
image = final.drawImage(scale=0.2)
5454

55-
# Add noise (changes underlying image array)
55+
# Add noise (overwrites underlying image array with new array)
5656
image.addNoise(jax_galsim.GaussianNoise(sigma=30.0))
5757
```
5858

5959
JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline:
6060

6161
```python
6262
@jax.jit
63-
def simulate(flux, sigma):
63+
def simulate(flux, sigma, *, slen=21, fft_size=128):
64+
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)
65+
6466
gal = jax_galsim.Gaussian(flux=flux, sigma=sigma)
6567
psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
66-
return jax_galsim.Convolve([gal, psf]).drawImage(scale=0.2).array.sum()
68+
gal_convolved = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
69+
image = gal_convolved.drawImage(nx=slen, ny=slen, scale=0.2)
70+
return image.array.sum()
6771

6872
# Compute gradients with respect to galaxy parameters
6973
grad_fn = jax.grad(simulate, argnums=(0, 1))
7074
dflux, dsigma = grad_fn(1e5, 2.0)
7175
```
7276

77+
Note that the size of the image in real space (`slen`) and fourier space
78+
(`minimum_fft_size = maximum_fft_size`) need to be specified in advance for jitting. See the rest
79+
of the documentation for more details and examples.
80+
7381
---
7482

7583
## Next Steps

0 commit comments

Comments
 (0)