1919from jax_galsim import fits
2020from jax_galsim .bounds import BoundsI
2121from jax_galsim .core .utils import (
22- compute_major_minor_from_jacobian ,
2322 ensure_hashable ,
2423 implements ,
2524)
@@ -175,8 +174,6 @@ def __init__(
175174 depixelize = depixelize ,
176175 offset = offset ,
177176 gsparams = GSParams .check (gsparams ),
178- _force_stepk = _force_stepk ,
179- _force_maxk = _force_maxk ,
180177 hdu = hdu ,
181178 _recenter_image = _recenter_image ,
182179 )
@@ -222,14 +219,19 @@ def _maxk(self):
222219 if self ._jax_aux_data ["_force_maxk" ] > 0 :
223220 return self ._jax_aux_data ["_force_maxk" ]
224221 else :
225- return super ()._maxk
222+ # galsim uses a different way to handle the WCS effects on maxk
223+ # for interpolated images. IDK why. - MRB
224+ return self ._original .maxk / self ._original ._wcs ._maxScale ()
226225
227226 @property
228227 def _stepk (self ):
229228 if self ._jax_aux_data ["_force_stepk" ] > 0 :
230229 return self ._jax_aux_data ["_force_stepk" ]
231230 else :
232- return super ()._stepk
231+ # galsim uses a different way to handle the WCS effects on stepk
232+ # for interpolated images. IDK why. - MRB
233+ # super()._stepk
234+ return self ._original .stepk / self ._original ._wcs ._minScale ()
233235
234236 @property
235237 @implements (_galsim .interpolatedimage .InterpolatedImage .x_interpolant )
@@ -367,6 +369,16 @@ def _zeropad_image(arr, npad):
367369
368370@register_pytree_node_class
369371class _InterpolatedImageImpl (GSObject ):
372+ """Internal class for handling interpolated images.
373+
374+ An interpolated image carries an intrinsic WCS with it that can be anything
375+ from a pixel-scale to a full Jacobian.
376+
377+ We use this internal class to separate the underlying image bits from the
378+ WCS handling bits. For those, we inherit from the Transform class so that
379+ we can reuse its methods.
380+ """
381+
370382 _cache_noise_pad = {}
371383
372384 _has_hard_edges = False
@@ -395,15 +407,12 @@ def __init__(
395407 depixelize = False ,
396408 offset = None ,
397409 gsparams = None ,
398- _force_stepk = 0.0 ,
399- _force_maxk = 0.0 ,
400410 hdu = None ,
401411 _recenter_image = True ,
402412 ):
403413 # this class does a ton of munging of the inputs that I don't want to reconstruct when
404414 # flattening and unflattening the class.
405415 # thus I am going to make some refs here so we have it when we need it
406- self ._workspace = {}
407416 self ._jax_children = (
408417 image ,
409418 dict (
@@ -428,8 +437,6 @@ def __init__(
428437 use_true_center = use_true_center ,
429438 depixelize = depixelize ,
430439 gsparams = gsparams ,
431- _force_stepk = _force_stepk ,
432- _force_maxk = _force_maxk ,
433440 _recenter_image = _recenter_image ,
434441 hdu = hdu ,
435442 )
@@ -530,13 +537,10 @@ def tree_unflatten(cls, aux_data, children):
530537 return ret
531538
532539 def __getstate__ (self ):
533- d = self .__dict__ .copy ()
534- d .pop ("_workspace" )
535- return d
540+ return self .__dict__ .copy ()
536541
537542 def __setstate__ (self , d ):
538543 self .__dict__ = d
539- self ._workspace = {}
540544
541545 @property
542546 def x_interpolant (self ):
@@ -683,31 +687,15 @@ def _kim(self):
683687
684688 @property
685689 def _maxk (self ):
686- if self ._jax_aux_data ["_force_maxk" ]:
687- _ , minor = compute_major_minor_from_jacobian (self ._jac_arr .reshape ((2 , 2 )))
688- return self ._jax_aux_data ["_force_maxk" ] * minor
689- else :
690- return self ._getMaxK (self ._jax_aux_data ["calculate_maxk" ])
690+ return self ._getMaxK (self ._jax_aux_data ["calculate_maxk" ])
691691
692692 @property
693693 def _stepk (self ):
694- if self ._jax_aux_data ["_force_stepk" ]:
695- _ , minor = compute_major_minor_from_jacobian (self ._jac_arr .reshape ((2 , 2 )))
696- return self ._jax_aux_data ["_force_stepk" ] * minor
697- else :
698- return self ._getStepK (self ._jax_aux_data ["calculate_stepk" ])
694+ return self ._getStepK (self ._jax_aux_data ["calculate_stepk" ])
699695
700696 def _getStepK (self , calculate_stepk ):
701697 # GalSim cannot automatically know what stepK and maxK are appropriate for the
702698 # input image. So it is usually worth it to do a manual calculation (below).
703- #
704- # However, there is also a hidden option to force it to use specific values of stepK and
705- # maxK (caveat user!). The values of _force_stepk and _force_maxk should be provided in
706- # terms of physical scale, e.g., for images that have a scale length of 0.1 arcsec, the
707- # stepK and maxK should be provided in units of 1/arcsec. Then we convert to the 1/pixel
708- # units required by the C++ layer below. Also note that profile recentering for even-sized
709- # images (see the ._adjust_offset step below) leads to automatic reduction of stepK slightly
710- # below what is provided here, while maxK is preserved.
711699 if calculate_stepk :
712700 if calculate_stepk is True :
713701 im = self .image
@@ -1173,9 +1161,9 @@ def _flux_frac(a, x, y, cenx, ceny):
11731161 dx = jnp .reshape (dx , (a .shape [0 ], a .shape [1 ], 1 ))
11741162 dy = y - ceny
11751163 dy = jnp .reshape (dy , (a .shape [0 ], a .shape [1 ], 1 ))
1176- d = jnp .arange (a .shape [0 ])
1164+ d = jnp .arange (min ( a .shape [0 ], a . shape [ 1 ]) )
11771165 d = jnp .reshape (d , (1 , 1 , - 1 ))
1178- msk = (jnp .abs (dx ) <= d ) & (jnp .abs (dx ) <= d )
1166+ msk = (jnp .abs (dx ) <= d ) & (jnp .abs (dy ) <= d )
11791167 res = jnp .sum (
11801168 jnp .where (
11811169 msk ,
@@ -1184,7 +1172,6 @@ def _flux_frac(a, x, y, cenx, ceny):
11841172 ),
11851173 axis = (0 , 1 ),
11861174 )
1187- res = jnp .where (res > 0 , res , - jnp .inf )
11881175 return res
11891176
11901177
@@ -1193,28 +1180,20 @@ def _calculate_size_containing_flux(image, thresh):
11931180 cenx , ceny = image .center .x , image .center .y
11941181 x , y = image .get_pixel_centers ()
11951182 fluxes = _flux_frac (image .array , x , y , cenx , ceny )
1196- msk = fluxes >= - jnp .inf
1197- fluxes = jnp .where (msk , fluxes , jnp .max (fluxes ))
1198- d = jnp .arange (image .array .shape [0 ]) + 1.0
1199- # below we use a linear interpolation table to find the maximum size
1200- # in pixels that contains a given flux (called thresh here)
1201- # expfac controls how much we oversample the interpolation table
1202- # in order to return a more accurate result
1203- # we have it hard coded at 4 to compromise between speed and accuracy
1204- expfac = 4.0
1205- dint = jnp .arange (image .array .shape [0 ] * expfac ) / expfac + 1.0
1206- fluxes = jnp .interp (dint , d , fluxes )
1207- msk = fluxes <= thresh
1183+ # we add 1 since the flux fraction computation above starts at
1184+ # one pixel and jnp.arange starts at zero
1185+ d = jnp .arange (min (image .array .shape [0 ], image .array .shape [1 ])) + 1.0
1186+ p = jnp .sign (thresh )
1187+ msk = (p * fluxes ) >= (p * thresh )
12081188 return (
1209- jnp .argmax (
1189+ jnp .argmin (
12101190 jnp .where (
12111191 msk ,
1212- dint ,
1213- - jnp .inf ,
1192+ d ,
1193+ jnp .inf ,
12141194 )
12151195 )
1216- / expfac
1217- + 1.0
1196+ + 0.5
12181197 )
12191198
12201199
@@ -1247,6 +1226,11 @@ def _find_maxk(kim, max_maxk, thresh):
12471226 # maxk from the image (computed by _inner_comp_find_maxk)
12481227 # by max_maxk from above
12491228 return jnp .minimum (
1250- _inner_comp_find_maxk (kim .array , thresh , kx , ky ),
1229+ # jax-galsim tends to be less conservative for maxk
1230+ # since compared to galsim, it does NOT require 5 rows
1231+ # of pixels in a row below the threshold.
1232+ # thus we add pixels here to ensure the galsim tests pass.
1233+ # it turns out one worked ok so that is what we did. - MRB
1234+ _inner_comp_find_maxk (kim .array , thresh , kx , ky ) + 1 * kim .scale ,
12511235 max_maxk ,
12521236 )
0 commit comments