Skip to content

Commit 58d1a0e

Browse files
authored
feat: investigate maxk differences and add galsim algorithm for maxk (#224)
* test: remove setting maxk since this now works too? * test: rename test to make purpose clearer * test: run new test suite * test: tighter tolerance * feat: add galsim maxk version * feat: use a scan op to match galsim exactly * fix: adjust indexing a bit * fix: off by one only if we reach the end of the array; galsim is fine * test: simplify tests * perf: use faster algorithm
1 parent e3ff71f commit 58d1a0e

5 files changed

Lines changed: 147 additions & 335 deletions

File tree

jax_galsim/interpolatedimage.py

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,7 +1214,78 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky):
12141214
-jnp.inf,
12151215
)
12161216
)
1217-
return jnp.maximum(max_kx, max_ky)
1217+
# galsim adds one pixel at the end so that maxk is
1218+
# the k value where things do not pass the threshold,
1219+
# so we do that here too.
1220+
return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0]
1221+
1222+
1223+
# this version matches galsim's maxk operation exactly, but is
1224+
# more expensive to compute since it has a scan operation.
1225+
# I am leaving it here for posterity. - MRB
1226+
# @jax.jit
1227+
# def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
1228+
# val = (arr * arr.conjugate()).real
1229+
# msk_thresh = val > thresh * thresh
1230+
# akx = jnp.abs(kx)
1231+
# aky = jnp.abs(ky)
1232+
#
1233+
# def _func(carry, x):
1234+
# msk_kx = akx <= x
1235+
# msk_ky = aky <= x
1236+
# return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)
1237+
#
1238+
# _, msk = jax.lax.scan(_func, None, xs=kx[0, :])
1239+
#
1240+
# # We are searching for the location of the first string of
1241+
# # five locations in a row in `msk` where the value stays the
1242+
# # same.
1243+
# # We do this by putting the array through jnp.diff, which
1244+
# # computes the difference of adjacent elements. Then we convolve
1245+
# # with a filter of ones of length five to sum groups of five
1246+
# # elements together. The first location where the result is
1247+
# # zero is the location we want. The tricky bit however is getting
1248+
# # the indexing right.
1249+
#
1250+
# # step 1. compute the diff of adjacent elements
1251+
# # The function jnp.diff returns an array of size one less than
1252+
# # the input. So we concatenate a zero at the front. This makes
1253+
# # sense since if the original array is all constant, then the
1254+
# # location of the first five zeros is at the start of the array.
1255+
# delta_msk = jnp.concatenate(
1256+
# [jnp.array([0], dtype=int), jnp.diff(msk)],
1257+
# axis=0,
1258+
# dtype=int,
1259+
# )
1260+
#
1261+
# # step 2. convolve with the filter
1262+
# # In the discrete convolution, you have to deal with edge
1263+
# # behavior where the filter only partially overlaps the arrays.
1264+
# # We use the mode `full` which returns an array containing
1265+
# # every possible combination with missing elements set to zero.
1266+
# # We cut the first `length of filter - 1` elements so that
1267+
# # index i of the result is the sum of the filter starting
1268+
# # at index i of the input.
1269+
# sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]
1270+
#
1271+
# # step 3. find first location of zero in the convolution
1272+
# # Finally, we use jnp.argmin to find the location of the first
1273+
# # zero. Per the doc string, if there is more than one zero, this
1274+
# # function returns the first location (i.e., smallest index)
1275+
# # which is what we want.
1276+
# msk_zero = sums == 0
1277+
# sind, dk = jax.lax.cond(
1278+
# jnp.any(msk_zero),
1279+
# # if we find a set of zeros, the code computes the next pixel past
1280+
# # the pixels where |kval| > thresh. So we set dk = 0 since we don't
1281+
# # need to shift things.
1282+
# lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
1283+
# # if we get to the end of the array, we add one pixel spacing
1284+
# # so we match galsim
1285+
# lambda x: (-1, kx[0, -1] - kx[0, -2]),
1286+
# msk_zero,
1287+
# )
1288+
# return kx[0, sind] + dk
12181289

12191290

12201291
@jax.jit
@@ -1226,11 +1297,6 @@ def _find_maxk(kim, max_maxk, thresh):
12261297
# maxk from the image (computed by _inner_comp_find_maxk)
12271298
# by max_maxk from above
12281299
return jnp.minimum(
1229-
# jax-galsim tends to be less conservative for maxk
1230-
# since compared to galsim, it does NOT require 5 rows
1231-
# of pixels in a row below the threshold.
1232-
# thus we add pixels here to ensure the galsim tests pass.
1233-
# it turns out one worked ok so that is what we did. - MRB
1234-
_inner_comp_find_maxk(kim.array, thresh, kx, ky) + 1 * kim.scale,
1300+
_inner_comp_find_maxk(kim.array, thresh, kx, ky),
12351301
max_maxk,
12361302
)

tests/GalSim

0 commit comments

Comments
 (0)