Skip to content

Commit 14b600a

Browse files
Merge branch 'main' into documentation
2 parents 2f5bd10 + ed84a05 commit 14b600a

12 files changed

Lines changed: 1133 additions & 168 deletions

jax_galsim/bounds.py

Lines changed: 421 additions & 69 deletions
Large diffs are not rendered by default.

jax_galsim/core/wrap_image.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _block_reduce_loop(sim, nx, ny, nxwrap, nywrap):
5555
return fim
5656

5757

58-
@partial(jax.jit, static_argnames=("xmin", "ymin", "nxwrap", "nywrap"))
58+
@partial(jax.jit, static_argnames=("nxwrap", "nywrap"))
5959
def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
6060
# these bits compute how many total blocks we need to cover the image
6161
nx = im.shape[1] // nxwrap
@@ -81,7 +81,11 @@ def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
8181
else:
8282
fim = _block_reduce_loop(sim, nx, ny, nxwrap, nywrap)
8383

84-
im = im.at[ymin : ymin + nywrap, xmin : xmin + nxwrap].set(fim)
84+
im = jax.lax.dynamic_update_slice(
85+
im,
86+
fim,
87+
(ymin, xmin),
88+
)
8589
return im
8690

8791

@@ -98,10 +102,6 @@ def contract_hermitian_x(im):
98102
@partial(
99103
jax.jit,
100104
static_argnames=[
101-
"im_xmin",
102-
"im_ymin",
103-
"wrap_xmin",
104-
"wrap_ymin",
105105
"wrap_nx",
106106
"wrap_ny",
107107
],
@@ -127,10 +127,6 @@ def contract_hermitian_y(im):
127127
@partial(
128128
jax.jit,
129129
static_argnames=[
130-
"im_xmin",
131-
"im_ymin",
132-
"wrap_xmin",
133-
"wrap_ymin",
134130
"wrap_nx",
135131
"wrap_ny",
136132
],

jax_galsim/gsobject.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def _setup_image(
380380
N = self.getGoodImageSize(1.0)
381381
if odd:
382382
N += 1
383-
bounds = BoundsI(1, N, 1, N)
383+
bounds = BoundsI(xmin=1, deltax=N, ymin=1, deltay=N)
384384
image.resize(bounds)
385385
# Else use the given image as is
386386

@@ -486,7 +486,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center):
486486
if image is not None and image.bounds.isDefined():
487487
return image.bounds
488488
elif nx is not None and ny is not None:
489-
b = BoundsI(1, nx, 1, ny)
489+
b = BoundsI(xmin=1, deltax=nx, ymin=1, deltay=ny)
490490
if center is not None:
491491
# this code has to match the code in _setup_image
492492
# for the same branch of the if statement block
@@ -853,7 +853,14 @@ def drawFFT_makeKImage(self, image):
853853
image_N = jnp.max(
854854
jnp.array(
855855
[
856-
jnp.max(jnp.abs(jnp.array(image.bounds._getinitargs()))) * 2,
856+
jnp.max(
857+
jnp.abs(
858+
jnp.array(
859+
[image.xmin, image.xmax, image.ymin, image.ymax]
860+
)
861+
)
862+
)
863+
* 2,
857864
jnp.max(jnp.array(image.bounds.numpyShape())),
858865
]
859866
)
@@ -880,7 +887,9 @@ def drawFFT_makeKImage(self, image):
880887
"drawFFT requires an FFT that is too large.", Nk
881888
)
882889

883-
bounds = BoundsI(0, Nk // 2, -Nk // 2, Nk // 2)
890+
bounds = BoundsI(
891+
xmin=0, deltax=Nk // 2 + 1, ymin=-Nk // 2, deltay=2 * (Nk // 2) + 1
892+
)
884893
if image.dtype in (np.complex128, np.float64, np.int32, np.uint32):
885894
kimage = ImageCD(bounds=bounds, scale=dk)
886895
else:
@@ -895,12 +904,20 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
895904
# Wrap the full image to the size we want for the FT.
896905
# Even if N == Nk, this is useful to make this portion properly Hermitian in the
897906
# N/2 column and N/2 row.
898-
bwrap = BoundsI(0, wrap_size // 2, -wrap_size // 2, wrap_size // 2 - 1)
899-
kimage_wrap = kimage._wrap(bwrap, True, False)
907+
bwrap = BoundsI(
908+
xmin=0,
909+
deltax=wrap_size // 2 + 1,
910+
ymin=-wrap_size // 2,
911+
deltay=2 * (wrap_size // 2),
912+
)
913+
kimage_wrap = kimage._wrap(bwrap, True, False, wrap_size)
900914

901915
# Perform the fourier transform.
902916
breal = BoundsI(
903-
-wrap_size // 2, wrap_size // 2 - 1, -wrap_size // 2, wrap_size // 2 - 1
917+
xmin=-wrap_size // 2,
918+
deltax=2 * (wrap_size // 2),
919+
ymin=-wrap_size // 2,
920+
deltay=2 * (wrap_size // 2),
904921
)
905922
kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,))
906923
real_image_arr = jnp.fft.fftshift(

0 commit comments

Comments
 (0)