You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/notable-differences.md
+48-19Lines changed: 48 additions & 19 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -9,26 +9,40 @@ that you should understand before porting code or writing new simulations.
9
9
10
10
## Immutability
11
11
12
-
JAX arrays are **immutable**. Any GalSim operation that modifies data in-place
13
-
returns a new object in JAX-GalSim instead.
12
+
JAX arrays are **immutable**. Any GalSim operation that originally modified data in-place, now
13
+
instead creates a new array that overwrites the original one. Let's look at `__iadd__` as an example.
14
14
15
15
```python
16
16
# GalSim — mutates the image in-place
17
-
image.addNoise(noise)
18
-
image.array[10, 10] =0.0
17
+
# i.e. no new numpy array is created
18
+
image +=1.0
19
+
# under the hood, some version of: `self.array[:,:] += a` does not create a new numpy array.
19
20
20
-
# JAX-GalSim — returns a new image each time
21
-
image = image.addNoise(noise)
21
+
# JAX-GalSim — creates a new array and overwrites original one
22
+
image +=1.0
23
+
# under the hood: `image._array = image._array + 1.0`. The RHS is a new JAX array.
24
+
```
25
+
26
+
This could become a subtle source of bugs if you are used to numpy in place mutability. Here
27
+
is another example with `__iadd__` that illustrates this:
28
+
29
+
```python
30
+
# galsim
31
+
image = galsim.ImageD(11, 11)
32
+
arr1 = image.array
22
33
23
-
# Direct array element mutation is not supported.
24
-
# Use jax.numpy operations to produce a new array:
25
-
new_array = image.array.at[10, 10].set(0.0)
34
+
image +=1.0
35
+
arr1.sum(), image.array.sum() # -> 121.0, 121.0
36
+
37
+
# jax-galsim
38
+
image = jax_galsim.ImageD(11, 11)
39
+
arr1 = image.array
40
+
41
+
image +=1.0
42
+
arr1.sum(), image.array.sum() # -> 0.0, 121.0, original image array was unmodified!
26
43
```
27
44
28
-
This is the most common change when porting GalSim code. Every call that
29
-
modifies an image, adds noise, or updates a value must capture the return value.
30
-
If you forget the assignment, the original object is unchanged and no error is
31
-
raised --- a subtle source of bugs.
45
+
For more details on JAX immutability please see the [Sharp Bits page](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates) of the JAX documentation.
32
46
33
47
---
34
48
@@ -64,7 +78,7 @@ user-facing interface looks the same:
64
78
65
79
```python
66
80
noise = jax_galsim.GaussianNoise(sigma=30.0)
67
-
image= image.addNoise(noise) # state is managed internally
81
+
image.addNoise(noise) # state is managed internally
68
82
```
69
83
70
84
**Different sequences**: Even with the same seed value, the actual random number
@@ -88,7 +102,7 @@ A PyTree splits each object into two parts:
88
102
|**Children** (traced) | Values JAX differentiates through |`flux`, `sigma`, `half_light_radius`| Re-evaluation, not recompilation |
89
103
|**Auxiliary data** (static) | Structure and configuration |`GSParams`, enum flags | Full recompilation under `jit`|
90
104
91
-
In practice, profile parameters live in a `_params` dict (children) and
105
+
For `GSObject`, profile parameters live in a `_params` dict (children) and
92
106
numerical configuration lives in `_gsparams` (auxiliary):
0 commit comments