So far, our approach has been to emulate the GalSim APIs as closely as possible. In particular, these APIs assume you can:
- modify views of an array, with the changes propagating to the parent array
- do dynamic indexing into arrays
- rely on global RNG state

and many other things I am surely leaving off this list.
However, JAX is fundamentally incompatible with these assumptions. My work on issue #190 is making it clear, at least to me, that as we try to do more complicated things, we will run into more and more difficulty.
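To make the mismatch concrete, the sketch below uses plain `jax.numpy` (generic JAX, not JAX-GalSim code). It shows that slicing returns a copy and updates go through the functional `.at[...]` API, that indexing under `jit` with a traced start needs `jax.lax.dynamic_slice` with a static slice size, and that randomness is driven by explicit keys rather than global state. The helper name `take_2x2` is purely illustrative.

```python
import jax
import jax.numpy as jnp

# 1. No views: slicing produces a new array, and JAX arrays are immutable.
parent = jnp.zeros((4, 4))
child = parent[1:3, 1:3] + 1.0           # a modified COPY; `parent` is untouched
updated = parent.at[1:3, 1:3].add(1.0)   # functional update returns a new array

# 2. Dynamic indexing: under jit the start indices may be traced,
#    but the slice size must be a static Python int.
@jax.jit
def take_2x2(arr, row, col):
    return jax.lax.dynamic_slice(arr, (row, col), (2, 2))

window = take_2x2(updated, 1, 1)

# 3. No global RNG state: every draw consumes an explicit key.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, (2, 2))
```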
For example, the code bit
```python
b = image.bounds & stamp.bounds
if b.isDefined():
    image[b] += stamp[b]
```
is a GIANT footgun in JAX for several reasons. The more obvious one is that the `if` statement is not compatible with tracing under `jit` when the bounds depend on traced values. The much more subtle reason is the order in which the various calls happen. JAX indexing always returns copies, not views, so in reality the code is doing the following:
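```python
sub_image = image[b]  # this is a COPY
sub_stamp = stamp[b]  # this is a COPY
sub_image._array = sub_image._array + sub_stamp._array  # what sub_image.__iadd__(sub_stamp) does
image[b] = sub_image  # Python then calls image.__setitem__(b, sub_image); the sum is lost if that write also goes through a copy
```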
The end result is buggy code that fails for incredibly confusing reasons, because we are layering GalSim-like APIs on top of JAX without being able to support the same semantics.
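The failure can be reproduced without JAX-GalSim at all. The sketch below uses a hypothetical `ToyImage` wrapper (not the actual JAX-GalSim `Image`) whose `__getitem__` returns a copy, as any JAX-backed implementation must. Python expands `image[b] += stamp[b]` into `__getitem__`, `__iadd__`, and then `__setitem__`; the stamp is lost if the final write-back goes through another copy, and lands only if `__setitem__` rebinds the parent's own buffer with `.at[...].set(...)`.

```python
import jax.numpy as jnp


class ToyImage:
    """Hypothetical stand-in for an image type; NOT the JAX-GalSim Image class."""

    def __init__(self, array):
        self._array = jnp.asarray(array)

    def __getitem__(self, b):
        # JAX indexing returns a new array, so this is a copy, not a view.
        return ToyImage(self._array[b])

    def __iadd__(self, other):
        # Rebinds the buffer of whatever object it is called on (here, a copy).
        self._array = self._array + other._array
        return self

    def __setitem__(self, b, value):
        # GalSim-style write-back through a sub-image: self[b] is a copy,
        # so this assignment is silently discarded.
        sub = self[b]
        sub._array = value._array


class FixedToyImage(ToyImage):
    def __getitem__(self, b):
        return FixedToyImage(self._array[b])

    def __setitem__(self, b, value):
        # Functional update that rebinds the parent's own buffer.
        self._array = self._array.at[b].set(value._array)


b = (slice(1, 3), slice(1, 3))  # stands in for the intersected bounds

buggy = ToyImage(jnp.zeros((4, 4)))
buggy[b] += ToyImage(jnp.ones((4, 4)))[b]

fixed = FixedToyImage(jnp.zeros((4, 4)))
fixed[b] += FixedToyImage(jnp.ones((4, 4)))[b]

print(float(buggy._array.sum()), float(fixed._array.sum()))  # 0.0 4.0
```

Even the "fixed" variant only works because every layer cooperates; nothing stops a helper elsewhere from writing into a copy, which is why this class of bug is so hard to spot.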
We have a few non-exclusive choices, none of which are fantastic.
We can muddle through with minimal changes and a "sharp bits" document that people have to read and understand.
This is what we do now, and it is hurting my brain.
We can declare certain things about JAX-GalSim objects/images immutable to better conform with JAX.
For example, we could disallow operations on JAX-GalSim images that set data in place for swaths of the array. While JAX allows this, it only works under certain conditions that don't seem general enough to be worth accounting for.
So, for example, the `image[b] += stamp[b]` code bit above would raise an error instead of silently producing buggy data.
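A minimal sketch of the immutable option, assuming a hypothetical `FrozenImage` class (not an existing JAX-GalSim API): in-place operators raise a `TypeError` that points users at a functional replacement, here a hypothetical `with_added` method built on `.at[...].add(...)`. Raising in `__iadd__` matters, because without it Python would silently fall back to `__add__` plus `__setitem__`.

```python
import jax.numpy as jnp


class FrozenImage:
    """Hypothetical immutable image wrapper; in-place updates are rejected."""

    def __init__(self, array):
        self._array = jnp.asarray(array)

    def __getitem__(self, b):
        return FrozenImage(self._array[b])

    def __setitem__(self, b, value):
        raise TypeError(
            "FrozenImage does not support in-place assignment; "
            "use image = image.with_added(b, stamp) instead."
        )

    def __iadd__(self, other):
        raise TypeError("FrozenImage does not support in-place addition.")

    def __add__(self, other):
        # Out-of-place arithmetic is fine: it returns a new image.
        return FrozenImage(self._array + other._array)

    def with_added(self, b, other):
        # Returns a NEW image with `other` added inside `b`; `self` is unchanged.
        return FrozenImage(self._array.at[b].add(other._array))


image = FrozenImage(jnp.zeros((4, 4)))
stamp = FrozenImage(jnp.ones((2, 2)))
b = (slice(1, 3), slice(1, 3))

rejected = False
try:
    image[b] += stamp  # __getitem__ succeeds, then __iadd__ raises
except TypeError as err:
    rejected = True
    print("rejected:", err)

image = image.with_added(b, stamp)
```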
We could also set flags on images produced by indexing with bounds, so that we can emit warnings if people attempt to modify them in place.
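The flag idea could look roughly like the following, assuming a hypothetical private attribute `_from_bounds_index` that `__getitem__` sets on the copies it returns; the class and attribute names are illustrative, not existing JAX-GalSim API. Any in-place addition on such a copy emits a `UserWarning`, since the parent image will never see the change.

```python
import warnings

import jax.numpy as jnp


class FlaggedImage:
    """Hypothetical image wrapper that remembers whether it is a detached copy."""

    def __init__(self, array, from_bounds_index=False):
        self._array = jnp.asarray(array)
        self._from_bounds_index = from_bounds_index

    def __getitem__(self, b):
        # Mark the result: it holds a copy of the parent's data, not a view.
        return FlaggedImage(self._array[b], from_bounds_index=True)

    def __iadd__(self, other):
        if self._from_bounds_index:
            warnings.warn(
                "In-place update of an image obtained by bounds indexing; "
                "the parent image will not be modified.",
                UserWarning,
                stacklevel=2,
            )
        self._array = self._array + other._array
        return self


image = FlaggedImage(jnp.zeros((4, 4)))
stamp = FlaggedImage(jnp.ones((2, 2)))
sub = image[1:3, 1:3]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sub += stamp  # warns: `image` stays all zeros
```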
I expect this last option will serve us better in the long run, since if we pre-conform to JAX's constraints, we'll avoid the complicated stack traces that come out when we break JAX's rules.
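As an illustration of what pre-conforming to JAX's constraints might look like, the stamp-adding pattern can be written as a pure function on raw arrays that is safe under `jax.jit`: the stamp shape is static, the offset may be traced, and there is no Python `if` on bounds. This is a sketch under simplifying assumptions; the function name `add_stamp` is hypothetical, and offsets are plain array indices rather than GalSim bounds. Padding the image by the stamp size lets partially or fully non-overlapping stamps fall into the padding, which is then cropped off.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_stamp(image, stamp, row0, col0):
    """Return a new array with `stamp` added starting at index (row0, col0).

    Pure function: `image` is not modified. Stamp pixels that fall outside
    the image are dropped, so no Python `if` on the overlap is needed.
    """
    sh, sw = stamp.shape  # static under jit
    # Pad by the stamp size so any overlapping placement stays in range;
    # dynamic_slice clamps far-away offsets into the padding, which is cropped.
    padded = jnp.pad(image, ((sh, sh), (sw, sw)))
    start = (row0 + sh, col0 + sw)
    region = jax.lax.dynamic_slice(padded, start, (sh, sw))
    padded = jax.lax.dynamic_update_slice(padded, region + stamp, start)
    # Crop the padding back off.
    return padded[sh:sh + image.shape[0], sw:sw + image.shape[1]]


image = jnp.zeros((4, 4))
stamp = jnp.ones((2, 2))
inside = add_stamp(image, stamp, 1, 1)      # full overlap
edge = add_stamp(image, stamp, 3, 3)        # only pixel (3, 3) overlaps
outside = add_stamp(image, stamp, 10, 10)   # no overlap: unchanged
```

The trade-off is that callers must write `image = add_stamp(image, stamp, row0, col0)` instead of mutating in place, which is exactly the kind of API change the immutability option implies.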