diff --git a/examples/demo1.py b/examples/demo1.py new file mode 100644 index 00000000..987d3e75 --- /dev/null +++ b/examples/demo1.py @@ -0,0 +1,150 @@ +# Copyright (c) 2012-2026 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# +""" +Demo #1 + +This is the first script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. +(This file is designed to be viewed in a window 100 characters wide.) + +Each of these demo*.py files are designed to be equivalent to the corresponding demo*.yaml file +(or demo*.json -- found in the json directory). If you are new to python, you should probably +look at those files first as they will probably have a quicker learning curve for you. Then you +can look through these python scripts, which show how to do the same thing. Of course, experienced +pythonistas may prefer to start with these scripts and then look at the corresponding YAML files. + +To run this script, simply write: + + python demo1.py + + +This first script is about as simple as it gets. We draw an image of a single galaxy convolved +with a PSF and write it to disk. We use a circular Gaussian profile for both the PSF and the +galaxy, and add a constant level of Gaussian noise to the image. + +In each demo, we list the new features introduced in that demo file. These will differ somewhat +between the .py and .yaml (or .json) versions, since the two methods implement things in different +ways. (demo*.py are python scripts, while demo*.yaml and demo*.json are configuration files.) + +New features introduced in this demo: + +- obj = jax_galsim.Gaussian(flux, sigma) +- obj = jax_galsim.Convolve([list of objects]) +- image = obj.drawImage(scale) +- image.added_flux (Only present after a drawImage command.) +- noise = jax_galsim.GaussianNoise(sigma) +- image.addNoise(noise) +- image.write(file_name) +- image.FindAdaptiveMom() +""" + +import logging +import math +import os +import sys + +import jax_galsim + + +def main(argv): + """ + About as simple as it gets: + - Use a circular Gaussian profile for the galaxy. + - Convolve it by a circular Gaussian PSF. + - Add Gaussian noise to the image. + """ + # In non-script code, use getLogger(__name__) at module scope instead. + logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) + logger = logging.getLogger("demo1") + + gal_flux = 1.0e5 # total counts on the image + gal_sigma = 2.0 # arcsec + psf_sigma = 1.0 # arcsec + pixel_scale = 0.2 # arcsec / pixel + noise = 30.0 # standard deviation of the counts in each pixel + + logger.info("Starting demo script 1 using:") + logger.info( + " - circular Gaussian galaxy (flux = %.1e, sigma = %.1f),", + gal_flux, + gal_sigma, + ) + logger.info(" - circular Gaussian PSF (sigma = %.1f),", psf_sigma) + logger.info(" - pixel scale = %.2f,", pixel_scale) + logger.info(" - Gaussian noise (sigma = %.2f).", noise) + + # Define the galaxy profile + gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma) + logger.debug("Made galaxy profile") + + # Define the PSF profile + psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma) # PSF flux should always = 1 + logger.debug("Made PSF profile") + + # Final profile is the convolution of these + # Can include any number of things in the list, all of which are convolved + # together to make the final flux profile. + final = jax_galsim.Convolve([gal, psf]) + logger.debug("Convolved components into final profile") + + # Draw the image with a particular pixel scale, given in arcsec/pixel. + # The returned image has a member, added_flux, which is gives the total flux actually added to + # the image. One could use this value to check if the image is large enough for some desired + # accuracy level. Here, we just ignore it. + image = final.drawImage(scale=pixel_scale) + logger.debug( + "Made image of the profile: flux = %f, added_flux = %f", + gal_flux, + image.added_flux, + ) + + # Add Gaussian noise to the image with specified sigma + image.addNoise(jax_galsim.GaussianNoise(sigma=noise)) + logger.debug("Added Gaussian noise") + + # Write the image to a file + if not os.path.isdir("output"): + os.mkdir("output") + file_name = os.path.join("output", "demo1.fits") + # Note: if the file already exists, this will overwrite it. + image.write(file_name) + logger.info( + "Wrote image to %r" % file_name + ) # using %r adds quotes around filename for us + + results = image.FindAdaptiveMom() + + logger.info("HSM reports that the image has observed shape and size:") + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", + results.observed_shape.e1, + results.observed_shape.e2, + results.moments_sigma, + ) + logger.info( + "Expected values in the limit that pixel response and noise are negligible:" + ) + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f", + 0.0, + 0.0, + math.sqrt(gal_sigma**2 + psf_sigma**2) / pixel_scale, + ) + + +if __name__ == "__main__": + main(sys.argv) diff --git a/examples/demo2.py b/examples/demo2.py new file mode 100644 index 00000000..4f64b436 --- /dev/null +++ b/examples/demo2.py @@ -0,0 +1,175 @@ +# Copyright (c) 2012-2026 by the GalSim developers team on GitHub +# https://github.com/GalSim-developers +# +# This file is part of GalSim: The modular galaxy image simulation toolkit. +# https://github.com/GalSim-developers/GalSim +# +# GalSim is free software: redistribution and use in source and binary forms, +# with or without modification, are permitted provided that the following +# conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions, and the disclaimer given in the accompanying LICENSE +# file. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions, and the disclaimer given in the documentation +# and/or other materials provided with the distribution. +# +""" +Demo #2 + +The second script in our tutorial about using JAX-GalSim in python scripts: examples/demo*.py. +(This file is designed to be viewed in a window 100 characters wide.) + +This script is a bit more sophisticated, but still pretty basic. We're still only making +a single image, but now the galaxy has an exponential radial profile and is sheared. +The PSF is a circular Moffat profile. The noise is drawn from a Poisson distribution +using the flux from both the object and a background sky level to determine the +variance in each pixel. + +New features introduced in this demo: + +- obj = jax_galsim.Exponential(flux, scale_radius) +- obj = jax_galsim.Moffat(beta, flux, half_light_radius) +- obj = obj.shear(g1, g2) -- with explanation of other ways to specify shear +- rng = jax_galsim.BaseDeviate(seed) +- noise = jax_galsim.PoissonNoise(rng, sky_level) +- galsim.hsm.EstimateShear(image, image_epsf) +""" + +import logging +import os +import sys + +import galsim + +import jax_galsim + + +def main(argv): + """ + A little bit more sophisticated, but still pretty basic: + - Use a sheared, exponential profile for the galaxy. + - Convolve it by a circular Moffat PSF. + - Add Poisson noise to the image. + """ + # In non-script code, use getLogger(__name__) at module scope instead. + logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stdout) + logger = logging.getLogger("demo2") + + gal_flux = 1.0e5 # counts + gal_r0 = 2.7 # arcsec + g1 = 0.1 # + g2 = 0.2 # + psf_beta = 5 # + psf_re = 1.0 # arcsec + pixel_scale = 0.2 # arcsec / pixel + sky_level = 2.5e3 # counts / arcsec^2 + + # This time use a particular seed, so the image is deterministic. + # This is the same seed that is used in demo2.yaml, which means the images + # produced by the two methods will be precisely identical. + random_seed = 1534225 + + # The first thing the config layer does with the random seed is to scramble + # it a bit. Specifically, it makes a random number generator (BaseDeviate) + # using that seed and asks for a raw value. This becomes the seed that + # actually gets used. + # The reason for this extra step is that eventually (cf. demo4) the config + # layer will want to increment these seed values when building multiple + # objects or images. If the user is likewise incrementing seed values for + # multiple runs of a given config file, these can interfere leading to + # surprising (and typically bad) results. + random_seed = jax_galsim.BaseDeviate(random_seed).raw() + + logger.info("Starting demo script 2 using:") + logger.info( + " - sheared (%.2f,%.2f) exponential galaxy (flux = %.1e, scale radius = %.2f),", + g1, + g2, + gal_flux, + gal_r0, + ) + logger.info(" - circular Moffat PSF (beta = %.1f, re = %.2f),", psf_beta, psf_re) + logger.info(" - pixel scale = %.2f,", pixel_scale) + logger.info(" - Poisson noise (sky level = %.1e).", sky_level) + + # Initialize the (pseudo-)random number generator that we will be using below. + # For a technical reason that will be explained later (demo9.py), we add 1 to the + # given random seed here. + rng = jax_galsim.BaseDeviate(random_seed + 1) + + # Define the galaxy profile. + gal = jax_galsim.Exponential(flux=gal_flux, scale_radius=gal_r0) + + # Shear the galaxy by some value. + # There are quite a few ways you can use to specify a shape. + # q, beta Axis ratio and position angle: q = b/a, 0 < q < 1 + # e, beta Ellipticity and position angle: |e| = (1-q^2)/(1+q^2) + # g, beta ("Reduced") Shear and position angle: |g| = (1-q)/(1+q) + # eta, beta Conformal shear and position angle: eta = ln(1/q) + # e1,e2 Ellipticity components: e1 = e cos(2 beta), e2 = e sin(2 beta) + # g1,g2 ("Reduced") shear components: g1 = g cos(2 beta), g2 = g sin(2 beta) + # eta1,eta2 Conformal shear components: eta1 = eta cos(2 beta), eta2 = eta sin(2 beta) + gal = gal.shear(g1=g1, g2=g2) + logger.debug("Made galaxy profile") + + # Define the PSF profile. + psf = jax_galsim.Moffat(beta=psf_beta, flux=1.0, half_light_radius=psf_re) + logger.debug("Made PSF profile") + + # Final profile is the convolution of these. + final = jax_galsim.Convolve([gal, psf]) + logger.debug("Convolved components into final profile") + + # Draw the image with a particular pixel scale. + image = final.drawImage(scale=pixel_scale) + # The "effective PSF" is the PSF as drawn on an image, which includes the convolution + # by the pixel response. We label it epsf here. + image_epsf = psf.drawImage(scale=pixel_scale) + logger.debug("Made image of the profile") + + # To get Poisson noise on the image, we will use a class called PoissonNoise. + # However, we want the noise to correspond to what you would get with a significant + # flux from tke sky. This is done by telling PoissonNoise to add noise from a + # sky level in addition to the counts currently in the image. + # + # One wrinkle here is that the PoissonNoise class needs the sky level in each pixel, + # while we have a sky_level in counts per arcsec^2. So we need to convert: + sky_level_pixel = sky_level * pixel_scale**2 + noise = jax_galsim.PoissonNoise(rng, sky_level=sky_level_pixel) + image.addNoise(noise) + logger.debug("Added Poisson noise") + + # Write the image to a file. + if not os.path.isdir("output"): + os.mkdir("output") + file_name = os.path.join("output", "demo2.fits") + file_name_epsf = os.path.join("output", "demo2_epsf.fits") + image.write(file_name) + image_epsf.write(file_name_epsf) + logger.info("Wrote image to %r", file_name) + logger.info("Wrote effective PSF image to %r", file_name_epsf) + + results = galsim.hsm.EstimateShear(image.to_galsim(), image_epsf.to_galsim()) + + logger.info("HSM reports that the image has observed shape and size:") + logger.info( + " e1 = %.3f, e2 = %.3f, sigma = %.3f (pixels)", + results.observed_shape.e1, + results.observed_shape.e2, + results.moments_sigma, + ) + logger.info( + "When carrying out Regaussianization PSF correction, HSM reports distortions" + ) + logger.info(" e1, e2 = %.3f, %.3f", results.corrected_e1, results.corrected_e2) + logger.info( + "Expected values in the limit that noise and non-Gaussianity are negligible:" + ) + exp_shear = galsim.Shear(g1=g1, g2=g2) + logger.info(" g1, g2 = %.3f, %.3f", exp_shear.e1, exp_shear.e2) + + +if __name__ == "__main__": + main(sys.argv) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index f8494a99..f6f0f518 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1094,15 +1094,45 @@ def tree_unflatten(cls, aux_data, children): @classmethod def from_galsim(cls, galsim_image): """Create a `Image` from a `galsim.Image` instance.""" + wcs = ( + BaseWCS.from_galsim(galsim_image.wcs) + if galsim_image.wcs is not None + else None + ) im = cls( array=galsim_image.array, - wcs=BaseWCS.from_galsim(galsim_image.wcs), + wcs=wcs, bounds=Bounds.from_galsim(galsim_image.bounds), ) if hasattr(galsim_image, "header"): im.header = galsim_image.header return im + def to_galsim(self): + """Create a galsim `Image` from a `jax_galsim.Image` object.""" + wcs = self.wcs.to_galsim() if self.wcs is not None else None + return _galsim.Image( + np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs + ) + + @implements( + _galsim.Image.FindAdaptiveMom, + lax_description=( + "This method converts the current `jax_galsim.Image` to a native " + "`galsim.Image` and delegates the computation to " + "`galsim.hsm.FindAdaptiveMom`. The returned object is GalSim's " + "`ShapeData`." + ), + ) + def FindAdaptiveMom(self, *args, **kwargs): + args_ = [arg.to_galsim() if hasattr(arg, "to_galsim") else arg for arg in args] + kwargs_ = { + key: val.to_galsim() if hasattr(val, "to_galsim") else val + for key, val in kwargs.items() + } + gs_image = self.to_galsim() + return gs_image.FindAdaptiveMom(*args_, **kwargs_) + @implements( _galsim._Image, diff --git a/tests/GalSim b/tests/GalSim index 2ed86695..3251a393 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 2ed86695df3669c4ff4de4cd3154e6fd76e206da +Subproject commit 3251a393bf7ea94fe9ccda3508bc7db722eca1cf diff --git a/tests/SBProfile_comparison_images b/tests/SBProfile_comparison_images new file mode 120000 index 00000000..6e72d788 --- /dev/null +++ b/tests/SBProfile_comparison_images @@ -0,0 +1 @@ +GalSim/tests/SBProfile_comparison_images \ No newline at end of file diff --git a/tests/fits_file b/tests/fits_file new file mode 120000 index 00000000..5b03e34d --- /dev/null +++ b/tests/fits_file @@ -0,0 +1 @@ +GalSim/tests/fits_files \ No newline at end of file diff --git a/tests/galsim_tests_config.yaml b/tests/galsim_tests_config.yaml index 0efc2e1b..2d532570 100644 --- a/tests/galsim_tests_config.yaml +++ b/tests/galsim_tests_config.yaml @@ -86,7 +86,6 @@ allowed_failures: - "'Image' object has no attribute 'bin'" - "module 'jax_galsim' has no attribute 'InterpolatedKImage'" - "module 'jax_galsim' has no attribute 'CorrelatedNoise'" - - "'Image' object has no attribute 'FindAdaptiveMom'" - "CelestialCoord.precess is too slow" # cannot get jax to warmup but once it does it passes - "ValueError not raised by from_xyz" - "ValueError not raised by greatCirclePoint" diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index f7445a6b..65aa1b1c 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -604,6 +604,7 @@ def _reg_sfun(g1): def test_api_image(obj): _run_object_checks(obj, obj.__class__, "docs-methods") _run_object_checks(obj, obj.__class__, "pickle-eval-repr-img") + _run_object_checks(obj, obj.__class__, "to-from-galsim") # JAX tracing should be an identity assert obj.__class__.tree_unflatten(*((obj.tree_flatten())[::-1])) == obj