diff --git a/roman_imsim/__init__.py b/roman_imsim/__init__.py index f7bccd8f..d2fc010c 100644 --- a/roman_imsim/__init__.py +++ b/roman_imsim/__init__.py @@ -23,6 +23,7 @@ from .psf import * from .sca import * from .skycat import * +from .roman_coadd import * from .stamp import * from .wcs import * diff --git a/roman_imsim/config/sim_coadd.yaml b/roman_imsim/config/sim_coadd.yaml new file mode 100644 index 00000000..227c1bef --- /dev/null +++ b/roman_imsim/config/sim_coadd.yaml @@ -0,0 +1,189 @@ +# Default settings for roman simulation +# Includes creation of noisless oversampled images (including PSF) +# -- processing of other detector and instrument effects are still handled in the +# python postprocessing layer to enable things not currently in galsim.roman + +modules: + + # Including galsim.roman in the list of modules to import will add a number of Roman-specific + # functions and classes that we will use here. + - roman_imsim + - galsim.roman + + # We need this for one of our Eval items. GalSim does not by default import datetime into + # the globals dict it uses when evaluating Eval items, so we can tell it to import it here. + - datetime + +# Define some other information about the images +image: + + # A special Image type that knows all the Roman SCA geometry, WCS, gain, etc. + # It also by default applies a number of detector effects, but these can be turned + # off if desired by setting some parameters (given below) to False. + type: roman_coadd + coadd_file: '/hpc/group/cosmology/yuedong/H158_Row00/prod_H_00_00_map.fits' + xsize: 2688 + ysize: 2688 + white_noise_weight: 0.7 + pink_noise_weight: 0.3 + pixel_scale: 0.0390625 + + wcs: + type: RomanWCS + # SCA: '@image.SCA' + # ra: { type: ObSeqData, field: ra } + # dec: { type: ObSeqData, field: dec } + # pa: { type: ObSeqData, field: pa } + # mjd: { type: ObSeqData, field: mjd } + # coadd_file: '/cwork/cmh215/OU24-Roman-Coadds/H158/prod_H_00_00_map.fits.gz' + + bandpass: + type: RomanBandpass + # name: { type: ObSeqData, field: filter } + name: "H158" + + # When you want to have multiple images generate the same random galaxies, then + # you can set up multiple random number generators with different update cadences + # by making random_seed a list. + # The default behavior is just to have the random seeds for each object go in order by + # object number across all images, but this shows how to set it up so we use two separate + # cadences. + # The first one behaves normally, which will be used for things like noise on the image. + # The second one sets the initial seed for each object to repeat to the same starting value + # at the start of each filter. If we were doing more than 3 total files, it would then + # move on to another sequence for the next 3 and so on. + random_seed: + # Used for noise and nobjects. + - { type: ObSeqData, field: visit } + + # Used for objects. Repeats sequence for each filter + # Note: Don't use $ shorthand here, since that will implicitly be evaluated once and then + # treated the same way as an integer (i.e. making a regular sequence starting from that + # value). Using an explicit dict with an Eval type means GalSim will leave it alone and + # evaluate it as is for each object. + + + # We're just doing one SCA here. + # If you wanted to do all of them in each of three filters (given below), you could use: + # + # SCA: + # type: Sequence + # first: 1 + # last: 18 + # repeat: 3 # repeat each SCA num 3 times before moving on, for the 3 filters. + # + + SCA: 10 + mjd: { type: ObSeqData, field: mjd } + filter: { type: ObSeqData, field: filter } + exptime: { type: ObSeqData, field: exptime } + + draw_method: 'fft' + # Photon shooting is way faster for chromatic objects than fft, especially when most of them + # are fairly faint. The cross-over point for achromatic objects is generally of order + # flux=1.e6 or so (depending on the profile). Most of our objects here are much fainter than + # that. The fft rendering for chromatic is a factor of 10 or so slower still, whereas + # chromatic photon shooting is only slighly slower than achromatic, so the difference + # is even more pronounced in this case. + use_fft_bright: True + + # These are all by default turned on, but you can turn any of them off if desired: + # ignore_noise: True + # stray_light: False + # thermal_background: False + # reciprocity_failure: False + # dark_current: False + # nonlinearity: False + # ipc: False + # read_noise: False + # sky_subtract: False + + # ignore_noise: False + # stray_light: True + # thermal_background: True + # reciprocity_failure: True + # dark_current: True + # nonlinearity: True + # ipc: True + # read_noise: True + # sky_subtract: False + + # nobjects: 500 + +stamp: + type: Roman_stamp + world_pos: + type: SkyCatWorldPos + exptime: { type: ObSeqData, field: exptime } + skip_failures: True + photon_ops: + - + type: ChargeDiff + +# psf: +# type: roman_psf +# # If omitted, it would figure this out automatically, because we are using the RomanSCA image +# # type. But if we weren't, you'd have to tell it which SCA to build the PSF for. +# SCA: '@image.SCA' +# # n_waves defines how finely to sample the PSF profile over the bandpass. +# # Using 10 wavelengths usually gives decent accuracy. +# n_waves: 10 + +psf: + type: RomanPSF + interpolator: + type: RomanPSFInterpolator + +# Define the galaxy type and positions to use +gal: + type: SkyCatObj + +input: + obseq_data: + file_name: /hpc/group/cosmology/OpenUniverse2024/RomanWAS/Roman_WAS_obseq_11_1_23.fits + visit: 4907 + SCA: '@image.SCA' + # roman_psf: + # SCA: '@image.SCA' + # n_waves: 5 + RomanPSFInterpolator: + kind: corners + n_waves: 5 + sky_catalog: + file_name: /hpc/home/yf194/Work/test_coadd_sims/roman_imsim/config/skyCatalog.yaml + edge_pix: 512 + mjd: { type: ObSeqData, field: mjd } + exptime: { type: ObSeqData, field: exptime } + obj_types: ['diffsky_galaxy','star','snana'] + +output: + + nfiles: 1 + dir: /hpc/group/cosmology/yuedong/roman_output/coadd_sim_new/images/truth + file_name: + type: FormattedStr + format: "Roman_WAS_truth_%s_%i_%i.fits.gz" + items: + - { type: ObSeqData, field: filter } + - { type: ObSeqData, field: visit } + - '@image.SCA' + + truth: + dir: /hpc/group/cosmology/yuedong/roman_output/coadd_sim_new/truth + file_name: + type: FormattedStr + format: "Roman_WAS_index_%s_%i_%i.txt" + items: + - { type: ObSeqData, field: filter } + - { type: ObSeqData, field: visit } + - '@image.SCA' + columns: + object_id: "@object_id" + ra: "$sky_pos.ra.deg" + dec: "$sky_pos.dec.deg" + x: "$image_pos.x" + y: "$image_pos.y" + realized_flux: "@realized_flux" + flux: "@flux" + mag: "@mag" + obj_type: "@object_type" diff --git a/roman_imsim/psf.py b/roman_imsim/psf.py index 033dbd10..c6066187 100644 --- a/roman_imsim/psf.py +++ b/roman_imsim/psf.py @@ -83,6 +83,18 @@ def _parse_pupil_bin(self, pupil_bin): else: return pupil_bin + def _psf_call_coadd(self, bpass, n_waves, logger): + # Currently only implementing a Gaussian PSF for each band + fwhm_dict = { + "Y106": 0.220, + "J129": 0.231, + "H158": 0.242, + "F184": 0.253, + "K213": 0.264, + } + psf = galsim.Gaussian(fwhm=fwhm_dict[bpass.name]) + return psf.withGSParams(maximum_fft_size=16384) + def _psf_call(self, SCA, bpass, SCA_pos, WCS, pupil_bin, n_waves, logger, extra_aberrations): if pupil_bin == 8: @@ -227,7 +239,9 @@ def initPSF( SCA, bandpass, cc, WCS, pupil_bin, n_waves, logger, self._extra_aberrations ) - def getPSF(self, pupil_bin, pos): + self.PSF_coadd = self._psf_call_coadd(bandpass, n_waves, logger) + + def getPSF(self, pupil_bin, pos, is_coadd=False): """ Return a PSF to be convolved with sources. @@ -256,6 +270,9 @@ def getPSF(self, pupil_bin, pos): # psf = self.PSF[pupil_bin]['cc'] # return psf + if is_coadd: + return self.PSF_coadd + psf = self.PSF[pupil_bin] if pupil_bin != 8: return psf diff --git a/roman_imsim/roman_coadd.py b/roman_imsim/roman_coadd.py new file mode 100644 index 00000000..f969b572 --- /dev/null +++ b/roman_imsim/roman_coadd.py @@ -0,0 +1,238 @@ +import galsim +import galsim.roman as roman +import galsim.config +from galsim.config import RegisterImageType +from galsim.config.image_scattered import ScatteredImageBuilder +from galsim.image import Image +from astropy.time import Time +from astropy.io import fits +import numpy as np + + +class RomanCoaddImageBuilder(ScatteredImageBuilder): + + def setup(self, config, base, image_num, obj_num, ignore, logger): + """Do the initialization and setup for building the image. + + This figures out the size that the image will be, but doesn't actually build it yet. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + ignore: A list of parameters that are allowed to be in config that we can + ignore here. i.e. it won't be an error if these parameters are present. + logger: If given, a logger object to log progress. + + Returns: + xsize, ysize + """ + # import os, psutil + # process = psutil.Process() + # print('sca setup 1',process.memory_info().rss) + logger.debug("image %d: Building RomanSCA: image, obj = %d,%d", image_num, image_num, obj_num) + + self.nobjects = self.getNObj(config, base, image_num, logger=logger) + logger.debug("image %d: nobj = %d", image_num, self.nobjects) + + # These are allowed for Scattered, but we don't use them here. + extra_ignore = ["image_pos", "world_pos", "stamp_size", "stamp_xsize", "stamp_ysize", "nobjects"] + req = { + "SCA": int, + "filter": str, + "mjd": float, + "exptime": float, + "coadd_file": str, + "white_noise_weight": float, + "pink_noise_weight": float, + "pixel_scale": float, + } + opt = { + "draw_method": str, + "ignore_noise": bool, + "use_fft_bright": bool, + # 'sca_filepath': str, + "xsize": int, + "ysize": int, + } + params = galsim.config.GetAllParams(config, base, req=req, opt=opt, ignore=ignore + extra_ignore)[0] + + self.sca = params["SCA"] + base["SCA"] = self.sca + self.filter = params["filter"] + self.mjd = params["mjd"] + self.exptime = params["exptime"] + + self.ignore_noise = params.get("ignore_noise", False) + # self.exptime = params.get('exptime', roman.exptime) # Default is roman standard exposure time. + # self.stray_light = params.get('stray_light', False) + # self.thermal_background = params.get('thermal_background', False) + # self.reciprocity_failure = params.get('reciprocity_failure', False) + # self.dark_current = params.get('dark_current', False) + # self.nonlinearity = params.get('nonlinearity', False) + # self.ipc = params.get('ipc', False) + # self.read_noise = params.get('read_noise', False) + # self.sky_subtract = params.get('sky_subtract', False) + + # If draw_method isn't in image field, it may be in stamp. Check. + self.draw_method = params.get("draw_method", base.get("stamp", {}).get("draw_method", "auto")) + + # pointing = CelestialCoord(ra=params['ra'], dec=params['dec']) + # wcs = roman.getWCS(world_pos = pointing, + # PA = params['pa']*galsim.degrees, + # date = params['date'], + # SCAs = self.sca, + # PA_is_FPA = True + # )[self.sca] + + # # GalSim expects a wcs in the image field. + # config['wcs'] = wcs + + self.rng = galsim.config.GetRNG(config, base) + self.visit = int(base["input"]["obseq_data"]["visit"]) + + # If user hasn't overridden the bandpass to use, get the standard one. + if "bandpass" not in config: + base["bandpass"] = galsim.config.BuildBandpass(base["image"], "bandpass", base, logger=logger) + + self.coadd_hdu = fits.open(params["coadd_file"]) + self.white_noise_weight = params["white_noise_weight"] + self.pink_noise_weight = params["pink_noise_weight"] + + # return roman.n_pix, roman.n_pix + return int(self.coadd_hdu[0].header["NAXIS1"]), int(self.coadd_hdu[0].header["NAXIS2"]) + + # def getBandpass(self, filter_name): + # if not hasattr(self, 'all_roman_bp'): + # self.all_roman_bp = roman.getBandpasses() + # return self.all_roman_bp[filter_name] + + def buildImage(self, config, base, image_num, obj_num, logger): + """Build an Image containing multiple objects placed at arbitrary locations. + + Parameters: + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + logger: If given, a logger object to log progress. + + Returns: + the final image and the current noise variance in the image as a tuple + """ + full_xsize = base["image_xsize"] + full_ysize = base["image_ysize"] + wcs = base["wcs"] + + full_image = Image(full_xsize, full_ysize, dtype=float) + full_image.setOrigin(base["image_origin"]) + full_image.wcs = wcs + full_image.setZero() + + full_image.header = galsim.FitsHeader() + full_image.header["EXPTIME"] = self.exptime + full_image.header["MJD-OBS"] = self.mjd + full_image.header["DATE-OBS"] = Time(self.mjd, format="mjd").datetime.isoformat() + full_image.header["FILTER"] = self.filter + full_image.header["ZPTMAG"] = 2.5 * np.log10(self.exptime * roman.collecting_area) + + base["current_image"] = full_image + + if "image_pos" in config and "world_pos" in config: + raise galsim.GalSimConfigValueError( + "Both image_pos and world_pos specified for Scattered image.", + (config["image_pos"], config["world_pos"]), + ) + + if "image_pos" not in config and "world_pos" not in config: + xmin = base["image_origin"].x + xmax = xmin + full_xsize - 1 + ymin = base["image_origin"].y + ymax = ymin + full_ysize - 1 + config["image_pos"] = { + "type": "XY", + "x": {"type": "Random", "min": xmin, "max": xmax}, + "y": {"type": "Random", "min": ymin, "max": ymax}, + } + + nbatch = self.nobjects // 1000 + 1 + for batch in range(nbatch): + start_obj_num = self.nobjects * batch // nbatch + end_obj_num = self.nobjects * (batch + 1) // nbatch + nobj_batch = end_obj_num - start_obj_num + if nbatch > 1: + logger.warning( + "Start batch %d/%d with %d objects [%d, %d)", + batch + 1, + nbatch, + nobj_batch, + start_obj_num, + end_obj_num, + ) + stamps, current_vars = galsim.config.BuildStamps( + nobj_batch, base, logger=logger, obj_num=start_obj_num, do_noise=False + ) + base["index_key"] = "image_num" + + for k in range(nobj_batch): + # This is our signal that the object was skipped. + if stamps[k] is None: + continue + bounds = stamps[k].bounds & full_image.bounds + if not bounds.isDefined(): # pragma: no cover + # These noramlly show up as stamp==None, but technically it is possible + # to get a stamp that is off the main image, so check for that here to + # avoid an error. But this isn't covered in the imsim test suite. + continue + + logger.debug("image %d: full bounds = %s", image_num, str(full_image.bounds)) + logger.debug( + "image %d: stamp %d bounds = %s", image_num, k + start_obj_num, str(stamps[k].bounds) + ) + logger.debug("image %d: Overlap = %s", image_num, str(bounds)) + full_image[bounds] += stamps[k][bounds] + stamps = None + + # # [TODO] + # break + + # # Bring the image so far up to a flat noise variance + # current_var = FlattenNoiseVariance( + # base, full_image, stamps, current_vars, logger) + + logger.info("roman pixel scale: %.5f" % (roman.pixel_scale)) + full_image /= (0.0390625 / 0.11) ** 2 + + return full_image, None + + def addNoise(self, image, config, base, image_num, obj_num, current_var, logger): + """Add the final noise to a Scattered image + + Parameters: + image: The image onto which to add the noise. + config: The configuration dict for the image field. + base: The base configuration dict. + image_num: The current image number. + obj_num: The first object number in the image. + current_var: The current noise variance in each postage stamps. + logger: If given, a logger object to log progress. + """ + # check ignore noise + if self.ignore_noise: + return + + base["current_noise_image"] = base["current_image"] + # wcs = base["wcs"] + # bp = base["bandpass"] + # rng = galsim.config.GetRNG(config, base) + logger.info("image %d: Start RomanSCA detector effects", base.get("image_num", 0)) + + noise_white = self.coadd_hdu[0].data[0][11] + noise_pink = self.coadd_hdu[0].data[0][10] + image += noise_white * self.white_noise_weight + image += noise_pink * self.pink_noise_weight + + +# Register this as a valid type +RegisterImageType("roman_coadd", RomanCoaddImageBuilder()) diff --git a/roman_imsim/skycat.py b/roman_imsim/skycat.py index 0b2b8c9b..b7b1a03f 100644 --- a/roman_imsim/skycat.py +++ b/roman_imsim/skycat.py @@ -17,9 +17,7 @@ class SkyCatalogInterface: """Interface to skyCatalogs package.""" _trivial_sed = galsim.SED( - galsim.LookupTable([100, 2600], [1, 1], interpolant="linear"), - wave_type="nm", - flux_type="fphotons", + galsim.LookupTable([100, 2600], [1, 1], interpolant="linear"), wave_type="nm", flux_type="fphotons" ) def __init__( @@ -260,6 +258,7 @@ def getKwargs(self, config, base, logger): } kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) wcs = galsim.config.BuildWCS(base["image"], "wcs", base, logger=logger) + kwargs["wcs"] = wcs kwargs["logger"] = logger @@ -267,6 +266,9 @@ def getKwargs(self, config, base, logger): base["bandpass"] = galsim.config.BuildBandpass(base["image"], "bandpass", base, logger=logger)[0] kwargs["bandpass"] = base["bandpass"] + if base["image"]["type"] == "roman_coadd": + kwargs["xsize"] = base["image"]["xsize"] + kwargs["ysize"] = base["image"]["ysize"] # Sky catalog object lists are created per CCD, so they are # not safe to reuse. safe = False diff --git a/roman_imsim/stamp.py b/roman_imsim/stamp.py index 2554d180..bd0a9900 100644 --- a/roman_imsim/stamp.py +++ b/roman_imsim/stamp.py @@ -55,6 +55,14 @@ def setup(self, config, base, xsize, ysize, ignore, logger): raise galsim.config.SkipThisObject("gal is None (invalid parameters)") base["object_type"] = getattr(gal, "object_type", "") bandpass = base["bandpass"] + + if base["image"]["type"] == "roman_coadd": + self.is_coadd = True + self.pixel_scale = float(base["image"]["pixel_scale"]) + else: + self.is_coadd = False + self.pixel_scale = roman.pixel_scale + if not hasattr(gal, "flux"): # In this case, the object flux has not been precomputed # or cached by the skyCatalogs code. @@ -100,7 +108,7 @@ def setup(self, config, base, xsize, ysize, ignore, logger): # psf = galsim.config.BuildGSObject(base, 'psf', logger=logger)[0]['achromatic'] # obj = galsim.Convolve(gal_achrom, psf).withFlux(self.flux) obj = gal_achrom.withGSParams(galsim.GSParams(stepk_minimum_hlr=20)) - image_size = obj.getGoodImageSize(roman.pixel_scale) + image_size = obj.getGoodImageSize(self.pixel_scale) # print('stamp setup3',process.memory_info().rss) base["pupil_bin"] = self.pupil_bin @@ -121,6 +129,29 @@ def setup(self, config, base, xsize, ysize, ignore, logger): return image_size, image_size, image_pos, world_pos + def buildPSF(self, config, base, gsparams, logger): + """Build the PSF object. + + For the Basic stamp type, this builds a PSF from the base['psf'] dict, if present, + else returns None. + + Parameters: + config: The configuration dict for the stamp field. + base: The base configuration dict. + gsparams: A dict of kwargs to use for a GSParams. More may be added to this + list by the galaxy object. + logger: A logger object to log progress. + + Returns: + the PSF + """ + if base.get("psf", {}).get("type", "roman_psf") != "roman_psf": + return galsim.config.BuildGSObject(base, "psf", gsparams=gsparams, logger=logger)[0] + + roman_psf = galsim.config.GetInputObj("roman_psf", config, base, "buildPSF") + psf = roman_psf.getPSF(self.pupil_bin, base["image_pos"], is_coadd=self.is_coadd) + return psf + def getDrawMethod(self, config, base, logger): """Determine the draw method to use. @@ -317,7 +348,7 @@ def draw(self, prof, image, method, offset, config, base, logger): raise # Check if we need to add photon noise for bright objects drawn # with FFT because we switched from phot to fft above. - if self.use_fft_bright: + if self.use_fft_bright and (not self.is_coadd): self.add_poisson_noise(fft_image) # In case we had to make a bigger image, just copy the part we need. image += fft_image[image.bounds] diff --git a/roman_imsim/wcs.py b/roman_imsim/wcs.py index 94700318..bde980d5 100644 --- a/roman_imsim/wcs.py +++ b/roman_imsim/wcs.py @@ -10,27 +10,37 @@ class RomanWCS(WCSBuilder): def buildWCS(self, config, base, logger): - req = { - "SCA": int, - "ra": Angle, - "dec": Angle, - "pa": Angle, - "mjd": float, - } - opt = {"max_sun_angle": float, "force_cvz": bool} - - kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) - if "max_sun_angle" in kwargs: - roman.max_sun_angle = kwargs["max_sun_angle"] - roman.roman_wcs.max_sun_angle = kwargs["max_sun_angle"] - pointing = CelestialCoord(ra=kwargs["ra"], dec=kwargs["dec"]) - wcs = roman.getWCS( - world_pos=pointing, - PA=kwargs["pa"], - date=Time(kwargs["mjd"], format="mjd").datetime, - SCAs=kwargs["SCA"], - PA_is_FPA=True, - )[kwargs["SCA"]] + if base["image"]["type"] == "roman_sca": + req = { + "SCA": int, + "ra": Angle, + "dec": Angle, + "pa": Angle, + "mjd": float, + } + opt = {"max_sun_angle": float, "force_cvz": bool} + + kwargs, safe = galsim.config.GetAllParams(config, base, req=req, opt=opt) + if "max_sun_angle" in kwargs: + roman.max_sun_angle = kwargs["max_sun_angle"] + roman.roman_wcs.max_sun_angle = kwargs["max_sun_angle"] + pointing = CelestialCoord(ra=kwargs["ra"], dec=kwargs["dec"]) + wcs = roman.getWCS( + world_pos=pointing, + PA=kwargs["pa"], + date=Time(kwargs["mjd"], format="mjd").datetime, + SCAs=kwargs["SCA"], + PA_is_FPA=True, + )[kwargs["SCA"]] + elif base["image"]["type"] == "roman_coadd": + # req = {'coadd_file': str, + # } + # opt = {} + # kwargs, safe = galsim.config.GetAllParams( + # config, base, req=req, opt=opt) + # wcs = galsim.FitsWCS(kwargs['coadd_file']) + wcs = galsim.GSFitsWCS(file_name=base["image"]["coadd_file"], hdu=0) + return wcs