Skip to content

Commit b5e5e2f

Browse files
some corrections in notable differences
1 parent 75f9767 commit b5e5e2f

1 file changed

Lines changed: 48 additions & 19 deletions

File tree

docs/notable-differences.md

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,40 @@ that you should understand before porting code or writing new simulations.
99

1010
## Immutability
1111

12-
JAX arrays are **immutable**. Any GalSim operation that modifies data in-place
13-
returns a new object in JAX-GalSim instead.
12+
JAX arrays are **immutable**. Any GalSim operation that originally modified data in-place, now
13+
instead creates a new array that overwrites the original one. Let's look at `__iadd__` as an example.
1414

1515
```python
1616
# GalSim — mutates the image in-place
17-
image.addNoise(noise)
18-
image.array[10, 10] = 0.0
17+
# i.e. no new numpy array is created
18+
image += 1.0
19+
# under the hood, some version of: `self.array[:,:] += a` does not create a new numpy array.
1920

20-
# JAX-GalSim — returns a new image each time
21-
image = image.addNoise(noise)
21+
# JAX-GalSim — creates a new array and overwrites original one
22+
image += 1.0
23+
# under the hood: `image._array = image._array + 1.0`. The RHS is a new JAX array.
24+
```
25+
26+
This could become a subtle source of bugs if you are used to numpy in place mutability. Here
27+
is another example with `__iadd__` that illustrates this:
28+
29+
```python
30+
# galsim
31+
image = galsim.ImageD(11, 11)
32+
arr1 = image.array
2233

23-
# Direct array element mutation is not supported.
24-
# Use jax.numpy operations to produce a new array:
25-
new_array = image.array.at[10, 10].set(0.0)
34+
image += 1.0
35+
arr1.sum(), image.array.sum() # -> 121.0, 121.0
36+
37+
# jax-galsim
38+
image = jax_galsim.ImageD(11, 11)
39+
arr1 = image.array
40+
41+
image += 1.0
42+
arr1.sum(), image.array.sum() # -> 0.0, 121.0, original image array was unmodified!
2643
```
2744

28-
This is the most common change when porting GalSim code. Every call that
29-
modifies an image, adds noise, or updates a value must capture the return value.
30-
If you forget the assignment, the original object is unchanged and no error is
31-
raised --- a subtle source of bugs.
45+
For more details on JAX immutability please see the [Sharp Bits page](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates) of the JAX documentation.
3246

3347
---
3448

@@ -64,7 +78,7 @@ user-facing interface looks the same:
6478

6579
```python
6680
noise = jax_galsim.GaussianNoise(sigma=30.0)
67-
image = image.addNoise(noise) # state is managed internally
81+
image.addNoise(noise) # state is managed internally
6882
```
6983

7084
**Different sequences**: Even with the same seed value, the actual random number
@@ -88,7 +102,7 @@ A PyTree splits each object into two parts:
88102
| **Children** (traced) | Values JAX differentiates through | `flux`, `sigma`, `half_light_radius` | Re-evaluation, not recompilation |
89103
| **Auxiliary data** (static) | Structure and configuration | `GSParams`, enum flags | Full recompilation under `jit` |
90104

91-
In practice, profile parameters live in a `_params` dict (children) and
105+
For `GSObject`, profile parameters live in a `_params` dict (children) and
92106
numerical configuration lives in `_gsparams` (auxiliary):
93107

94108
```python
@@ -104,12 +118,13 @@ calls when possible.
104118
```python
105119
import jax
106120

107-
gsparams = jax_galsim.GSParams(maximum_fft_size=8192)
121+
gsparams = jax_galsim.GSParams(minimum_fft_size=8192, maximum_fft_size=8192)
122+
slen = 21 # image size should also be constant for jit to work (see below for more details).
108123

109124
@jax.jit
110125
def simulate(flux, sigma):
111126
gal = jax_galsim.Gaussian(flux=flux, sigma=sigma, gsparams=gsparams)
112-
return gal.drawImage(scale=0.2).array.sum()
127+
return gal.drawImage(nx=slen, ny=slen, scale=0.2).array.sum()
113128

114129
# Changing gsparams here would cause recompilation on next call
115130
```
@@ -146,17 +161,31 @@ avoid problematic control flow in its own implementations.
146161

147162
Under `jit`, the **shape** of every array must be determinable at compile time.
148163
Operations whose output size depends on input values (e.g., adaptive image
149-
sizing based on a traced parameter) may not work. When using `jax.vmap`, you
164+
sizing based on a traced parameter) may not work. When using `jax.jit` or `jax.vmap`, you
150165
must specify fixed image dimensions:
151166

152167
```python
168+
@jax.jit
153169
@jax.vmap
154170
def batch(sigma):
155-
gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma)
171+
gsparams = GSParams(minimum_fft_size=256, maximum_fft_size=256)
172+
gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma).withGSParams(gsparams)
156173
# Must specify nx, ny so all images have the same shape
157174
return gal.drawImage(scale=0.2, nx=64, ny=64).array
158175
```
159176

177+
Importantly, the default (and most commonly used) drawing procedure in GalSim (and JAX-GalSim)
178+
transforms image to k-space via an FFT. The size of the "images" in Fourier space usually depends
179+
on traced galaxy profile paramers e.g. size, which makes this incompatible with `jit`. Thus, in JAX-GalSim
180+
we allow for this k-space image size to be fixed explicitly via `GSParams` as done above:
181+
182+
```python
183+
gsparams = GSParams(minimum_fft_size=256, maximum_fft_size=256)
184+
```
185+
186+
where both `minimum_fft_size` and `maximum_fft_size` need to be set to the same value.
187+
188+
160189
### The `__init__` gotcha
161190

162191
During `jit` tracing, JAX calls constructors with **tracer objects** rather than

0 commit comments

Comments
 (0)