@@ -1214,7 +1214,78 @@ def _inner_comp_find_maxk(arr, thresh, kx, ky):
12141214 - jnp .inf ,
12151215 )
12161216 )
1217- return jnp .maximum (max_kx , max_ky )
1217+ # galsim adds one pixel at the end so that maxk is
1218+ # the k value where things do not pass the threshold,
1219+ # so we do that here too.
1220+ return jnp .maximum (max_kx , max_ky ) + kx [0 , 1 ] - kx [0 , 0 ]
1221+
1222+
1223+ # this version matches galsim's maxk operation exactly, but is
1224+ # more expensive to compute since it has a scan operation.
1225+ # I am leaving it here for posterity. - MRB
1226+ # @jax.jit
1227+ # def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
1228+ # val = (arr * arr.conjugate()).real
1229+ # msk_thresh = val > thresh * thresh
1230+ # akx = jnp.abs(kx)
1231+ # aky = jnp.abs(ky)
1232+ #
1233+ # def _func(carry, x):
1234+ # msk_kx = akx <= x
1235+ # msk_ky = aky <= x
1236+ # return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)
1237+ #
1238+ # _, msk = jax.lax.scan(_func, None, xs=kx[0, :])
1239+ #
1240+ # # We are searching for the location of the first string of
1241+ # # five locations in a row in `msk` where the value stays the
1242+ # # same.
1243+ # # We do this by putting the array through jnp.diff, which
1244+ # # computes the difference of adjacent elements. Then we convolve
1245+ # # with a filter of ones of length five to sum groups of five
1246+ # # elements together. The first location where the result is
1247+ # # zero is the location we want. The tricky bit however is getting
1248+ # # the indexing right.
1249+ #
1250+ # # step 1. compute the diff of adjacent elements
1251+ # # The function jnp.diff returns an array of size one less than
1252+ # # the input. So we concatenate a zero at the front. This makes
1253+ # # sense since if the original array is all constant, then the
1254+ # # location of the first five zeros is at the start of the array.
1255+ # delta_msk = jnp.concatenate(
1256+ # [jnp.array([0], dtype=int), jnp.diff(msk)],
1257+ # axis=0,
1258+ # dtype=int,
1259+ # )
1260+ #
1261+ # # step 2. convolve with the filter
1262+ # # In the discrete convolution, you have to deal with edge
1263+ # # behavior where the filter only partially overlaps the arrays.
1264+ # # We use the mode `full` which returns an array containing
1265+ # # every possible combination with missing elements set to zero.
1266+ # # We cut the first `length of filter - 1` elements so that
1267+ # # index i of the result is the sum of the filter starting
1268+ # # at index i of the input.
1269+ # sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]
1270+ #
1271+ # # step 3. find first location of zero in the convolution
1272+ # # Finally, we use jnp.argmin to find the location of the first
1273+ # # zero. Per the doc string, if there is more than one zero, this
1274+ # # function returns the first location (i.e., smallest index)
1275+ # # which is what we want.
1276+ # msk_zero = sums == 0
1277+ # sind, dk = jax.lax.cond(
1278+ # jnp.any(msk_zero),
1279+ # # if we find a set of zeros, the code computes the next pixel past
1280+ # # the pixels where |kval| > thresh. So we set dk = 0 since we don't
1281+ # # need to shift things.
1282+ # lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
1283+ # # if we get to the end of the array, we add one pixel spacing
1284+ # # so we match galsim
1285+ # lambda x: (-1, kx[0, -1] - kx[0, -2]),
1286+ # msk_zero,
1287+ # )
1288+ # return kx[0, sind] + dk
12181289
12191290
12201291@jax .jit
@@ -1226,11 +1297,6 @@ def _find_maxk(kim, max_maxk, thresh):
12261297 # maxk from the image (computed by _inner_comp_find_maxk)
12271298 # by max_maxk from above
12281299 return jnp .minimum (
1229- # jax-galsim tends to be less conservative for maxk
1230- # since compared to galsim, it does NOT require 5 rows
1231- # of pixels in a row below the threshold.
1232- # thus we add pixels here to ensure the galsim tests pass.
1233- # it turns out one worked ok so that is what we did. - MRB
1234- _inner_comp_find_maxk (kim .array , thresh , kx , ky ) + 1 * kim .scale ,
1300+ _inner_comp_find_maxk (kim .array , thresh , kx , ky ),
12351301 max_maxk ,
12361302 )
0 commit comments